From 09fee52abd2094e3a2f86595437e9716c522c04f Mon Sep 17 00:00:00 2001 From: Wes Brown Date: Thu, 7 Apr 2022 14:50:23 -0400 Subject: [PATCH] Add `num_seqs` support to GooseAI/OpenAI client handler. --- aiserver.py | 48 +++++++++++++++++++++++++++++++----------------- 1 file changed, 31 insertions(+), 17 deletions(-) diff --git a/aiserver.py b/aiserver.py index 74b21d21..68713417 100644 --- a/aiserver.py +++ b/aiserver.py @@ -4373,7 +4373,7 @@ def oairequest(txt, min, max): 'repetition_penalty': vars.rep_pen, 'repetition_penalty_slope': vars.rep_pen_slope, 'repetition_penalty_range': vars.rep_pen_range, - 'n': 1, + 'n': vars.numseqs, 'stream': False } else: @@ -4382,7 +4382,7 @@ def oairequest(txt, min, max): 'max_tokens': vars.genamt, 'temperature': vars.temp, 'top_p': vars.top_p, - 'n': 1, + 'n': vars.numseqs, 'stream': False } @@ -4397,24 +4397,27 @@ def oairequest(txt, min, max): # Deal with the response if(req.status_code == 200): - genout = req.json()["choices"][0]["text"] + outputs = [out["text"] for out in req.json()["choices"]] - vars.lua_koboldbridge.outputs[1] = genout + for idx in range(len(outputs)): + vars.lua_koboldbridge.outputs[idx+1] = outputs[idx] execute_outmod() - if(vars.lua_koboldbridge.regeneration_required): + if (vars.lua_koboldbridge.regeneration_required): vars.lua_koboldbridge.regeneration_required = False - genout = vars.lua_koboldbridge.outputs[1] - assert genout is str + genout = [] + for i in range(len(outputs)): + genout.append( + {"generated_text": vars.lua_koboldbridge.outputs[i + 1]}) + assert type(genout[-1]["generated_text"]) is str + else: + genout = [ + {"generated_text": utils.decodenewlines(txt)} + for txt in outputs] - if not vars.quiet: - print("{0}{1}{2}".format(colors.CYAN, genout, colors.END)) - vars.actions.append(genout) - # we now need to update the actions_metadata - # we'll have two conditions. - # 1. This is totally new (user entered) if vars.actions.get_last_key() not in vars.actions_metadata: - vars.actions_metadata[vars.actions.get_last_key()] = {"Selected Text": genout, "Alternative Text": []} + vars.actions_metadata[vars.actions.get_last_key()] = { + "Selected Text": genout[0], "Alternative Text": []} else: # 2. We've selected a chunk of text that is was presented previously try: @@ -4427,9 +4430,20 @@ def oairequest(txt, min, max): alternatives = [item for item in vars.actions_metadata[vars.actions.get_last_key() ]["Alternative Text"] if item['Text'] != genout] vars.actions_metadata[vars.actions.get_last_key()]["Alternative Text"] = alternatives vars.actions_metadata[vars.actions.get_last_key()]["Selected Text"] = genout - update_story_chunk('last') - emit('from_server', {'cmd': 'texteffect', 'data': vars.actions.get_last_key() + 1 if len(vars.actions) else 0}, broadcast=True) - send_debug() + + if (len(genout) == 1): + genresult(genout[0]["generated_text"]) + else: + if (vars.lua_koboldbridge.restart_sequence is not None and + vars.lua_koboldbridge.restart_sequence > 0): + genresult(genout[vars.lua_koboldbridge.restart_sequence - 1][ + "generated_text"]) + else: + genselect(genout) + + if not vars.quiet: + print("{0}{1}{2}".format(colors.CYAN, genout, colors.END)) + set_aibusy(0) else: # Send error message to web client