From e6656d68a102b0d00ec724d67688e7e8699ca784 Mon Sep 17 00:00:00 2001 From: somebody Date: Fri, 9 Dec 2022 13:45:45 -0600 Subject: [PATCH] Move probability visualization to after logitwarpers --- aiserver.py | 35 ++++++++++++++--------------------- 1 file changed, 14 insertions(+), 21 deletions(-) diff --git a/aiserver.py b/aiserver.py index eff21923..f3117604 100644 --- a/aiserver.py +++ b/aiserver.py @@ -1936,34 +1936,26 @@ def patch_transformers(): from torch.nn import functional as F - class ProbabilityVisualizerLogitsProcessor(LogitsProcessor): - def __init__(self): - pass + def visualize_probabilities(scores: torch.FloatTensor) -> None: + assert scores.ndim == 2 - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: - assert scores.ndim == 2 - assert input_ids.ndim == 2 + if vars.numseqs > 1 or not vars.show_probs: + return - if vars.numseqs > 1 or not vars.show_probs: - return scores + probs = F.softmax(scores, dim = -1).cpu().numpy()[0] + token_prob_info = [] + for token_id, score in sorted(enumerate(probs), key=lambda x: x[1], reverse=True)[:8]: + token_prob_info.append({ + "tokenId": token_id, + "decoded": utils.decodenewlines(tokenizer.decode(token_id)), + "score": float(score), + }) - probs = F.softmax(scores, dim = -1).cpu().numpy()[0] - - token_prob_info = [] - for token_id, score in sorted(enumerate(probs), key=lambda x: x[1], reverse=True)[:8]: - token_prob_info.append({ - "tokenId": token_id, - "decoded": utils.decodenewlines(tokenizer.decode(token_id)), - "score": float(score), - }) - - vars.token_stream_queue.probability_buffer = token_prob_info - return scores + vars.token_stream_queue.probability_buffer = token_prob_info def new_get_logits_processor(*args, **kwargs) -> LogitsProcessorList: processors = new_get_logits_processor.old_get_logits_processor(*args, **kwargs) processors.insert(0, LuaLogitsProcessor()) - processors.append(ProbabilityVisualizerLogitsProcessor()) return processors new_get_logits_processor.old_get_logits_processor = transformers.generation_utils.GenerationMixin._get_logits_processor transformers.generation_utils.GenerationMixin._get_logits_processor = new_get_logits_processor @@ -1985,6 +1977,7 @@ def patch_transformers(): sampler_order = [6] + sampler_order for k in sampler_order: scores = self.__warper_list[k](input_ids, scores, *args, **kwargs) + visualize_probabilities(scores) return scores def new_get_logits_warper(beams: int = 1,) -> LogitsProcessorList: