diff --git a/aiserver.py b/aiserver.py index 56874954..1e2cf8ed 100644 --- a/aiserver.py +++ b/aiserver.py @@ -656,8 +656,6 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme class LuaLogitsWarper(LogitsWarper): def __init__(self): - self.regeneration_required = False - self.halt = False pass def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: @@ -673,19 +671,8 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme vars.lua_koboldbridge.logits[r+1] = vars.lua_state.table(*row) vars.lua_koboldbridge.vocab_size = scores_shape[-1] - if(vars.lua_koboldbridge.generated_cols != 0): - for i in range(vars.numseqs): - vars.lua_koboldbridge.generated[i+1][vars.lua_koboldbridge.generated_cols] = input_ids[i, -1].item() - execute_genmod() - if(vars.lua_koboldbridge.regeneration_required): - vars.lua_koboldbridge.regeneration_required = False - self.regeneration_required = True - - if(not vars.lua_koboldbridge.generating): - self.halt = True - scores = torch.tensor( tuple(tuple(row.values()) for row in vars.lua_koboldbridge.logits.values()), device=scores.device, @@ -758,8 +745,12 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme ) -> bool: assert input_ids.ndim == 2 assert len(self.excluded_world_info) == input_ids.shape[0] - self.regeneration_required = vars.lua_warper.regeneration_required - self.halt = vars.lua_warper.halt + self.regeneration_required = vars.lua_koboldbridge.regeneration_required + self.halt = not vars.lua_koboldbridge.generating + vars.lua_koboldbridge.regeneration_required = False + + for i in range(vars.numseqs): + vars.lua_koboldbridge.generated[i+1][vars.lua_koboldbridge.generated_cols] = input_ids[i, -1].item() if(not vars.dynamicscan): return self.regeneration_required or self.halt