mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Add IP whitelisting to --host
This commit is contained in:
91
aiserver.py
91
aiserver.py
@@ -73,6 +73,8 @@ import torch
|
||||
from transformers import StoppingCriteria, GPT2Tokenizer, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer, PreTrainedModel, modeling_utils, AutoModelForTokenClassification
|
||||
from transformers import __version__ as transformers_version
|
||||
import transformers
|
||||
import ipaddress
|
||||
from functools import wraps
|
||||
try:
|
||||
from transformers.models.opt.modeling_opt import OPTDecoder
|
||||
except:
|
||||
@@ -86,6 +88,9 @@ from PIL import Image
|
||||
from io import BytesIO
|
||||
|
||||
global tpu_mtj_backend
|
||||
global allowed_ips
|
||||
allowed_ips = set() # empty set
|
||||
enable_whitelist = False
|
||||
|
||||
|
||||
if lupa.LUA_VERSION[:2] != (5, 4):
|
||||
@@ -1469,13 +1474,15 @@ def spRequest(filename):
|
||||
#==================================================================#
|
||||
def general_startup(override_args=None):
|
||||
global args
|
||||
global enable_whitelist
|
||||
global allowed_ips
|
||||
# Parsing Parameters
|
||||
parser = argparse.ArgumentParser(description="KoboldAI Server")
|
||||
parser.add_argument("--remote", action='store_true', help="Optimizes KoboldAI for Remote Play")
|
||||
parser.add_argument("--noaimenu", action='store_true', help="Disables the ability to select the AI")
|
||||
parser.add_argument("--ngrok", action='store_true', help="Optimizes KoboldAI for Remote Play using Ngrok")
|
||||
parser.add_argument("--localtunnel", action='store_true', help="Optimizes KoboldAI for Remote Play using Localtunnel")
|
||||
parser.add_argument("--host", action='store_true', help="Optimizes KoboldAI for Remote Play without using a proxy service")
|
||||
parser.add_argument("--host", type=str, default="", nargs="?", const="", help="Optimizes KoboldAI for LAN Remote Play without using a proxy service. --host opens to all LAN. Enable IP whitelisting by using a comma separated IP list. Supports individual IPs, ranges, and subnets --host 127.0.0.1,127.0.0.2,127.0.0.3,192.168.1.0-192.168.1.255,10.0.0.0/24,etc")
|
||||
parser.add_argument("--port", type=int, help="Specify the port on which the application will be joinable")
|
||||
parser.add_argument("--aria2_port", type=int, help="Specify the port on which aria2's RPC interface will be open if aria2 is installed (defaults to 6799)")
|
||||
parser.add_argument("--model", help="Specify the Model Type to skip the Menu")
|
||||
@@ -1525,6 +1532,10 @@ def general_startup(override_args=None):
|
||||
|
||||
utils.args = args
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
#load system and user settings
|
||||
for setting in ['user_settings', 'system_settings']:
|
||||
@@ -1602,9 +1613,32 @@ def general_startup(override_args=None):
|
||||
if args.localtunnel:
|
||||
koboldai_vars.host = True;
|
||||
|
||||
if args.host == "":
|
||||
koboldai_vars.host = True
|
||||
args.unblock = True
|
||||
if args.host:
|
||||
koboldai_vars.host = True;
|
||||
args.unblock = True;
|
||||
# This means --host option was submitted without an argument
|
||||
# Enable all LAN IPs (0.0.0.0/0)
|
||||
if args.host != "":
|
||||
# Check if --host option was submitted with an argument
|
||||
# Parse the supplied IP(s) and add them to the allowed IPs list
|
||||
koboldai_vars.host = True
|
||||
args.unblock = True
|
||||
enable_whitelist = True
|
||||
for ip_str in args.host.split(","):
|
||||
if "/" in ip_str:
|
||||
allowed_ips |= set(str(ip) for ip in ipaddress.IPv4Network(ip_str, strict=False).hosts())
|
||||
elif "-" in ip_str:
|
||||
start_ip, end_ip = ip_str.split("-")
|
||||
start_ip_int = int(ipaddress.IPv4Address(start_ip))
|
||||
end_ip_int = int(ipaddress.IPv4Address(end_ip))
|
||||
allowed_ips |= set(str(ipaddress.IPv4Address(ip)) for ip in range(start_ip_int, end_ip_int + 1))
|
||||
else:
|
||||
allowed_ips.add(ip_str.strip())
|
||||
# Sort and print the allowed IPs list
|
||||
allowed_ips = sorted(allowed_ips, key=lambda ip: int(''.join([i.zfill(3) for i in ip.split('.')])))
|
||||
print(f"Allowed IPs: {allowed_ips}")
|
||||
|
||||
|
||||
if args.cpu:
|
||||
koboldai_vars.use_colab_tpu = False
|
||||
@@ -3450,22 +3484,52 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
|
||||
print(format(colors.GREEN) + "KoboldAI has finished loading and is available at the following link for UI 1: " + koboldai_vars.cloudflare_link + format(colors.END))
|
||||
print(format(colors.GREEN) + "KoboldAI has finished loading and is available at the following link for UI 2: " + koboldai_vars.cloudflare_link + "/new_ui" + format(colors.END))
|
||||
|
||||
|
||||
# Setup IP Whitelisting
|
||||
# Define a function to check if IP is allowed
|
||||
def is_allowed_ip():
|
||||
global allowed_ips
|
||||
client_ip = request.remote_addr
|
||||
if request.path != '/genre_data.json':
|
||||
print("Connection Attempt: " + request.remote_addr)
|
||||
print("Allowed?: ", request.remote_addr in allowed_ips)
|
||||
return client_ip in allowed_ips
|
||||
|
||||
|
||||
|
||||
# Define a decorator to enforce IP whitelisting
|
||||
def require_allowed_ip(func):
|
||||
@wraps(func)
|
||||
def decorated(*args, **kwargs):
|
||||
|
||||
if enable_whitelist and not is_allowed_ip():
|
||||
return abort(403)
|
||||
return func(*args, **kwargs)
|
||||
return decorated
|
||||
|
||||
|
||||
|
||||
|
||||
# Set up Flask routes
|
||||
@app.route('/')
|
||||
@app.route('/index')
|
||||
@require_allowed_ip
|
||||
def index():
|
||||
if args.no_ui:
|
||||
return redirect('/api/latest')
|
||||
else:
|
||||
return render_template('index.html', hide_ai_menu=args.noaimenu)
|
||||
@app.route('/api', strict_slashes=False)
|
||||
@require_allowed_ip
|
||||
def api():
|
||||
return redirect('/api/latest')
|
||||
@app.route('/favicon.ico')
|
||||
|
||||
def favicon():
|
||||
return send_from_directory(app.root_path,
|
||||
'koboldai.ico', mimetype='image/vnd.microsoft.icon')
|
||||
@app.route('/download')
|
||||
@require_allowed_ip
|
||||
def download():
|
||||
if args.no_ui:
|
||||
raise NotFound()
|
||||
@@ -4126,6 +4190,8 @@ def execute_outmod():
|
||||
#==================================================================#
|
||||
@socketio.on('connect')
|
||||
def do_connect():
|
||||
print("Connection Attempt: " + request.remote_addr)
|
||||
print("Allowed?: ", request.remote_addr in allowed_ips)
|
||||
if request.args.get("rely") == "true":
|
||||
return
|
||||
logger.info("Client connected! UI_{}".format(request.args.get('ui')))
|
||||
@@ -8021,6 +8087,7 @@ def show_folder_usersripts(data):
|
||||
# UI V2 CODE
|
||||
#==================================================================#
|
||||
@app.route('/new_ui')
|
||||
@require_allowed_ip
|
||||
@logger.catch
|
||||
def new_ui_index():
|
||||
if args.no_ui:
|
||||
@@ -8048,6 +8115,7 @@ def ui2_connect():
|
||||
# UI V2 CODE Themes
|
||||
#==================================================================#
|
||||
@app.route('/themes/<path:path>')
|
||||
#@require_allowed_ip
|
||||
@logger.catch
|
||||
def ui2_serve_themes(path):
|
||||
return send_from_directory('themes', path)
|
||||
@@ -8086,6 +8154,7 @@ def upload_file(data):
|
||||
get_files_folders(session['current_folder'])
|
||||
|
||||
@app.route("/upload_kai_story/<string:file_name>", methods=["POST"])
|
||||
@require_allowed_ip
|
||||
@logger.catch
|
||||
def UI_2_upload_kai_story(file_name: str):
|
||||
|
||||
@@ -8549,6 +8618,7 @@ def directory_to_zip_data(directory: str, overrides: Optional[dict]) -> bytes:
|
||||
# Save story to json
|
||||
#==================================================================#
|
||||
@app.route("/story_download")
|
||||
@require_allowed_ip
|
||||
@logger.catch
|
||||
def UI_2_download_story():
|
||||
if args.no_ui:
|
||||
@@ -9011,6 +9081,7 @@ def UI_2_delete_wi_folder(folder):
|
||||
# Event triggered when user exports world info folder
|
||||
#==================================================================#
|
||||
@app.route('/export_world_info_folder')
|
||||
@require_allowed_ip
|
||||
@logger.catch
|
||||
def UI_2_export_world_info_folder():
|
||||
if 'folder' in request.args:
|
||||
@@ -9038,6 +9109,7 @@ def UI_2_upload_world_info_folder(data):
|
||||
koboldai_vars.calc_ai_text()
|
||||
|
||||
@app.route("/upload_wi", methods=["POST"])
|
||||
@require_allowed_ip
|
||||
@logger.catch
|
||||
def UI_2_import_world_info():
|
||||
wi_data = request.get_json()
|
||||
@@ -9115,6 +9187,7 @@ def UI_2_update_wi_keys(data):
|
||||
socketio.emit("world_info_entry", koboldai_vars.worldinfo_v2.world_info[uid], broadcast=True, room="UI_2")
|
||||
|
||||
@app.route("/set_wi_image/<int(signed=True):uid>", methods=["POST"])
|
||||
@require_allowed_ip
|
||||
@logger.catch
|
||||
def UI_2_set_wi_image(uid):
|
||||
if uid < 0:
|
||||
@@ -9146,6 +9219,7 @@ def UI_2_set_wi_image(uid):
|
||||
return ":)"
|
||||
|
||||
@app.route("/get_wi_image/<int(signed=True):uid>", methods=["GET"])
|
||||
@require_allowed_ip
|
||||
@logger.catch
|
||||
def UI_2_get_wi_image(uid):
|
||||
if args.no_ui:
|
||||
@@ -9157,6 +9231,7 @@ def UI_2_get_wi_image(uid):
|
||||
return ":( Couldn't find image", 204
|
||||
|
||||
@app.route("/set_commentator_picture/<int(signed=True):commentator_id>", methods=["POST"])
|
||||
@require_allowed_ip
|
||||
@logger.catch
|
||||
def UI_2_set_commentator_image(commentator_id):
|
||||
data = request.get_data()
|
||||
@@ -9165,6 +9240,7 @@ def UI_2_set_commentator_image(commentator_id):
|
||||
return ":)"
|
||||
|
||||
@app.route("/image_db.json", methods=["GET"])
|
||||
@require_allowed_ip
|
||||
@logger.catch
|
||||
def UI_2_get_image_db():
|
||||
if args.no_ui:
|
||||
@@ -9175,6 +9251,7 @@ def UI_2_get_image_db():
|
||||
return jsonify([])
|
||||
|
||||
@app.route("/action_composition.json", methods=["GET"])
|
||||
@require_allowed_ip
|
||||
@logger.catch
|
||||
def UI_2_get_action_composition():
|
||||
if args.no_ui:
|
||||
@@ -9200,6 +9277,7 @@ def UI_2_get_action_composition():
|
||||
return jsonify(ret)
|
||||
|
||||
@app.route("/generated_images/<path:path>")
|
||||
@require_allowed_ip
|
||||
def UI_2_send_generated_images(path):
|
||||
return send_from_directory(koboldai_vars.save_paths.generated_images, path)
|
||||
|
||||
@@ -9509,6 +9587,7 @@ def UI_2_generate_wi(data):
|
||||
socketio.emit("generated_wi", {"uid": uid, "field": field, "out": out_text}, room="UI_2")
|
||||
|
||||
@app.route("/generate_raw", methods=["GET"])
|
||||
@require_allowed_ip
|
||||
def UI_2_generate_raw():
|
||||
prompt = request.args.get("prompt")
|
||||
|
||||
@@ -10210,6 +10289,7 @@ def UI_2_privacy_mode(data):
|
||||
# Genres
|
||||
#==================================================================#
|
||||
@app.route("/genre_data.json", methods=["GET"])
|
||||
@require_allowed_ip
|
||||
def UI_2_get_applicable_genres():
|
||||
with open("data/genres.json", "r") as file:
|
||||
genre_list = json.load(file)
|
||||
@@ -10275,12 +10355,14 @@ def UI_2_get_log(data):
|
||||
emit("log_message", web_log_history)
|
||||
|
||||
@app.route("/get_log")
|
||||
@require_allowed_ip
|
||||
def UI_2_get_log_get():
|
||||
if args.no_ui:
|
||||
return redirect('/api/latest')
|
||||
return {'aiserver_log': web_log_history}
|
||||
|
||||
@app.route("/test_match")
|
||||
@require_allowed_ip
|
||||
@logger.catch
|
||||
def UI_2_test_match():
|
||||
koboldai_vars.assign_world_info_to_actions()
|
||||
@@ -10290,6 +10372,7 @@ def UI_2_test_match():
|
||||
# Download of the audio file
|
||||
#==================================================================#
|
||||
@app.route("/audio")
|
||||
@require_allowed_ip
|
||||
@logger.catch
|
||||
def UI_2_audio():
|
||||
if args.no_ui:
|
||||
@@ -10316,6 +10399,7 @@ def UI_2_audio():
|
||||
# Download of the image for an action
|
||||
#==================================================================#
|
||||
@app.route("/action_image")
|
||||
@require_allowed_ip
|
||||
@logger.catch
|
||||
def UI_2_action_image():
|
||||
if args.no_ui:
|
||||
@@ -10375,6 +10459,7 @@ def model_info():
|
||||
return {"Model Type": "Read Only", "Model Size": "0", "Model Name": koboldai_vars.model.replace("_", "/")}
|
||||
|
||||
@app.route("/vars")
|
||||
@require_allowed_ip
|
||||
@logger.catch
|
||||
def show_vars():
|
||||
if args.no_ui:
|
||||
|
Reference in New Issue
Block a user