Mirror of https://github.com/KoboldAI/KoboldAI-Client.git
Update Lua modeltype and model API
Commit b2def30d9d (parent c305076cf3)
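In brief, this commit: (1) adds vars.model_orig, preserving the model ID string the user chose before automatic model type detection can rewrite vars.model; (2) makes lua_get_modeltype() key its hidden-size detection off the auto-detected vars.model_type ("gpt2", "gpt_neo", "gptj") rather than only the hardcoded "GPT2Custom"/"NeoCustom" loader IDs; and (3) changes the Lua bridge's model getter to return model_orig, so scripts see the original model string.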
aiserver.py (23 changed lines)
@@ -82,6 +82,7 @@ class vars:
     submission = ""  # Same as above, but after applying input formatting
     lastctx    = ""  # The last context submitted to the generator
     model      = ""  # Model ID string chosen at startup
+    model_orig = ""  # Original model string before being changed by auto model type detection
     model_type = ""  # Model Type (Automatically taken from the model config)
     noai       = False  # Runs the script without starting up the transformers pipeline
     aibusy     = False  # Stops submissions while the AI is working
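The comments above say vars.model_type is taken automatically from the model config, and that auto-detection may change vars.model. The detection code itself is outside this diff; purely as a hedged sketch of the idea (AutoConfig is the real Hugging Face API, and the model ID is just an example):

# Hedged sketch only -- not the literal aiserver.py detection code.
from transformers import AutoConfig

class vars:
    model = model_orig = "EleutherAI/gpt-neo-2.7B"  # user's original choice, kept in model_orig
    model_type = ""

config = AutoConfig.from_pretrained(vars.model)  # reads the model's config.json
vars.model_type = config.model_type              # e.g. "gpt2", "gpt_neo", "gptj"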
@@ -184,7 +185,7 @@ def getModelSelection():
     while(vars.model == ''):
         modelsel = input("Model #> ")
         if(modelsel.isnumeric() and int(modelsel) > 0 and int(modelsel) <= len(modellist)):
-            vars.model = modellist[int(modelsel)-1][1]
+            vars.model = vars.model_orig = modellist[int(modelsel)-1][1]
         else:
             print("{0}Please enter a valid selection.{1}".format(colors.RED, colors.END))
 
@@ -365,7 +366,7 @@ parser.add_argument("--override_rename", action='store_true', help="Renaming sto
 parser.add_argument("--configname", help="Force a fixed configuration name to aid with config management.")
 
 args = parser.parse_args()
-vars.model = args.model;
+vars.model = vars.model_orig = args.model;
 
 if args.remote:
     vars.remote = True;
@@ -1386,23 +1387,23 @@ def lua_get_modeltype():
         return "readonly"
     if(vars.model in ("Colab", "OAI", "InferKit")):
         return "api"
-    if(vars.model in ("GPT2Custom", "NeoCustom")):
+    if(vars.model not in ("TPUMeshTransformerGPTJ",) and (vars.model in ("GPT2Custom", "NeoCustom") or vars.model_type in ("gpt2", "gpt_neo", "gptj"))):
         hidden_size = get_hidden_size_from_model(model)
-    if(vars.model in ("gpt2",) or (vars.model == "GPT2Custom" and hidden_size == 768)):
+    if(vars.model in ("gpt2",) or (vars.model_type == "gpt2" and hidden_size == 768)):
         return "gpt2"
-    if(vars.model in ("gpt2-medium",) or (vars.model == "GPT2Custom" and hidden_size == 1024)):
+    if(vars.model in ("gpt2-medium",) or (vars.model_type == "gpt2" and hidden_size == 1024)):
         return "gpt2-medium"
-    if(vars.model in ("gpt2-large",) or (vars.model == "GPT2Custom" and hidden_size == 1280)):
+    if(vars.model in ("gpt2-large",) or (vars.model_type == "gpt2" and hidden_size == 1280)):
         return "gpt2-large"
-    if(vars.model in ("gpt2-xl",) or (vars.model == "GPT2Custom" and hidden_size == 1600)):
+    if(vars.model in ("gpt2-xl",) or (vars.model_type == "gpt2" and hidden_size == 1600)):
         return "gpt2-xl"
-    if(vars.model == "NeoCustom" and hidden_size == 768):
+    if(vars.model_type == "gpt_neo" and hidden_size == 768):
         return "gpt-neo-125M"
-    if(vars.model in ("EleutherAI/gpt-neo-1.3B",) or (vars.model == "NeoCustom" and hidden_size == 2048)):
+    if(vars.model in ("EleutherAI/gpt-neo-1.3B",) or (vars.model_type == "gpt_neo" and hidden_size == 2048)):
         return "gpt-neo-1.3B"
-    if(vars.model in ("EleutherAI/gpt-neo-2.7B",) or (vars.model == "NeoCustom" and hidden_size == 2560)):
+    if(vars.model in ("EleutherAI/gpt-neo-2.7B",) or (vars.model_type == "gpt_neo" and hidden_size == 2560)):
         return "gpt-neo-2.7B"
-    if(vars.model in ("EleutherAI/gpt-j-6B",) or (vars.model == "NeoCustom" and hidden_size == 4096) or (vars.model == "TPUMeshTransformerGPTJ" and tpu_mtj_backend.params["d_model"] == 4096)):
+    if(vars.model in ("EleutherAI/gpt-j-6B",) or (vars.model == "TPUMeshTransformerGPTJ" and tpu_mtj_backend.params["d_model"] == 4096) or (vars.model_type in ("gpt_neo", "gptj") and hidden_size == 4096)):
         return "gpt-j-6B"
     return "unknown"
 
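The model_type-driven branch of the new lua_get_modeltype() above boils down to a hidden-size lookup per detected architecture. A self-contained restatement of just that branch (the function and table names are illustrative, not from aiserver.py; the exact-ID checks like vars.model == "gpt2-medium" are omitted):

# Restates the mapping encoded in the diff above; names are illustrative.
GPT2_SIZES = {768: "gpt2", 1024: "gpt2-medium", 1280: "gpt2-large", 1600: "gpt2-xl"}
NEO_SIZES = {768: "gpt-neo-125M", 2048: "gpt-neo-1.3B", 2560: "gpt-neo-2.7B"}

def modeltype_from_config(model_type, hidden_size):
    if model_type == "gpt2":
        return GPT2_SIZES.get(hidden_size, "unknown")
    if model_type in ("gpt_neo", "gptj") and hidden_size == 4096:
        return "gpt-j-6B"  # 4096-wide gpt_neo and gptj configs both report gpt-j-6B
    if model_type == "gpt_neo":
        return NEO_SIZES.get(hidden_size, "unknown")
    return "unknown"

assert modeltype_from_config("gpt_neo", 2560) == "gpt-neo-2.7B"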
bridge.lua

@@ -1038,7 +1038,7 @@ return function(_python, _bridged)
     ---@param t KoboldLib
     ---@return string
     function KoboldLib_getters.model(t)
-        return bridged.vars.model
+        return bridged.vars.model_orig
     end
 
     ---@param t KoboldLib
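Net effect on the script-facing API: model type detection now resolves for any local model whose config reports gpt2, gpt_neo, or gptj (not just the GPT2Custom/NeoCustom loaders), and the model getter returns the string the user originally chose, untouched by auto model type detection. Assuming the usual KoboldAI exposure of these getters, userscripts would read the two values as kobold.modeltype and kobold.model.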