From 7ed48652dd1aae38a75b7395f8457bb6fe0aef71 Mon Sep 17 00:00:00 2001
From: somebody
Date: Wed, 17 Aug 2022 14:13:04 -0500
Subject: [PATCH] Fix probabilities

---
 aiserver.py | 24 ++++++++++++++----------
 1 file changed, 14 insertions(+), 10 deletions(-)

diff --git a/aiserver.py b/aiserver.py
index ec628535..3347ea7a 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -1717,20 +1717,24 @@ def patch_transformers():
             assert scores.ndim == 2
             assert input_ids.ndim == 2
 
-            if koboldai_vars.numseqs > 1 or not koboldai_vars.show_probs:
+            if not koboldai_vars.show_probs:
                 return scores
 
-            probs = F.softmax(scores, dim = -1).cpu().numpy()[0]
+            for batch_index, batch in enumerate(scores):
+                probs = F.softmax(batch, dim = -1).cpu().numpy()
 
-            token_prob_info = []
-            for token_id, score in sorted(enumerate(probs), key=lambda x: x[1], reverse=True)[:8]:
-                token_prob_info.append({
-                    "tokenId": token_id,
-                    "decoded": utils.decodenewlines(tokenizer.decode(token_id)),
-                    "score": float(score),
-                })
+                token_prob_info = []
+                for token_id, score in sorted(enumerate(probs), key=lambda x: x[1], reverse=True)[:8]:
+                    token_prob_info.append({
+                        "tokenId": token_id,
+                        "decoded": utils.decodenewlines(tokenizer.decode(token_id)),
+                        "score": float(score),
+                    })
 
-            koboldai_vars.token_stream_queue.probability_buffer = token_prob_info
+                if len(scores) == 1:
+                    koboldai_vars.actions.set_probabilities(token_prob_info)
+                else:
+                    koboldai_vars.actions.set_option_probabilities(token_prob_info, batch_index)
             return scores
 
     def new_get_logits_processor(*args, **kwargs) -> LogitsProcessorList: