mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Merge pull request #297 from one-some/ui2-contrastive-search
Revert "Basic Contrastive Search support"
This commit is contained in:
22
aiserver.py
22
aiserver.py
@@ -90,10 +90,6 @@ global tpu_mtj_backend
|
|||||||
if lupa.LUA_VERSION[:2] != (5, 4):
|
if lupa.LUA_VERSION[:2] != (5, 4):
|
||||||
logger.error(f"Please install lupa==1.10. You have lupa {lupa.__version__}.")
|
logger.error(f"Please install lupa==1.10. You have lupa {lupa.__version__}.")
|
||||||
|
|
||||||
if packaging.version.parse(transformers_version) < packaging.version.parse("4.24.0"):
|
|
||||||
logger.warning(f"Please upgrade to transformers 4.24.0 or later for Contrastive Search. You have transformers {transformers_version}.")
|
|
||||||
|
|
||||||
|
|
||||||
patch_causallm_patched = False
|
patch_causallm_patched = False
|
||||||
|
|
||||||
# Make sure tqdm progress bars display properly in Colab
|
# Make sure tqdm progress bars display properly in Colab
|
||||||
@@ -1186,9 +1182,6 @@ def loadmodelsettings():
|
|||||||
if("top_a" in js):
|
if("top_a" in js):
|
||||||
koboldai_vars.top_a = js["top_a"]
|
koboldai_vars.top_a = js["top_a"]
|
||||||
koboldai_vars.default_preset['top_a'] = js["top_a"]
|
koboldai_vars.default_preset['top_a'] = js["top_a"]
|
||||||
if("penalty_alpha" in js):
|
|
||||||
koboldai_vars.penalty_alpha = js["penalty_alpha"]
|
|
||||||
koboldai_vars.default_preset['penalty_alpha'] = js["penalty_alpha"]
|
|
||||||
if("rep_pen" in js):
|
if("rep_pen" in js):
|
||||||
koboldai_vars.rep_pen = js["rep_pen"]
|
koboldai_vars.rep_pen = js["rep_pen"]
|
||||||
koboldai_vars.default_preset['rep_pen'] = js["rep_pen"]
|
koboldai_vars.default_preset['rep_pen'] = js["rep_pen"]
|
||||||
@@ -2440,7 +2433,6 @@ def reset_model_settings():
|
|||||||
koboldai_vars.top_a = 0.0 # Default generator top-a
|
koboldai_vars.top_a = 0.0 # Default generator top-a
|
||||||
koboldai_vars.tfs = 1.0 # Default generator tfs (tail-free sampling)
|
koboldai_vars.tfs = 1.0 # Default generator tfs (tail-free sampling)
|
||||||
koboldai_vars.typical = 1.0 # Default generator typical sampling threshold
|
koboldai_vars.typical = 1.0 # Default generator typical sampling threshold
|
||||||
koboldai_vars.penalty_alpha = 0.0 # Default generator penalty_alpha (contrastive search)
|
|
||||||
koboldai_vars.numseqs = 1 # Number of sequences to ask the generator to create
|
koboldai_vars.numseqs = 1 # Number of sequences to ask the generator to create
|
||||||
koboldai_vars.generated_tkns = 0 # If using a backend that supports Lua generation modifiers, how many tokens have already been generated, otherwise 0
|
koboldai_vars.generated_tkns = 0 # If using a backend that supports Lua generation modifiers, how many tokens have already been generated, otherwise 0
|
||||||
koboldai_vars.badwordsids = []
|
koboldai_vars.badwordsids = []
|
||||||
@@ -3622,7 +3614,6 @@ def lua_has_setting(setting):
|
|||||||
"tfs",
|
"tfs",
|
||||||
"typical",
|
"typical",
|
||||||
"topa",
|
"topa",
|
||||||
"penalty_alpha",
|
|
||||||
"reppen",
|
"reppen",
|
||||||
"reppenslope",
|
"reppenslope",
|
||||||
"reppenrange",
|
"reppenrange",
|
||||||
@@ -3660,7 +3651,6 @@ def lua_get_setting(setting):
|
|||||||
if(setting in ("settfs", "tfs")): return koboldai_vars.tfs
|
if(setting in ("settfs", "tfs")): return koboldai_vars.tfs
|
||||||
if(setting in ("settypical", "typical")): return koboldai_vars.typical
|
if(setting in ("settypical", "typical")): return koboldai_vars.typical
|
||||||
if(setting in ("settopa", "topa")): return koboldai_vars.top_a
|
if(setting in ("settopa", "topa")): return koboldai_vars.top_a
|
||||||
if(setting in ("setpenaltyalpha", "penalty_alpha")): return koboldai_vars.penalty_alpha
|
|
||||||
if(setting in ("setreppen", "reppen")): return koboldai_vars.rep_pen
|
if(setting in ("setreppen", "reppen")): return koboldai_vars.rep_pen
|
||||||
if(setting in ("setreppenslope", "reppenslope")): return koboldai_vars.rep_pen_slope
|
if(setting in ("setreppenslope", "reppenslope")): return koboldai_vars.rep_pen_slope
|
||||||
if(setting in ("setreppenrange", "reppenrange")): return koboldai_vars.rep_pen_range
|
if(setting in ("setreppenrange", "reppenrange")): return koboldai_vars.rep_pen_range
|
||||||
@@ -3699,7 +3689,6 @@ def lua_set_setting(setting, v):
|
|||||||
if(setting in ("settfs", "tfs")): koboldai_vars.tfs = v
|
if(setting in ("settfs", "tfs")): koboldai_vars.tfs = v
|
||||||
if(setting in ("settypical", "typical")): koboldai_vars.typical = v
|
if(setting in ("settypical", "typical")): koboldai_vars.typical = v
|
||||||
if(setting in ("settopa", "topa")): koboldai_vars.top_a = v
|
if(setting in ("settopa", "topa")): koboldai_vars.top_a = v
|
||||||
if(setting in ("setpenaltyalpha", "penalty_alpha")): koboldai_vars.penalty_alpha = v
|
|
||||||
if(setting in ("setreppen", "reppen")): koboldai_vars.rep_pen = v
|
if(setting in ("setreppen", "reppen")): koboldai_vars.rep_pen = v
|
||||||
if(setting in ("setreppenslope", "reppenslope")): koboldai_vars.rep_pen_slope = v
|
if(setting in ("setreppenslope", "reppenslope")): koboldai_vars.rep_pen_slope = v
|
||||||
if(setting in ("setreppenrange", "reppenrange")): koboldai_vars.rep_pen_range = v
|
if(setting in ("setreppenrange", "reppenrange")): koboldai_vars.rep_pen_range = v
|
||||||
@@ -4102,11 +4091,6 @@ def get_message(msg):
|
|||||||
emit('from_server', {'cmd': 'setlabeltopa', 'data': msg['data']}, broadcast=True, room="UI_1")
|
emit('from_server', {'cmd': 'setlabeltopa', 'data': msg['data']}, broadcast=True, room="UI_1")
|
||||||
settingschanged()
|
settingschanged()
|
||||||
refresh_settings()
|
refresh_settings()
|
||||||
elif(msg['cmd'] == 'setpenaltyalpha'):
|
|
||||||
koboldai_vars.penalty_alpha = float(msg['data'])
|
|
||||||
emit('from_server', {'cmd': 'setlabelpenaltyalpha', 'data': msg['data']}, broadcast=True, room="UI_1")
|
|
||||||
settingschanged()
|
|
||||||
refresh_settings()
|
|
||||||
elif(msg['cmd'] == 'setreppen'):
|
elif(msg['cmd'] == 'setreppen'):
|
||||||
koboldai_vars.rep_pen = float(msg['data'])
|
koboldai_vars.rep_pen = float(msg['data'])
|
||||||
emit('from_server', {'cmd': 'setlabelreppen', 'data': msg['data']}, broadcast=True, room="UI_1")
|
emit('from_server', {'cmd': 'setlabelreppen', 'data': msg['data']}, broadcast=True, room="UI_1")
|
||||||
@@ -5373,7 +5357,6 @@ class GenerationSettings:
|
|||||||
"tfs",
|
"tfs",
|
||||||
"typical",
|
"typical",
|
||||||
"top_a",
|
"top_a",
|
||||||
"penalty_alpha",
|
|
||||||
"rep_pen",
|
"rep_pen",
|
||||||
"rep_pen_slope",
|
"rep_pen_slope",
|
||||||
"rep_pen_range",
|
"rep_pen_range",
|
||||||
@@ -5568,7 +5551,6 @@ def torch_raw_generate(
|
|||||||
bad_words_ids=koboldai_vars.badwordsids,
|
bad_words_ids=koboldai_vars.badwordsids,
|
||||||
use_cache=True,
|
use_cache=True,
|
||||||
num_return_sequences=batch_count,
|
num_return_sequences=batch_count,
|
||||||
penalty_alpha=koboldai_vars.penalty_alpha,
|
|
||||||
)
|
)
|
||||||
logger.debug("torch_raw_generate: run generator {}s".format(time.time()-start_time))
|
logger.debug("torch_raw_generate: run generator {}s".format(time.time()-start_time))
|
||||||
|
|
||||||
@@ -6469,7 +6451,6 @@ def refresh_settings():
|
|||||||
emit('from_server', {'cmd': 'updatetfs', 'data': koboldai_vars.tfs}, broadcast=True, room="UI_1")
|
emit('from_server', {'cmd': 'updatetfs', 'data': koboldai_vars.tfs}, broadcast=True, room="UI_1")
|
||||||
emit('from_server', {'cmd': 'updatetypical', 'data': koboldai_vars.typical}, broadcast=True, room="UI_1")
|
emit('from_server', {'cmd': 'updatetypical', 'data': koboldai_vars.typical}, broadcast=True, room="UI_1")
|
||||||
emit('from_server', {'cmd': 'updatetopa', 'data': koboldai_vars.top_a}, broadcast=True, room="UI_1")
|
emit('from_server', {'cmd': 'updatetopa', 'data': koboldai_vars.top_a}, broadcast=True, room="UI_1")
|
||||||
emit('from_server', {'cmd': 'updatepenaltyalpha', 'data': koboldai_vars.penalty_alpha}, broadcast=True, room="UI_1")
|
|
||||||
emit('from_server', {'cmd': 'updatereppen', 'data': koboldai_vars.rep_pen}, broadcast=True, room="UI_1")
|
emit('from_server', {'cmd': 'updatereppen', 'data': koboldai_vars.rep_pen}, broadcast=True, room="UI_1")
|
||||||
emit('from_server', {'cmd': 'updatereppenslope', 'data': koboldai_vars.rep_pen_slope}, broadcast=True, room="UI_1")
|
emit('from_server', {'cmd': 'updatereppenslope', 'data': koboldai_vars.rep_pen_slope}, broadcast=True, room="UI_1")
|
||||||
emit('from_server', {'cmd': 'updatereppenrange', 'data': koboldai_vars.rep_pen_range}, broadcast=True, room="UI_1")
|
emit('from_server', {'cmd': 'updatereppenrange', 'data': koboldai_vars.rep_pen_range}, broadcast=True, room="UI_1")
|
||||||
@@ -9084,7 +9065,7 @@ def UI_2_load_cookies():
|
|||||||
def UI_2_save_new_preset(data):
|
def UI_2_save_new_preset(data):
|
||||||
preset = {}
|
preset = {}
|
||||||
#Data to get from current settings
|
#Data to get from current settings
|
||||||
for item in ["genamt", "rep_pen", "rep_pen_range", "rep_pen_slope", "sampler_order", "temp", "tfs", "top_a", "top_k", "top_p", "typical", "penalty_alpha"]:
|
for item in ["genamt", "rep_pen", "rep_pen_range", "rep_pen_slope", "sampler_order", "temp", "tfs", "top_a", "top_k", "top_p", "typical"]:
|
||||||
preset[item] = getattr(koboldai_vars, item)
|
preset[item] = getattr(koboldai_vars, item)
|
||||||
#Data to get from UI
|
#Data to get from UI
|
||||||
for item in ['preset', 'description']:
|
for item in ['preset', 'description']:
|
||||||
@@ -9895,7 +9876,6 @@ def _generate_text(body: GenerationInputSchema):
|
|||||||
"top_k": ("koboldai_vars", "top_k", None),
|
"top_k": ("koboldai_vars", "top_k", None),
|
||||||
"top_a": ("koboldai_vars", "top_a", None),
|
"top_a": ("koboldai_vars", "top_a", None),
|
||||||
"top_p": ("koboldai_vars", "top_p", None),
|
"top_p": ("koboldai_vars", "top_p", None),
|
||||||
"penalty_alpha": ("koboldai_vars", "penalty_alpha", None),
|
|
||||||
"tfs": ("koboldai_vars", "tfs", None),
|
"tfs": ("koboldai_vars", "tfs", None),
|
||||||
"typical": ("koboldai_vars", "typical", None),
|
"typical": ("koboldai_vars", "typical", None),
|
||||||
"temperature": ("koboldai_vars", "temp", None),
|
"temperature": ("koboldai_vars", "temp", None),
|
||||||
|
@@ -107,22 +107,6 @@ gensettingstf = [
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"UI_V2_Only": True,
|
"UI_V2_Only": True,
|
||||||
"uitype": "slider",
|
|
||||||
"unit": "float",
|
|
||||||
"label": "Penalty Alpha",
|
|
||||||
"id": "setpenaltyalpha",
|
|
||||||
"min": 0.0,
|
|
||||||
"max": 1.0,
|
|
||||||
"step": 0.01,
|
|
||||||
"default": 0.0,
|
|
||||||
"tooltip": "Alternative search method. Encourages diversity of output embeddings. To be used with Top-K. Does not work on TPU! (Put this value on 0 to disable its effect)",
|
|
||||||
"menu_path": "Settings",
|
|
||||||
"sub_path": "Sampling",
|
|
||||||
"classname": "model",
|
|
||||||
"name": "penalty_alpha",
|
|
||||||
"extra_classes": "var_sync_alt_system_use_colab_tpu"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"uitype": "slider",
|
"uitype": "slider",
|
||||||
"unit": "float",
|
"unit": "float",
|
||||||
"label": "Repetition Penalty",
|
"label": "Repetition Penalty",
|
||||||
|
@@ -622,7 +622,6 @@ class model_settings(settings):
|
|||||||
self.top_a = 0.0 # Default generator top-a
|
self.top_a = 0.0 # Default generator top-a
|
||||||
self.tfs = 1.0 # Default generator tfs (tail-free sampling)
|
self.tfs = 1.0 # Default generator tfs (tail-free sampling)
|
||||||
self.typical = 1.0 # Default generator typical sampling threshold
|
self.typical = 1.0 # Default generator typical sampling threshold
|
||||||
self.penalty_alpha = 0.0 # Default generator penalty_alpha (contrastive search)
|
|
||||||
self.numseqs = 1 # Number of sequences to ask the generator to create
|
self.numseqs = 1 # Number of sequences to ask the generator to create
|
||||||
self.badwordsids = []
|
self.badwordsids = []
|
||||||
self.fp32_model = False # Whether or not the most recently loaded HF model was in fp32 format
|
self.fp32_model = False # Whether or not the most recently loaded HF model was in fp32 format
|
||||||
|
@@ -1,4 +1,4 @@
|
|||||||
transformers>=4.24.0
|
transformers>=4.20.1
|
||||||
Flask
|
Flask
|
||||||
Flask-SocketIO
|
Flask-SocketIO
|
||||||
requests
|
requests
|
||||||
|
@@ -5,7 +5,7 @@ requests
|
|||||||
dm-haiku == 0.0.5
|
dm-haiku == 0.0.5
|
||||||
jax == 0.2.21
|
jax == 0.2.21
|
||||||
jaxlib >= 0.1.69, <= 0.3.7
|
jaxlib >= 0.1.69, <= 0.3.7
|
||||||
transformers >=4.24.0
|
transformers >=4.20.1
|
||||||
progressbar2
|
progressbar2
|
||||||
git+https://github.com/VE-FORBRYDERNE/mesh-transformer-jax@ck
|
git+https://github.com/VE-FORBRYDERNE/mesh-transformer-jax@ck
|
||||||
flask
|
flask
|
||||||
|
@@ -2619,10 +2619,6 @@ $(document).ready(function(){
|
|||||||
// Send current top a value to input
|
// Send current top a value to input
|
||||||
$("#settopacur").val(msg.data);
|
$("#settopacur").val(msg.data);
|
||||||
$("#settopa").val(parseFloat(msg.data)).trigger("change");
|
$("#settopa").val(parseFloat(msg.data)).trigger("change");
|
||||||
} else if(msg.cmd == "updatepenaltyalpha") {
|
|
||||||
// Send current top p value to input
|
|
||||||
$("#setpenaltyalphacur").val(msg.data);
|
|
||||||
$("#setpenaltyalpha").val(parseFloat(msg.data)).trigger("change");
|
|
||||||
} else if(msg.cmd == "updatereppen") {
|
} else if(msg.cmd == "updatereppen") {
|
||||||
// Send current rep pen value to input
|
// Send current rep pen value to input
|
||||||
$("#setreppencur").val(msg.data);
|
$("#setreppencur").val(msg.data);
|
||||||
@@ -2653,9 +2649,6 @@ $(document).ready(function(){
|
|||||||
} else if(msg.cmd == "setlabeltopp") {
|
} else if(msg.cmd == "setlabeltopp") {
|
||||||
// Update setting label with value from server
|
// Update setting label with value from server
|
||||||
$("#settoppcur").val(msg.data);
|
$("#settoppcur").val(msg.data);
|
||||||
} else if(msg.cmd == "setlabelpenaltyalpha") {
|
|
||||||
// Update setting label with value from server
|
|
||||||
$("#setpenaltyalphacur").val(msg.data);
|
|
||||||
} else if(msg.cmd == "setlabeltopk") {
|
} else if(msg.cmd == "setlabeltopk") {
|
||||||
// Update setting label with value from server
|
// Update setting label with value from server
|
||||||
$("#settopkcur").val(msg.data);
|
$("#settopkcur").val(msg.data);
|
||||||
|
Reference in New Issue
Block a user