diff --git a/aiserver.py b/aiserver.py index 176d5acb..cd49b249 100644 --- a/aiserver.py +++ b/aiserver.py @@ -591,6 +591,12 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]): return stopping_criteria transformers.generation_utils.GenerationMixin._get_stopping_criteria = new_get_stopping_criteria + def get_hidden_size_from_model(model): + try: + return int(model.transformer.hidden_size) + except: + return int(model.transformer.embed_dim) + # If custom GPT Neo model was chosen if(vars.model == "NeoCustom"): model_config = open(vars.custmodpth + "/config.json", "r") @@ -632,17 +638,20 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]): if(vars.hascuda): if(vars.usegpu): model = AutoModelForCausalLM.from_pretrained(vars.model, device=0) - vars.modeldim = int(model.transformer.hidden_size) + vars.modeldim = get_hidden_size_from_model(model) model = model.to(0) generator = model.generate elif(vars.breakmodel): # Use both RAM and VRAM (breakmodel) model = AutoModelForCausalLM.from_pretrained(vars.model) + vars.modeldim = get_hidden_size_from_model(model) device_config(model) else: model = AutoModelForCausalLM.from_pretrained(vars.model) + vars.modeldim = get_hidden_size_from_model(model) generator = model.generate else: model = AutoModelForCausalLM.from_pretrained(vars.model) + vars.modeldim = get_hidden_size_from_model(model) generator = model.generate # Suppress Author's Note by flagging square brackets (Old implementation) @@ -718,6 +727,8 @@ def do_connect(): emit('from_server', {'cmd': 'connected', 'smandelete': vars.smandelete, 'smanrename': vars.smanrename}) if(vars.remote): emit('from_server', {'cmd': 'runs_remotely'}) + if(vars.allowsp): + emit('from_server', {'cmd': 'allowsp', 'data': vars.allowsp}) if(not vars.gamestarted): setStartState() @@ -961,8 +972,6 @@ def get_message(msg): #==================================================================# def setStartState(): txt = "Welcome to KoboldAI! You are running "+getmodelname()+".
" - if(vars.allowsp): - emit('from_server', {'cmd': 'allowsp', 'data': vars.allowsp}, broadcast=True) if(not vars.noai): txt = txt + "Please load a game or enter a prompt below to begin!
" else: