Fix Lua regeneration system

This commit is contained in:
Gnome Ann 2021-12-13 19:17:18 -05:00
parent 462040ed6f
commit e5bb20cc8f
1 changed files with 6 additions and 15 deletions

View File

@ -656,8 +656,6 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
class LuaLogitsWarper(LogitsWarper):
def __init__(self):
self.regeneration_required = False
self.halt = False
pass
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
@ -673,19 +671,8 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
vars.lua_koboldbridge.logits[r+1] = vars.lua_state.table(*row)
vars.lua_koboldbridge.vocab_size = scores_shape[-1]
if(vars.lua_koboldbridge.generated_cols != 0):
for i in range(vars.numseqs):
vars.lua_koboldbridge.generated[i+1][vars.lua_koboldbridge.generated_cols] = input_ids[i, -1].item()
execute_genmod()
if(vars.lua_koboldbridge.regeneration_required):
vars.lua_koboldbridge.regeneration_required = False
self.regeneration_required = True
if(not vars.lua_koboldbridge.generating):
self.halt = True
scores = torch.tensor(
tuple(tuple(row.values()) for row in vars.lua_koboldbridge.logits.values()),
device=scores.device,
@ -758,8 +745,12 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
) -> bool:
assert input_ids.ndim == 2
assert len(self.excluded_world_info) == input_ids.shape[0]
self.regeneration_required = vars.lua_warper.regeneration_required
self.halt = vars.lua_warper.halt
self.regeneration_required = vars.lua_koboldbridge.regeneration_required
self.halt = not vars.lua_koboldbridge.generating
vars.lua_koboldbridge.regeneration_required = False
for i in range(vars.numseqs):
vars.lua_koboldbridge.generated[i+1][vars.lua_koboldbridge.generated_cols] = input_ids[i, -1].item()
if(not vars.dynamicscan):
return self.regeneration_required or self.halt