This commit is contained in:
ebolam
2022-09-22 12:06:51 -04:00
parent 48bc9c7317
commit 2f31958b5b
2 changed files with 12 additions and 12 deletions

View File

@@ -3219,7 +3219,7 @@ def lua_compute_context(submission, entries, folders, kwargs):
# scan_story=kwargs["scan_story"] if kwargs["scan_story"] != None else True,
#)
if koboldai_vars.alt_gen:
txt, _, _ = koboldai_vars.calc_ai_text()
txt, _, _, found_entries = koboldai_vars.calc_ai_text()
print("Using Alt Gen")
else:
#txt, _, _ = calcsubmitbudget(
@@ -3229,7 +3229,7 @@ def lua_compute_context(submission, entries, folders, kwargs):
# anotetxt,
# actions,
#)
txt, _, _ = koboldai_vars.calc_ai_text(method=1)
txt, _, _, found_entries = koboldai_vars.calc_ai_text(method=1)
return utils.decodenewlines(tokenizer.decode(txt))
#==================================================================#
@@ -4832,11 +4832,11 @@ def calcsubmit(txt):
# For all transformers models
if(koboldai_vars.model != "InferKit"):
if koboldai_vars.alt_gen:
subtxt, min, max = koboldai_vars.calc_ai_text(submitted_text=txt)
subtxt, min, max, found_entries = koboldai_vars.calc_ai_text(submitted_text=txt)
logger.debug("Using Alt Gen")
else:
#subtxt, min, max = calcsubmitbudget(actionlen, winfo, mem, anotetxt, koboldai_vars.actions, submission=txt)
subtxt, min, max = koboldai_vars.calc_ai_text(submitted_text=txt, method=1)
subtxt, min, max, found_entries = koboldai_vars.calc_ai_text(submitted_text=txt, method=1)
if(actionlen == 0):
if(not koboldai_vars.use_colab_tpu and koboldai_vars.model not in ["Colab", "API", "CLUSTER", "OAI", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"]):
generate(subtxt, min, max, found_entries=found_entries)
@@ -4981,11 +4981,11 @@ def _generate(txt, minimum, maximum, found_entries):
#winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True, actions=koboldai_vars.actions)
found_entries[i].update(_found_entries)
if koboldai_vars.alt_gen:
txt, _, _ = koboldai_vars.calc_ai_text(submitted_text=txt)
txt, _, _, found_entries = koboldai_vars.calc_ai_text(submitted_text=txt)
logger.debug("Using Alt Gen")
else:
#txt, _, _ = calcsubmitbudget(len(koboldai_vars.actions), winfo, mem, anotetxt, koboldai_vars.actions, submission=txt)
txt, _, _ = koboldai_vars.calc_ai_text(submitted_text=txt, method=1)
txt, _, _, found_entries = koboldai_vars.calc_ai_text(submitted_text=txt, method=1)
encoded.append(torch.tensor(txt, dtype=torch.long, device=genout.device))
max_length = len(max(encoded, key=len))
encoded = torch.stack(tuple(torch.nn.functional.pad(e, (max_length - len(e), 0), value=model.config.pad_token_id or model.config.eos_token_id) for e in encoded))
@@ -5536,11 +5536,11 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
#winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True, actions=koboldai_vars.actions)
found_entries[i].update(_found_entries)
if koboldai_vars.alt_gen:
txt, _, _ = koboldai_vars.calc_ai_text(submitted_text=txt)
txt, _, _, found_entries = koboldai_vars.calc_ai_text(submitted_text=txt)
logger.debug("Using Alt Gen")
else:
#txt, _, _ = calcsubmitbudget(len(koboldai_vars.actions), winfo, mem, anotetxt, koboldai_vars.actions, submission=txt)
txt, _, _ = koboldai_vars.calc_ai_text(submitted_text=txt, method=1)
txt, _, _, found_entries = koboldai_vars.calc_ai_text(submitted_text=txt, method=1)
encoded.append(np.array(txt, dtype=np.uint32))
max_length = len(max(encoded, key=len))
encoded = np.stack(tuple(np.pad(e, (max_length - len(e), 0), constant_values=tpu_mtj_backend.pad_token_id) for e in encoded))
@@ -7192,9 +7192,9 @@ def show_folder_usersripts(data):
@app.route('/ai_text')
def ai_text():
start_time = time.time()
text = koboldai_vars.calc_ai_text()
print("Generating Game Text took {} seconds".format(time.time()-start_time))
return text
text = koboldai_vars.calc_ai_text(return_text=True)
logger.debug("Generating Game Text took {} seconds".format(time.time()-start_time))
return "{}\n\n\n{}".format(text, "Generating Game Text took {} seconds".format(time.time()-start_time))

View File

@@ -370,7 +370,7 @@ class koboldai_vars(object):
self.context = context
if return_text:
return text
return tokens, used_tokens, used_tokens+self.genamt
return tokens, used_tokens, used_tokens+self.genamt, used_world_info
def __setattr__(self, name, value):
if name[0] == "_" or name == "tokenizer":