Merge pull request #110 from gooseai/united.add-oai-numseqs-support

Add `numseqs` support to GooseAI/OpenAI client handler.
This commit is contained in:
henk717 2022-04-07 20:57:10 +02:00 committed by GitHub
commit 47e825c83c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 31 additions and 17 deletions

View File

@ -4373,7 +4373,7 @@ def oairequest(txt, min, max):
'repetition_penalty': vars.rep_pen, 'repetition_penalty': vars.rep_pen,
'repetition_penalty_slope': vars.rep_pen_slope, 'repetition_penalty_slope': vars.rep_pen_slope,
'repetition_penalty_range': vars.rep_pen_range, 'repetition_penalty_range': vars.rep_pen_range,
'n': 1, 'n': vars.numseqs,
'stream': False 'stream': False
} }
else: else:
@ -4382,7 +4382,7 @@ def oairequest(txt, min, max):
'max_tokens': vars.genamt, 'max_tokens': vars.genamt,
'temperature': vars.temp, 'temperature': vars.temp,
'top_p': vars.top_p, 'top_p': vars.top_p,
'n': 1, 'n': vars.numseqs,
'stream': False 'stream': False
} }
@ -4397,24 +4397,27 @@ def oairequest(txt, min, max):
# Deal with the response # Deal with the response
if(req.status_code == 200): if(req.status_code == 200):
genout = req.json()["choices"][0]["text"] outputs = [out["text"] for out in req.json()["choices"]]
vars.lua_koboldbridge.outputs[1] = genout for idx in range(len(outputs)):
vars.lua_koboldbridge.outputs[idx+1] = outputs[idx]
execute_outmod() execute_outmod()
if(vars.lua_koboldbridge.regeneration_required): if (vars.lua_koboldbridge.regeneration_required):
vars.lua_koboldbridge.regeneration_required = False vars.lua_koboldbridge.regeneration_required = False
genout = vars.lua_koboldbridge.outputs[1] genout = []
assert genout is str for i in range(len(outputs)):
genout.append(
{"generated_text": vars.lua_koboldbridge.outputs[i + 1]})
assert type(genout[-1]["generated_text"]) is str
else:
genout = [
{"generated_text": utils.decodenewlines(txt)}
for txt in outputs]
if not vars.quiet:
print("{0}{1}{2}".format(colors.CYAN, genout, colors.END))
vars.actions.append(genout)
# we now need to update the actions_metadata
# we'll have two conditions.
# 1. This is totally new (user entered)
if vars.actions.get_last_key() not in vars.actions_metadata: if vars.actions.get_last_key() not in vars.actions_metadata:
vars.actions_metadata[vars.actions.get_last_key()] = {"Selected Text": genout, "Alternative Text": []} vars.actions_metadata[vars.actions.get_last_key()] = {
"Selected Text": genout[0], "Alternative Text": []}
else: else:
# 2. We've selected a chunk of text that is was presented previously # 2. We've selected a chunk of text that is was presented previously
try: try:
@ -4427,9 +4430,20 @@ def oairequest(txt, min, max):
alternatives = [item for item in vars.actions_metadata[vars.actions.get_last_key() ]["Alternative Text"] if item['Text'] != genout] alternatives = [item for item in vars.actions_metadata[vars.actions.get_last_key() ]["Alternative Text"] if item['Text'] != genout]
vars.actions_metadata[vars.actions.get_last_key()]["Alternative Text"] = alternatives vars.actions_metadata[vars.actions.get_last_key()]["Alternative Text"] = alternatives
vars.actions_metadata[vars.actions.get_last_key()]["Selected Text"] = genout vars.actions_metadata[vars.actions.get_last_key()]["Selected Text"] = genout
update_story_chunk('last')
emit('from_server', {'cmd': 'texteffect', 'data': vars.actions.get_last_key() + 1 if len(vars.actions) else 0}, broadcast=True) if (len(genout) == 1):
send_debug() genresult(genout[0]["generated_text"])
else:
if (vars.lua_koboldbridge.restart_sequence is not None and
vars.lua_koboldbridge.restart_sequence > 0):
genresult(genout[vars.lua_koboldbridge.restart_sequence - 1][
"generated_text"])
else:
genselect(genout)
if not vars.quiet:
print("{0}{1}{2}".format(colors.CYAN, genout, colors.END))
set_aibusy(0) set_aibusy(0)
else: else:
# Send error message to web client # Send error message to web client