Model: Add debug code for detecting faulty samplers

This commit is contained in:
somebody
2023-03-04 18:56:41 -06:00
parent c7822464c7
commit beef23f5a1

View File

@@ -48,6 +48,10 @@ except ModuleNotFoundError as e:
            if not utils.koboldai_vars.use_colab_tpu:
                raise e
# When set to true, messages will appear in the console if samplers are not
# changing the scores. Keep in mind some samplers don't always change the
# scores for each token.
LOG_SAMPLER_NO_EFFECT = False
class HFTorchInferenceModel(HFInferenceModel):
    def __init__(
@@ -87,6 +91,10 @@ class HFTorchInferenceModel(HFInferenceModel):
        self, scores: torch.Tensor, input_ids: torch.Tensor
    ) -> torch.Tensor:
        warpers.update_settings()
if LOG_SAMPLER_NO_EFFECT:
pre = torch.Tensor(scores)
        for sid in utils.koboldai_vars.sampler_order:
            warper = Warper.from_id(sid)
            if warper == warpers.RepetitionPenalty:
@@ -94,6 +102,11 @@ class HFTorchInferenceModel(HFInferenceModel):
                scores = warper.torch(scores, input_ids=input_ids)
            else:
                scores = warper.torch(scores)
if LOG_SAMPLER_NO_EFFECT:
if torch.equal(pre, scores):
logger.info(warper, "had no effect on the scores.")
pre = torch.Tensor(scores)
        return scores
    def _post_load(self) -> None: