Merge pull request #191 from henk717/united

Probability Viewer Fix

Replaces the ProbabilityVisualizerLogitsProcessor logits processor with a standalone visualize_probabilities() helper that is called at the end of the sampler warper chain, so the probability viewer reports token probabilities after all warpers have run.
henk717 2022-12-09 21:56:41 +01:00 committed by GitHub
commit dd7363548c

@@ -1936,34 +1936,26 @@ def patch_transformers():
     from torch.nn import functional as F

-    class ProbabilityVisualizerLogitsProcessor(LogitsProcessor):
-        def __init__(self):
-            pass
-
-        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
-            assert scores.ndim == 2
-            assert input_ids.ndim == 2
-
-            if vars.numseqs > 1 or not vars.show_probs:
-                return scores
-
-            probs = F.softmax(scores, dim = -1).cpu().numpy()[0]
-            token_prob_info = []
-            for token_id, score in sorted(enumerate(probs), key=lambda x: x[1], reverse=True)[:8]:
-                token_prob_info.append({
-                    "tokenId": token_id,
-                    "decoded": utils.decodenewlines(tokenizer.decode(token_id)),
-                    "score": float(score),
-                })
-
-            vars.token_stream_queue.probability_buffer = token_prob_info
-            return scores
+    def visualize_probabilities(scores: torch.FloatTensor) -> None:
+        assert scores.ndim == 2
+
+        if vars.numseqs > 1 or not vars.show_probs:
+            return
+
+        probs = F.softmax(scores, dim = -1).cpu().numpy()[0]
+        token_prob_info = []
+        for token_id, score in sorted(enumerate(probs), key=lambda x: x[1], reverse=True)[:8]:
+            token_prob_info.append({
+                "tokenId": token_id,
+                "decoded": utils.decodenewlines(tokenizer.decode(token_id)),
+                "score": float(score),
+            })
+        vars.token_stream_queue.probability_buffer = token_prob_info

     def new_get_logits_processor(*args, **kwargs) -> LogitsProcessorList:
         processors = new_get_logits_processor.old_get_logits_processor(*args, **kwargs)
         processors.insert(0, LuaLogitsProcessor())
-        processors.append(ProbabilityVisualizerLogitsProcessor())
         return processors
     new_get_logits_processor.old_get_logits_processor = transformers.generation_utils.GenerationMixin._get_logits_processor
     transformers.generation_utils.GenerationMixin._get_logits_processor = new_get_logits_processor
@@ -1985,6 +1977,7 @@ def patch_transformers():
                 sampler_order = [6] + sampler_order
             for k in sampler_order:
                 scores = self.__warper_list[k](input_ids, scores, *args, **kwargs)
+            visualize_probabilities(scores)
             return scores

     def new_get_logits_warper(beams: int = 1,) -> LogitsProcessorList:
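
For reference, the selection logic the moved helper performs is small: softmax the single (1, vocab) score row left after the warpers run, keep the eight highest-probability token ids, and hand them to the UI buffer. Below is a minimal standalone sketch of that step; it uses a random dummy score tensor and plain token ids in place of KoboldAI's tokenizer, vars globals, and token_stream_queue, and the helper name top_token_probs is hypothetical.

import torch
from torch.nn import functional as F

def top_token_probs(scores: torch.FloatTensor, k: int = 8) -> list:
    # scores is a (1, vocab_size) logits row, as produced after the sampler
    # warpers; this mirrors the top-8 pass in visualize_probabilities().
    assert scores.ndim == 2
    probs = F.softmax(scores, dim=-1).cpu().numpy()[0]
    top = sorted(enumerate(probs), key=lambda x: x[1], reverse=True)[:k]
    return [{"tokenId": token_id, "score": float(score)} for token_id, score in top]

# Dummy logits over a toy 10-token vocabulary (illustrative data only).
scores = torch.randn(1, 10)
for entry in top_token_probs(scores):
    print(entry)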