mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Add IP whitelisting to --host
This commit is contained in:
91
aiserver.py
91
aiserver.py
@@ -73,6 +73,8 @@ import torch
|
|||||||
from transformers import StoppingCriteria, GPT2Tokenizer, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer, PreTrainedModel, modeling_utils, AutoModelForTokenClassification
|
from transformers import StoppingCriteria, GPT2Tokenizer, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer, PreTrainedModel, modeling_utils, AutoModelForTokenClassification
|
||||||
from transformers import __version__ as transformers_version
|
from transformers import __version__ as transformers_version
|
||||||
import transformers
|
import transformers
|
||||||
|
import ipaddress
|
||||||
|
from functools import wraps
|
||||||
try:
|
try:
|
||||||
from transformers.models.opt.modeling_opt import OPTDecoder
|
from transformers.models.opt.modeling_opt import OPTDecoder
|
||||||
except:
|
except:
|
||||||
@@ -86,6 +88,9 @@ from PIL import Image
|
|||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
|
|
||||||
global tpu_mtj_backend
|
global tpu_mtj_backend
|
||||||
|
global allowed_ips
|
||||||
|
allowed_ips = set() # empty set
|
||||||
|
enable_whitelist = False
|
||||||
|
|
||||||
|
|
||||||
if lupa.LUA_VERSION[:2] != (5, 4):
|
if lupa.LUA_VERSION[:2] != (5, 4):
|
||||||
@@ -1469,13 +1474,15 @@ def spRequest(filename):
|
|||||||
#==================================================================#
|
#==================================================================#
|
||||||
def general_startup(override_args=None):
|
def general_startup(override_args=None):
|
||||||
global args
|
global args
|
||||||
|
global enable_whitelist
|
||||||
|
global allowed_ips
|
||||||
# Parsing Parameters
|
# Parsing Parameters
|
||||||
parser = argparse.ArgumentParser(description="KoboldAI Server")
|
parser = argparse.ArgumentParser(description="KoboldAI Server")
|
||||||
parser.add_argument("--remote", action='store_true', help="Optimizes KoboldAI for Remote Play")
|
parser.add_argument("--remote", action='store_true', help="Optimizes KoboldAI for Remote Play")
|
||||||
parser.add_argument("--noaimenu", action='store_true', help="Disables the ability to select the AI")
|
parser.add_argument("--noaimenu", action='store_true', help="Disables the ability to select the AI")
|
||||||
parser.add_argument("--ngrok", action='store_true', help="Optimizes KoboldAI for Remote Play using Ngrok")
|
parser.add_argument("--ngrok", action='store_true', help="Optimizes KoboldAI for Remote Play using Ngrok")
|
||||||
parser.add_argument("--localtunnel", action='store_true', help="Optimizes KoboldAI for Remote Play using Localtunnel")
|
parser.add_argument("--localtunnel", action='store_true', help="Optimizes KoboldAI for Remote Play using Localtunnel")
|
||||||
parser.add_argument("--host", action='store_true', help="Optimizes KoboldAI for Remote Play without using a proxy service")
|
parser.add_argument("--host", type=str, default="", nargs="?", const="", help="Optimizes KoboldAI for LAN Remote Play without using a proxy service. --host opens to all LAN. Enable IP whitelisting by using a comma separated IP list. Supports individual IPs, ranges, and subnets --host 127.0.0.1,127.0.0.2,127.0.0.3,192.168.1.0-192.168.1.255,10.0.0.0/24,etc")
|
||||||
parser.add_argument("--port", type=int, help="Specify the port on which the application will be joinable")
|
parser.add_argument("--port", type=int, help="Specify the port on which the application will be joinable")
|
||||||
parser.add_argument("--aria2_port", type=int, help="Specify the port on which aria2's RPC interface will be open if aria2 is installed (defaults to 6799)")
|
parser.add_argument("--aria2_port", type=int, help="Specify the port on which aria2's RPC interface will be open if aria2 is installed (defaults to 6799)")
|
||||||
parser.add_argument("--model", help="Specify the Model Type to skip the Menu")
|
parser.add_argument("--model", help="Specify the Model Type to skip the Menu")
|
||||||
@@ -1525,6 +1532,10 @@ def general_startup(override_args=None):
|
|||||||
|
|
||||||
utils.args = args
|
utils.args = args
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
#load system and user settings
|
#load system and user settings
|
||||||
for setting in ['user_settings', 'system_settings']:
|
for setting in ['user_settings', 'system_settings']:
|
||||||
@@ -1602,9 +1613,32 @@ def general_startup(override_args=None):
|
|||||||
if args.localtunnel:
|
if args.localtunnel:
|
||||||
koboldai_vars.host = True;
|
koboldai_vars.host = True;
|
||||||
|
|
||||||
|
if args.host == "":
|
||||||
|
koboldai_vars.host = True
|
||||||
|
args.unblock = True
|
||||||
if args.host:
|
if args.host:
|
||||||
koboldai_vars.host = True;
|
# This means --host option was submitted without an argument
|
||||||
args.unblock = True;
|
# Enable all LAN IPs (0.0.0.0/0)
|
||||||
|
if args.host != "":
|
||||||
|
# Check if --host option was submitted with an argument
|
||||||
|
# Parse the supplied IP(s) and add them to the allowed IPs list
|
||||||
|
koboldai_vars.host = True
|
||||||
|
args.unblock = True
|
||||||
|
enable_whitelist = True
|
||||||
|
for ip_str in args.host.split(","):
|
||||||
|
if "/" in ip_str:
|
||||||
|
allowed_ips |= set(str(ip) for ip in ipaddress.IPv4Network(ip_str, strict=False).hosts())
|
||||||
|
elif "-" in ip_str:
|
||||||
|
start_ip, end_ip = ip_str.split("-")
|
||||||
|
start_ip_int = int(ipaddress.IPv4Address(start_ip))
|
||||||
|
end_ip_int = int(ipaddress.IPv4Address(end_ip))
|
||||||
|
allowed_ips |= set(str(ipaddress.IPv4Address(ip)) for ip in range(start_ip_int, end_ip_int + 1))
|
||||||
|
else:
|
||||||
|
allowed_ips.add(ip_str.strip())
|
||||||
|
# Sort and print the allowed IPs list
|
||||||
|
allowed_ips = sorted(allowed_ips, key=lambda ip: int(''.join([i.zfill(3) for i in ip.split('.')])))
|
||||||
|
print(f"Allowed IPs: {allowed_ips}")
|
||||||
|
|
||||||
|
|
||||||
if args.cpu:
|
if args.cpu:
|
||||||
koboldai_vars.use_colab_tpu = False
|
koboldai_vars.use_colab_tpu = False
|
||||||
@@ -3450,22 +3484,52 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
|
|||||||
print(format(colors.GREEN) + "KoboldAI has finished loading and is available at the following link for UI 1: " + koboldai_vars.cloudflare_link + format(colors.END))
|
print(format(colors.GREEN) + "KoboldAI has finished loading and is available at the following link for UI 1: " + koboldai_vars.cloudflare_link + format(colors.END))
|
||||||
print(format(colors.GREEN) + "KoboldAI has finished loading and is available at the following link for UI 2: " + koboldai_vars.cloudflare_link + "/new_ui" + format(colors.END))
|
print(format(colors.GREEN) + "KoboldAI has finished loading and is available at the following link for UI 2: " + koboldai_vars.cloudflare_link + "/new_ui" + format(colors.END))
|
||||||
|
|
||||||
|
|
||||||
|
# Setup IP Whitelisting
|
||||||
|
# Define a function to check if IP is allowed
|
||||||
|
def is_allowed_ip():
|
||||||
|
global allowed_ips
|
||||||
|
client_ip = request.remote_addr
|
||||||
|
if request.path != '/genre_data.json':
|
||||||
|
print("Connection Attempt: " + request.remote_addr)
|
||||||
|
print("Allowed?: ", request.remote_addr in allowed_ips)
|
||||||
|
return client_ip in allowed_ips
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# Define a decorator to enforce IP whitelisting
|
||||||
|
def require_allowed_ip(func):
|
||||||
|
@wraps(func)
|
||||||
|
def decorated(*args, **kwargs):
|
||||||
|
|
||||||
|
if enable_whitelist and not is_allowed_ip():
|
||||||
|
return abort(403)
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
return decorated
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Set up Flask routes
|
# Set up Flask routes
|
||||||
@app.route('/')
|
@app.route('/')
|
||||||
@app.route('/index')
|
@app.route('/index')
|
||||||
|
@require_allowed_ip
|
||||||
def index():
|
def index():
|
||||||
if args.no_ui:
|
if args.no_ui:
|
||||||
return redirect('/api/latest')
|
return redirect('/api/latest')
|
||||||
else:
|
else:
|
||||||
return render_template('index.html', hide_ai_menu=args.noaimenu)
|
return render_template('index.html', hide_ai_menu=args.noaimenu)
|
||||||
@app.route('/api', strict_slashes=False)
|
@app.route('/api', strict_slashes=False)
|
||||||
|
@require_allowed_ip
|
||||||
def api():
|
def api():
|
||||||
return redirect('/api/latest')
|
return redirect('/api/latest')
|
||||||
@app.route('/favicon.ico')
|
@app.route('/favicon.ico')
|
||||||
|
|
||||||
def favicon():
|
def favicon():
|
||||||
return send_from_directory(app.root_path,
|
return send_from_directory(app.root_path,
|
||||||
'koboldai.ico', mimetype='image/vnd.microsoft.icon')
|
'koboldai.ico', mimetype='image/vnd.microsoft.icon')
|
||||||
@app.route('/download')
|
@app.route('/download')
|
||||||
|
@require_allowed_ip
|
||||||
def download():
|
def download():
|
||||||
if args.no_ui:
|
if args.no_ui:
|
||||||
raise NotFound()
|
raise NotFound()
|
||||||
@@ -4126,6 +4190,8 @@ def execute_outmod():
|
|||||||
#==================================================================#
|
#==================================================================#
|
||||||
@socketio.on('connect')
|
@socketio.on('connect')
|
||||||
def do_connect():
|
def do_connect():
|
||||||
|
print("Connection Attempt: " + request.remote_addr)
|
||||||
|
print("Allowed?: ", request.remote_addr in allowed_ips)
|
||||||
if request.args.get("rely") == "true":
|
if request.args.get("rely") == "true":
|
||||||
return
|
return
|
||||||
logger.info("Client connected! UI_{}".format(request.args.get('ui')))
|
logger.info("Client connected! UI_{}".format(request.args.get('ui')))
|
||||||
@@ -8021,6 +8087,7 @@ def show_folder_usersripts(data):
|
|||||||
# UI V2 CODE
|
# UI V2 CODE
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
@app.route('/new_ui')
|
@app.route('/new_ui')
|
||||||
|
@require_allowed_ip
|
||||||
@logger.catch
|
@logger.catch
|
||||||
def new_ui_index():
|
def new_ui_index():
|
||||||
if args.no_ui:
|
if args.no_ui:
|
||||||
@@ -8048,6 +8115,7 @@ def ui2_connect():
|
|||||||
# UI V2 CODE Themes
|
# UI V2 CODE Themes
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
@app.route('/themes/<path:path>')
|
@app.route('/themes/<path:path>')
|
||||||
|
#@require_allowed_ip
|
||||||
@logger.catch
|
@logger.catch
|
||||||
def ui2_serve_themes(path):
|
def ui2_serve_themes(path):
|
||||||
return send_from_directory('themes', path)
|
return send_from_directory('themes', path)
|
||||||
@@ -8086,6 +8154,7 @@ def upload_file(data):
|
|||||||
get_files_folders(session['current_folder'])
|
get_files_folders(session['current_folder'])
|
||||||
|
|
||||||
@app.route("/upload_kai_story/<string:file_name>", methods=["POST"])
|
@app.route("/upload_kai_story/<string:file_name>", methods=["POST"])
|
||||||
|
@require_allowed_ip
|
||||||
@logger.catch
|
@logger.catch
|
||||||
def UI_2_upload_kai_story(file_name: str):
|
def UI_2_upload_kai_story(file_name: str):
|
||||||
|
|
||||||
@@ -8549,6 +8618,7 @@ def directory_to_zip_data(directory: str, overrides: Optional[dict]) -> bytes:
|
|||||||
# Save story to json
|
# Save story to json
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
@app.route("/story_download")
|
@app.route("/story_download")
|
||||||
|
@require_allowed_ip
|
||||||
@logger.catch
|
@logger.catch
|
||||||
def UI_2_download_story():
|
def UI_2_download_story():
|
||||||
if args.no_ui:
|
if args.no_ui:
|
||||||
@@ -9011,6 +9081,7 @@ def UI_2_delete_wi_folder(folder):
|
|||||||
# Event triggered when user exports world info folder
|
# Event triggered when user exports world info folder
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
@app.route('/export_world_info_folder')
|
@app.route('/export_world_info_folder')
|
||||||
|
@require_allowed_ip
|
||||||
@logger.catch
|
@logger.catch
|
||||||
def UI_2_export_world_info_folder():
|
def UI_2_export_world_info_folder():
|
||||||
if 'folder' in request.args:
|
if 'folder' in request.args:
|
||||||
@@ -9038,6 +9109,7 @@ def UI_2_upload_world_info_folder(data):
|
|||||||
koboldai_vars.calc_ai_text()
|
koboldai_vars.calc_ai_text()
|
||||||
|
|
||||||
@app.route("/upload_wi", methods=["POST"])
|
@app.route("/upload_wi", methods=["POST"])
|
||||||
|
@require_allowed_ip
|
||||||
@logger.catch
|
@logger.catch
|
||||||
def UI_2_import_world_info():
|
def UI_2_import_world_info():
|
||||||
wi_data = request.get_json()
|
wi_data = request.get_json()
|
||||||
@@ -9115,6 +9187,7 @@ def UI_2_update_wi_keys(data):
|
|||||||
socketio.emit("world_info_entry", koboldai_vars.worldinfo_v2.world_info[uid], broadcast=True, room="UI_2")
|
socketio.emit("world_info_entry", koboldai_vars.worldinfo_v2.world_info[uid], broadcast=True, room="UI_2")
|
||||||
|
|
||||||
@app.route("/set_wi_image/<int(signed=True):uid>", methods=["POST"])
|
@app.route("/set_wi_image/<int(signed=True):uid>", methods=["POST"])
|
||||||
|
@require_allowed_ip
|
||||||
@logger.catch
|
@logger.catch
|
||||||
def UI_2_set_wi_image(uid):
|
def UI_2_set_wi_image(uid):
|
||||||
if uid < 0:
|
if uid < 0:
|
||||||
@@ -9146,6 +9219,7 @@ def UI_2_set_wi_image(uid):
|
|||||||
return ":)"
|
return ":)"
|
||||||
|
|
||||||
@app.route("/get_wi_image/<int(signed=True):uid>", methods=["GET"])
|
@app.route("/get_wi_image/<int(signed=True):uid>", methods=["GET"])
|
||||||
|
@require_allowed_ip
|
||||||
@logger.catch
|
@logger.catch
|
||||||
def UI_2_get_wi_image(uid):
|
def UI_2_get_wi_image(uid):
|
||||||
if args.no_ui:
|
if args.no_ui:
|
||||||
@@ -9157,6 +9231,7 @@ def UI_2_get_wi_image(uid):
|
|||||||
return ":( Couldn't find image", 204
|
return ":( Couldn't find image", 204
|
||||||
|
|
||||||
@app.route("/set_commentator_picture/<int(signed=True):commentator_id>", methods=["POST"])
|
@app.route("/set_commentator_picture/<int(signed=True):commentator_id>", methods=["POST"])
|
||||||
|
@require_allowed_ip
|
||||||
@logger.catch
|
@logger.catch
|
||||||
def UI_2_set_commentator_image(commentator_id):
|
def UI_2_set_commentator_image(commentator_id):
|
||||||
data = request.get_data()
|
data = request.get_data()
|
||||||
@@ -9165,6 +9240,7 @@ def UI_2_set_commentator_image(commentator_id):
|
|||||||
return ":)"
|
return ":)"
|
||||||
|
|
||||||
@app.route("/image_db.json", methods=["GET"])
|
@app.route("/image_db.json", methods=["GET"])
|
||||||
|
@require_allowed_ip
|
||||||
@logger.catch
|
@logger.catch
|
||||||
def UI_2_get_image_db():
|
def UI_2_get_image_db():
|
||||||
if args.no_ui:
|
if args.no_ui:
|
||||||
@@ -9175,6 +9251,7 @@ def UI_2_get_image_db():
|
|||||||
return jsonify([])
|
return jsonify([])
|
||||||
|
|
||||||
@app.route("/action_composition.json", methods=["GET"])
|
@app.route("/action_composition.json", methods=["GET"])
|
||||||
|
@require_allowed_ip
|
||||||
@logger.catch
|
@logger.catch
|
||||||
def UI_2_get_action_composition():
|
def UI_2_get_action_composition():
|
||||||
if args.no_ui:
|
if args.no_ui:
|
||||||
@@ -9200,6 +9277,7 @@ def UI_2_get_action_composition():
|
|||||||
return jsonify(ret)
|
return jsonify(ret)
|
||||||
|
|
||||||
@app.route("/generated_images/<path:path>")
|
@app.route("/generated_images/<path:path>")
|
||||||
|
@require_allowed_ip
|
||||||
def UI_2_send_generated_images(path):
|
def UI_2_send_generated_images(path):
|
||||||
return send_from_directory(koboldai_vars.save_paths.generated_images, path)
|
return send_from_directory(koboldai_vars.save_paths.generated_images, path)
|
||||||
|
|
||||||
@@ -9509,6 +9587,7 @@ def UI_2_generate_wi(data):
|
|||||||
socketio.emit("generated_wi", {"uid": uid, "field": field, "out": out_text}, room="UI_2")
|
socketio.emit("generated_wi", {"uid": uid, "field": field, "out": out_text}, room="UI_2")
|
||||||
|
|
||||||
@app.route("/generate_raw", methods=["GET"])
|
@app.route("/generate_raw", methods=["GET"])
|
||||||
|
@require_allowed_ip
|
||||||
def UI_2_generate_raw():
|
def UI_2_generate_raw():
|
||||||
prompt = request.args.get("prompt")
|
prompt = request.args.get("prompt")
|
||||||
|
|
||||||
@@ -10210,6 +10289,7 @@ def UI_2_privacy_mode(data):
|
|||||||
# Genres
|
# Genres
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
@app.route("/genre_data.json", methods=["GET"])
|
@app.route("/genre_data.json", methods=["GET"])
|
||||||
|
@require_allowed_ip
|
||||||
def UI_2_get_applicable_genres():
|
def UI_2_get_applicable_genres():
|
||||||
with open("data/genres.json", "r") as file:
|
with open("data/genres.json", "r") as file:
|
||||||
genre_list = json.load(file)
|
genre_list = json.load(file)
|
||||||
@@ -10275,12 +10355,14 @@ def UI_2_get_log(data):
|
|||||||
emit("log_message", web_log_history)
|
emit("log_message", web_log_history)
|
||||||
|
|
||||||
@app.route("/get_log")
|
@app.route("/get_log")
|
||||||
|
@require_allowed_ip
|
||||||
def UI_2_get_log_get():
|
def UI_2_get_log_get():
|
||||||
if args.no_ui:
|
if args.no_ui:
|
||||||
return redirect('/api/latest')
|
return redirect('/api/latest')
|
||||||
return {'aiserver_log': web_log_history}
|
return {'aiserver_log': web_log_history}
|
||||||
|
|
||||||
@app.route("/test_match")
|
@app.route("/test_match")
|
||||||
|
@require_allowed_ip
|
||||||
@logger.catch
|
@logger.catch
|
||||||
def UI_2_test_match():
|
def UI_2_test_match():
|
||||||
koboldai_vars.assign_world_info_to_actions()
|
koboldai_vars.assign_world_info_to_actions()
|
||||||
@@ -10290,6 +10372,7 @@ def UI_2_test_match():
|
|||||||
# Download of the audio file
|
# Download of the audio file
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
@app.route("/audio")
|
@app.route("/audio")
|
||||||
|
@require_allowed_ip
|
||||||
@logger.catch
|
@logger.catch
|
||||||
def UI_2_audio():
|
def UI_2_audio():
|
||||||
if args.no_ui:
|
if args.no_ui:
|
||||||
@@ -10316,6 +10399,7 @@ def UI_2_audio():
|
|||||||
# Download of the image for an action
|
# Download of the image for an action
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
@app.route("/action_image")
|
@app.route("/action_image")
|
||||||
|
@require_allowed_ip
|
||||||
@logger.catch
|
@logger.catch
|
||||||
def UI_2_action_image():
|
def UI_2_action_image():
|
||||||
if args.no_ui:
|
if args.no_ui:
|
||||||
@@ -10375,6 +10459,7 @@ def model_info():
|
|||||||
return {"Model Type": "Read Only", "Model Size": "0", "Model Name": koboldai_vars.model.replace("_", "/")}
|
return {"Model Type": "Read Only", "Model Size": "0", "Model Name": koboldai_vars.model.replace("_", "/")}
|
||||||
|
|
||||||
@app.route("/vars")
|
@app.route("/vars")
|
||||||
|
@require_allowed_ip
|
||||||
@logger.catch
|
@logger.catch
|
||||||
def show_vars():
|
def show_vars():
|
||||||
if args.no_ui:
|
if args.no_ui:
|
||||||
|
Reference in New Issue
Block a user