Merge pull request #234 from one-some/united

Move probability visualization to after logitwarpers
This commit is contained in:
henk717 2022-12-09 21:22:33 +01:00 committed by GitHub
commit 686845cd21
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 14 additions and 21 deletions

View File

@ -1936,34 +1936,26 @@ def patch_transformers():
from torch.nn import functional as F
class ProbabilityVisualizerLogitsProcessor(LogitsProcessor):
def __init__(self):
pass
def visualize_probabilities(scores: torch.FloatTensor) -> None:
assert scores.ndim == 2
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
assert scores.ndim == 2
assert input_ids.ndim == 2
if vars.numseqs > 1 or not vars.show_probs:
return
if vars.numseqs > 1 or not vars.show_probs:
return scores
probs = F.softmax(scores, dim = -1).cpu().numpy()[0]
token_prob_info = []
for token_id, score in sorted(enumerate(probs), key=lambda x: x[1], reverse=True)[:8]:
token_prob_info.append({
"tokenId": token_id,
"decoded": utils.decodenewlines(tokenizer.decode(token_id)),
"score": float(score),
})
probs = F.softmax(scores, dim = -1).cpu().numpy()[0]
token_prob_info = []
for token_id, score in sorted(enumerate(probs), key=lambda x: x[1], reverse=True)[:8]:
token_prob_info.append({
"tokenId": token_id,
"decoded": utils.decodenewlines(tokenizer.decode(token_id)),
"score": float(score),
})
vars.token_stream_queue.probability_buffer = token_prob_info
return scores
vars.token_stream_queue.probability_buffer = token_prob_info
def new_get_logits_processor(*args, **kwargs) -> LogitsProcessorList:
processors = new_get_logits_processor.old_get_logits_processor(*args, **kwargs)
processors.insert(0, LuaLogitsProcessor())
processors.append(ProbabilityVisualizerLogitsProcessor())
return processors
new_get_logits_processor.old_get_logits_processor = transformers.generation_utils.GenerationMixin._get_logits_processor
transformers.generation_utils.GenerationMixin._get_logits_processor = new_get_logits_processor
@ -1985,6 +1977,7 @@ def patch_transformers():
sampler_order = [6] + sampler_order
for k in sampler_order:
scores = self.__warper_list[k](input_ids, scores, *args, **kwargs)
visualize_probabilities(scores)
return scores
def new_get_logits_warper(beams: int = 1,) -> LogitsProcessorList: