Fix Lua regeneration system
This commit is contained in:
parent
462040ed6f
commit
e5bb20cc8f
21
aiserver.py
21
aiserver.py
|
@ -656,8 +656,6 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
|
||||||
class LuaLogitsWarper(LogitsWarper):
|
class LuaLogitsWarper(LogitsWarper):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.regeneration_required = False
|
|
||||||
self.halt = False
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||||
|
@ -673,19 +671,8 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
|
||||||
vars.lua_koboldbridge.logits[r+1] = vars.lua_state.table(*row)
|
vars.lua_koboldbridge.logits[r+1] = vars.lua_state.table(*row)
|
||||||
vars.lua_koboldbridge.vocab_size = scores_shape[-1]
|
vars.lua_koboldbridge.vocab_size = scores_shape[-1]
|
||||||
|
|
||||||
if(vars.lua_koboldbridge.generated_cols != 0):
|
|
||||||
for i in range(vars.numseqs):
|
|
||||||
vars.lua_koboldbridge.generated[i+1][vars.lua_koboldbridge.generated_cols] = input_ids[i, -1].item()
|
|
||||||
|
|
||||||
execute_genmod()
|
execute_genmod()
|
||||||
|
|
||||||
if(vars.lua_koboldbridge.regeneration_required):
|
|
||||||
vars.lua_koboldbridge.regeneration_required = False
|
|
||||||
self.regeneration_required = True
|
|
||||||
|
|
||||||
if(not vars.lua_koboldbridge.generating):
|
|
||||||
self.halt = True
|
|
||||||
|
|
||||||
scores = torch.tensor(
|
scores = torch.tensor(
|
||||||
tuple(tuple(row.values()) for row in vars.lua_koboldbridge.logits.values()),
|
tuple(tuple(row.values()) for row in vars.lua_koboldbridge.logits.values()),
|
||||||
device=scores.device,
|
device=scores.device,
|
||||||
|
@ -758,8 +745,12 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
|
||||||
) -> bool:
|
) -> bool:
|
||||||
assert input_ids.ndim == 2
|
assert input_ids.ndim == 2
|
||||||
assert len(self.excluded_world_info) == input_ids.shape[0]
|
assert len(self.excluded_world_info) == input_ids.shape[0]
|
||||||
self.regeneration_required = vars.lua_warper.regeneration_required
|
self.regeneration_required = vars.lua_koboldbridge.regeneration_required
|
||||||
self.halt = vars.lua_warper.halt
|
self.halt = not vars.lua_koboldbridge.generating
|
||||||
|
vars.lua_koboldbridge.regeneration_required = False
|
||||||
|
|
||||||
|
for i in range(vars.numseqs):
|
||||||
|
vars.lua_koboldbridge.generated[i+1][vars.lua_koboldbridge.generated_cols] = input_ids[i, -1].item()
|
||||||
|
|
||||||
if(not vars.dynamicscan):
|
if(not vars.dynamicscan):
|
||||||
return self.regeneration_required or self.halt
|
return self.regeneration_required or self.halt
|
||||||
|
|
Loading…
Reference in New Issue