mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Fix for custom model names
This commit is contained in:
@@ -1034,7 +1034,7 @@ def getmodelname():
|
||||
if(koboldai_vars.online_model != ''):
|
||||
return(f"{koboldai_vars.model}/{koboldai_vars.online_model}")
|
||||
if(koboldai_vars.model in ("NeoCustom", "GPT2Custom", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")):
|
||||
modelname = os.path.basename(os.path.normpath(koboldai_vars.custmodpth))
|
||||
modelname = os.path.basename(os.path.normpath(model.path))
|
||||
return modelname
|
||||
else:
|
||||
modelname = koboldai_vars.model if koboldai_vars.model is not None else "Read Only"
|
||||
@@ -1687,6 +1687,9 @@ def load_model(model_backend, initial_load=False):
|
||||
model = model_backends[model_backend]
|
||||
model.load(initial_load=initial_load, save_model=not (args.colab or args.cacheonly) or args.savemodel)
|
||||
koboldai_vars.model = model.model_name if "model_name" in vars(model) else model.id #Should have model_name, but it could be set to id depending on how it's setup
|
||||
if koboldai_vars.model in ("NeoCustom", "GPT2Custom", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"):
|
||||
koboldai_vars.model = os.path.basename(os.path.normpath(model.path))
|
||||
logger.info(koboldai_vars.model)
|
||||
logger.debug("Model Type: {}".format(koboldai_vars.model_type))
|
||||
|
||||
# TODO: Convert everywhere to use model.tokenizer
|
||||
|
@@ -41,7 +41,7 @@ class model_backend(HFTorchInferenceModel):
|
||||
|
||||
if self.model_name == "NeoCustom":
|
||||
self.model_name = os.path.basename(
|
||||
os.path.normpath(utils.koboldai_vars.custmodpth)
|
||||
os.path.normpath(self.path)
|
||||
)
|
||||
utils.koboldai_vars.model = self.model_name
|
||||
|
||||
|
@@ -188,6 +188,7 @@ class HFInferenceModel(InferenceModel):
|
||||
self.usegpu = parameters['use_gpu'] if 'use_gpu' in parameters else None
|
||||
self.breakmodel = False
|
||||
self.lazy_load = False
|
||||
logger.info(parameters)
|
||||
self.model_name = parameters['custom_model_name'] if 'custom_model_name' in parameters else parameters['id']
|
||||
self.path = parameters['path'] if 'path' in parameters else None
|
||||
|
||||
|
Reference in New Issue
Block a user