Merge pull request #26 from VE-FORBRYDERNE/sp-patch
More softprompting bug fixes
This commit is contained in:
commit
8ad3863854
15
aiserver.py
15
aiserver.py
|
@ -591,6 +591,12 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
|
||||||
return stopping_criteria
|
return stopping_criteria
|
||||||
transformers.generation_utils.GenerationMixin._get_stopping_criteria = new_get_stopping_criteria
|
transformers.generation_utils.GenerationMixin._get_stopping_criteria = new_get_stopping_criteria
|
||||||
|
|
||||||
|
def get_hidden_size_from_model(model):
|
||||||
|
try:
|
||||||
|
return int(model.transformer.hidden_size)
|
||||||
|
except:
|
||||||
|
return int(model.transformer.embed_dim)
|
||||||
|
|
||||||
# If custom GPT Neo model was chosen
|
# If custom GPT Neo model was chosen
|
||||||
if(vars.model == "NeoCustom"):
|
if(vars.model == "NeoCustom"):
|
||||||
model_config = open(vars.custmodpth + "/config.json", "r")
|
model_config = open(vars.custmodpth + "/config.json", "r")
|
||||||
|
@ -632,17 +638,20 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
|
||||||
if(vars.hascuda):
|
if(vars.hascuda):
|
||||||
if(vars.usegpu):
|
if(vars.usegpu):
|
||||||
model = AutoModelForCausalLM.from_pretrained(vars.model, device=0)
|
model = AutoModelForCausalLM.from_pretrained(vars.model, device=0)
|
||||||
vars.modeldim = int(model.transformer.hidden_size)
|
vars.modeldim = get_hidden_size_from_model(model)
|
||||||
model = model.to(0)
|
model = model.to(0)
|
||||||
generator = model.generate
|
generator = model.generate
|
||||||
elif(vars.breakmodel): # Use both RAM and VRAM (breakmodel)
|
elif(vars.breakmodel): # Use both RAM and VRAM (breakmodel)
|
||||||
model = AutoModelForCausalLM.from_pretrained(vars.model)
|
model = AutoModelForCausalLM.from_pretrained(vars.model)
|
||||||
|
vars.modeldim = get_hidden_size_from_model(model)
|
||||||
device_config(model)
|
device_config(model)
|
||||||
else:
|
else:
|
||||||
model = AutoModelForCausalLM.from_pretrained(vars.model)
|
model = AutoModelForCausalLM.from_pretrained(vars.model)
|
||||||
|
vars.modeldim = get_hidden_size_from_model(model)
|
||||||
generator = model.generate
|
generator = model.generate
|
||||||
else:
|
else:
|
||||||
model = AutoModelForCausalLM.from_pretrained(vars.model)
|
model = AutoModelForCausalLM.from_pretrained(vars.model)
|
||||||
|
vars.modeldim = get_hidden_size_from_model(model)
|
||||||
generator = model.generate
|
generator = model.generate
|
||||||
|
|
||||||
# Suppress Author's Note by flagging square brackets (Old implementation)
|
# Suppress Author's Note by flagging square brackets (Old implementation)
|
||||||
|
@ -718,6 +727,8 @@ def do_connect():
|
||||||
emit('from_server', {'cmd': 'connected', 'smandelete': vars.smandelete, 'smanrename': vars.smanrename})
|
emit('from_server', {'cmd': 'connected', 'smandelete': vars.smandelete, 'smanrename': vars.smanrename})
|
||||||
if(vars.remote):
|
if(vars.remote):
|
||||||
emit('from_server', {'cmd': 'runs_remotely'})
|
emit('from_server', {'cmd': 'runs_remotely'})
|
||||||
|
if(vars.allowsp):
|
||||||
|
emit('from_server', {'cmd': 'allowsp', 'data': vars.allowsp})
|
||||||
|
|
||||||
if(not vars.gamestarted):
|
if(not vars.gamestarted):
|
||||||
setStartState()
|
setStartState()
|
||||||
|
@ -961,8 +972,6 @@ def get_message(msg):
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
def setStartState():
|
def setStartState():
|
||||||
txt = "<span>Welcome to <span class=\"color_cyan\">KoboldAI</span>! You are running <span class=\"color_green\">"+getmodelname()+"</span>.<br/>"
|
txt = "<span>Welcome to <span class=\"color_cyan\">KoboldAI</span>! You are running <span class=\"color_green\">"+getmodelname()+"</span>.<br/>"
|
||||||
if(vars.allowsp):
|
|
||||||
emit('from_server', {'cmd': 'allowsp', 'data': vars.allowsp}, broadcast=True)
|
|
||||||
if(not vars.noai):
|
if(not vars.noai):
|
||||||
txt = txt + "Please load a game or enter a prompt below to begin!</span>"
|
txt = txt + "Please load a game or enter a prompt below to begin!</span>"
|
||||||
else:
|
else:
|
||||||
|
|
Loading…
Reference in New Issue