Add IP whitelisting to --host

This commit is contained in:
YellowRoseCx
2023-04-05 21:23:24 -05:00
committed by GitHub
parent 80e4b9e536
commit 71e5d23a5b

View File

@@ -73,6 +73,8 @@ import torch
from transformers import StoppingCriteria, GPT2Tokenizer, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer, PreTrainedModel, modeling_utils, AutoModelForTokenClassification from transformers import StoppingCriteria, GPT2Tokenizer, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer, PreTrainedModel, modeling_utils, AutoModelForTokenClassification
from transformers import __version__ as transformers_version from transformers import __version__ as transformers_version
import transformers import transformers
import ipaddress
from functools import wraps
try: try:
from transformers.models.opt.modeling_opt import OPTDecoder from transformers.models.opt.modeling_opt import OPTDecoder
except: except:
@@ -86,6 +88,9 @@ from PIL import Image
from io import BytesIO from io import BytesIO
global tpu_mtj_backend global tpu_mtj_backend
global allowed_ips
allowed_ips = set() # empty set
enable_whitelist = False
if lupa.LUA_VERSION[:2] != (5, 4): if lupa.LUA_VERSION[:2] != (5, 4):
@@ -1469,13 +1474,15 @@ def spRequest(filename):
#==================================================================# #==================================================================#
def general_startup(override_args=None): def general_startup(override_args=None):
global args global args
global enable_whitelist
global allowed_ips
# Parsing Parameters # Parsing Parameters
parser = argparse.ArgumentParser(description="KoboldAI Server") parser = argparse.ArgumentParser(description="KoboldAI Server")
parser.add_argument("--remote", action='store_true', help="Optimizes KoboldAI for Remote Play") parser.add_argument("--remote", action='store_true', help="Optimizes KoboldAI for Remote Play")
parser.add_argument("--noaimenu", action='store_true', help="Disables the ability to select the AI") parser.add_argument("--noaimenu", action='store_true', help="Disables the ability to select the AI")
parser.add_argument("--ngrok", action='store_true', help="Optimizes KoboldAI for Remote Play using Ngrok") parser.add_argument("--ngrok", action='store_true', help="Optimizes KoboldAI for Remote Play using Ngrok")
parser.add_argument("--localtunnel", action='store_true', help="Optimizes KoboldAI for Remote Play using Localtunnel") parser.add_argument("--localtunnel", action='store_true', help="Optimizes KoboldAI for Remote Play using Localtunnel")
parser.add_argument("--host", action='store_true', help="Optimizes KoboldAI for Remote Play without using a proxy service") parser.add_argument("--host", type=str, default="", nargs="?", const="", help="Optimizes KoboldAI for LAN Remote Play without using a proxy service. --host opens to all LAN. Enable IP whitelisting by using a comma separated IP list. Supports individual IPs, ranges, and subnets --host 127.0.0.1,127.0.0.2,127.0.0.3,192.168.1.0-192.168.1.255,10.0.0.0/24,etc")
parser.add_argument("--port", type=int, help="Specify the port on which the application will be joinable") parser.add_argument("--port", type=int, help="Specify the port on which the application will be joinable")
parser.add_argument("--aria2_port", type=int, help="Specify the port on which aria2's RPC interface will be open if aria2 is installed (defaults to 6799)") parser.add_argument("--aria2_port", type=int, help="Specify the port on which aria2's RPC interface will be open if aria2 is installed (defaults to 6799)")
parser.add_argument("--model", help="Specify the Model Type to skip the Menu") parser.add_argument("--model", help="Specify the Model Type to skip the Menu")
@@ -1525,6 +1532,10 @@ def general_startup(override_args=None):
utils.args = args utils.args = args
#load system and user settings #load system and user settings
for setting in ['user_settings', 'system_settings']: for setting in ['user_settings', 'system_settings']:
@@ -1602,9 +1613,32 @@ def general_startup(override_args=None):
if args.localtunnel: if args.localtunnel:
koboldai_vars.host = True; koboldai_vars.host = True;
if args.host == "":
koboldai_vars.host = True
args.unblock = True
if args.host: if args.host:
koboldai_vars.host = True; # This means --host option was submitted without an argument
args.unblock = True; # Enable all LAN IPs (0.0.0.0/0)
if args.host != "":
# Check if --host option was submitted with an argument
# Parse the supplied IP(s) and add them to the allowed IPs list
koboldai_vars.host = True
args.unblock = True
enable_whitelist = True
for ip_str in args.host.split(","):
if "/" in ip_str:
allowed_ips |= set(str(ip) for ip in ipaddress.IPv4Network(ip_str, strict=False).hosts())
elif "-" in ip_str:
start_ip, end_ip = ip_str.split("-")
start_ip_int = int(ipaddress.IPv4Address(start_ip))
end_ip_int = int(ipaddress.IPv4Address(end_ip))
allowed_ips |= set(str(ipaddress.IPv4Address(ip)) for ip in range(start_ip_int, end_ip_int + 1))
else:
allowed_ips.add(ip_str.strip())
# Sort and print the allowed IPs list
allowed_ips = sorted(allowed_ips, key=lambda ip: int(''.join([i.zfill(3) for i in ip.split('.')])))
print(f"Allowed IPs: {allowed_ips}")
if args.cpu: if args.cpu:
koboldai_vars.use_colab_tpu = False koboldai_vars.use_colab_tpu = False
@@ -3450,22 +3484,52 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
print(format(colors.GREEN) + "KoboldAI has finished loading and is available at the following link for UI 1: " + koboldai_vars.cloudflare_link + format(colors.END)) print(format(colors.GREEN) + "KoboldAI has finished loading and is available at the following link for UI 1: " + koboldai_vars.cloudflare_link + format(colors.END))
print(format(colors.GREEN) + "KoboldAI has finished loading and is available at the following link for UI 2: " + koboldai_vars.cloudflare_link + "/new_ui" + format(colors.END)) print(format(colors.GREEN) + "KoboldAI has finished loading and is available at the following link for UI 2: " + koboldai_vars.cloudflare_link + "/new_ui" + format(colors.END))
# Setup IP Whitelisting
# Define a function to check if IP is allowed
def is_allowed_ip():
global allowed_ips
client_ip = request.remote_addr
if request.path != '/genre_data.json':
print("Connection Attempt: " + request.remote_addr)
print("Allowed?: ", request.remote_addr in allowed_ips)
return client_ip in allowed_ips
# Define a decorator to enforce IP whitelisting
def require_allowed_ip(func):
@wraps(func)
def decorated(*args, **kwargs):
if enable_whitelist and not is_allowed_ip():
return abort(403)
return func(*args, **kwargs)
return decorated
# Set up Flask routes # Set up Flask routes
@app.route('/') @app.route('/')
@app.route('/index') @app.route('/index')
@require_allowed_ip
def index(): def index():
if args.no_ui: if args.no_ui:
return redirect('/api/latest') return redirect('/api/latest')
else: else:
return render_template('index.html', hide_ai_menu=args.noaimenu) return render_template('index.html', hide_ai_menu=args.noaimenu)
@app.route('/api', strict_slashes=False) @app.route('/api', strict_slashes=False)
@require_allowed_ip
def api(): def api():
return redirect('/api/latest') return redirect('/api/latest')
@app.route('/favicon.ico') @app.route('/favicon.ico')
def favicon(): def favicon():
return send_from_directory(app.root_path, return send_from_directory(app.root_path,
'koboldai.ico', mimetype='image/vnd.microsoft.icon') 'koboldai.ico', mimetype='image/vnd.microsoft.icon')
@app.route('/download') @app.route('/download')
@require_allowed_ip
def download(): def download():
if args.no_ui: if args.no_ui:
raise NotFound() raise NotFound()
@@ -4126,6 +4190,8 @@ def execute_outmod():
#==================================================================# #==================================================================#
@socketio.on('connect') @socketio.on('connect')
def do_connect(): def do_connect():
print("Connection Attempt: " + request.remote_addr)
print("Allowed?: ", request.remote_addr in allowed_ips)
if request.args.get("rely") == "true": if request.args.get("rely") == "true":
return return
logger.info("Client connected! UI_{}".format(request.args.get('ui'))) logger.info("Client connected! UI_{}".format(request.args.get('ui')))
@@ -8021,6 +8087,7 @@ def show_folder_usersripts(data):
# UI V2 CODE # UI V2 CODE
#==================================================================# #==================================================================#
@app.route('/new_ui') @app.route('/new_ui')
@require_allowed_ip
@logger.catch @logger.catch
def new_ui_index(): def new_ui_index():
if args.no_ui: if args.no_ui:
@@ -8048,6 +8115,7 @@ def ui2_connect():
# UI V2 CODE Themes # UI V2 CODE Themes
#==================================================================# #==================================================================#
@app.route('/themes/<path:path>') @app.route('/themes/<path:path>')
#@require_allowed_ip
@logger.catch @logger.catch
def ui2_serve_themes(path): def ui2_serve_themes(path):
return send_from_directory('themes', path) return send_from_directory('themes', path)
@@ -8086,6 +8154,7 @@ def upload_file(data):
get_files_folders(session['current_folder']) get_files_folders(session['current_folder'])
@app.route("/upload_kai_story/<string:file_name>", methods=["POST"]) @app.route("/upload_kai_story/<string:file_name>", methods=["POST"])
@require_allowed_ip
@logger.catch @logger.catch
def UI_2_upload_kai_story(file_name: str): def UI_2_upload_kai_story(file_name: str):
@@ -8549,6 +8618,7 @@ def directory_to_zip_data(directory: str, overrides: Optional[dict]) -> bytes:
# Save story to json # Save story to json
#==================================================================# #==================================================================#
@app.route("/story_download") @app.route("/story_download")
@require_allowed_ip
@logger.catch @logger.catch
def UI_2_download_story(): def UI_2_download_story():
if args.no_ui: if args.no_ui:
@@ -9011,6 +9081,7 @@ def UI_2_delete_wi_folder(folder):
# Event triggered when user exports world info folder # Event triggered when user exports world info folder
#==================================================================# #==================================================================#
@app.route('/export_world_info_folder') @app.route('/export_world_info_folder')
@require_allowed_ip
@logger.catch @logger.catch
def UI_2_export_world_info_folder(): def UI_2_export_world_info_folder():
if 'folder' in request.args: if 'folder' in request.args:
@@ -9038,6 +9109,7 @@ def UI_2_upload_world_info_folder(data):
koboldai_vars.calc_ai_text() koboldai_vars.calc_ai_text()
@app.route("/upload_wi", methods=["POST"]) @app.route("/upload_wi", methods=["POST"])
@require_allowed_ip
@logger.catch @logger.catch
def UI_2_import_world_info(): def UI_2_import_world_info():
wi_data = request.get_json() wi_data = request.get_json()
@@ -9115,6 +9187,7 @@ def UI_2_update_wi_keys(data):
socketio.emit("world_info_entry", koboldai_vars.worldinfo_v2.world_info[uid], broadcast=True, room="UI_2") socketio.emit("world_info_entry", koboldai_vars.worldinfo_v2.world_info[uid], broadcast=True, room="UI_2")
@app.route("/set_wi_image/<int(signed=True):uid>", methods=["POST"]) @app.route("/set_wi_image/<int(signed=True):uid>", methods=["POST"])
@require_allowed_ip
@logger.catch @logger.catch
def UI_2_set_wi_image(uid): def UI_2_set_wi_image(uid):
if uid < 0: if uid < 0:
@@ -9146,6 +9219,7 @@ def UI_2_set_wi_image(uid):
return ":)" return ":)"
@app.route("/get_wi_image/<int(signed=True):uid>", methods=["GET"]) @app.route("/get_wi_image/<int(signed=True):uid>", methods=["GET"])
@require_allowed_ip
@logger.catch @logger.catch
def UI_2_get_wi_image(uid): def UI_2_get_wi_image(uid):
if args.no_ui: if args.no_ui:
@@ -9157,6 +9231,7 @@ def UI_2_get_wi_image(uid):
return ":( Couldn't find image", 204 return ":( Couldn't find image", 204
@app.route("/set_commentator_picture/<int(signed=True):commentator_id>", methods=["POST"]) @app.route("/set_commentator_picture/<int(signed=True):commentator_id>", methods=["POST"])
@require_allowed_ip
@logger.catch @logger.catch
def UI_2_set_commentator_image(commentator_id): def UI_2_set_commentator_image(commentator_id):
data = request.get_data() data = request.get_data()
@@ -9165,6 +9240,7 @@ def UI_2_set_commentator_image(commentator_id):
return ":)" return ":)"
@app.route("/image_db.json", methods=["GET"]) @app.route("/image_db.json", methods=["GET"])
@require_allowed_ip
@logger.catch @logger.catch
def UI_2_get_image_db(): def UI_2_get_image_db():
if args.no_ui: if args.no_ui:
@@ -9175,6 +9251,7 @@ def UI_2_get_image_db():
return jsonify([]) return jsonify([])
@app.route("/action_composition.json", methods=["GET"]) @app.route("/action_composition.json", methods=["GET"])
@require_allowed_ip
@logger.catch @logger.catch
def UI_2_get_action_composition(): def UI_2_get_action_composition():
if args.no_ui: if args.no_ui:
@@ -9200,6 +9277,7 @@ def UI_2_get_action_composition():
return jsonify(ret) return jsonify(ret)
@app.route("/generated_images/<path:path>") @app.route("/generated_images/<path:path>")
@require_allowed_ip
def UI_2_send_generated_images(path): def UI_2_send_generated_images(path):
return send_from_directory(koboldai_vars.save_paths.generated_images, path) return send_from_directory(koboldai_vars.save_paths.generated_images, path)
@@ -9509,6 +9587,7 @@ def UI_2_generate_wi(data):
socketio.emit("generated_wi", {"uid": uid, "field": field, "out": out_text}, room="UI_2") socketio.emit("generated_wi", {"uid": uid, "field": field, "out": out_text}, room="UI_2")
@app.route("/generate_raw", methods=["GET"]) @app.route("/generate_raw", methods=["GET"])
@require_allowed_ip
def UI_2_generate_raw(): def UI_2_generate_raw():
prompt = request.args.get("prompt") prompt = request.args.get("prompt")
@@ -10210,6 +10289,7 @@ def UI_2_privacy_mode(data):
# Genres # Genres
#==================================================================# #==================================================================#
@app.route("/genre_data.json", methods=["GET"]) @app.route("/genre_data.json", methods=["GET"])
@require_allowed_ip
def UI_2_get_applicable_genres(): def UI_2_get_applicable_genres():
with open("data/genres.json", "r") as file: with open("data/genres.json", "r") as file:
genre_list = json.load(file) genre_list = json.load(file)
@@ -10275,12 +10355,14 @@ def UI_2_get_log(data):
emit("log_message", web_log_history) emit("log_message", web_log_history)
@app.route("/get_log") @app.route("/get_log")
@require_allowed_ip
def UI_2_get_log_get(): def UI_2_get_log_get():
if args.no_ui: if args.no_ui:
return redirect('/api/latest') return redirect('/api/latest')
return {'aiserver_log': web_log_history} return {'aiserver_log': web_log_history}
@app.route("/test_match") @app.route("/test_match")
@require_allowed_ip
@logger.catch @logger.catch
def UI_2_test_match(): def UI_2_test_match():
koboldai_vars.assign_world_info_to_actions() koboldai_vars.assign_world_info_to_actions()
@@ -10290,6 +10372,7 @@ def UI_2_test_match():
# Download of the audio file # Download of the audio file
#==================================================================# #==================================================================#
@app.route("/audio") @app.route("/audio")
@require_allowed_ip
@logger.catch @logger.catch
def UI_2_audio(): def UI_2_audio():
if args.no_ui: if args.no_ui:
@@ -10316,6 +10399,7 @@ def UI_2_audio():
# Download of the image for an action # Download of the image for an action
#==================================================================# #==================================================================#
@app.route("/action_image") @app.route("/action_image")
@require_allowed_ip
@logger.catch @logger.catch
def UI_2_action_image(): def UI_2_action_image():
if args.no_ui: if args.no_ui:
@@ -10375,6 +10459,7 @@ def model_info():
return {"Model Type": "Read Only", "Model Size": "0", "Model Name": koboldai_vars.model.replace("_", "/")} return {"Model Type": "Read Only", "Model Size": "0", "Model Name": koboldai_vars.model.replace("_", "/")}
@app.route("/vars") @app.route("/vars")
@require_allowed_ip
@logger.catch @logger.catch
def show_vars(): def show_vars():
if args.no_ui: if args.no_ui: