Merge pull request #178 from one-some/token-prob

Add token probability visualizer
henk717 authored on 2022-08-05 14:27:46 +02:00, committed by GitHub
5 changed files with 211 additions and 16 deletions


@@ -215,6 +215,19 @@ model_menu = {
["Return to Main Menu", "mainmenu", "", True],
]
}
class TokenStreamQueue:
def __init__(self):
self.probability_buffer = None
self.queue = []
def add_text(self, text):
self.queue.append({
"decoded": text,
"probabilities": self.probability_buffer
})
self.probability_buffer = None
# Variables
class vars:
lastact = "" # The last action received from the user
@@ -352,7 +365,8 @@ class vars:
use_colab_tpu = os.environ.get("COLAB_TPU_ADDR", "") != "" or os.environ.get("TPU_NAME", "") != "" # Whether or not we're in a Colab TPU instance or Kaggle TPU instance and are going to use the TPU rather than the CPU
revision = None
output_streaming = False
- token_stream_queue = []  # Queue for the token streaming
+ token_stream_queue = TokenStreamQueue()  # Queue for the token streaming
+ show_probs = False  # Whether or not to show token probabilities
utils.vars = vars
@@ -803,6 +817,7 @@ def savesettings():
js["autosave"] = vars.autosave
js["welcome"] = vars.welcome
js["output_streaming"] = vars.output_streaming
js["show_probs"] = vars.show_probs
if(vars.seed_specified):
js["seed"] = vars.seed
@@ -916,6 +931,8 @@ def processsettings(js):
vars.welcome = js["welcome"]
if("output_streaming" in js):
vars.output_streaming = js["output_streaming"]
if("show_probs" in js):
vars.show_probs = js["show_probs"]
if("seed" in js):
vars.seed = js["seed"]
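
Guarding the read with `if("show_probs" in js)` keeps older settings files, which lack the key, loading without errors. A rough sketch of that round trip, using a plain dict in place of the saved JSON:

    import json

    js = json.loads('{"output_streaming": true}')  # older file: no "show_probs" key
    show_probs = False                             # default from the vars class
    if "show_probs" in js:
        show_probs = js["show_probs"]
    print(show_probs)                              # False -- missing key keeps the default
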
@@ -955,11 +972,11 @@ def check_for_sp_change():
emit('from_server', {'cmd': 'spstatitems', 'data': {vars.spfilename: vars.spmeta} if vars.allowsp and len(vars.spfilename) else {}}, namespace=None, broadcast=True)
vars.sp_changed = False
- if(vars.output_streaming and vars.token_stream_queue):
+ if(vars.token_stream_queue.queue):
# If emit blocks, waiting for it to complete before clearing could
# introduce a race condition that drops tokens.
- queued_tokens = list(vars.token_stream_queue)
- vars.token_stream_queue.clear()
+ queued_tokens = list(vars.token_stream_queue.queue)
+ vars.token_stream_queue.queue.clear()
socketio.emit("from_server", {"cmd": "streamtoken", "data": queued_tokens}, namespace=None, broadcast=True)
socketio.start_background_task(check_for_sp_change)
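
The emit logic keeps the copy-then-clear ordering described in the comment: snapshot the queue and empty it before the potentially blocking emit, so tokens appended while the emit is in flight land in the fresh queue instead of being wiped afterwards. A compact sketch of the pattern, with emit stubbed out as print:

    def flush(queue, emit):
        queued_tokens = list(queue)  # snapshot the current contents
        queue.clear()                # clear *before* emitting
        emit(queued_tokens)          # safe even if this call blocks for a while

    flush(["He", "llo"], print)      # prints ['He', 'llo']
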
@@ -1509,10 +1526,37 @@ def patch_transformers():
assert scores.shape == scores_shape
return scores
from torch.nn import functional as F
class ProbabilityVisualizerLogitsProcessor(LogitsProcessor):
def __init__(self):
pass
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
assert scores.ndim == 2
assert input_ids.ndim == 2
if vars.numseqs > 1 or not vars.show_probs:
return scores
probs = F.softmax(scores, dim = -1).cpu().numpy()[0]
token_prob_info = []
for token_id, score in sorted(enumerate(probs), key=lambda x: x[1], reverse=True)[:8]:
token_prob_info.append({
"tokenId": token_id,
"decoded": utils.decodenewlines(tokenizer.decode(token_id)),
"score": float(score),
})
vars.token_stream_queue.probability_buffer = token_prob_info
return scores
def new_get_logits_processor(*args, **kwargs) -> LogitsProcessorList:
processors = new_get_logits_processor.old_get_logits_processor(*args, **kwargs)
processors.insert(0, LuaLogitsProcessor())
processors.append(ProbabilityVisualizerLogitsProcessor())
return processors
new_get_logits_processor.old_get_logits_processor = transformers.generation_utils.GenerationMixin._get_logits_processor
transformers.generation_utils.GenerationMixin._get_logits_processor = new_get_logits_processor
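
The new processor never modifies the scores; for the single-sequence case it only records the eight most likely next tokens from the softmax. A rough standalone sketch of that extraction (the vocabulary size and logits are made up, and the tokenizer.decode() step from the diff is omitted):

    import torch
    from torch.nn import functional as F

    scores = torch.randn(1, 50257)                     # fake logits for one sequence
    probs = F.softmax(scores, dim=-1).cpu().numpy()[0]
    top8 = sorted(enumerate(probs), key=lambda x: x[1], reverse=True)[:8]
    token_prob_info = [{"tokenId": tid, "score": float(p)} for tid, p in top8]
    print(token_prob_info[0])                          # e.g. {'tokenId': ..., 'score': ...}
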
@@ -1568,12 +1612,14 @@ def patch_transformers():
**kwargs,
) -> bool:
# Do not intermingle multiple generations' outputs!
- if(vars.numseqs > 1):
+ if vars.numseqs > 1:
return False
+ if not (vars.show_probs or vars.output_streaming):
+ return False
tokenizer_text = utils.decodenewlines(tokenizer.decode(input_ids[0, -1]))
- vars.token_stream_queue.append(tokenizer_text)
+ vars.token_stream_queue.add_text(tokenizer_text)
return False
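
Because the callback always returns False it never halts generation; it works purely as a per-token hook that feeds the stream queue. A minimal sketch of that shape, written here as a transformers StoppingCriteria (an assumption about how the hook is registered, which this diff does not show):

    import torch
    from transformers import StoppingCriteria

    class TokenStreamerCriteria(StoppingCriteria):
        """Hook-style criterion: records each new token, never signals a stop."""
        def __init__(self, tokenizer, queue):
            self.tokenizer = tokenizer
            self.queue = queue

        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
            self.queue.add_text(self.tokenizer.decode(input_ids[0, -1]))
            return False  # never stop generation from here
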
@@ -2760,7 +2806,8 @@ def lua_has_setting(setting):
"rmspch",
"adsnsp",
"singleline",
"output_streaming"
"output_streaming",
"show_probs"
)
#==================================================================#
@@ -2793,6 +2840,7 @@ def lua_get_setting(setting):
if(setting in ("frmtadsnsp", "adsnsp")): return vars.formatoptns["frmtadsnsp"]
if(setting in ("frmtsingleline", "singleline")): return vars.formatoptns["singleline"]
if(setting == "output_streaming"): return vars.output_streaming
if(setting == "show_probs"): return vars.show_probs
#==================================================================#
# Set the setting with the given name if it exists
@@ -2830,6 +2878,7 @@ def lua_set_setting(setting, v):
if(setting in ("frmtadsnsp", "adsnsp")): vars.formatoptns["frmtadsnsp"] = v
if(setting in ("frmtsingleline", "singleline")): vars.formatoptns["singleline"] = v
if(setting == "output_streaming"): vars.output_streaming = v
if(setting == "show_probs"): vars.show_probs = v
#==================================================================#
# Get contents of memory
@@ -3546,6 +3595,10 @@ def get_message(msg):
vars.output_streaming = msg['data']
settingschanged()
refresh_settings()
elif(msg['cmd'] == 'setshowprobs'):
vars.show_probs = msg['data']
settingschanged()
refresh_settings()
elif(not vars.host and msg['cmd'] == 'importwi'):
wiimportrequest()
elif(msg['cmd'] == 'debug'):
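
The new branch mirrors the existing output_streaming toggle: the client sends a command message and the server updates the flag, marks the settings as changed, and refreshes them. A sketch of the message shape the handler above expects (presumably emitted by the client-side toggle):

    msg = {"cmd": "setshowprobs", "data": True}   # toggle the probability visualizer on
    if msg["cmd"] == "setshowprobs":
        show_probs = msg["data"]                  # vars.show_probs in the handler
    print(show_probs)                             # True
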
@@ -4744,6 +4797,7 @@ def refresh_settings():
emit('from_server', {'cmd': 'updatefrmtadsnsp', 'data': vars.formatoptns["frmtadsnsp"]}, broadcast=True)
emit('from_server', {'cmd': 'updatesingleline', 'data': vars.formatoptns["singleline"]}, broadcast=True)
emit('from_server', {'cmd': 'updateoutputstreaming', 'data': vars.output_streaming}, broadcast=True)
emit('from_server', {'cmd': 'updateshowprobs', 'data': vars.show_probs}, broadcast=True)
# Allow toggle events again
emit('from_server', {'cmd': 'allowtoggle', 'data': True}, broadcast=True)