diff --git a/aiserver.py b/aiserver.py index 7cbbc8ac..c22be16c 100644 --- a/aiserver.py +++ b/aiserver.py @@ -73,6 +73,8 @@ import torch from transformers import StoppingCriteria, GPT2Tokenizer, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer, PreTrainedModel, modeling_utils, AutoModelForTokenClassification from transformers import __version__ as transformers_version import transformers +import ipaddress +from functools import wraps try: from transformers.models.opt.modeling_opt import OPTDecoder except: @@ -86,6 +88,9 @@ from PIL import Image from io import BytesIO global tpu_mtj_backend +global allowed_ips +allowed_ips = set() # empty set +enable_whitelist = False if lupa.LUA_VERSION[:2] != (5, 4): @@ -1469,13 +1474,15 @@ def spRequest(filename): #==================================================================# def general_startup(override_args=None): global args + global enable_whitelist + global allowed_ips # Parsing Parameters parser = argparse.ArgumentParser(description="KoboldAI Server") parser.add_argument("--remote", action='store_true', help="Optimizes KoboldAI for Remote Play") parser.add_argument("--noaimenu", action='store_true', help="Disables the ability to select the AI") parser.add_argument("--ngrok", action='store_true', help="Optimizes KoboldAI for Remote Play using Ngrok") parser.add_argument("--localtunnel", action='store_true', help="Optimizes KoboldAI for Remote Play using Localtunnel") - parser.add_argument("--host", action='store_true', help="Optimizes KoboldAI for Remote Play without using a proxy service") + parser.add_argument("--host", type=str, default="", nargs="?", const="", help="Optimizes KoboldAI for LAN Remote Play without using a proxy service. --host opens to all LAN. Enable IP whitelisting by using a comma separated IP list. Supports individual IPs, ranges, and subnets --host 127.0.0.1,127.0.0.2,127.0.0.3,192.168.1.0-192.168.1.255,10.0.0.0/24,etc") parser.add_argument("--port", type=int, help="Specify the port on which the application will be joinable") parser.add_argument("--aria2_port", type=int, help="Specify the port on which aria2's RPC interface will be open if aria2 is installed (defaults to 6799)") parser.add_argument("--model", help="Specify the Model Type to skip the Menu") @@ -1525,6 +1532,10 @@ def general_startup(override_args=None): utils.args = args + + + + #load system and user settings for setting in ['user_settings', 'system_settings']: @@ -1602,9 +1613,32 @@ def general_startup(override_args=None): if args.localtunnel: koboldai_vars.host = True; + if args.host == "": + koboldai_vars.host = True + args.unblock = True if args.host: - koboldai_vars.host = True; - args.unblock = True; + # This means --host option was submitted without an argument + # Enable all LAN IPs (0.0.0.0/0) + if args.host != "": + # Check if --host option was submitted with an argument + # Parse the supplied IP(s) and add them to the allowed IPs list + koboldai_vars.host = True + args.unblock = True + enable_whitelist = True + for ip_str in args.host.split(","): + if "/" in ip_str: + allowed_ips |= set(str(ip) for ip in ipaddress.IPv4Network(ip_str, strict=False).hosts()) + elif "-" in ip_str: + start_ip, end_ip = ip_str.split("-") + start_ip_int = int(ipaddress.IPv4Address(start_ip)) + end_ip_int = int(ipaddress.IPv4Address(end_ip)) + allowed_ips |= set(str(ipaddress.IPv4Address(ip)) for ip in range(start_ip_int, end_ip_int + 1)) + else: + allowed_ips.add(ip_str.strip()) + # Sort and print the allowed IPs list + allowed_ips = sorted(allowed_ips, key=lambda ip: int(''.join([i.zfill(3) for i in ip.split('.')]))) + print(f"Allowed IPs: {allowed_ips}") + if args.cpu: koboldai_vars.use_colab_tpu = False @@ -3450,22 +3484,52 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal print(format(colors.GREEN) + "KoboldAI has finished loading and is available at the following link for UI 1: " + koboldai_vars.cloudflare_link + format(colors.END)) print(format(colors.GREEN) + "KoboldAI has finished loading and is available at the following link for UI 2: " + koboldai_vars.cloudflare_link + "/new_ui" + format(colors.END)) + +# Setup IP Whitelisting +# Define a function to check if IP is allowed +def is_allowed_ip(): + global allowed_ips + client_ip = request.remote_addr + if request.path != '/genre_data.json': + print("Connection Attempt: " + request.remote_addr) + print("Allowed?: ", request.remote_addr in allowed_ips) + return client_ip in allowed_ips + + + +# Define a decorator to enforce IP whitelisting +def require_allowed_ip(func): + @wraps(func) + def decorated(*args, **kwargs): + + if enable_whitelist and not is_allowed_ip(): + return abort(403) + return func(*args, **kwargs) + return decorated + + + + # Set up Flask routes @app.route('/') @app.route('/index') +@require_allowed_ip def index(): if args.no_ui: return redirect('/api/latest') else: return render_template('index.html', hide_ai_menu=args.noaimenu) @app.route('/api', strict_slashes=False) +@require_allowed_ip def api(): return redirect('/api/latest') @app.route('/favicon.ico') + def favicon(): return send_from_directory(app.root_path, 'koboldai.ico', mimetype='image/vnd.microsoft.icon') @app.route('/download') +@require_allowed_ip def download(): if args.no_ui: raise NotFound() @@ -4126,6 +4190,8 @@ def execute_outmod(): #==================================================================# @socketio.on('connect') def do_connect(): + print("Connection Attempt: " + request.remote_addr) + print("Allowed?: ", request.remote_addr in allowed_ips) if request.args.get("rely") == "true": return logger.info("Client connected! UI_{}".format(request.args.get('ui'))) @@ -8021,6 +8087,7 @@ def show_folder_usersripts(data): # UI V2 CODE #==================================================================# @app.route('/new_ui') +@require_allowed_ip @logger.catch def new_ui_index(): if args.no_ui: @@ -8048,6 +8115,7 @@ def ui2_connect(): # UI V2 CODE Themes #==================================================================# @app.route('/themes/') +#@require_allowed_ip @logger.catch def ui2_serve_themes(path): return send_from_directory('themes', path) @@ -8086,6 +8154,7 @@ def upload_file(data): get_files_folders(session['current_folder']) @app.route("/upload_kai_story/", methods=["POST"]) +@require_allowed_ip @logger.catch def UI_2_upload_kai_story(file_name: str): @@ -8549,6 +8618,7 @@ def directory_to_zip_data(directory: str, overrides: Optional[dict]) -> bytes: # Save story to json #==================================================================# @app.route("/story_download") +@require_allowed_ip @logger.catch def UI_2_download_story(): if args.no_ui: @@ -9011,6 +9081,7 @@ def UI_2_delete_wi_folder(folder): # Event triggered when user exports world info folder #==================================================================# @app.route('/export_world_info_folder') +@require_allowed_ip @logger.catch def UI_2_export_world_info_folder(): if 'folder' in request.args: @@ -9038,6 +9109,7 @@ def UI_2_upload_world_info_folder(data): koboldai_vars.calc_ai_text() @app.route("/upload_wi", methods=["POST"]) +@require_allowed_ip @logger.catch def UI_2_import_world_info(): wi_data = request.get_json() @@ -9115,6 +9187,7 @@ def UI_2_update_wi_keys(data): socketio.emit("world_info_entry", koboldai_vars.worldinfo_v2.world_info[uid], broadcast=True, room="UI_2") @app.route("/set_wi_image/", methods=["POST"]) +@require_allowed_ip @logger.catch def UI_2_set_wi_image(uid): if uid < 0: @@ -9146,6 +9219,7 @@ def UI_2_set_wi_image(uid): return ":)" @app.route("/get_wi_image/", methods=["GET"]) +@require_allowed_ip @logger.catch def UI_2_get_wi_image(uid): if args.no_ui: @@ -9157,6 +9231,7 @@ def UI_2_get_wi_image(uid): return ":( Couldn't find image", 204 @app.route("/set_commentator_picture/", methods=["POST"]) +@require_allowed_ip @logger.catch def UI_2_set_commentator_image(commentator_id): data = request.get_data() @@ -9165,6 +9240,7 @@ def UI_2_set_commentator_image(commentator_id): return ":)" @app.route("/image_db.json", methods=["GET"]) +@require_allowed_ip @logger.catch def UI_2_get_image_db(): if args.no_ui: @@ -9175,6 +9251,7 @@ def UI_2_get_image_db(): return jsonify([]) @app.route("/action_composition.json", methods=["GET"]) +@require_allowed_ip @logger.catch def UI_2_get_action_composition(): if args.no_ui: @@ -9200,6 +9277,7 @@ def UI_2_get_action_composition(): return jsonify(ret) @app.route("/generated_images/") +@require_allowed_ip def UI_2_send_generated_images(path): return send_from_directory(koboldai_vars.save_paths.generated_images, path) @@ -9509,6 +9587,7 @@ def UI_2_generate_wi(data): socketio.emit("generated_wi", {"uid": uid, "field": field, "out": out_text}, room="UI_2") @app.route("/generate_raw", methods=["GET"]) +@require_allowed_ip def UI_2_generate_raw(): prompt = request.args.get("prompt") @@ -10210,6 +10289,7 @@ def UI_2_privacy_mode(data): # Genres #==================================================================# @app.route("/genre_data.json", methods=["GET"]) +@require_allowed_ip def UI_2_get_applicable_genres(): with open("data/genres.json", "r") as file: genre_list = json.load(file) @@ -10275,12 +10355,14 @@ def UI_2_get_log(data): emit("log_message", web_log_history) @app.route("/get_log") +@require_allowed_ip def UI_2_get_log_get(): if args.no_ui: return redirect('/api/latest') return {'aiserver_log': web_log_history} @app.route("/test_match") +@require_allowed_ip @logger.catch def UI_2_test_match(): koboldai_vars.assign_world_info_to_actions() @@ -10290,6 +10372,7 @@ def UI_2_test_match(): # Download of the audio file #==================================================================# @app.route("/audio") +@require_allowed_ip @logger.catch def UI_2_audio(): if args.no_ui: @@ -10316,6 +10399,7 @@ def UI_2_audio(): # Download of the image for an action #==================================================================# @app.route("/action_image") +@require_allowed_ip @logger.catch def UI_2_action_image(): if args.no_ui: @@ -10375,6 +10459,7 @@ def model_info(): return {"Model Type": "Read Only", "Model Size": "0", "Model Name": koboldai_vars.model.replace("_", "/")} @app.route("/vars") +@require_allowed_ip @logger.catch def show_vars(): if args.no_ui: