mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-02-17 12:10:49 +01:00
Fix Lua regeneration system
This commit is contained in:
parent
462040ed6f
commit
e5bb20cc8f
21
aiserver.py
21
aiserver.py
@ -656,8 +656,6 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
|
||||
class LuaLogitsWarper(LogitsWarper):
|
||||
|
||||
def __init__(self):
|
||||
self.regeneration_required = False
|
||||
self.halt = False
|
||||
pass
|
||||
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||
@ -673,19 +671,8 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
|
||||
vars.lua_koboldbridge.logits[r+1] = vars.lua_state.table(*row)
|
||||
vars.lua_koboldbridge.vocab_size = scores_shape[-1]
|
||||
|
||||
if(vars.lua_koboldbridge.generated_cols != 0):
|
||||
for i in range(vars.numseqs):
|
||||
vars.lua_koboldbridge.generated[i+1][vars.lua_koboldbridge.generated_cols] = input_ids[i, -1].item()
|
||||
|
||||
execute_genmod()
|
||||
|
||||
if(vars.lua_koboldbridge.regeneration_required):
|
||||
vars.lua_koboldbridge.regeneration_required = False
|
||||
self.regeneration_required = True
|
||||
|
||||
if(not vars.lua_koboldbridge.generating):
|
||||
self.halt = True
|
||||
|
||||
scores = torch.tensor(
|
||||
tuple(tuple(row.values()) for row in vars.lua_koboldbridge.logits.values()),
|
||||
device=scores.device,
|
||||
@ -758,8 +745,12 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
|
||||
) -> bool:
|
||||
assert input_ids.ndim == 2
|
||||
assert len(self.excluded_world_info) == input_ids.shape[0]
|
||||
self.regeneration_required = vars.lua_warper.regeneration_required
|
||||
self.halt = vars.lua_warper.halt
|
||||
self.regeneration_required = vars.lua_koboldbridge.regeneration_required
|
||||
self.halt = not vars.lua_koboldbridge.generating
|
||||
vars.lua_koboldbridge.regeneration_required = False
|
||||
|
||||
for i in range(vars.numseqs):
|
||||
vars.lua_koboldbridge.generated[i+1][vars.lua_koboldbridge.generated_cols] = input_ids[i, -1].item()
|
||||
|
||||
if(not vars.dynamicscan):
|
||||
return self.regeneration_required or self.halt
|
||||
|
Loading…
x
Reference in New Issue
Block a user