Merge pull request #293 from one-some/ui2-contrastive-search

Basic Contrastive Search support for GPUs
Committed by ebolam on 2022-11-11 09:24:04 -05:00 (via GitHub)
6 changed files with 47 additions and 4 deletions
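
Contrastive search (Su et al., 2022, "A Contrastive Framework for Neural Text Generation") picks each token from the model's top-k candidates by balancing model confidence against a degeneration penalty, the maximum similarity between the candidate's hidden state and the hidden states of everything generated so far; penalty_alpha is the weight on that penalty. A minimal sketch of the upstream transformers API this PR wires into (model name and prompt are placeholders):

from transformers import AutoModelForCausalLM, AutoTokenizer

# Any causal LM works; contrastive search needs transformers >= 4.24.0.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The night sky was", return_tensors="pt")
# generate() selects contrastive search when penalty_alpha > 0 and
# top_k > 1; with penalty_alpha=0 it keeps the ordinary decoding path.
output = model.generate(**inputs, penalty_alpha=0.6, top_k=4, max_new_tokens=40)
print(tokenizer.decode(output[0], skip_special_tokens=True))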


@@ -90,6 +90,10 @@ global tpu_mtj_backend
 if lupa.LUA_VERSION[:2] != (5, 4):
     logger.error(f"Please install lupa==1.10. You have lupa {lupa.__version__}.")
+if packaging.version.parse(transformers_version) < packaging.version.parse("4.24.0"):
+    logger.warning(f"Please upgrade to transformers 4.24.0 or later for Contrastive Search. You have transformers {transformers_version}.")
 patch_causallm_patched = False
 # Make sure tqdm progress bars display properly in Colab
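
The gate above only warns rather than aborts: the server still starts on an older transformers, but penalty_alpha will not reach a decoding path that honours it. A pre-flight check mirroring the same comparison (a sketch, assuming the standard packaging and transformers imports):

import packaging.version
import transformers

# True when the installed transformers can honour penalty_alpha;
# contrastive search landed upstream in 4.24.0.
supports_contrastive = packaging.version.parse(
    transformers.__version__
) >= packaging.version.parse("4.24.0")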
@@ -1182,6 +1186,9 @@ def loadmodelsettings():
if("top_a" in js):
koboldai_vars.top_a = js["top_a"]
koboldai_vars.default_preset['top_a'] = js["top_a"]
if("penalty_alpha" in js):
koboldai_vars.penalty_alpha = js["penalty_alpha"]
koboldai_vars.default_preset['penalty_alpha'] = js["penalty_alpha"]
if("rep_pen" in js):
koboldai_vars.rep_pen = js["rep_pen"]
koboldai_vars.default_preset['rep_pen'] = js["rep_pen"]
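
loadmodelsettings() pulls per-model defaults from the model's settings JSON, so a model can now ship a contrastive-search default alongside its other samplers. A hypothetical parsed settings dict that would exercise the new branch (values are illustrative):

# js as the branches above consume it; only a few keys shown.
js = {
    "top_a": 0.0,
    "penalty_alpha": 0.6,  # 0.0 would leave contrastive search disabled
    "rep_pen": 1.1,
}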
@@ -2432,6 +2439,7 @@ def reset_model_settings():
     koboldai_vars.top_a = 0.0  # Default generator top-a
     koboldai_vars.tfs = 1.0  # Default generator tfs (tail-free sampling)
     koboldai_vars.typical = 1.0  # Default generator typical sampling threshold
+    koboldai_vars.penalty_alpha = 0.0  # Default generator penalty_alpha (contrastive search)
     koboldai_vars.numseqs = 1  # Number of sequences to ask the generator to create
     koboldai_vars.generated_tkns = 0  # If using a backend that supports Lua generation modifiers, how many tokens have already been generated, otherwise 0
     koboldai_vars.badwordsids = []
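
The 0.0 default is deliberate: transformers treats penalty_alpha as opt-in, so resetting to zero restores the pre-PR decoding behaviour for existing stories and presets. Toggling the feature from code would look like this (hypothetical direct usage of the variable introduced above):

# Enable contrastive search (top_k > 1 is also required, see below)...
koboldai_vars.penalty_alpha = 0.6
# ...and switch back to the stock behaviour.
koboldai_vars.penalty_alpha = 0.0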
@@ -3613,6 +3621,7 @@ def lua_has_setting(setting):
"tfs",
"typical",
"topa",
"penalty_alpha",
"reppen",
"reppenslope",
"reppenrange",
@@ -3650,6 +3659,7 @@ def lua_get_setting(setting):
if(setting in ("settfs", "tfs")): return koboldai_vars.tfs
if(setting in ("settypical", "typical")): return koboldai_vars.typical
if(setting in ("settopa", "topa")): return koboldai_vars.top_a
if(setting in ("setpenaltyalpha", "penalty_alpha")): return koboldai_vars.penalty_alpha
if(setting in ("setreppen", "reppen")): return koboldai_vars.rep_pen
if(setting in ("setreppenslope", "reppenslope")): return koboldai_vars.rep_pen_slope
if(setting in ("setreppenrange", "reppenrange")): return koboldai_vars.rep_pen_range
@@ -3688,6 +3698,7 @@ def lua_set_setting(setting, v):
if(setting in ("settfs", "tfs")): koboldai_vars.tfs = v
if(setting in ("settypical", "typical")): koboldai_vars.typical = v
if(setting in ("settopa", "topa")): koboldai_vars.top_a = v
if(setting in ("setpenaltyalpha", "penalty_alpha")): koboldai_vars.penalty_alpha = v
if(setting in ("setreppen", "reppen")): koboldai_vars.rep_pen = v
if(setting in ("setreppenslope", "reppenslope")): koboldai_vars.rep_pen_slope = v
if(setting in ("setreppenrange", "reppenrange")): koboldai_vars.rep_pen_range = v
@@ -4090,6 +4101,11 @@ def get_message(msg):
         emit('from_server', {'cmd': 'setlabeltopa', 'data': msg['data']}, broadcast=True, room="UI_1")
         settingschanged()
         refresh_settings()
+    elif(msg['cmd'] == 'setpenaltyalpha'):
+        koboldai_vars.penalty_alpha = float(msg['data'])
+        emit('from_server', {'cmd': 'setlabelpenaltyalpha', 'data': msg['data']}, broadcast=True, room="UI_1")
+        settingschanged()
+        refresh_settings()
     elif(msg['cmd'] == 'setreppen'):
         koboldai_vars.rep_pen = float(msg['data'])
         emit('from_server', {'cmd': 'setlabelreppen', 'data': msg['data']}, broadcast=True, room="UI_1")
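
On the classic UI side the new slider emits a socket.io message that lands in get_message(), which coerces the payload to float, stores it, and broadcasts the label update to every connected client. A hypothetical payload matching the branch above (value illustrative):

# Message shape handled by the new elif; 'data' arrives as a string.
msg = {"cmd": "setpenaltyalpha", "data": "0.6"}
# get_message() then runs koboldai_vars.penalty_alpha = float(msg["data"]).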
@@ -5356,6 +5372,7 @@ class GenerationSettings:
"tfs",
"typical",
"top_a",
"penalty_alpha",
"rep_pen",
"rep_pen_slope",
"rep_pen_range",
@@ -5531,7 +5548,7 @@ def torch_raw_generate(
     model.kai_scanner_excluded_world_info = model.kai_scanner_excluded_world_info or set()
     logger.debug("torch_raw_generate: setup inference_config {}s".format(time.time()-start_time))
     if not isinstance(prompt_tokens, torch.Tensor):
         gen_in = torch.tensor(prompt_tokens, dtype=torch.long)[None]
     else:
@@ -5550,6 +5567,7 @@ def torch_raw_generate(
             bad_words_ids=koboldai_vars.badwordsids,
             use_cache=True,
             num_return_sequences=batch_count,
+            penalty_alpha=koboldai_vars.penalty_alpha,
         )
     logger.debug("torch_raw_generate: run generator {}s".format(time.time()-start_time))
@@ -6450,6 +6468,7 @@ def refresh_settings():
     emit('from_server', {'cmd': 'updatetfs', 'data': koboldai_vars.tfs}, broadcast=True, room="UI_1")
     emit('from_server', {'cmd': 'updatetypical', 'data': koboldai_vars.typical}, broadcast=True, room="UI_1")
     emit('from_server', {'cmd': 'updatetopa', 'data': koboldai_vars.top_a}, broadcast=True, room="UI_1")
+    emit('from_server', {'cmd': 'updatepenaltyalpha', 'data': koboldai_vars.penalty_alpha}, broadcast=True, room="UI_1")
     emit('from_server', {'cmd': 'updatereppen', 'data': koboldai_vars.rep_pen}, broadcast=True, room="UI_1")
     emit('from_server', {'cmd': 'updatereppenslope', 'data': koboldai_vars.rep_pen_slope}, broadcast=True, room="UI_1")
     emit('from_server', {'cmd': 'updatereppenrange', 'data': koboldai_vars.rep_pen_range}, broadcast=True, room="UI_1")
@@ -9064,7 +9083,7 @@ def UI_2_load_cookies():
 def UI_2_save_new_preset(data):
     preset = {}
     #Data to get from current settings
-    for item in ["genamt", "rep_pen", "rep_pen_range", "rep_pen_slope", "sampler_order", "temp", "tfs", "top_a", "top_k", "top_p", "typical"]:
+    for item in ["genamt", "rep_pen", "rep_pen_range", "rep_pen_slope", "sampler_order", "temp", "tfs", "top_a", "top_k", "top_p", "typical", "penalty_alpha"]:
         preset[item] = getattr(koboldai_vars, item)
     #Data to get from UI
     for item in ['preset', 'description']:
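
With "penalty_alpha" appended to the field list, UI2 presets capture the contrastive-search strength alongside the other samplers. An illustrative result of the two loops (placeholder values, most fields elided):

# Partial preset as assembled above.
preset = {
    "genamt": 80,
    "temp": 0.7,
    "penalty_alpha": 0.6,
    "preset": "Contrastive demo",            # from the UI
    "description": "penalty_alpha example",  # from the UI
}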
@@ -9862,6 +9881,7 @@ def _generate_text(body: GenerationInputSchema):
"top_k": ("koboldai_vars", "top_k", None),
"top_a": ("koboldai_vars", "top_a", None),
"top_p": ("koboldai_vars", "top_p", None),
"penalty_alpha": ("koboldai_vars", "penalty_alpha", None),
"tfs": ("koboldai_vars", "tfs", None),
"typical": ("koboldai_vars", "typical", None),
"temperature": ("koboldai_vars", "temp", None),