Remove tpumtjgenerate

Dead code as far as I can tell; this path is now handled by tpu_raw_generate.
This commit is contained in:
somebody
2022-12-18 15:29:17 -06:00
parent 0983042953
commit bf82f257d1

View File

@@ -6349,145 +6349,6 @@ def pinsequence(n):
    text = koboldai_vars.genseqs[int(n)]['generated_text']
    send_debug()
#==================================================================#
#  Send text to TPU mesh transformer backend
#==================================================================#
def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
    """Submit tokenized text to the TPU mesh-transformer backend and emit results.

    Args:
        txt: tokenized prompt (sequence of token ids; converted with np.uint32).
        minimum: minimum generation length (tokens).
        maximum: maximum generation length (tokens); gen_len is maximum-minimum+1.
        found_entries: optional set of already-triggered world-info entries;
            copied into one set per sequence so each sequence tracks its own.

    Side effects: updates koboldai_vars state, the Lua bridge, the module-level
    `past` array, and emits results/errors over socketio. Returns None.
    """
    if(koboldai_vars.full_determinism):
        tpu_mtj_backend.set_rng_seed(koboldai_vars.seed)

    koboldai_vars.generated_tkns = 0

    if(found_entries is None):
        found_entries = set()
    # One independent copy per sequence so world-info hits don't leak across seqs.
    found_entries = tuple(found_entries.copy() for _ in range(koboldai_vars.numseqs))

    if not koboldai_vars.quiet:
        logger.debug(f"Prompt Min:{minimum}, Max:{maximum}")
        # unicode_escape round-trip makes control characters visible in the log
        logger.prompt(utils.decodenewlines(tokenizer.decode(txt)).encode("unicode_escape").decode("utf-8"))

    koboldai_vars._prompt = koboldai_vars.prompt

    # Submit input text to generator
    try:
        soft_tokens = tpumtjgetsofttokens()

        global past

        socketio.start_background_task(copy_current_request_context(check_for_backend_compilation))

        if(koboldai_vars.dynamicscan or (not koboldai_vars.nogenmod and koboldai_vars.has_genmod)):
            # Dynamic path: generate incrementally so world-info scanning /
            # generation modifiers can force a regeneration mid-stream.
            context = np.tile(np.uint32(txt), (koboldai_vars.numseqs, 1))
            past = np.empty((koboldai_vars.numseqs, 0), dtype=np.uint32)
            while(True):
                genout, n_generated, regeneration_required, halt = tpool.execute(
                    tpu_mtj_backend.infer_dynamic,
                    context,
                    gen_len = maximum-minimum+1,
                    numseqs=koboldai_vars.numseqs,
                    soft_embeddings=koboldai_vars.sp,
                    soft_tokens=soft_tokens,
                    excluded_world_info=found_entries,
                )

                # Extend `past` with the newly generated columns, then copy the
                # authoritative token values back from the Lua bridge (scripts
                # may have rewritten them).
                past = np.pad(past, ((0, 0), (0, n_generated)))
                for r in range(koboldai_vars.numseqs):
                    for c in range(koboldai_vars.lua_koboldbridge.generated_cols):
                        assert koboldai_vars.lua_koboldbridge.generated[r+1][c+1] is not None
                        past[r, c] = koboldai_vars.lua_koboldbridge.generated[r+1][c+1]

                if(koboldai_vars.abort or halt or not regeneration_required):
                    break

                print("(regeneration triggered)")

                # Rebuild the context for each sequence from its generated text.
                encoded = []
                for i in range(koboldai_vars.numseqs):
                    txt = utils.decodenewlines(tokenizer.decode(past[i]))
                    # BUGFIX: previously this read `_found_entries` (whose defining
                    # line was commented out -> NameError) and then rebound the
                    # `found_entries` tuple itself to calc_ai_text's return value,
                    # clobbering the per-sequence sets. Capture the helper's found
                    # entries separately and merge them into this sequence's set.
                    txt, _, _, _found_entries = koboldai_vars.calc_ai_text(submitted_text=txt)
                    found_entries[i].update(_found_entries)
                    encoded.append(np.array(txt, dtype=np.uint32))
                # Left-pad all sequences to a common length with the pad token.
                max_length = len(max(encoded, key=len))
                encoded = np.stack(tuple(np.pad(e, (max_length - len(e), 0), constant_values=tpu_mtj_backend.pad_token_id) for e in encoded))
                context = np.concatenate(
                    (
                        encoded,
                        past,
                    ),
                    axis=-1,
                )
        else:
            # Static path: one-shot generation with the full sampler settings.
            genout = tpool.execute(
                tpu_mtj_backend.infer_static,
                np.uint32(txt),
                gen_len = maximum-minimum+1,
                temp=koboldai_vars.temp,
                top_p=koboldai_vars.top_p,
                top_k=koboldai_vars.top_k,
                tfs=koboldai_vars.tfs,
                typical=koboldai_vars.typical,
                top_a=koboldai_vars.top_a,
                numseqs=koboldai_vars.numseqs,
                repetition_penalty=koboldai_vars.rep_pen,
                rpslope=koboldai_vars.rep_pen_slope,
                rprange=koboldai_vars.rep_pen_range,
                soft_embeddings=koboldai_vars.sp,
                soft_tokens=soft_tokens,
                sampler_order=koboldai_vars.sampler_order,
            )
            past = genout
            # Mirror the generated tokens into the Lua bridge (1-based tables).
            for i in range(koboldai_vars.numseqs):
                koboldai_vars.lua_koboldbridge.generated[i+1] = koboldai_vars.lua_state.table(*genout[i].tolist())
            koboldai_vars.lua_koboldbridge.generated_cols = koboldai_vars.generated_tkns = genout[0].shape[-1]
    except Exception as e:
        if(issubclass(type(e), lupa.LuaError)):
            # Lua userscript crashed: tear down the scripting sandbox and tell
            # both UI generations about it.
            koboldai_vars.lua_koboldbridge.obliterate_multiverse()
            koboldai_vars.lua_running = False
            emit('from_server', {'cmd': 'errmsg', 'data': 'Lua script error; please check console.'}, broadcast=True, room="UI_1")
            sendUSStatItems()
            logger.error('LUA ERROR: ' + str(e).replace("\033", ""))
            logger.warning("Lua engine stopped; please open 'Userscripts' and press Load to reinitialize scripts.")
            socketio.emit("error", str(e), broadcast=True, room="UI_2")
        else:
            emit('from_server', {'cmd': 'errmsg', 'data': 'Error occurred during generator call; please check console.'}, broadcast=True, room="UI_1")
            print("{0}{1}{2}".format(colors.RED, traceback.format_exc().replace("\033", ""), colors.END), file=sys.stderr)
            socketio.emit("error", str(e), broadcast=True, room="UI_2")
        set_aibusy(0)
        return

    # Publish raw outputs to the Lua bridge, then let output-modifier scripts run.
    for i in range(koboldai_vars.numseqs):
        koboldai_vars.lua_koboldbridge.outputs[i+1] = utils.decodenewlines(tokenizer.decode(past[i]))
    genout = past

    execute_outmod()
    if(koboldai_vars.lua_koboldbridge.regeneration_required):
        # A script rewrote the outputs: take the (string) bridge outputs instead
        # of decoding the token arrays.
        koboldai_vars.lua_koboldbridge.regeneration_required = False
        genout = []
        for i in range(koboldai_vars.numseqs):
            genout.append({"generated_text": koboldai_vars.lua_koboldbridge.outputs[i+1]})
            assert type(genout[-1]["generated_text"]) is str
    else:
        genout = [{"generated_text": utils.decodenewlines(tokenizer.decode(txt))} for txt in genout]

    koboldai_vars.actions.append_options([applyoutputformatting(x["generated_text"]) for x in genout])
    genout = [{"generated_text": x['text']} for x in koboldai_vars.actions.get_current_options()]
    if(len(koboldai_vars.actions.get_current_options()) == 1):
        genresult(koboldai_vars.actions.get_current_options()[0]['text'])
    else:
        if(koboldai_vars.lua_koboldbridge.restart_sequence is not None and koboldai_vars.lua_koboldbridge.restart_sequence > 0):
            genresult(genout[koboldai_vars.lua_koboldbridge.restart_sequence-1]["generated_text"])
        else:
            genselect([{"generated_text": x['text']} for x in koboldai_vars.actions.get_current_options()])

    set_aibusy(0)
#==================================================================#
#  Replaces returns and newlines with HTML breaks
#==================================================================#