And the rest

This commit is contained in:
somebody
2022-12-12 19:19:28 -06:00
parent 62d4ed8f3e
commit 3b23d1f9c8
6 changed files with 209 additions and 13 deletions

View File

@@ -9433,41 +9433,65 @@ def UI_2_generate_wi(data):
field = data["field"]
existing = data["existing"]
gen_amount = data["genAmount"]
print(uid, field)
with open("data/wi_fewshot.txt", "r") as file:
fewshot_template = file.read()
# The template to coax what we want from the model
extractor_string = ""
if field == "title":
prompt = fewshot_template
for thing in ["type", "desc"]:
if not existing[thing]:
continue
pretty = {"type": "Type", "desc": "Description"}[thing]
prompt += f"{pretty}: {existing[thing]}\n"
extractor_string += f"{pretty}: {existing[thing]}\n"
pretty = "Title"
if existing["desc"]:
# Don't let the model think we're starting a new entry
pretty = "Alternate Title"
prompt += pretty + ":"
extractor_string += pretty + ":"
elif field == "desc":
# Both title and type MUST already be set
assert existing["title"]
assert existing["type"]
prompt = f"{fewshot_template}Title: {existing['title']}\nType: {existing['type']}\nDescription:"
extractor_string = f"Title: {existing['title']}\nType: {existing['type']}\nDescription:"
else:
assert False, "What"
with open("data/wi_fewshot.txt", "r") as file:
fewshot_entries = [x.strip() for x in file.read().split("\n\n") if x]
# Use user's own WI entries in prompt
if koboldai_vars.wigen_use_own_wi:
fewshot_entries += koboldai_vars.worldinfo_v2.to_wi_fewshot_format(excluding_uid=uid)
logger.info(prompt)
# We must have this amount or less in our context.
target = koboldai_vars.max_length - gen_amount - len(tokenizer.encode(extractor_string))
used = []
# Walk the entries backwards until we can't cram any more in
for entry in reversed(fewshot_entries):
maybe = [entry] + used
maybe_str = "\n\n".join(maybe)
possible_encoded = tokenizer.encode(maybe_str)
if len(possible_encoded) > target:
break
yes_str = maybe_str
used = maybe
prompt = f"{yes_str}\n\n{extractor_string}"
# logger.info(prompt)
# TODO: Make single_line mode that stops on newline rather than bans it (for title)
out_text = tpool.execute(
raw_generate,
prompt,
max_new=gen_amount,
single_line=True,
).decoded[0]
print(f'{out_text}')
out_text = utils.trimincompletesentence(out_text.strip())
socketio.emit("generated_wi", {"uid": uid, "field": field, "out": out_text}, room="UI_2")
@app.route("/generate_raw", methods=["GET"])
def UI_2_generate_raw():