mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Add API for generated tokens and output text
This commit is contained in:
106
aiserver.py
106
aiserver.py
@ -654,26 +654,37 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
|
|||||||
|
|
||||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||||
assert scores.ndim == 2
|
assert scores.ndim == 2
|
||||||
|
assert input_ids.ndim == 2
|
||||||
self.regeneration_required = False
|
self.regeneration_required = False
|
||||||
self.halt = False
|
self.halt = False
|
||||||
|
|
||||||
scores_shape = scores.shape
|
scores_shape = scores.shape
|
||||||
scores_list = self.scores.tolist()
|
scores_list = self.scores.tolist()
|
||||||
vars.lua_koboldbridge.logits = vars.lua_state.table()
|
vars.lua_koboldbridge.logits = vars.lua_state.table()
|
||||||
for r, row in enumerate(scores_list):
|
for r, row in enumerate(scores_list):
|
||||||
vars.lua_koboldbridge.logits[r+1] = vars.lua_state.table(*row)
|
vars.lua_koboldbridge.logits[r+1] = vars.lua_state.table(*row)
|
||||||
vars.lua_koboldbridge.vocab_size = scores_shape[-1]
|
vars.lua_koboldbridge.vocab_size = scores_shape[-1]
|
||||||
|
|
||||||
|
if(vars.lua_koboldbridge.generated_cols != 0):
|
||||||
|
for i in range(vars.numseqs):
|
||||||
|
vars.lua_koboldbridge.generated[i+1][vars.lua_koboldbridge.generated_cols] = input_ids[i, -1].item()
|
||||||
|
|
||||||
execute_genmod()
|
execute_genmod()
|
||||||
|
|
||||||
|
if(vars.lua_koboldbridge.regeneration_required):
|
||||||
|
vars.lua_koboldbridge.regeneration_required = False
|
||||||
|
self.regeneration_required = True
|
||||||
|
|
||||||
|
if(not vars.lua_koboldbridge.generating):
|
||||||
|
self.halt = True
|
||||||
|
|
||||||
scores = torch.tensor(
|
scores = torch.tensor(
|
||||||
tuple(tuple(row.values()) for row in vars.lua_koboldbridge.logits.values()),
|
tuple(tuple(row.values()) for row in vars.lua_koboldbridge.logits.values()),
|
||||||
device=scores.device,
|
device=scores.device,
|
||||||
dtype=scores.dtype,
|
dtype=scores.dtype,
|
||||||
)
|
)
|
||||||
assert scores.shape == scores_shape
|
assert scores.shape == scores_shape
|
||||||
if(vars.lua_koboldbridge.regeneration_required):
|
|
||||||
vars.lua_koboldbridge.regeneration_required = False
|
|
||||||
self.regeneration_required = True
|
|
||||||
if(not vars.lua_koboldbridge.generating):
|
|
||||||
self.halt = True
|
|
||||||
return scores
|
return scores
|
||||||
|
|
||||||
def new_get_logits_warper(
|
def new_get_logits_warper(
|
||||||
@ -1748,9 +1759,19 @@ def actionsubmit(data, actionmode=0, force_submit=False):
|
|||||||
calcsubmit(data) # Run the first action through the generator
|
calcsubmit(data) # Run the first action through the generator
|
||||||
emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True)
|
emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True)
|
||||||
else:
|
else:
|
||||||
execute_outmod()
|
|
||||||
# Save this first action as the prompt
|
# Save this first action as the prompt
|
||||||
vars.prompt = data
|
vars.prompt = data
|
||||||
|
execute_outmod()
|
||||||
|
if(vars.lua_koboldbridge.regeneration_required):
|
||||||
|
vars.lua_koboldbridge.regeneration_required = False
|
||||||
|
genout = []
|
||||||
|
for i in range(vars.numseqs):
|
||||||
|
genout.append({"generated_text": vars.lua_koboldbridge.outputs[i+1]})
|
||||||
|
assert type(genout[-1]["generated_text"]) is str
|
||||||
|
if(len(genout) == 1):
|
||||||
|
genresult(genout[0]["generated_text"])
|
||||||
|
else:
|
||||||
|
genselect(genout)
|
||||||
refresh_story()
|
refresh_story()
|
||||||
set_aibusy(0)
|
set_aibusy(0)
|
||||||
emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True)
|
emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True)
|
||||||
@ -1776,6 +1797,16 @@ def actionsubmit(data, actionmode=0, force_submit=False):
|
|||||||
else:
|
else:
|
||||||
execute_outmod()
|
execute_outmod()
|
||||||
set_aibusy(0)
|
set_aibusy(0)
|
||||||
|
if(vars.lua_koboldbridge.regeneration_required):
|
||||||
|
vars.lua_koboldbridge.regeneration_required = False
|
||||||
|
genout = []
|
||||||
|
for i in range(vars.numseqs):
|
||||||
|
genout.append({"generated_text": vars.lua_koboldbridge.outputs[i+1]})
|
||||||
|
assert type(genout[-1]["generated_text"]) is str
|
||||||
|
if(len(genout) == 1):
|
||||||
|
genresult(genout[0]["generated_text"])
|
||||||
|
else:
|
||||||
|
genselect(genout)
|
||||||
emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True)
|
emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True)
|
||||||
|
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
@ -2076,6 +2107,12 @@ def generate(txt, minimum, maximum, found_entries=None):
|
|||||||
break
|
break
|
||||||
assert genout.ndim >= 2
|
assert genout.ndim >= 2
|
||||||
assert genout.shape[0] == vars.numseqs
|
assert genout.shape[0] == vars.numseqs
|
||||||
|
if(already_generated != vars.lua_koboldbridge.generated_cols):
|
||||||
|
raise RuntimeError("WI scanning error")
|
||||||
|
for r in range(vars.numseqs):
|
||||||
|
for c in range(already_generated):
|
||||||
|
assert vars.lua_koboldbridge.generated[r+1][c+1] is not None
|
||||||
|
genout[r][genout.shape[-1] - already_generated - c] = vars.lua_koboldbridge.generated[r+1][c+1]
|
||||||
encoded = []
|
encoded = []
|
||||||
for i in range(vars.numseqs):
|
for i in range(vars.numseqs):
|
||||||
txt = tokenizer.decode(genout[i, -already_generated:])
|
txt = tokenizer.decode(genout[i, -already_generated:])
|
||||||
@ -2111,12 +2148,20 @@ def generate(txt, minimum, maximum, found_entries=None):
|
|||||||
print("{0}{1}{2}".format(colors.RED, e, colors.END))
|
print("{0}{1}{2}".format(colors.RED, e, colors.END))
|
||||||
set_aibusy(0)
|
set_aibusy(0)
|
||||||
return
|
return
|
||||||
|
|
||||||
# Need to manually strip and decode tokens if we're not using a pipeline
|
for i in range(vars.numseqs):
|
||||||
#already_generated = -(len(gen_in[0]) - len(tokens))
|
vars.lua_koboldbridge.generated[i+1][vars.lua_koboldbridge.generated_cols] = genout[i, -1].item()
|
||||||
genout = [{"generated_text": tokenizer.decode(tokens[-already_generated:])} for tokens in genout]
|
vars.lua_koboldbridge.outputs[i+1] = tokenizer.decode(genout[i, -already_generated:])
|
||||||
|
|
||||||
execute_outmod()
|
execute_outmod()
|
||||||
|
if(vars.lua_koboldbridge.regeneration_required):
|
||||||
|
vars.lua_koboldbridge.regeneration_required = False
|
||||||
|
genout = []
|
||||||
|
for i in range(vars.numseqs):
|
||||||
|
genout.append({"generated_text": vars.lua_koboldbridge.outputs[i+1]})
|
||||||
|
assert type(genout[-1]["generated_text"]) is str
|
||||||
|
else:
|
||||||
|
genout = [{"generated_text": tokenizer.decode(tokens[-already_generated:])} for tokens in genout]
|
||||||
|
|
||||||
if(len(genout) == 1):
|
if(len(genout) == 1):
|
||||||
genresult(genout[0]["generated_text"])
|
genresult(genout[0]["generated_text"])
|
||||||
@ -2217,8 +2262,17 @@ def sendtocolab(txt, min, max):
|
|||||||
else:
|
else:
|
||||||
genout = js["seqs"]
|
genout = js["seqs"]
|
||||||
|
|
||||||
|
for i in range(vars.numseqs):
|
||||||
|
vars.lua_koboldbridge.outputs[i+1] = genout[i]
|
||||||
|
|
||||||
execute_outmod()
|
execute_outmod()
|
||||||
|
if(vars.lua_koboldbridge.regeneration_required):
|
||||||
|
vars.lua_koboldbridge.regeneration_required = False
|
||||||
|
genout = []
|
||||||
|
for i in range(vars.numseqs):
|
||||||
|
genout.append(vars.lua_koboldbridge.outputs[i+1])
|
||||||
|
assert type(genout[-1]) is str
|
||||||
|
|
||||||
if(len(genout) == 1):
|
if(len(genout) == 1):
|
||||||
genresult(genout[0])
|
genresult(genout[0])
|
||||||
else:
|
else:
|
||||||
@ -2297,10 +2351,20 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
|
|||||||
print("{0}{1}{2}".format(colors.RED, e, colors.END))
|
print("{0}{1}{2}".format(colors.RED, e, colors.END))
|
||||||
set_aibusy(0)
|
set_aibusy(0)
|
||||||
return
|
return
|
||||||
|
|
||||||
genout = [{"generated_text": txt} for txt in genout]
|
for i in range(vars.numseqs):
|
||||||
|
vars.lua_koboldbridge.generated[i+1] = vars.lua_state.table(*genout[i].tolist())
|
||||||
|
vars.lua_koboldbridge.outputs[i+1] = tokenizer.decode(genout[i])
|
||||||
|
|
||||||
execute_outmod()
|
execute_outmod()
|
||||||
|
if(vars.lua_koboldbridge.regeneration_required):
|
||||||
|
vars.lua_koboldbridge.regeneration_required = False
|
||||||
|
genout = []
|
||||||
|
for i in range(vars.numseqs):
|
||||||
|
genout.append({"generated_text": vars.lua_koboldbridge.outputs[i+1]})
|
||||||
|
assert type(genout[-1]["generated_text"]) is str
|
||||||
|
else:
|
||||||
|
genout = [{"generated_text": tokenizer.decode(txt)} for txt in genout]
|
||||||
|
|
||||||
if(len(genout) == 1):
|
if(len(genout) == 1):
|
||||||
genresult(genout[0]["generated_text"])
|
genresult(genout[0]["generated_text"])
|
||||||
@ -2888,7 +2952,15 @@ def ikrequest(txt):
|
|||||||
# Deal with the response
|
# Deal with the response
|
||||||
if(req.status_code == 200):
|
if(req.status_code == 200):
|
||||||
genout = req.json()["data"]["text"]
|
genout = req.json()["data"]["text"]
|
||||||
|
|
||||||
|
vars.lua_koboldbridge.outputs[1] = genout
|
||||||
|
|
||||||
execute_outmod()
|
execute_outmod()
|
||||||
|
if(vars.lua_koboldbridge.regeneration_required):
|
||||||
|
vars.lua_koboldbridge.regeneration_required = False
|
||||||
|
genout = vars.lua_koboldbridge.outputs[1]
|
||||||
|
assert genout is str
|
||||||
|
|
||||||
print("{0}{1}{2}".format(colors.CYAN, genout, colors.END))
|
print("{0}{1}{2}".format(colors.CYAN, genout, colors.END))
|
||||||
vars.actions.append(genout)
|
vars.actions.append(genout)
|
||||||
update_story_chunk('last')
|
update_story_chunk('last')
|
||||||
@ -2939,7 +3011,15 @@ def oairequest(txt, min, max):
|
|||||||
# Deal with the response
|
# Deal with the response
|
||||||
if(req.status_code == 200):
|
if(req.status_code == 200):
|
||||||
genout = req.json()["choices"][0]["text"]
|
genout = req.json()["choices"][0]["text"]
|
||||||
|
|
||||||
|
vars.lua_koboldbridge.outputs[1] = genout
|
||||||
|
|
||||||
execute_outmod()
|
execute_outmod()
|
||||||
|
if(vars.lua_koboldbridge.regeneration_required):
|
||||||
|
vars.lua_koboldbridge.regeneration_required = False
|
||||||
|
genout = vars.lua_koboldbridge.outputs[1]
|
||||||
|
assert genout is str
|
||||||
|
|
||||||
print("{0}{1}{2}".format(colors.CYAN, genout, colors.END))
|
print("{0}{1}{2}".format(colors.CYAN, genout, colors.END))
|
||||||
vars.actions.append(genout)
|
vars.actions.append(genout)
|
||||||
update_story_chunk('last')
|
update_story_chunk('last')
|
||||||
|
169
bridge.lua
169
bridge.lua
@ -146,6 +146,13 @@ return function(_python, _bridged)
|
|||||||
---@field is_custommodel boolean
|
---@field is_custommodel boolean
|
||||||
---@field custmodpth string
|
---@field custmodpth string
|
||||||
---@field logits table<integer, table<integer, number>>
|
---@field logits table<integer, table<integer, number>>
|
||||||
|
---@field logits_rows integer
|
||||||
|
---@field logits_cols integer
|
||||||
|
---@field generated table<integer, table<integer, integer>>
|
||||||
|
---@field generated_rows integer
|
||||||
|
---@field generated_cols integer
|
||||||
|
---@field outputs table<integer, string>
|
||||||
|
---@field num_outputs integer
|
||||||
local kobold = setmetatable({}, metawrapper)
|
local kobold = setmetatable({}, metawrapper)
|
||||||
local KoboldLib_mt = setmetatable({}, metawrapper)
|
local KoboldLib_mt = setmetatable({}, metawrapper)
|
||||||
local KoboldLib_getters = setmetatable({}, metawrapper)
|
local KoboldLib_getters = setmetatable({}, metawrapper)
|
||||||
@ -203,10 +210,13 @@ return function(_python, _bridged)
|
|||||||
koboldbridge.userstate = "inmod"
|
koboldbridge.userstate = "inmod"
|
||||||
koboldbridge.logits = {}
|
koboldbridge.logits = {}
|
||||||
koboldbridge.vocab_size = 0
|
koboldbridge.vocab_size = 0
|
||||||
|
koboldbridge.generated = {}
|
||||||
|
koboldbridge.generated_cols = 0
|
||||||
|
koboldbridge.outputs = {}
|
||||||
|
|
||||||
---@return nil
|
---@return nil
|
||||||
local function maybe_require_regeneration()
|
local function maybe_require_regeneration()
|
||||||
if koboldbridge.userstate == "genmod" then
|
if koboldbridge.userstate == "genmod" or koboldbridge.userstate == "outmod" then
|
||||||
koboldbridge.regeneration_required = true
|
koboldbridge.regeneration_required = true
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
@ -917,6 +927,9 @@ return function(_python, _bridged)
|
|||||||
---@param t KoboldLib
|
---@param t KoboldLib
|
||||||
---@return integer
|
---@return integer
|
||||||
function KoboldLib_getters.logits_rows(t)
|
function KoboldLib_getters.logits_rows(t)
|
||||||
|
if koboldbridge.userstate ~= "genmod" then
|
||||||
|
return 0
|
||||||
|
end
|
||||||
local backend = kobold.modelbackend
|
local backend = kobold.modelbackend
|
||||||
if backend == "readonly" or backend == "api" then
|
if backend == "readonly" or backend == "api" then
|
||||||
return 0
|
return 0
|
||||||
@ -933,6 +946,9 @@ return function(_python, _bridged)
|
|||||||
---@param t KoboldLib
|
---@param t KoboldLib
|
||||||
---@return integer
|
---@return integer
|
||||||
function KoboldLib_getters.logits_cols(t)
|
function KoboldLib_getters.logits_cols(t)
|
||||||
|
if koboldbridge.userstate ~= "genmod" then
|
||||||
|
return 0
|
||||||
|
end
|
||||||
local backend = kobold.modelbackend
|
local backend = kobold.modelbackend
|
||||||
if backend == "readonly" or backend == "api" then
|
if backend == "readonly" or backend == "api" then
|
||||||
return 0
|
return 0
|
||||||
@ -969,6 +985,119 @@ return function(_python, _bridged)
|
|||||||
end
|
end
|
||||||
|
|
||||||
|
|
||||||
|
--==========================================================================
|
||||||
|
-- Userscript API: Generated Tokens
|
||||||
|
--==========================================================================
|
||||||
|
|
||||||
|
---@param t KoboldLib
|
||||||
|
---@return integer
|
||||||
|
function KoboldLib_getters.generated_rows(t)
|
||||||
|
local backend = kobold.modelbackend
|
||||||
|
if backend == "readonly" or backend == "api" then
|
||||||
|
return 0
|
||||||
|
end
|
||||||
|
return kobold.settings.numseqs
|
||||||
|
end
|
||||||
|
|
||||||
|
---@param t KoboldLib
|
||||||
|
---@return integer
|
||||||
|
function KoboldLib_setters.generated_rows(t)
|
||||||
|
error("`KoboldLib.generated_rows` is a read-only attribute")
|
||||||
|
end
|
||||||
|
|
||||||
|
---@param t KoboldLib
|
||||||
|
---@return integer
|
||||||
|
function KoboldLib_getters.generated_cols(t)
|
||||||
|
if koboldbridge.userstate ~= "genmod" then
|
||||||
|
return 0
|
||||||
|
end
|
||||||
|
local backend = kobold.modelbackend
|
||||||
|
if backend == "readonly" or backend == "api" then
|
||||||
|
return 0
|
||||||
|
end
|
||||||
|
return math.tointeger(koboldbridge.generated_cols)
|
||||||
|
end
|
||||||
|
|
||||||
|
---@param t KoboldLib
|
||||||
|
---@return integer
|
||||||
|
function KoboldLib_setters.generated_cols(t)
|
||||||
|
error("`KoboldLib.generated_cols` is a read-only attribute")
|
||||||
|
end
|
||||||
|
|
||||||
|
---@param t KoboldLib
|
||||||
|
---@return table<integer, table<integer, integer>>
|
||||||
|
function KoboldLib_getters.generated(t)
|
||||||
|
if koboldbridge.userstate ~= "genmod" and koboldbridge.userstate ~= "outmod" then
|
||||||
|
return
|
||||||
|
end
|
||||||
|
local backend = kobold.modelbackend
|
||||||
|
if backend == "readonly" or backend == "api" then
|
||||||
|
return
|
||||||
|
end
|
||||||
|
return koboldbridge.generated
|
||||||
|
end
|
||||||
|
|
||||||
|
---@param t KoboldLib
|
||||||
|
---@param v table<integer, table<integer, integer>>
|
||||||
|
function KoboldLib_setters.generated(t, v)
|
||||||
|
if koboldbridge.userstate ~= "genmod" then
|
||||||
|
error("Cannot write to `KoboldLib.generated` from outside of a generation modifier")
|
||||||
|
return
|
||||||
|
elseif type(v) ~= "table" then
|
||||||
|
error("`KoboldLib.generated` must be a 2D list (table) of integers; you attempted to set it to a " .. type(v))
|
||||||
|
return
|
||||||
|
end
|
||||||
|
koboldbridge.generated = v
|
||||||
|
end
|
||||||
|
|
||||||
|
|
||||||
|
--==========================================================================
|
||||||
|
-- Userscript API: Output
|
||||||
|
--==========================================================================
|
||||||
|
|
||||||
|
---@param t KoboldLib
|
||||||
|
---@return integer
|
||||||
|
function KoboldLib_getters.num_outputs(t)
|
||||||
|
local backend = kobold.modelbackend
|
||||||
|
if backend == "readonly" then
|
||||||
|
return 0
|
||||||
|
end
|
||||||
|
local model = kobold.model
|
||||||
|
if model == "OAI" or model == "InferKit" then
|
||||||
|
return 1
|
||||||
|
end
|
||||||
|
return kobold.settings.numseqs
|
||||||
|
end
|
||||||
|
|
||||||
|
---@param t KoboldLib
|
||||||
|
---@return integer
|
||||||
|
function KoboldLib_setters.num_outputs(t)
|
||||||
|
error("`KoboldLib.num_outputs` is a read-only attribute")
|
||||||
|
end
|
||||||
|
|
||||||
|
---@param t KoboldLib
|
||||||
|
---@return table<integer, string>
|
||||||
|
function KoboldLib_getters.outputs(t)
|
||||||
|
if koboldbridge.userstate ~= "outmod" then
|
||||||
|
return
|
||||||
|
end
|
||||||
|
return koboldbridge.outputs
|
||||||
|
end
|
||||||
|
|
||||||
|
---@param t KoboldLib
|
||||||
|
---@param v table<integer, string>
|
||||||
|
function KoboldLib_setters.outputs(t, v)
|
||||||
|
if koboldbridge.userstate ~= "outmod" then
|
||||||
|
error("Cannot write to `KoboldLib.generated` from outside of an output modifier")
|
||||||
|
return
|
||||||
|
elseif type(v) ~= "table" then
|
||||||
|
error("`KoboldLib.generated` must be a list (table) of strings; you attempted to set it to a " .. type(v))
|
||||||
|
return
|
||||||
|
end
|
||||||
|
koboldbridge.outputs = v
|
||||||
|
end
|
||||||
|
|
||||||
|
|
||||||
--==========================================================================
|
--==========================================================================
|
||||||
-- Userscript API: Utilities
|
-- Userscript API: Utilities
|
||||||
--==========================================================================
|
--==========================================================================
|
||||||
@ -1433,8 +1562,18 @@ return function(_python, _bridged)
|
|||||||
|
|
||||||
function koboldbridge.execute_inmod()
|
function koboldbridge.execute_inmod()
|
||||||
local r
|
local r
|
||||||
|
koboldbridge.regeneration_required = false
|
||||||
koboldbridge.generating = false
|
koboldbridge.generating = false
|
||||||
koboldbridge.userstate = "inmod"
|
koboldbridge.userstate = "inmod"
|
||||||
|
koboldbridge.generated_cols = 0
|
||||||
|
koboldbridge.generated = {}
|
||||||
|
for i = 1, kobold.settings.numseqs do
|
||||||
|
koboldbridge.generated[i] = {}
|
||||||
|
end
|
||||||
|
koboldbridge.outputs = {}
|
||||||
|
for i = 1, kobold.num_outputs do
|
||||||
|
koboldbridge.outputs[i] = {}
|
||||||
|
end
|
||||||
if koboldbridge.inmod ~= nil then
|
if koboldbridge.inmod ~= nil then
|
||||||
r = koboldbridge.inmod()
|
r = koboldbridge.inmod()
|
||||||
end
|
end
|
||||||
@ -1447,6 +1586,7 @@ return function(_python, _bridged)
|
|||||||
koboldbridge.generating = true
|
koboldbridge.generating = true
|
||||||
koboldbridge.userstate = "genmod"
|
koboldbridge.userstate = "genmod"
|
||||||
if koboldbridge.genmod ~= nil then
|
if koboldbridge.genmod ~= nil then
|
||||||
|
local _generated = deepcopy(koboldbridge.generated)
|
||||||
r = koboldbridge.genmod()
|
r = koboldbridge.genmod()
|
||||||
setmetatable(koboldbridge.logits, nil)
|
setmetatable(koboldbridge.logits, nil)
|
||||||
for kr, vr in old_next, koboldbridge.logits, nil do
|
for kr, vr in old_next, koboldbridge.logits, nil do
|
||||||
@ -1458,7 +1598,22 @@ return function(_python, _bridged)
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
setmetatable(koboldbridge.generated, nil)
|
||||||
|
for kr, vr in old_next, koboldbridge.generated, nil do
|
||||||
|
setmetatable(vr, nil)
|
||||||
|
for kc, vc in old_next, vr, nil do
|
||||||
|
if math.tointeger(vc) == nil then
|
||||||
|
error("`kobold.generated` must be a 2D table of integers, but found a non-integer element at row " .. kr .. ", column " .. kc)
|
||||||
|
return r
|
||||||
|
end
|
||||||
|
vr[kc] = math.tointeger(vc)
|
||||||
|
if vr[kc] ~= _generated[kr][kc] then
|
||||||
|
maybe_require_regeneration()
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
end
|
end
|
||||||
|
koboldbridge.generated_cols = koboldbridge.generated_cols + 1
|
||||||
return r
|
return r
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -1467,12 +1622,22 @@ return function(_python, _bridged)
|
|||||||
koboldbridge.generating = false
|
koboldbridge.generating = false
|
||||||
koboldbridge.userstate = "outmod"
|
koboldbridge.userstate = "outmod"
|
||||||
if koboldbridge.outmod ~= nil then
|
if koboldbridge.outmod ~= nil then
|
||||||
|
local _outputs = deepcopy(koboldbridge.outputs)
|
||||||
r = koboldbridge.outmod()
|
r = koboldbridge.outmod()
|
||||||
|
setmetatable(koboldbridge.outputs, nil)
|
||||||
|
for k, v in old_next, koboldbridge.outputs, nil do
|
||||||
|
if type(v) ~= "string" then
|
||||||
|
error("`kobold.outputs` must be a 1D list of strings, but found a non-string element at index " .. k)
|
||||||
|
return r
|
||||||
|
end
|
||||||
|
if v ~= _outputs[k] then
|
||||||
|
maybe_require_regeneration()
|
||||||
|
end
|
||||||
|
end
|
||||||
end
|
end
|
||||||
if koboldbridge.resend_settings_required then
|
if koboldbridge.resend_settings_required then
|
||||||
bridged.resend_settings()
|
bridged.resend_settings()
|
||||||
end
|
end
|
||||||
koboldbridge.userstate = "inmod"
|
|
||||||
return r
|
return r
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -305,7 +305,7 @@ def infer(
|
|||||||
soft_embeddings=soft_embeddings,
|
soft_embeddings=soft_embeddings,
|
||||||
)[0]
|
)[0]
|
||||||
for o in output:
|
for o in output:
|
||||||
samples.append(tokenizer.decode(o[0][0, 0, params["seq"] : params["seq"] + gen_len]))
|
samples.append(o[0][0, 0, params["seq"] : params["seq"] + gen_len])
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user