mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Merge pull request #173 from one-some/token-streaming
Add token streaming option
This commit is contained in:
47
aiserver.py
47
aiserver.py
@ -351,6 +351,8 @@ class vars:
|
||||
lazy_load = True # Whether or not to use torch_lazy_loader.py for transformers models in order to reduce CPU memory usage
|
||||
use_colab_tpu = os.environ.get("COLAB_TPU_ADDR", "") != "" or os.environ.get("TPU_NAME", "") != "" # Whether or not we're in a Colab TPU instance or Kaggle TPU instance and are going to use the TPU rather than the CPU
|
||||
revision = None
|
||||
output_streaming = False
|
||||
token_stream_queue = [] # Queue for the token streaming
|
||||
|
||||
utils.vars = vars
|
||||
|
||||
@ -800,6 +802,7 @@ def savesettings():
|
||||
js["fulldeterminism"] = vars.full_determinism
|
||||
js["autosave"] = vars.autosave
|
||||
js["welcome"] = vars.welcome
|
||||
js["output_streaming"] = vars.output_streaming
|
||||
|
||||
if(vars.seed_specified):
|
||||
js["seed"] = vars.seed
|
||||
@ -911,6 +914,8 @@ def processsettings(js):
|
||||
vars.newlinemode = js["newlinemode"]
|
||||
if("welcome" in js):
|
||||
vars.welcome = js["welcome"]
|
||||
if("output_streaming" in js):
|
||||
vars.output_streaming = js["output_streaming"]
|
||||
|
||||
if("seed" in js):
|
||||
vars.seed = js["seed"]
|
||||
@ -943,12 +948,20 @@ def processsettings(js):
|
||||
|
||||
def check_for_sp_change():
    """Poll shared state forever and push pending updates to all clients.

    Meant to be run as a Socket.IO background task (it never returns).
    Every cycle it:

    * broadcasts an ``spstatitems`` message and resets ``vars.sp_changed``
      when that flag has been raised, and
    * when ``vars.output_streaming`` is enabled and tokens are waiting in
      ``vars.token_stream_queue``, broadcasts them in a single
      ``streamtoken`` message and empties the queue.
    """
    while(True):
        # One short sleep per cycle; this interval bounds how long a queued
        # token waits before it is sent to the clients.
        time.sleep(0.05)

        if(vars.sp_changed):
            # emit() is called from outside a request handler here, so an
            # application context is pushed around it.
            with app.app_context():
                emit('from_server', {'cmd': 'spstatitems', 'data': {vars.spfilename: vars.spmeta} if vars.allowsp and len(vars.spfilename) else {}}, namespace=None, broadcast=True)
            vars.sp_changed = False

        if(vars.output_streaming and vars.token_stream_queue):
            # If emit blocks, waiting for it to complete before clearing could
            # introduce a race condition that drops tokens, so snapshot and
            # clear the queue first and only then send the snapshot.
            queued_tokens = list(vars.token_stream_queue)
            vars.token_stream_queue.clear()
            socketio.emit("from_server", {"cmd": "streamtoken", "data": queued_tokens}, namespace=None, broadcast=True)
|
||||
|
||||
# Launch the polling loop defined above as a Socket.IO background task; the
# loop never returns, so it keeps running alongside the request handlers.
socketio.start_background_task(check_for_sp_change)
|
||||
|
||||
def spRequest(filename):
|
||||
@ -1542,6 +1555,27 @@ def patch_transformers():
|
||||
new_init.old_init = transformers.generation_logits_process.NoBadWordsLogitsProcessor.__init__
|
||||
transformers.generation_logits_process.NoBadWordsLogitsProcessor.__init__ = new_init
|
||||
|
||||
class TokenStreamer(StoppingCriteria):
    """Queue the text of each newly generated token for streaming.

    A StoppingCriteria is used here because it seems to run after
    everything has been evaluated score-wise. This criteria never asks
    generation to stop: ``__call__`` always returns ``False``.
    """

    def __init__(self, tokenizer):
        """Remember the tokenizer used to decode streamed tokens."""
        self.tokenizer = tokenizer

    def __call__(
        self,
        input_ids: torch.LongTensor,
        scores: torch.FloatTensor,
        **kwargs,
    ) -> bool:
        """Append the latest token's decoded text to ``vars.token_stream_queue``.

        Only the last token of the first sequence (``input_ids[0, -1]``) is
        decoded. ``scores`` and ``kwargs`` are accepted for interface
        compatibility and are not used.

        Returns:
            Always ``False``, so generation is never stopped by this criteria.
        """
        # The queue is only drained while output streaming is enabled, so
        # queueing tokens with streaming off would let it grow without bound
        # and flush a stale backlog once streaming is turned on.
        if(not vars.output_streaming):
            return False

        # Do not intermingle multiple generations' outputs!
        if(vars.numseqs > 1):
            return False

        # Decode with the tokenizer this instance was constructed with.
        tokenizer_text = utils.decodenewlines(self.tokenizer.decode(input_ids[0, -1]))

        vars.token_stream_queue.append(tokenizer_text)
        return False
|
||||
|
||||
|
||||
# Sets up dynamic world info scanner
|
||||
class DynamicWorldInfoScanCriteria(StoppingCriteria):
|
||||
@ -1596,7 +1630,10 @@ def patch_transformers():
|
||||
tokenizer=tokenizer,
|
||||
excluded_world_info=self.kai_scanner_excluded_world_info,
|
||||
)
|
||||
token_streamer = TokenStreamer(tokenizer=tokenizer)
|
||||
|
||||
stopping_criteria.insert(0, self.kai_scanner)
|
||||
stopping_criteria.insert(0, token_streamer)
|
||||
return stopping_criteria
|
||||
transformers.generation_utils.GenerationMixin._get_stopping_criteria = new_get_stopping_criteria
|
||||
|
||||
@ -2698,6 +2735,7 @@ def lua_has_setting(setting):
|
||||
"rmspch",
|
||||
"adsnsp",
|
||||
"singleline",
|
||||
"output_streaming"
|
||||
)
|
||||
|
||||
#==================================================================#
|
||||
@ -2729,6 +2767,7 @@ def lua_get_setting(setting):
|
||||
if(setting in ("frmtrmspch", "rmspch")): return vars.formatoptns["frmttrmspch"]
|
||||
if(setting in ("frmtadsnsp", "adsnsp")): return vars.formatoptns["frmtadsnsp"]
|
||||
if(setting in ("frmtsingleline", "singleline")): return vars.formatoptns["singleline"]
|
||||
if(setting == "output_streaming"): return vars.output_streaming
|
||||
|
||||
#==================================================================#
|
||||
# Set the setting with the given name if it exists
|
||||
@ -2765,6 +2804,7 @@ def lua_set_setting(setting, v):
|
||||
if(setting in ("frmtrmspch", "rmspch")): vars.formatoptns["frmttrmspch"] = v
|
||||
if(setting in ("frmtadsnsp", "adsnsp")): vars.formatoptns["frmtadsnsp"] = v
|
||||
if(setting in ("frmtsingleline", "singleline")): vars.formatoptns["singleline"] = v
|
||||
if(setting == "output_streaming"): vars.output_streaming = v
|
||||
|
||||
#==================================================================#
|
||||
# Get contents of memory
|
||||
@ -3477,6 +3517,10 @@ def get_message(msg):
|
||||
vars.full_determinism = msg['data']
|
||||
settingschanged()
|
||||
refresh_settings()
|
||||
elif(msg['cmd'] == 'setoutputstreaming'):
|
||||
vars.output_streaming = msg['data']
|
||||
settingschanged()
|
||||
refresh_settings()
|
||||
elif(not vars.host and msg['cmd'] == 'importwi'):
|
||||
wiimportrequest()
|
||||
elif(msg['cmd'] == 'debug'):
|
||||
@ -4674,6 +4718,7 @@ def refresh_settings():
|
||||
emit('from_server', {'cmd': 'updatefrmtrmspch', 'data': vars.formatoptns["frmtrmspch"]}, broadcast=True)
|
||||
emit('from_server', {'cmd': 'updatefrmtadsnsp', 'data': vars.formatoptns["frmtadsnsp"]}, broadcast=True)
|
||||
emit('from_server', {'cmd': 'updatesingleline', 'data': vars.formatoptns["singleline"]}, broadcast=True)
|
||||
emit('from_server', {'cmd': 'updateoutputstreaming', 'data': vars.output_streaming}, broadcast=True)
|
||||
|
||||
# Allow toggle events again
|
||||
emit('from_server', {'cmd': 'allowtoggle', 'data': True}, broadcast=True)
|
||||
|
Reference in New Issue
Block a user