Repository: https://github.com/KoboldAI/KoboldAI-Client.git
Support for multiple gens per action with dynamic scan
aiserver.py (47 changed lines)
@@ -14,7 +14,7 @@ from tkinter import messagebox
 import json
 import collections
 import zipfile
-from typing import Union, Dict, Set
+from typing import Union, Dict, Set, List
 
 import requests
 import html
@@ -565,8 +565,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
         def __init__(
             self,
             tokenizer,
-            excluded_world_info: set,
-            #head_length: torch.LongTensor,
+            excluded_world_info: List[Set],
             head_length: int,
         ):
             self.any_new_entries = False
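The constructor of the world-info stopping criteria (the class name is outside this hunk) now receives one excluded-entry set per generated sequence instead of a single shared set, alongside the token length of the shared prompt head. A minimal standalone sketch of the new shape, using a hypothetical class name and no KoboldAI internals:

from typing import List, Set

class DynamicScanCriteria:  # hypothetical name; the real class name is not shown in this hunk
    def __init__(self, tokenizer, excluded_world_info: List[Set], head_length: int):
        self.tokenizer = tokenizer
        # One set of already-triggered world-info entries per sequence in the batch.
        self.excluded_world_info = excluded_world_info
        # Token length of the prompt shared by all sequences; only tokens past
        # this point are scanned for newly triggered entries.
        self.head_length = head_length
        self.any_new_entries = False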
@@ -580,15 +579,15 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
             **kwargs,
         ) -> bool:
             assert input_ids.ndim == 2
-            #assert input_ids.shape[:-1] == self.head_length.shape
+            assert len(self.excluded_world_info) == input_ids.shape[0]
             self.any_new_entries = False
             if(not vars.dynamicscan):
                 return False
             tail = input_ids[..., self.head_length:]
-            for t in tail:
+            for i, t in enumerate(tail):
                 decoded = tokenizer.decode(t)
                 _, found = checkworldinfo(decoded, force_use_txt=True)
-                found -= self.excluded_world_info
+                found -= self.excluded_world_info[i]
                 if(len(found) != 0):
                     self.any_new_entries = True
                     break
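With one exclusion set per sequence, the scan enumerates each row of the batch and subtracts only that row's already-known entries, so an entry triggered by one candidate sequence no longer hides it from the others. A self-contained sketch of the same logic, with stand-in decode and check callables in place of tokenizer.decode and checkworldinfo:

from typing import Callable, List, Set

import torch

def scan_for_new_entries(input_ids: torch.Tensor, head_length: int,
                         excluded: List[Set[str]],
                         decode: Callable[[torch.Tensor], str],
                         check: Callable[[str], Set[str]]) -> bool:
    # One exclusion set per batch row; only the freshly generated tail is scanned.
    assert input_ids.ndim == 2
    assert len(excluded) == input_ids.shape[0]
    tail = input_ids[..., head_length:]
    for i, t in enumerate(tail):
        found = check(decode(t)) - excluded[i]  # keep only entries new to row i
        if len(found) != 0:
            return True
    return False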
@@ -1423,8 +1422,12 @@ def calcsubmit(txt):
 #==================================================================#
 # Send text to generator and deal with output
 #==================================================================#
-def generate(txt, min, max, found_entries=set()):
-    print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, min, max, txt, colors.END))
+def generate(txt, minimum, maximum, found_entries=None):
+    if(found_entries is None):
+        found_entries = set()
+    found_entries = tuple(found_entries.copy() for _ in range(vars.numseqs))
+
+    print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, minimum, maximum, txt, colors.END))
 
     # Store context in memory to use it for comparison with generated content
     vars.lastctx = txt
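Two fixes land in this hunk. First, found_entries=set() was a mutable default argument: the set is created once at function definition time and shared across every call. The None sentinel restores a fresh set per call. Second, the set is copied once per sequence so each of the vars.numseqs candidates accumulates its own world-info hits. A short demonstration of the pitfall the sentinel avoids:

def broken(found=set()):   # one set object, created once when the def runs
    found.add(len(found))
    return found

def fixed(found=None):     # fresh set on every call
    if found is None:
        found = set()
    found.add(len(found))
    return found

print(broken())  # {0}
print(broken())  # {0, 1} -- state leaked from the first call
print(fixed())   # {0}
print(fixed())   # {0} -- each call starts clean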
@@ -1466,13 +1469,13 @@ def generate(txt, min, max, found_entries=set()):
 
     with torch.no_grad():
         already_generated = 0
-        numseqs = vars.numseqs if not vars.dynamicscan else 1
+        numseqs = vars.numseqs
         while True:
             genout = generator(
                 gen_in,
                 do_sample=True,
-                min_length=min,
-                max_length=max-already_generated,
+                min_length=minimum,
+                max_length=maximum-already_generated,
                 repetition_penalty=vars.rep_pen,
                 top_p=top_p,
                 top_k=top_k,
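Renaming the min and max parameters to minimum and maximum stops them from shadowing the Python builtins, which the next hunk relies on when it calls max(encoded, key=len). This hunk also stops forcing a single sequence when dynamic scan is enabled; all vars.numseqs sequences are now generated in one batch. A quick illustration of the shadowing problem the rename avoids:

def shadowed(min, max, values):
    return max(values)   # TypeError: 'int' object is not callable

def renamed(minimum, maximum, values):
    return max(values)   # the builtin max is still reachable

print(renamed(1, 10, [3, 7, 5]))   # 7
# shadowed(1, 10, [3, 7, 5])       # would raise TypeError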
@@ -1485,11 +1488,17 @@ def generate(txt, min, max, found_entries=set()):
             already_generated += len(genout[0]) - len(gen_in[0])
             if(not model.kai_scanner.any_new_entries):
                 break
-            txt = tokenizer.decode(genout[0, -already_generated:])
-            winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True)
-            found_entries |= _found_entries
-            txt, _, _ = calcsubmitbudget(len(actions), winfo, mem, anotetxt, actions)
-            encoded = tokenizer.encode(txt, return_tensors="pt", truncation=True).long().to(genout.device)
+            assert genout.ndim >= 2
+            assert genout.shape[0] == vars.numseqs
+            encoded = []
+            for i in range(vars.numseqs):
+                txt = tokenizer.decode(genout[i, -already_generated:])
+                winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True)
+                found_entries[i].update(_found_entries)
+                txt, _, _ = calcsubmitbudget(len(actions), winfo, mem, anotetxt, actions)
+                encoded.append(tokenizer.encode(txt, return_tensors="pt", truncation=True)[0].long().to(genout.device))
+            max_length = len(max(encoded, key=len))
+            encoded = torch.stack(tuple(torch.nn.functional.pad(e, (max_length - len(e), 0), value=model.config.pad_token_id or model.config.eos_token_id) for e in encoded))
             genout = torch.cat(
                 (
                     encoded,
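Rather than re-encoding only sequence 0, the loop now decodes the freshly generated tail of every sequence, rebuilds that sequence's context budget, and re-encodes it. The rebuilt contexts can differ in length, so each is left-padded to the longest before being stacked back into a batch; left padding keeps the live end of every context aligned on the right, where generation resumes. A self-contained sketch of just the padding step, with a hypothetical pad_id:

import torch

def left_pad_stack(seqs, pad_id):
    # Left-pad each 1-D tensor to the longest length so the final tokens of
    # shorter contexts stay right-aligned, where generation continues.
    max_length = len(max(seqs, key=len))
    return torch.stack(tuple(
        torch.nn.functional.pad(s, (max_length - len(s), 0), value=pad_id)
        for s in seqs))

batch = left_pad_stack([torch.tensor([1, 2, 3]), torch.tensor([4, 5])], pad_id=0)
print(batch)  # tensor([[1, 2, 3], [0, 4, 5]])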
@@ -1503,10 +1512,10 @@ def generate(txt, min, max, found_entries=set()):
                 model.config.vocab_size + vars.sp.shape[0],
                 device=genout.device,
             )
-            genout = torch.cat((soft_tokens[None], genout), dim=-1)
+            genout = torch.cat((soft_tokens.tile(vars.numseqs, 1), genout), dim=-1)
             diff = genout.shape[-1] - gen_in.shape[-1]
-            min += diff
-            max += diff
+            minimum += diff
+            maximum += diff
             gen_in = genout
             model.kai_scanner_head_length = encoded.shape[-1]
             numseqs = 1
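soft_tokens[None] prepends a batch dimension of size 1, which only matches a single-sequence batch; .tile(vars.numseqs, 1) instead repeats the soft-prompt token IDs once per row so the concatenation fits the whole batch. The minimum/maximum adjustments simply follow the earlier rename. A small sketch of the shape difference, with made-up token IDs and batch size:

import torch

soft_tokens = torch.tensor([50257, 50258, 50259])  # hypothetical soft-prompt token IDs
numseqs = 3                                        # hypothetical batch size

single = soft_tokens[None]            # shape (1, 3): fits only one sequence
tiled = soft_tokens.tile(numseqs, 1)  # shape (3, 3): one copy per sequence
print(single.shape, tiled.shape)      # torch.Size([1, 3]) torch.Size([3, 3])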