Mirror of https://github.com/KoboldAI/KoboldAI-Client.git (synced 2025-02-17 12:10:49 +01:00)

Commit 496ef1472d ("updated"), parent 42e04afc83
1 changed file: aiserver.py | 102
@@ -319,6 +319,7 @@ class vars:
     colaburl = ""  # Ngrok url for Google Colab mode
     apikey = ""  # API key to use for InferKit API calls
     oaiapikey = ""  # API key to use for OpenAI API calls
+    cluster_requested_models = []  # The models which we allow to generate during cluster mode
     savedir = getcwd()+"\\stories"
     hascuda = False  # Whether torch has detected CUDA on the system
     usegpu = False  # Whether to launch pipeline with GPU support
@@ -1288,6 +1289,8 @@ def general_startup(override_args=None):
     parser.add_argument("--aria2_port", type=int, help="Specify the port on which aria2's RPC interface will be open if aria2 is installed (defaults to 6799)")
     parser.add_argument("--model", help="Specify the Model Type to skip the Menu")
     parser.add_argument("--path", help="Specify the Path for local models (For model NeoCustom or GPT2Custom)")
+    parser.add_argument("--apikey", help="Specify the API key to use for online services")
+    parser.add_argument("--req_model", type=str, action='append', required=False, help="Specify which models we allow to generate for us during cluster mode. Can be specified multiple times.")
     parser.add_argument("--revision", help="Specify the model revision for huggingface models (can be a git branch/tag name or a git commit hash)")
     parser.add_argument("--cpu", action='store_true', help="By default unattended launches are on the GPU; use this option to force CPU usage.")
     parser.add_argument("--breakmodel", action='store_true', help=argparse.SUPPRESS)
@@ -1336,6 +1339,11 @@ def general_startup(override_args=None):
     vars.model = args.model;
     vars.revision = args.revision

+    if args.apikey:
+        vars.apikey = args.apikey
+    if args.req_model:
+        vars.cluster_requested_models = args.req_model
+
     if args.colab:
         args.remote = True;
         args.override_rename = True;
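Taken together, the two new flags are wired straight into vars here. A hypothetical launch command exercising them (the API key and model names are placeholders, and CLUSTER is assumed to be the model value that selects cluster mode, as suggested by the actionsubmit hunk below):

    python aiserver.py --model CLUSTER --apikey <your-cluster-key> --req_model ExampleOrg/example-6B --req_model ExampleOrg/example-13B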
@@ -3979,11 +3987,19 @@ def actionsubmit(data, actionmode=0, force_submit=False, force_prompt_gen=False,
         while(True):
             set_aibusy(1)

-            if(vars.model == "API"):
+            if(vars.model in ["API","CLUSTER"]):
                 global tokenizer
-                tokenizer_id = requests.get(
-                    vars.colaburl[:-8] + "/api/v1/model",
-                ).json()["result"]
+                if vars.model == "API":
+                    tokenizer_id = requests.get(
+                        vars.colaburl[:-8] + "/api/v1/model",
+                    ).json()["result"]
+                elif len(vars.cluster_requested_models) >= 1:
+                    # If the player has requested one or more models, use the first one for the tokenizer
+                    tokenizer_id = vars.cluster_requested_models[0]
+                    # The cluster can return any of the requested models for each generation, but that
+                    # happens after this step, so the actual model is still unknown at this point
+                else:
+                    tokenizer_id = ""
                 if tokenizer_id != vars.api_tokenizer_id:
                     try:
                         if(os.path.isdir(tokenizer_id)):
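The code after this hunk (outside the diff) decides how tokenizer_id is actually loaded; the visible os.path.isdir check suggests a local-directory-or-hub lookup. A minimal sketch of that fallback, assuming the Hugging Face transformers loaders are used downstream (load_cluster_tokenizer is a hypothetical helper, not part of this commit):

    from transformers import AutoTokenizer, GPT2Tokenizer

    def load_cluster_tokenizer(tokenizer_id):
        # Hypothetical helper: an empty id means no model was requested,
        # so fall back to a generic GPT-2 tokenizer for token counting
        if not tokenizer_id:
            return GPT2Tokenizer.from_pretrained("gpt2")
        # Otherwise tokenizer_id is a local directory or a Hugging Face model name
        return AutoTokenizer.from_pretrained(tokenizer_id)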
@@ -5024,6 +5040,84 @@ def sendtoapi(txt, min, max):
     set_aibusy(0)
     return

+#==================================================================#
+#  Send transformers-style request to KoboldAI Cluster
+#==================================================================#
+def sendtocluster(txt, min, max):
+    # Log request to console
+    if not vars.quiet:
+        print("{0}Tokens:{1}, Txt:{2}{3}".format(colors.YELLOW, min-1, txt, colors.END))
+
+    # Store context in memory to use it for comparison with generated content
+    vars.lastctx = txt
+
+    # Build request JSON data
+    reqdata = {
+        'max_length': max - min + 1,
+        'max_context_length': vars.max_length,
+        'rep_pen': vars.rep_pen,
+        'rep_pen_slope': vars.rep_pen_slope,
+        'rep_pen_range': vars.rep_pen_range,
+        'temperature': vars.temp,
+        'top_p': vars.top_p,
+        'top_k': vars.top_k,
+        'top_a': vars.top_a,
+        'tfs': vars.tfs,
+        'typical': vars.typical,
+        'n': vars.numseqs,
+    }
+    cluster_metadata = {
+        'prompt': txt,
+        'params': reqdata,
+        'username': vars.apikey,
+        'models': vars.cluster_requested_models,
+    }
+
+    # Create request
+    req = requests.post(
+        vars.colaburl[:-8] + "/generate/sync",
+        json=cluster_metadata,
+    )
+    js = req.json()
+    if(req.status_code == 503):
+        errmsg = "KoboldAI API Error: No available KoboldAI servers found in cluster to fulfil this request using the selected models and requested lengths."
+        print("{0}{1}{2}".format(colors.RED, json.dumps(js, indent=2), colors.END))
+        emit('from_server', {'cmd': 'errmsg', 'data': errmsg}, broadcast=True)
+        set_aibusy(0)
+        return
+    if(req.status_code != 200):
+        errmsg = "KoboldAI API Error: Failed to get a reply from the server. Please check the console."
+        print("{0}{1}{2}".format(colors.RED, json.dumps(js, indent=2), colors.END))
+        emit('from_server', {'cmd': 'errmsg', 'data': errmsg}, broadcast=True)
+        set_aibusy(0)
+        return
+    genout = js
+
+    for i in range(vars.numseqs):
+        vars.lua_koboldbridge.outputs[i+1] = genout[i]
+
+    execute_outmod()
+    if(vars.lua_koboldbridge.regeneration_required):
+        vars.lua_koboldbridge.regeneration_required = False
+        genout = []
+        for i in range(vars.numseqs):
+            genout.append(vars.lua_koboldbridge.outputs[i+1])
+            assert type(genout[-1]) is str
+
+    if(len(genout) == 1):
+        genresult(genout[0])
+    else:
+        # Convert torch output format to transformers
+        seqs = []
+        for seq in genout:
+            seqs.append({"generated_text": seq})
+        if(vars.lua_koboldbridge.restart_sequence is not None and vars.lua_koboldbridge.restart_sequence > 0):
+            genresult(genout[vars.lua_koboldbridge.restart_sequence-1]["generated_text"])
+        else:
+            genselect(genout)
+
+    set_aibusy(0)
+    return
+
 #==================================================================#
 # Send text to TPU mesh transformer backend
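For reference, the request body that sendtocluster POSTs to <cluster>/generate/sync, reconstructed from the code above as a Python dict. The concrete values are illustrative only; the field names come straight from reqdata and cluster_metadata in the diff:

    # Illustrative cluster_metadata payload (values are made up for the example)
    {
        "prompt": "You are standing in an open field west of a white house.",
        "params": {
            "max_length": 80,            # max - min + 1
            "max_context_length": 1024,  # vars.max_length
            "rep_pen": 1.1,
            "rep_pen_slope": 0.7,
            "rep_pen_range": 1024,
            "temperature": 0.5,
            "top_p": 0.9,
            "top_k": 0,
            "top_a": 0.0,
            "tfs": 1.0,
            "typical": 1.0,
            "n": 1,                      # vars.numseqs
        },
        "username": "<your API key>",    # vars.apikey doubles as the cluster username
        "models": ["<requested model>"], # vars.cluster_requested_models
    }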