diff --git a/aiserver.py b/aiserver.py index e3f7a0f6..9be01f10 100644 --- a/aiserver.py +++ b/aiserver.py @@ -654,26 +654,37 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: assert scores.ndim == 2 + assert input_ids.ndim == 2 self.regeneration_required = False self.halt = False + scores_shape = scores.shape scores_list = self.scores.tolist() vars.lua_koboldbridge.logits = vars.lua_state.table() for r, row in enumerate(scores_list): vars.lua_koboldbridge.logits[r+1] = vars.lua_state.table(*row) vars.lua_koboldbridge.vocab_size = scores_shape[-1] + + if(vars.lua_koboldbridge.generated_cols != 0): + for i in range(vars.numseqs): + vars.lua_koboldbridge.generated[i+1][vars.lua_koboldbridge.generated_cols] = input_ids[i, -1].item() + execute_genmod() + + if(vars.lua_koboldbridge.regeneration_required): + vars.lua_koboldbridge.regeneration_required = False + self.regeneration_required = True + + if(not vars.lua_koboldbridge.generating): + self.halt = True + scores = torch.tensor( tuple(tuple(row.values()) for row in vars.lua_koboldbridge.logits.values()), device=scores.device, dtype=scores.dtype, ) assert scores.shape == scores_shape - if(vars.lua_koboldbridge.regeneration_required): - vars.lua_koboldbridge.regeneration_required = False - self.regeneration_required = True - if(not vars.lua_koboldbridge.generating): - self.halt = True + return scores def new_get_logits_warper( @@ -1748,9 +1759,19 @@ def actionsubmit(data, actionmode=0, force_submit=False): calcsubmit(data) # Run the first action through the generator emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True) else: - execute_outmod() # Save this first action as the prompt vars.prompt = data + execute_outmod() + if(vars.lua_koboldbridge.regeneration_required): + vars.lua_koboldbridge.regeneration_required = False + genout = [] + for i in range(vars.numseqs): + 
genout.append({"generated_text": vars.lua_koboldbridge.outputs[i+1]}) + assert type(genout[-1]["generated_text"]) is str + if(len(genout) == 1): + genresult(genout[0]["generated_text"]) + else: + genselect(genout) refresh_story() set_aibusy(0) emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True) @@ -1776,6 +1797,16 @@ def actionsubmit(data, actionmode=0, force_submit=False): else: execute_outmod() set_aibusy(0) + if(vars.lua_koboldbridge.regeneration_required): + vars.lua_koboldbridge.regeneration_required = False + genout = [] + for i in range(vars.numseqs): + genout.append({"generated_text": vars.lua_koboldbridge.outputs[i+1]}) + assert type(genout[-1]["generated_text"]) is str + if(len(genout) == 1): + genresult(genout[0]["generated_text"]) + else: + genselect(genout) emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True) #==================================================================# @@ -2076,6 +2107,12 @@ def generate(txt, minimum, maximum, found_entries=None): break assert genout.ndim >= 2 assert genout.shape[0] == vars.numseqs + if(already_generated != vars.lua_koboldbridge.generated_cols): + raise RuntimeError("WI scanning error") + for r in range(vars.numseqs): + for c in range(already_generated): + assert vars.lua_koboldbridge.generated[r+1][c+1] is not None + genout[r][genout.shape[-1] - already_generated + c] = vars.lua_koboldbridge.generated[r+1][c+1] encoded = [] for i in range(vars.numseqs): txt = tokenizer.decode(genout[i, -already_generated:]) @@ -2111,12 +2148,20 @@ def generate(txt, minimum, maximum, found_entries=None): print("{0}{1}{2}".format(colors.RED, e, colors.END)) set_aibusy(0) return - - # Need to manually strip and decode tokens if we're not using a pipeline - #already_generated = -(len(gen_in[0]) - len(tokens)) - genout = [{"generated_text": tokenizer.decode(tokens[-already_generated:])} for tokens in genout] + + for i in range(vars.numseqs): + 
vars.lua_koboldbridge.generated[i+1][vars.lua_koboldbridge.generated_cols] = genout[i, -1].item() + vars.lua_koboldbridge.outputs[i+1] = tokenizer.decode(genout[i, -already_generated:]) execute_outmod() + if(vars.lua_koboldbridge.regeneration_required): + vars.lua_koboldbridge.regeneration_required = False + genout = [] + for i in range(vars.numseqs): + genout.append({"generated_text": vars.lua_koboldbridge.outputs[i+1]}) + assert type(genout[-1]["generated_text"]) is str + else: + genout = [{"generated_text": tokenizer.decode(tokens[-already_generated:])} for tokens in genout] if(len(genout) == 1): genresult(genout[0]["generated_text"]) @@ -2217,8 +2262,17 @@ def sendtocolab(txt, min, max): else: genout = js["seqs"] + for i in range(vars.numseqs): + vars.lua_koboldbridge.outputs[i+1] = genout[i] + execute_outmod() - + if(vars.lua_koboldbridge.regeneration_required): + vars.lua_koboldbridge.regeneration_required = False + genout = [] + for i in range(vars.numseqs): + genout.append(vars.lua_koboldbridge.outputs[i+1]) + assert type(genout[-1]) is str + if(len(genout) == 1): genresult(genout[0]) else: @@ -2297,10 +2351,20 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None): print("{0}{1}{2}".format(colors.RED, e, colors.END)) set_aibusy(0) return - - genout = [{"generated_text": txt} for txt in genout] + + for i in range(vars.numseqs): + vars.lua_koboldbridge.generated[i+1] = vars.lua_state.table(*genout[i].tolist()) + vars.lua_koboldbridge.outputs[i+1] = tokenizer.decode(genout[i]) execute_outmod() + if(vars.lua_koboldbridge.regeneration_required): + vars.lua_koboldbridge.regeneration_required = False + genout = [] + for i in range(vars.numseqs): + genout.append({"generated_text": vars.lua_koboldbridge.outputs[i+1]}) + assert type(genout[-1]["generated_text"]) is str + else: + genout = [{"generated_text": tokenizer.decode(txt)} for txt in genout] if(len(genout) == 1): genresult(genout[0]["generated_text"]) @@ -2888,7 +2952,15 @@ def ikrequest(txt): # 
Deal with the response if(req.status_code == 200): genout = req.json()["data"]["text"] + + vars.lua_koboldbridge.outputs[1] = genout + execute_outmod() + if(vars.lua_koboldbridge.regeneration_required): + vars.lua_koboldbridge.regeneration_required = False + genout = vars.lua_koboldbridge.outputs[1] + assert type(genout) is str + print("{0}{1}{2}".format(colors.CYAN, genout, colors.END)) vars.actions.append(genout) update_story_chunk('last') @@ -2939,7 +3011,15 @@ def oairequest(txt, min, max): # Deal with the response if(req.status_code == 200): genout = req.json()["choices"][0]["text"] + + vars.lua_koboldbridge.outputs[1] = genout + execute_outmod() + if(vars.lua_koboldbridge.regeneration_required): + vars.lua_koboldbridge.regeneration_required = False + genout = vars.lua_koboldbridge.outputs[1] + assert type(genout) is str + print("{0}{1}{2}".format(colors.CYAN, genout, colors.END)) vars.actions.append(genout) update_story_chunk('last') diff --git a/bridge.lua b/bridge.lua index 13be25d2..d8741174 100644 --- a/bridge.lua +++ b/bridge.lua @@ -146,6 +146,13 @@ return function(_python, _bridged) ---@field is_custommodel boolean ---@field custmodpth string ---@field logits table> + ---@field logits_rows integer + ---@field logits_cols integer + ---@field generated table> + ---@field generated_rows integer + ---@field generated_cols integer + ---@field outputs table + ---@field num_outputs integer local kobold = setmetatable({}, metawrapper) local KoboldLib_mt = setmetatable({}, metawrapper) local KoboldLib_getters = setmetatable({}, metawrapper) @@ -203,10 +210,13 @@ return function(_python, _bridged) koboldbridge.userstate = "inmod" koboldbridge.logits = {} koboldbridge.vocab_size = 0 + koboldbridge.generated = {} + koboldbridge.generated_cols = 0 + koboldbridge.outputs = {} ---@return nil local function maybe_require_regeneration() - if koboldbridge.userstate == "genmod" then + if koboldbridge.userstate == "genmod" or koboldbridge.userstate == "outmod" then 
koboldbridge.regeneration_required = true end end @@ -917,6 +927,9 @@ return function(_python, _bridged) ---@param t KoboldLib ---@return integer function KoboldLib_getters.logits_rows(t) + if koboldbridge.userstate ~= "genmod" then + return 0 + end local backend = kobold.modelbackend if backend == "readonly" or backend == "api" then return 0 @@ -933,6 +946,9 @@ return function(_python, _bridged) ---@param t KoboldLib ---@return integer function KoboldLib_getters.logits_cols(t) + if koboldbridge.userstate ~= "genmod" then + return 0 + end local backend = kobold.modelbackend if backend == "readonly" or backend == "api" then return 0 @@ -969,6 +985,119 @@ return function(_python, _bridged) end + --========================================================================== + -- Userscript API: Generated Tokens + --========================================================================== + + ---@param t KoboldLib + ---@return integer + function KoboldLib_getters.generated_rows(t) + local backend = kobold.modelbackend + if backend == "readonly" or backend == "api" then + return 0 + end + return kobold.settings.numseqs + end + + ---@param t KoboldLib + ---@return integer + function KoboldLib_setters.generated_rows(t) + error("`KoboldLib.generated_rows` is a read-only attribute") + end + + ---@param t KoboldLib + ---@return integer + function KoboldLib_getters.generated_cols(t) + if koboldbridge.userstate ~= "genmod" then + return 0 + end + local backend = kobold.modelbackend + if backend == "readonly" or backend == "api" then + return 0 + end + return math.tointeger(koboldbridge.generated_cols) + end + + ---@param t KoboldLib + ---@return integer + function KoboldLib_setters.generated_cols(t) + error("`KoboldLib.generated_cols` is a read-only attribute") + end + + ---@param t KoboldLib + ---@return table> + function KoboldLib_getters.generated(t) + if koboldbridge.userstate ~= "genmod" and koboldbridge.userstate ~= "outmod" then + return + end + local backend = 
kobold.modelbackend + if backend == "readonly" or backend == "api" then + return + end + return koboldbridge.generated + end + + ---@param t KoboldLib + ---@param v table> + function KoboldLib_setters.generated(t, v) + if koboldbridge.userstate ~= "genmod" then + error("Cannot write to `KoboldLib.generated` from outside of a generation modifier") + return + elseif type(v) ~= "table" then + error("`KoboldLib.generated` must be a 2D list (table) of integers; you attempted to set it to a " .. type(v)) + return + end + koboldbridge.generated = v + end + + + --========================================================================== + -- Userscript API: Output + --========================================================================== + + ---@param t KoboldLib + ---@return integer + function KoboldLib_getters.num_outputs(t) + local backend = kobold.modelbackend + if backend == "readonly" then + return 0 + end + local model = kobold.model + if model == "OAI" or model == "InferKit" then + return 1 + end + return kobold.settings.numseqs + end + + ---@param t KoboldLib + ---@return integer + function KoboldLib_setters.num_outputs(t) + error("`KoboldLib.num_outputs` is a read-only attribute") + end + + ---@param t KoboldLib + ---@return table + function KoboldLib_getters.outputs(t) + if koboldbridge.userstate ~= "outmod" then + return + end + return koboldbridge.outputs + end + + ---@param t KoboldLib + ---@param v table + function KoboldLib_setters.outputs(t, v) + if koboldbridge.userstate ~= "outmod" then + error("Cannot write to `KoboldLib.outputs` from outside of an output modifier") + return + elseif type(v) ~= "table" then + error("`KoboldLib.outputs` must be a list (table) of strings; you attempted to set it to a " .. 
type(v)) + return + end + koboldbridge.outputs = v + end + + --========================================================================== -- Userscript API: Utilities --========================================================================== @@ -1433,8 +1562,18 @@ return function(_python, _bridged) function koboldbridge.execute_inmod() local r + koboldbridge.regeneration_required = false koboldbridge.generating = false koboldbridge.userstate = "inmod" + koboldbridge.generated_cols = 0 + koboldbridge.generated = {} + for i = 1, kobold.settings.numseqs do + koboldbridge.generated[i] = {} + end + koboldbridge.outputs = {} + for i = 1, kobold.num_outputs do + koboldbridge.outputs[i] = {} + end if koboldbridge.inmod ~= nil then r = koboldbridge.inmod() end @@ -1447,6 +1586,7 @@ return function(_python, _bridged) koboldbridge.generating = true koboldbridge.userstate = "genmod" if koboldbridge.genmod ~= nil then + local _generated = deepcopy(koboldbridge.generated) r = koboldbridge.genmod() setmetatable(koboldbridge.logits, nil) for kr, vr in old_next, koboldbridge.logits, nil do @@ -1458,7 +1598,22 @@ return function(_python, _bridged) end end end + setmetatable(koboldbridge.generated, nil) + for kr, vr in old_next, koboldbridge.generated, nil do + setmetatable(vr, nil) + for kc, vc in old_next, vr, nil do + if math.tointeger(vc) == nil then + error("`kobold.generated` must be a 2D table of integers, but found a non-integer element at row " .. kr .. ", column " .. 
kc) + return r + end + vr[kc] = math.tointeger(vc) + if vr[kc] ~= _generated[kr][kc] then + maybe_require_regeneration() + end + end + end end + koboldbridge.generated_cols = koboldbridge.generated_cols + 1 return r end @@ -1467,12 +1622,22 @@ return function(_python, _bridged) koboldbridge.generating = false koboldbridge.userstate = "outmod" if koboldbridge.outmod ~= nil then + local _outputs = deepcopy(koboldbridge.outputs) r = koboldbridge.outmod() + setmetatable(koboldbridge.outputs, nil) + for k, v in old_next, koboldbridge.outputs, nil do + if type(v) ~= "string" then + error("`kobold.outputs` must be a 1D list of strings, but found a non-string element at index " .. k) + return r + end + if v ~= _outputs[k] then + maybe_require_regeneration() + end + end end if koboldbridge.resend_settings_required then bridged.resend_settings() end - koboldbridge.userstate = "inmod" return r end diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py index cbf6b499..ffb7b3e3 100644 --- a/tpu_mtj_backend.py +++ b/tpu_mtj_backend.py @@ -305,7 +305,7 @@ def infer( soft_embeddings=soft_embeddings, )[0] for o in output: - samples.append(tokenizer.decode(o[0][0, 0, params["seq"] : params["seq"] + gen_len])) + samples.append(o[0][0, 0, params["seq"] : params["seq"] + gen_len]) return samples