Merge pull request #64 from pi6am/fix/multinomial-workaround
Resample to work around a bug in torch.multinomial
@@ -293,7 +293,14 @@ class model_backend(InferenceModel):
         scores = torch.softmax(scores, dim=-1)
-        token = torch.multinomial(scores, 1)
+        # Work around a bug in torch.multinomial (https://github.com/pytorch/pytorch/issues/48841)
+        # With low probability, multinomial can return an element with zero weight. Since this
+        # happens infrequently, just sample repeatedly until all tokens have non-zero probability.
+        for _ in range(100):
+            token = torch.multinomial(scores, 1)
+            # Verify that all selected tokens correspond to positive probabilities.
+            if (scores.gather(1, token) > 0).all():
+                break
         self.generator.gen_accept_token(token)
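For context, the hunk above can be reproduced outside the backend. The following is a minimal standalone sketch of the same resampling workaround, assuming a made-up `scores` distribution (the third entry has probability exactly zero, the case that triggers the upstream bug); it is not KoboldAI's actual sampling loop.

    import torch

    # A made-up distribution whose last entry has exactly zero probability.
    scores = torch.softmax(torch.tensor([[2.0, 0.5, -float("inf")]]), dim=-1)

    # torch.multinomial can, with low probability, return an index whose
    # weight is zero (https://github.com/pytorch/pytorch/issues/48841), so
    # resample until every selected token has positive probability.
    for _ in range(100):
        token = torch.multinomial(scores, 1)
        if (scores.gather(1, token) > 0).all():
            break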
@@ -301,7 +308,7 @@ class model_backend(InferenceModel):
         utils.koboldai_vars.generated_tkns += 1
-        if token.item() == self.tokenizer.eos_token_id:
+        if (token == self.tokenizer.eos_token_id).any():
             trim_count = 1
             break
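A note on the second hunk: `token.item()` only works for a single-element tensor, while `(token == eos_token_id).any()` works for any shape. The sketch below illustrates that difference with made-up values (neither `eos_token_id` nor the `token` tensor comes from KoboldAI).

    import torch

    eos_token_id = 2                   # hypothetical eos id
    token = torch.tensor([[5], [2]])   # hypothetical batch of two sampled tokens

    # token.item() would raise on this multi-element tensor; .any() is
    # shape-agnostic and fires if any sampled token is the eos token.
    if (token == eos_token_id).any():
        print("eos token sampled; stop generation")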