mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
And the rest
This commit is contained in:
44
aiserver.py
44
aiserver.py
@@ -9433,41 +9433,65 @@ def UI_2_generate_wi(data):
|
||||
field = data["field"]
|
||||
existing = data["existing"]
|
||||
gen_amount = data["genAmount"]
|
||||
print(uid, field)
|
||||
|
||||
with open("data/wi_fewshot.txt", "r") as file:
|
||||
fewshot_template = file.read()
|
||||
|
||||
# The template to coax what we want from the model
|
||||
extractor_string = ""
|
||||
|
||||
if field == "title":
|
||||
prompt = fewshot_template
|
||||
for thing in ["type", "desc"]:
|
||||
if not existing[thing]:
|
||||
continue
|
||||
pretty = {"type": "Type", "desc": "Description"}[thing]
|
||||
prompt += f"{pretty}: {existing[thing]}\n"
|
||||
extractor_string += f"{pretty}: {existing[thing]}\n"
|
||||
|
||||
pretty = "Title"
|
||||
if existing["desc"]:
|
||||
# Don't let the model think we're starting a new entry
|
||||
pretty = "Alternate Title"
|
||||
|
||||
prompt += pretty + ":"
|
||||
extractor_string += pretty + ":"
|
||||
elif field == "desc":
|
||||
# Both title and type MUST already be filled in to generate a description
|
||||
assert existing["title"]
|
||||
assert existing["type"]
|
||||
prompt = f"{fewshot_template}Title: {existing['title']}\nType: {existing['type']}\nDescription:"
|
||||
extractor_string = f"Title: {existing['title']}\nType: {existing['type']}\nDescription:"
|
||||
else:
|
||||
assert False, "What"
|
||||
|
||||
with open("data/wi_fewshot.txt", "r") as file:
|
||||
fewshot_entries = [x.strip() for x in file.read().split("\n\n") if x]
|
||||
|
||||
# Use user's own WI entries in prompt
|
||||
if koboldai_vars.wigen_use_own_wi:
|
||||
fewshot_entries += koboldai_vars.worldinfo_v2.to_wi_fewshot_format(excluding_uid=uid)
|
||||
|
||||
logger.info(prompt)
|
||||
# Token budget for the few-shot entries: max length minus the generation amount and the extractor string.
|
||||
target = koboldai_vars.max_length - gen_amount - len(tokenizer.encode(extractor_string))
|
||||
|
||||
used = []
|
||||
# Walk the entries backwards until we can't cram any more in
|
||||
for entry in reversed(fewshot_entries):
|
||||
maybe = [entry] + used
|
||||
maybe_str = "\n\n".join(maybe)
|
||||
possible_encoded = tokenizer.encode(maybe_str)
|
||||
if len(possible_encoded) > target:
|
||||
break
|
||||
yes_str = maybe_str
|
||||
used = maybe
|
||||
|
||||
prompt = f"{yes_str}\n\n{extractor_string}"
|
||||
|
||||
# logger.info(prompt)
|
||||
# TODO: Make single_line mode that stops on newline rather than bans it (for title)
|
||||
out_text = tpool.execute(
|
||||
raw_generate,
|
||||
prompt,
|
||||
max_new=gen_amount,
|
||||
single_line=True,
|
||||
).decoded[0]
|
||||
print(f'{out_text}')
|
||||
out_text = utils.trimincompletesentence(out_text.strip())
|
||||
|
||||
socketio.emit("generated_wi", {"uid": uid, "field": field, "out": out_text}, room="UI_2")
|
||||
|
||||
@app.route("/generate_raw", methods=["GET"])
|
||||
def UI_2_generate_raw():
|
||||
|
Reference in New Issue
Block a user