Move probability visualization to after logitwarpers
This commit is contained in:
parent
55ef53f39b
commit
e6656d68a1
35
aiserver.py
35
aiserver.py
|
@ -1936,34 +1936,26 @@ def patch_transformers():
|
|||
|
||||
from torch.nn import functional as F
|
||||
|
||||
class ProbabilityVisualizerLogitsProcessor(LogitsProcessor):
|
||||
def __init__(self):
|
||||
pass
|
||||
def visualize_probabilities(scores: torch.FloatTensor) -> None:
|
||||
assert scores.ndim == 2
|
||||
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||
assert scores.ndim == 2
|
||||
assert input_ids.ndim == 2
|
||||
if vars.numseqs > 1 or not vars.show_probs:
|
||||
return
|
||||
|
||||
if vars.numseqs > 1 or not vars.show_probs:
|
||||
return scores
|
||||
probs = F.softmax(scores, dim = -1).cpu().numpy()[0]
|
||||
token_prob_info = []
|
||||
for token_id, score in sorted(enumerate(probs), key=lambda x: x[1], reverse=True)[:8]:
|
||||
token_prob_info.append({
|
||||
"tokenId": token_id,
|
||||
"decoded": utils.decodenewlines(tokenizer.decode(token_id)),
|
||||
"score": float(score),
|
||||
})
|
||||
|
||||
probs = F.softmax(scores, dim = -1).cpu().numpy()[0]
|
||||
|
||||
token_prob_info = []
|
||||
for token_id, score in sorted(enumerate(probs), key=lambda x: x[1], reverse=True)[:8]:
|
||||
token_prob_info.append({
|
||||
"tokenId": token_id,
|
||||
"decoded": utils.decodenewlines(tokenizer.decode(token_id)),
|
||||
"score": float(score),
|
||||
})
|
||||
|
||||
vars.token_stream_queue.probability_buffer = token_prob_info
|
||||
return scores
|
||||
vars.token_stream_queue.probability_buffer = token_prob_info
|
||||
|
||||
def new_get_logits_processor(*args, **kwargs) -> LogitsProcessorList:
|
||||
processors = new_get_logits_processor.old_get_logits_processor(*args, **kwargs)
|
||||
processors.insert(0, LuaLogitsProcessor())
|
||||
processors.append(ProbabilityVisualizerLogitsProcessor())
|
||||
return processors
|
||||
new_get_logits_processor.old_get_logits_processor = transformers.generation_utils.GenerationMixin._get_logits_processor
|
||||
transformers.generation_utils.GenerationMixin._get_logits_processor = new_get_logits_processor
|
||||
|
@ -1985,6 +1977,7 @@ def patch_transformers():
|
|||
sampler_order = [6] + sampler_order
|
||||
for k in sampler_order:
|
||||
scores = self.__warper_list[k](input_ids, scores, *args, **kwargs)
|
||||
visualize_probabilities(scores)
|
||||
return scores
|
||||
|
||||
def new_get_logits_warper(beams: int = 1,) -> LogitsProcessorList:
|
||||
|
|
Loading…
Reference in New Issue