Mirror of https://github.com/KoboldAI/KoboldAI-Client.git, synced 2025-02-26 00:17:41 +01:00
Merge pull request #173 from one-some/token-streaming
Add token streaming option
This commit is contained in:
commit 050e195420

aiserver.py: 47 changed lines
@@ -351,6 +351,8 @@ class vars:
     lazy_load = True # Whether or not to use torch_lazy_loader.py for transformers models in order to reduce CPU memory usage
     use_colab_tpu = os.environ.get("COLAB_TPU_ADDR", "") != "" or os.environ.get("TPU_NAME", "") != "" # Whether or not we're in a Colab TPU instance or Kaggle TPU instance and are going to use the TPU rather than the CPU
     revision = None
+    output_streaming = False
+    token_stream_queue = [] # Queue for the token streaming

 utils.vars = vars

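For orientation: class vars is KoboldAI's module-level state container, and utils.vars = vars shares that same object with the utils module, which is why the generation code can append to token_stream_queue while the flush loop further down sees the tokens. A self-contained sketch of the aliasing pattern, with types.SimpleNamespace standing in for the real utils module:

    # Sketch: a module-level state bag shared by reference.
    import types

    class vars:                    # KoboldAI's global state container
        output_streaming = False   # the new UI toggle
        token_stream_queue = []    # producer side of the token stream

    utils = types.SimpleNamespace()  # stand-in for the utils module
    utils.vars = vars                # both names now alias one object
    utils.vars.token_stream_queue.append("Hello")
    assert vars.token_stream_queue == ["Hello"]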
@@ -800,6 +802,7 @@ def savesettings():
     js["fulldeterminism"] = vars.full_determinism
     js["autosave"] = vars.autosave
     js["welcome"] = vars.welcome
+    js["output_streaming"] = vars.output_streaming

     if(vars.seed_specified):
         js["seed"] = vars.seed
@@ -911,6 +914,8 @@ def processsettings(js):
     vars.newlinemode = js["newlinemode"]
     if("welcome" in js):
         vars.welcome = js["welcome"]
+    if("output_streaming" in js):
+        vars.output_streaming = js["output_streaming"]

     if("seed" in js):
         vars.seed = js["seed"]
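The savesettings/processsettings pair above persists the new toggle in the client settings JSON. A self-contained sketch of the same round trip (the settings.json path and the SimpleVars container are illustrative, not KoboldAI's actual structure):

    # Sketch: how a boolean setting survives a save/load cycle.
    import json

    class SimpleVars:
        output_streaming = False

    def savesettings(vars, path="settings.json"):
        js = {"output_streaming": vars.output_streaming}
        with open(path, "w") as f:
            json.dump(js, f)

    def processsettings(vars, path="settings.json"):
        with open(path) as f:
            js = json.load(f)
        # The "in" guard lets settings files written before this
        # feature existed load without a KeyError.
        if "output_streaming" in js:
            vars.output_streaming = js["output_streaming"]

    v = SimpleVars()
    v.output_streaming = True
    savesettings(v)
    processsettings(v)
    assert v.output_streaming is True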
@@ -943,12 +948,20 @@ def processsettings(js):

 def check_for_sp_change():
     while(True):
-        time.sleep(0.1)
+        time.sleep(0.05)

         if(vars.sp_changed):
             with app.app_context():
                 emit('from_server', {'cmd': 'spstatitems', 'data': {vars.spfilename: vars.spmeta} if vars.allowsp and len(vars.spfilename) else {}}, namespace=None, broadcast=True)
             vars.sp_changed = False

+        if(vars.output_streaming and vars.token_stream_queue):
+            # If emit blocks, waiting for it to complete before clearing could
+            # introduce a race condition that drops tokens.
+            queued_tokens = list(vars.token_stream_queue)
+            vars.token_stream_queue.clear()
+            socketio.emit("from_server", {"cmd": "streamtoken", "data": queued_tokens}, namespace=None, broadcast=True)
+
 socketio.start_background_task(check_for_sp_change)

 def spRequest(filename):
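Two things happen in this hunk: the poll interval drops from 0.1 s to 0.05 s so queued tokens reach the browser sooner, and the queue is snapshotted and cleared before the potentially blocking socketio.emit call, per the race-condition comment. A self-contained sketch of the snapshot-then-clear pattern with illustrative names (under cooperative scheduling such as eventlet's, nothing can append between the copy and the clear, since no blocking call separates them):

    # Sketch: flush a shared token queue without dropping tokens.
    import threading, time

    token_queue = []  # appended to by the generation thread

    def flush_loop(emit):
        while True:
            time.sleep(0.05)
            if token_queue:
                # Snapshot and clear BEFORE emitting: if we emitted first and
                # cleared afterwards, tokens appended while emit blocked would
                # be wiped without ever having been sent.
                batch = list(token_queue)
                token_queue.clear()
                emit(batch)

    threading.Thread(target=flush_loop, args=(print,), daemon=True).start()
    token_queue.extend(["Hel", "lo"])
    time.sleep(0.2)  # give the flusher a tick; prints ['Hel', 'lo']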
@@ -1542,6 +1555,27 @@ def patch_transformers():
     new_init.old_init = transformers.generation_logits_process.NoBadWordsLogitsProcessor.__init__
     transformers.generation_logits_process.NoBadWordsLogitsProcessor.__init__ = new_init

+    class TokenStreamer(StoppingCriteria):
+        # A StoppingCriteria is used here because it seems to run after
+        # everything has been evaluated score-wise.
+        def __init__(self, tokenizer):
+            self.tokenizer = tokenizer
+
+        def __call__(
+            self,
+            input_ids: torch.LongTensor,
+            scores: torch.FloatTensor,
+            **kwargs,
+        ) -> bool:
+            # Do not intermingle multiple generations' outputs!
+            if(vars.numseqs > 1):
+                return False
+
+            tokenizer_text = utils.decodenewlines(tokenizer.decode(input_ids[0, -1]))
+
+            vars.token_stream_queue.append(tokenizer_text)
+            return False
+
     # Sets up dynamic world info scanner
     class DynamicWorldInfoScanCriteria(StoppingCriteria):
@@ -1596,7 +1630,10 @@ def patch_transformers():
             tokenizer=tokenizer,
             excluded_world_info=self.kai_scanner_excluded_world_info,
         )
+        token_streamer = TokenStreamer(tokenizer=tokenizer)
+
         stopping_criteria.insert(0, self.kai_scanner)
+        stopping_criteria.insert(0, token_streamer)
         return stopping_criteria
     transformers.generation_utils.GenerationMixin._get_stopping_criteria = new_get_stopping_criteria

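This is the heart of the PR. The transformers version targeted here has no streaming callback in generate(), but every StoppingCriteria is invoked once per generated token, after that step's scores are finalized, so a criterion that always returns False acts as a pure per-token observer; inserting it at index 0 makes it run before the world-info scanner can halt generation. A self-contained sketch of the same trick (the gpt2 checkpoint and the prompt are arbitrary illustrations; newer transformers releases expose a dedicated streamer= argument that makes this workaround unnecessary):

    # Sketch: use a never-firing StoppingCriteria as a per-token hook.
    from transformers import (AutoModelForCausalLM, AutoTokenizer,
                              StoppingCriteria, StoppingCriteriaList)

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    class PrintStreamer(StoppingCriteria):
        # Called after each decoding step; returning False means
        # "never stop", so this only observes the newest token.
        def __call__(self, input_ids, scores, **kwargs) -> bool:
            print(tokenizer.decode(input_ids[0, -1]), end="", flush=True)
            return False

    ids = tokenizer("The quick brown", return_tensors="pt").input_ids
    model.generate(ids, max_new_tokens=20,
                   stopping_criteria=StoppingCriteriaList([PrintStreamer()]))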
@@ -2698,6 +2735,7 @@ def lua_has_setting(setting):
         "rmspch",
         "adsnsp",
         "singleline",
+        "output_streaming"
     )

 #==================================================================#
@@ -2729,6 +2767,7 @@ def lua_get_setting(setting):
     if(setting in ("frmtrmspch", "rmspch")): return vars.formatoptns["frmttrmspch"]
     if(setting in ("frmtadsnsp", "adsnsp")): return vars.formatoptns["frmtadsnsp"]
     if(setting in ("frmtsingleline", "singleline")): return vars.formatoptns["singleline"]
+    if(setting == "output_streaming"): return vars.output_streaming

 #==================================================================#
 # Set the setting with the given name if it exists
@@ -2765,6 +2804,7 @@ def lua_set_setting(setting, v):
     if(setting in ("frmtrmspch", "rmspch")): vars.formatoptns["frmttrmspch"] = v
     if(setting in ("frmtadsnsp", "adsnsp")): vars.formatoptns["frmtadsnsp"] = v
     if(setting in ("frmtsingleline", "singleline")): vars.formatoptns["singleline"] = v
+    if(setting == "output_streaming"): vars.output_streaming = v

 #==================================================================#
 # Get contents of memory
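These three hunks expose the toggle to Lua userscripts: lua_has_setting whitelists the name, and the getter and setter map it onto vars.output_streaming. A self-contained sketch of the contract they establish (stand-in implementations, not KoboldAI's actual ones):

    # Sketch of the contract the three lua_*_setting hunks establish.
    class vars: output_streaming = False

    SETTINGS = ("rmspch", "adsnsp", "singleline", "output_streaming")

    def lua_has_setting(setting): return setting in SETTINGS
    def lua_get_setting(setting):
        if setting == "output_streaming": return vars.output_streaming
    def lua_set_setting(setting, v):
        if setting == "output_streaming": vars.output_streaming = v

    assert lua_has_setting("output_streaming")   # name is whitelisted
    lua_set_setting("output_streaming", True)    # flips vars.output_streaming
    assert lua_get_setting("output_streaming") is True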
@@ -3477,6 +3517,10 @@ def get_message(msg):
         vars.full_determinism = msg['data']
         settingschanged()
         refresh_settings()
+    elif(msg['cmd'] == 'setoutputstreaming'):
+        vars.output_streaming = msg['data']
+        settingschanged()
+        refresh_settings()
     elif(not vars.host and msg['cmd'] == 'importwi'):
         wiimportrequest()
     elif(msg['cmd'] == 'debug'):
@@ -4674,6 +4718,7 @@ def refresh_settings():
     emit('from_server', {'cmd': 'updatefrmtrmspch', 'data': vars.formatoptns["frmtrmspch"]}, broadcast=True)
     emit('from_server', {'cmd': 'updatefrmtadsnsp', 'data': vars.formatoptns["frmtadsnsp"]}, broadcast=True)
     emit('from_server', {'cmd': 'updatesingleline', 'data': vars.formatoptns["singleline"]}, broadcast=True)
+    emit('from_server', {'cmd': 'updateoutputstreaming', 'data': vars.output_streaming}, broadcast=True)

     # Allow toggle events again
     emit('from_server', {'cmd': 'allowtoggle', 'data': True}, broadcast=True)
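Together with the setoutputstreaming handler above, this completes the toggle's round trip: a client flips the switch, the server stores the value, and refresh_settings broadcasts updateoutputstreaming so every connected browser's checkbox converges on the stored state. A self-contained sketch of that flow, with in-process stand-ins for the socket transport:

    # Sketch: the settings round trip, transport replaced by direct calls.
    class Vars: output_streaming = False
    vars = Vars()
    clients = []  # each dict stands in for one connected browser UI

    def refresh_settings():
        # broadcast: every connected UI converges on the stored state
        for client in clients:
            client["setoutputstreaming_checked"] = vars.output_streaming

    def get_message(msg):
        if msg["cmd"] == "setoutputstreaming":
            vars.output_streaming = msg["data"]
            refresh_settings()

    clients.append({"setoutputstreaming_checked": False})
    get_message({"cmd": "setoutputstreaming", "data": True})
    assert clients[0]["setoutputstreaming_checked"] is True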
gensettings.py (where gensettingstf is defined):

@@ -251,7 +251,18 @@ gensettingstf = [
     "step": 1,
     "default": 0,
     "tooltip": "Show debug info"
-    }
+    },
+    {
+    "uitype": "toggle",
+    "unit": "bool",
+    "label": "Token Streaming",
+    "id": "setoutputstreaming",
+    "min": 0,
+    "max": 1,
+    "step": 1,
+    "default": 0,
+    "tooltip": "Shows outputs to you as they are made. Does not work with more than one gens per action."
+    },
 ]

 gensettingsik =[{
static/application.js:

@@ -78,6 +78,7 @@ var rs_accept;
 var rs_close;
 var seqselmenu;
 var seqselcontents;
+var stream_preview;

 var storyname = null;
 var memorymode = false;
@@ -103,6 +104,7 @@ var gamestate = "";
 var gamesaved = true;
 var modelname = null;
 var model = "";
+var ignore_stream = false;

 // This is true iff [we're in macOS and the browser is Safari] or [we're in iOS]
 var using_webkit_patch = true;
@@ -888,6 +890,7 @@ function formatChunkInnerText(chunk) {
 }

 function dosubmit(disallow_abort) {
+    ignore_stream = false;
     submit_start = Date.now();
     var txt = input_text.val().replace(/\u00a0/g, " ");
     if((disallow_abort || gamestate !== "wait") && !memorymode && !gamestarted && ((!adventure || !action_mode) && txt.trim().length == 0)) {
@@ -902,6 +905,7 @@ function dosubmit(disallow_abort) {
 }

 function _dosubmit() {
+    ignore_stream = false;
     var txt = submit_throttle.txt;
     var disallow_abort = submit_throttle.disallow_abort;
     submit_throttle = null;
@@ -2082,6 +2086,15 @@ function unbindGametext() {
     gametext_bound = false;
 }

+function endStream() {
+    // Clear stream, the real text is about to be displayed.
+    ignore_stream = true;
+    if (stream_preview) {
+        stream_preview.remove();
+        stream_preview = null;
+    }
+}
+
 function update_gpu_layers() {
     var gpu_layers
     gpu_layers = 0;
@@ -2258,6 +2271,21 @@ $(document).ready(function(){
                 active_element.focus();
             })();
             $("body").addClass("connected");
+        } else if (msg.cmd == "streamtoken") {
+            // Sometimes the stream_token messages will come in too late, after
+            // we have received the full text. This leads to some stray tokens
+            // appearing after the output. To combat this, we only allow tokens
+            // to be displayed after requesting and before receiving text.
+            if (ignore_stream) return;
+            if (!$("#setoutputstreaming")[0].checked) return;
+
+            if (!stream_preview) {
+                stream_preview = document.createElement("span");
+                game_text.append(stream_preview);
+            }
+
+            stream_preview.innerText += msg.data.join("");
+            scrollToBottom();
         } else if(msg.cmd == "updatescreen") {
             var _gamestarted = gamestarted;
             gamestarted = msg.gamestarted;
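Because the streamtoken payload is just a list of decoded token strings on the from_server event, a headless client can consume the stream as well. A hypothetical example using the python-socketio package (KoboldAI's UI uses the browser Socket.IO client; the URL assumes the default local server on port 5000):

    # Hypothetical headless consumer of the "streamtoken" events.
    import socketio

    sio = socketio.Client()

    @sio.on("from_server")
    def on_from_server(msg):
        if msg.get("cmd") == "streamtoken":
            # msg["data"] is a list of decoded token strings
            print("".join(msg["data"]), end="", flush=True)

    sio.connect("http://localhost:5000")
    sio.wait()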
@@ -2333,6 +2361,7 @@ $(document).ready(function(){
         } else if(msg.cmd == "setgamestate") {
             // Enable or Disable buttons
             if(msg.data == "ready") {
+                endStream();
                 enableSendBtn();
                 enableButtons([button_actmem, button_actwi, button_actback, button_actfwd, button_actretry]);
                 hideWaitAnimation();
@@ -2519,6 +2548,9 @@ $(document).ready(function(){
         } else if(msg.cmd == "updatesingleline") {
             // Update toggle state
             $("#singleline").prop('checked', msg.data).change();
+        } else if(msg.cmd == "updateoutputstreaming") {
+            // Update toggle state
+            $("#setoutputstreaming").prop('checked', msg.data).change();
         } else if(msg.cmd == "allowtoggle") {
             // Allow toggle change states to propagate
             allowtoggle = msg.data;
@@ -2914,6 +2946,7 @@ $(document).ready(function(){
         });

         button_actretry.on("click", function(ev) {
+            ignore_stream = false;
             hideMessage();
             socket.send({'cmd': 'retry', 'chatname': chatmode ? chat_name.val() : undefined, 'data': ''});
             hidegenseqs();
@@ -3160,6 +3193,7 @@ $(document).ready(function(){
         });

         rs_accept.on("click", function(ev) {
+            ignore_stream = false;
             hideMessage();
             socket.send({'cmd': 'rndgame', 'memory': $("#rngmemory").val(), 'data': topic.val()});
             hideRandomStoryPopup();