Replace model path differently

The path correction was applied to soon and broke online loading, applying the replace where it is relevant instead.
This commit is contained in:
henk717 2021-12-23 03:05:53 +01:00
parent 4ff1a6e940
commit a2d8347939
1 changed files with 3 additions and 3 deletions

View File

@ -393,7 +393,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
# Make model path the same as the model name to make this consistent with the other loading method if it isn't a known model type # Make model path the same as the model name to make this consistent with the other loading method if it isn't a known model type
# This code is not just a workaround for below, it is also used to make the behavior consistent with other loading methods - Henk717 # This code is not just a workaround for below, it is also used to make the behavior consistent with other loading methods - Henk717
if(not vars.model in ["NeoCustom", "GPT2Custom"]): if(not vars.model in ["NeoCustom", "GPT2Custom"]):
vars.custmodpth = vars.model.replace('/', '_') vars.custmodpth = vars.model
# Get the model_type from the config or assume a model type if it isn't present # Get the model_type from the config or assume a model type if it isn't present
from transformers import AutoConfig from transformers import AutoConfig
try: try:
@ -1948,8 +1948,8 @@ def loadsettings():
# Allow the models to override some settings # Allow the models to override some settings
#==================================================================# #==================================================================#
def loadmodelsettings(): def loadmodelsettings():
if(path.exists(vars.custmodpth + "/config.json")): if(path.exists(vars.custmodpth.replace('/', '_') + "/config.json")):
model_config = open(vars.custmodpth + "/config.json", "r") model_config = open(vars.custmodpth.replace('/', '_') + "/config.json", "r")
js = json.load(model_config) js = json.load(model_config)
if("badwordsids" in js): if("badwordsids" in js):
vars.badwordsids = js["badwordsids"] vars.badwordsids = js["badwordsids"]