Move probability visualization to after logitwarpers
This commit is contained in:
parent
55ef53f39b
commit
e6656d68a1
35
aiserver.py
35
aiserver.py
|
@ -1936,34 +1936,26 @@ def patch_transformers():
|
||||||
|
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
|
|
||||||
class ProbabilityVisualizerLogitsProcessor(LogitsProcessor):
|
def visualize_probabilities(scores: torch.FloatTensor) -> None:
|
||||||
def __init__(self):
|
assert scores.ndim == 2
|
||||||
pass
|
|
||||||
|
|
||||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
if vars.numseqs > 1 or not vars.show_probs:
|
||||||
assert scores.ndim == 2
|
return
|
||||||
assert input_ids.ndim == 2
|
|
||||||
|
|
||||||
if vars.numseqs > 1 or not vars.show_probs:
|
probs = F.softmax(scores, dim = -1).cpu().numpy()[0]
|
||||||
return scores
|
token_prob_info = []
|
||||||
|
for token_id, score in sorted(enumerate(probs), key=lambda x: x[1], reverse=True)[:8]:
|
||||||
|
token_prob_info.append({
|
||||||
|
"tokenId": token_id,
|
||||||
|
"decoded": utils.decodenewlines(tokenizer.decode(token_id)),
|
||||||
|
"score": float(score),
|
||||||
|
})
|
||||||
|
|
||||||
probs = F.softmax(scores, dim = -1).cpu().numpy()[0]
|
vars.token_stream_queue.probability_buffer = token_prob_info
|
||||||
|
|
||||||
token_prob_info = []
|
|
||||||
for token_id, score in sorted(enumerate(probs), key=lambda x: x[1], reverse=True)[:8]:
|
|
||||||
token_prob_info.append({
|
|
||||||
"tokenId": token_id,
|
|
||||||
"decoded": utils.decodenewlines(tokenizer.decode(token_id)),
|
|
||||||
"score": float(score),
|
|
||||||
})
|
|
||||||
|
|
||||||
vars.token_stream_queue.probability_buffer = token_prob_info
|
|
||||||
return scores
|
|
||||||
|
|
||||||
def new_get_logits_processor(*args, **kwargs) -> LogitsProcessorList:
|
def new_get_logits_processor(*args, **kwargs) -> LogitsProcessorList:
|
||||||
processors = new_get_logits_processor.old_get_logits_processor(*args, **kwargs)
|
processors = new_get_logits_processor.old_get_logits_processor(*args, **kwargs)
|
||||||
processors.insert(0, LuaLogitsProcessor())
|
processors.insert(0, LuaLogitsProcessor())
|
||||||
processors.append(ProbabilityVisualizerLogitsProcessor())
|
|
||||||
return processors
|
return processors
|
||||||
new_get_logits_processor.old_get_logits_processor = transformers.generation_utils.GenerationMixin._get_logits_processor
|
new_get_logits_processor.old_get_logits_processor = transformers.generation_utils.GenerationMixin._get_logits_processor
|
||||||
transformers.generation_utils.GenerationMixin._get_logits_processor = new_get_logits_processor
|
transformers.generation_utils.GenerationMixin._get_logits_processor = new_get_logits_processor
|
||||||
|
@ -1985,6 +1977,7 @@ def patch_transformers():
|
||||||
sampler_order = [6] + sampler_order
|
sampler_order = [6] + sampler_order
|
||||||
for k in sampler_order:
|
for k in sampler_order:
|
||||||
scores = self.__warper_list[k](input_ids, scores, *args, **kwargs)
|
scores = self.__warper_list[k](input_ids, scores, *args, **kwargs)
|
||||||
|
visualize_probabilities(scores)
|
||||||
return scores
|
return scores
|
||||||
|
|
||||||
def new_get_logits_warper(beams: int = 1,) -> LogitsProcessorList:
|
def new_get_logits_warper(beams: int = 1,) -> LogitsProcessorList:
|
||||||
|
|
Loading…
Reference in New Issue