From b7e38b47570cb910d4b5b9c853985e6d3fba9107 Mon Sep 17 00:00:00 2001
From: Llama <34464159+pi6am@users.noreply.github.com>
Date: Sat, 26 Aug 2023 22:26:26 -0700
Subject: [PATCH] Resample to work around a bug in torch.multinomial

There is a bug in PyTorch 2.0.1 that allows torch.multinomial to
sometimes choose elements that have zero probability. Since this is
uncommon we can continue to use torch.multinomial as long as we
verify that the results are valid. If they aren't, try again until
the probability of each selected token is positive.
---
 modeling/inference_models/exllama/class.py | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/modeling/inference_models/exllama/class.py b/modeling/inference_models/exllama/class.py
index 2540d3f4..3fb8d252 100644
--- a/modeling/inference_models/exllama/class.py
+++ b/modeling/inference_models/exllama/class.py
@@ -293,7 +293,14 @@ class model_backend(InferenceModel):
 
             scores = torch.softmax(scores, dim=-1)
 
-            token = torch.multinomial(scores, 1)
+            # Work around a bug in torch.multinomial (https://github.com/pytorch/pytorch/issues/48841)
+            # With low probability, multinomial can return an element with zero weight. Since this
+            # happens infrequently, just sample repeatedly until all tokens have non-zero probability.
+            for _ in range(100):
+                token = torch.multinomial(scores, 1)
+                # Verify that all selected tokens correspond to positive probabilities.
+                if (scores.gather(1, token) > 0).all():
+                    break
 
             self.generator.gen_accept_token(token)
 
@@ -301,7 +308,7 @@ class model_backend(InferenceModel):
 
             utils.koboldai_vars.generated_tkns += 1
 
-            if token.item() == self.tokenizer.eos_token_id:
+            if (token == self.tokenizer.eos_token_id).any():
                 trim_count = 1
                 break
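
For reference, below is a minimal standalone sketch of the resample-and-verify
workaround described in the commit message, written outside the KoboldAI backend
so it can be run on its own. The helper name sample_with_retry, the retry limit,
and the example probability tensor are illustrative assumptions, not part of this
patch.

    import torch

    def sample_with_retry(probs: torch.Tensor, max_retries: int = 100) -> torch.Tensor:
        # Draw one token index per row of `probs`. torch.multinomial can, rarely,
        # return an index whose weight is zero (pytorch/pytorch#48841), so re-draw
        # until every selected entry has a positive probability.
        for _ in range(max_retries):
            token = torch.multinomial(probs, 1)
            if (probs.gather(1, token) > 0).all():
                break
        return token

    if __name__ == "__main__":
        # A batch of two distributions with several zero-probability entries.
        probs = torch.tensor([[0.0, 0.7, 0.3, 0.0],
                              [0.5, 0.0, 0.0, 0.5]])
        print(sample_with_retry(probs))

The retry loop is bounded rather than infinite so that a pathological input (for
example, an all-zero row) cannot hang generation; in that case the last draw is
returned unchanged, matching the behavior of the patched loop above.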