Custom Path Load fix
This commit is contained in:
parent
d1a64e25da
commit
da53d7edb3
10
aiserver.py
10
aiserver.py
|
@ -860,15 +860,9 @@ def load_model(use_gpu=True, key='', gpu_layers=None, initial_load=False):
|
||||||
|
|
||||||
# Get the model_type from the config or assume a model type if it isn't present
|
# Get the model_type from the config or assume a model type if it isn't present
|
||||||
from transformers import AutoConfig
|
from transformers import AutoConfig
|
||||||
if(os.path.isdir(vars.custmodpth.replace('/', '_'))):
|
if(os.path.isdir(vars.custmodpth) and vars.custmodpth != ""):
|
||||||
try:
|
try:
|
||||||
model_config = AutoConfig.from_pretrained(vars.custmodpth.replace('/', '_'), cache_dir="cache/")
|
model_config = AutoConfig.from_pretrained(vars.custmodpth, cache_dir="cache/")
|
||||||
vars.model_type = model_config.model_type
|
|
||||||
except ValueError as e:
|
|
||||||
vars.model_type = "not_found"
|
|
||||||
elif(os.path.isdir("models/{}".format(vars.custmodpth.replace('/', '_'))) and vars.custmodpth != ""):
|
|
||||||
try:
|
|
||||||
model_config = AutoConfig.from_pretrained("models/{}".format(vars.custmodpth.replace('/', '_')), cache_dir="cache/")
|
|
||||||
vars.model_type = model_config.model_type
|
vars.model_type = model_config.model_type
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
vars.model_type = "not_found"
|
vars.model_type = "not_found"
|
||||||
|
|
Loading…
Reference in New Issue