Move probability visualization to after logitwarpers

This commit is contained in:
somebody 2022-12-09 13:45:45 -06:00
parent 55ef53f39b
commit e6656d68a1
1 changed files with 14 additions and 21 deletions

View File

@ -1936,19 +1936,13 @@ def patch_transformers():
from torch.nn import functional as F
class ProbabilityVisualizerLogitsProcessor(LogitsProcessor):
def __init__(self):
pass
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
def visualize_probabilities(scores: torch.FloatTensor) -> None:
assert scores.ndim == 2
assert input_ids.ndim == 2
if vars.numseqs > 1 or not vars.show_probs:
return scores
return
probs = F.softmax(scores, dim = -1).cpu().numpy()[0]
token_prob_info = []
for token_id, score in sorted(enumerate(probs), key=lambda x: x[1], reverse=True)[:8]:
token_prob_info.append({
@ -1958,12 +1952,10 @@ def patch_transformers():
})
vars.token_stream_queue.probability_buffer = token_prob_info
return scores
def new_get_logits_processor(*args, **kwargs) -> LogitsProcessorList:
processors = new_get_logits_processor.old_get_logits_processor(*args, **kwargs)
processors.insert(0, LuaLogitsProcessor())
processors.append(ProbabilityVisualizerLogitsProcessor())
return processors
new_get_logits_processor.old_get_logits_processor = transformers.generation_utils.GenerationMixin._get_logits_processor
transformers.generation_utils.GenerationMixin._get_logits_processor = new_get_logits_processor
@ -1985,6 +1977,7 @@ def patch_transformers():
sampler_order = [6] + sampler_order
for k in sampler_order:
scores = self.__warper_list[k](input_ids, scores, *args, **kwargs)
visualize_probabilities(scores)
return scores
def new_get_logits_warper(beams: int = 1,) -> LogitsProcessorList: