mirror of https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Merge branch 'united' into gui-and-scripting
aiserver.py (34 changes)
@@ -82,6 +82,7 @@ class vars:
     submission  = ""     # Same as above, but after applying input formatting
     lastctx     = ""     # The last context submitted to the generator
     model       = ""     # Model ID string chosen at startup
+    model_type  = ""     # Model Type (Automatically taken from the model config)
     noai        = False  # Runs the script without starting up the transformers pipeline
     aibusy      = False  # Stops submissions while the AI is working
     max_length  = 1024   # Maximum number of tokens to submit per action
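For context, the vars class touched by this hunk is used as a bare module-level namespace: its attributes are read and mutated directly on the class, and it is never instantiated. A minimal sketch of the pattern (attribute values here are illustrative, not the full class):

    class vars:
        model      = ""  # Model ID string chosen at startup
        model_type = ""  # Filled in later from the model config

    vars.model_type = "gpt_neo"  # startup code mutates the class attributes in place
    print(vars.model_type)       # -> gpt_neo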
@@ -391,9 +392,29 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
     vars.allowsp = True
     # Test for GPU support
     import torch
+
+    # Make model path the same as the model name to make this consistent with the other loading method if it isn't a known model type
+    # This code is not just a workaround for below, it is also used to make the behavior consistent with other loading methods - Henk717
+    if(not vars.model in ["NeoCustom", "GPT2Custom"]):
+        vars.custmodpth = vars.model
+    # Get the model_type from the config or assume a model type if it isn't present
+    from transformers import AutoConfig
+    try:
+        model_config = AutoConfig.from_pretrained(vars.custmodpth)
+    except ValueError as e:
+        vars.model_type = "not_found"
+    if(not vars.model_type == "not_found"):
+        vars.model_type = model_config.model_type
+    elif(vars.model == "NeoCustom"):
+        vars.model_type = "gpt_neo"
+    elif(vars.model == "GPT2Custom"):
+        vars.model_type = "gpt2"
+    else:
+        print("WARNING: No model type detected, assuming Neo (If this is a GPT2 model use the other menu option or --model GPT2Custom)")
+        vars.model_type = "gpt_neo"
     print("{0}Looking for GPU support...{1}".format(colors.PURPLE, colors.END), end="")
     vars.hascuda = torch.cuda.is_available()
-    vars.bmsupported = vars.model in ("EleutherAI/gpt-neo-1.3B", "EleutherAI/gpt-neo-2.7B", "EleutherAI/gpt-j-6B", "NeoCustom")
+    vars.bmsupported = vars.model_type in ("gpt_neo", "gptj")
     if(args.breakmodel is not None and args.breakmodel):
         print("WARNING: --breakmodel is no longer supported. Breakmodel mode is now automatically enabled when --layers is used (see --help for details).", file=sys.stderr)
     if(args.breakmodel_layers is not None):
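The hunk above replaces a hardcoded list of breakmodel-capable model names with a check on the architecture string that transformers reads from the model's config.json. A minimal standalone sketch of that detection pattern; detect_model_type is a hypothetical helper name, not part of aiserver.py:

    from transformers import AutoConfig

    def detect_model_type(model_path: str, fallback: str = "gpt_neo") -> str:
        """Read model_type from a model's config.json, falling back if it can't be read."""
        try:
            config = AutoConfig.from_pretrained(model_path)
            return config.model_type
        except (ValueError, OSError):
            # No usable config.json (e.g. a bare checkpoint directory): assume the fallback
            return fallback

    # Breakmodel (layer-splitting) support then keys off the architecture
    # rather than a fixed list of model names (downloads the config on first use):
    bmsupported = detect_model_type("EleutherAI/gpt-neo-2.7B") in ("gpt_neo", "gptj")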
@@ -870,13 +891,14 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
     # feature yet
     if("/" not in vars.model and vars.model.lower().startswith("gpt2")):
         lowmem = {}

     # Is CUDA available? If so, use GPU, otherwise fall back to CPU

     # Download model from Huggingface if it does not exist, otherwise load locally
     if(os.path.isdir(vars.model.replace('/', '_'))):
         with(maybe_use_float16()):
             tokenizer = GPT2TokenizerFast.from_pretrained(vars.model.replace('/', '_'), cache_dir="cache/")
             model = AutoModelForCausalLM.from_pretrained(vars.model.replace('/', '_'), cache_dir="cache/", **lowmem)
     else:
         print("Model does not exist locally, attempting to download from Huggingface...")
         tokenizer = GPT2TokenizerFast.from_pretrained(vars.model, cache_dir="cache/")
@@ -887,7 +909,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
         shutil.rmtree("cache/")
         model.save_pretrained(vars.model.replace('/', '_'))
         tokenizer.save_pretrained(vars.model.replace('/', '_'))

 if(vars.hascuda):
     if(vars.usegpu):
         vars.modeldim = get_hidden_size_from_model(model)
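The two hunks above implement a download-once scheme: if a flat local directory (the model name with '/' replaced by '_') already exists, the model loads from disk; otherwise it is downloaded through a temporary cache/ directory, persisted into the flat directory, and the cache removed. A rough self-contained sketch of the same idea; load_or_download is a hypothetical helper, not part of aiserver.py:

    import os
    import shutil
    from transformers import AutoModelForCausalLM, GPT2TokenizerFast

    def load_or_download(model_name: str):
        local_dir = model_name.replace('/', '_')
        if os.path.isdir(local_dir):
            # A previously saved copy exists: load it without touching the network
            tokenizer = GPT2TokenizerFast.from_pretrained(local_dir)
            model = AutoModelForCausalLM.from_pretrained(local_dir)
        else:
            # First run: download through the temporary cache directory
            tokenizer = GPT2TokenizerFast.from_pretrained(model_name, cache_dir="cache/")
            model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir="cache/")
            # save_pretrained writes from the in-memory objects, so the cache
            # can be dropped first and the weights end up stored only once
            shutil.rmtree("cache/")
            model.save_pretrained(local_dir)
            tokenizer.save_pretrained(local_dir)
        return tokenizer, model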
@@ -1961,8 +1983,8 @@ def loadsettings():
 # Allow the models to override some settings
 #==================================================================#
 def loadmodelsettings():
-    if(path.exists(vars.custmodpth + "/config.json")):
-        model_config = open(vars.custmodpth + "/config.json", "r")
+    if(path.exists(vars.custmodpth.replace('/', '_') + "/config.json")):
+        model_config = open(vars.custmodpth.replace('/', '_') + "/config.json", "r")
         js = json.load(model_config)
         if("badwordsids" in js):
             vars.badwordsids = js["badwordsids"]
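This last hunk makes loadmodelsettings() follow the same directory convention as the loading code: config.json is now looked up under the flattened path (with '/' replaced by '_'), which is where the model was saved. A small sketch of the override-loading idea, with load_model_overrides as a hypothetical name:

    import json
    import os

    def load_model_overrides(custmodpth: str) -> dict:
        """Read per-model setting overrides from the flattened model directory."""
        config_path = custmodpth.replace('/', '_') + "/config.json"
        overrides = {}
        if os.path.exists(config_path):
            with open(config_path, "r") as f:
                js = json.load(f)
            # Only keys the model's config actually provides become overrides
            if "badwordsids" in js:
                overrides["badwordsids"] = js["badwordsids"]
        return overrides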