mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Eternal gen work
This commit is contained in:
157
aiserver.py
157
aiserver.py
@@ -2013,6 +2013,49 @@ def patch_transformers():
|
||||
koboldai_vars.actions.stream_tokens([utils.decodenewlines(tokenizer.decode(x[-1])) for x in input_ids])
|
||||
return False
|
||||
|
||||
class CoreStopper(StoppingCriteria):
|
||||
# Controls core generation stuff; aborting, counting generated tokens, etc
|
||||
def __init__(self):
|
||||
self.regeneration_required = False
|
||||
self.halt = False
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
scores: torch.FloatTensor,
|
||||
**kwargs,
|
||||
) -> bool:
|
||||
koboldai_vars.generated_tkns += 1
|
||||
|
||||
if (
|
||||
not koboldai_vars.standalone
|
||||
and koboldai_vars.lua_koboldbridge.generated_cols
|
||||
and koboldai_vars.generated_tkns != koboldai_vars.lua_koboldbridge.generated_cols
|
||||
):
|
||||
print("[TODO] Fix generated_cols")
|
||||
# raise RuntimeError(f"Inconsistency detected between KoboldAI Python and Lua backends ({koboldai_vars.generated_tkns} != {koboldai_vars.lua_koboldbridge.generated_cols})")
|
||||
|
||||
if koboldai_vars.abort:
|
||||
koboldai_vars.abort = False
|
||||
self.regeneration_required = False
|
||||
self.halt = False
|
||||
return True
|
||||
|
||||
if koboldai_vars.standalone:
|
||||
return False
|
||||
|
||||
assert input_ids.ndim == 2
|
||||
|
||||
self.regeneration_required = koboldai_vars.lua_koboldbridge.regeneration_required
|
||||
self.halt = not koboldai_vars.lua_koboldbridge.generating
|
||||
koboldai_vars.lua_koboldbridge.regeneration_required = False
|
||||
|
||||
for i in range(koboldai_vars.numseqs):
|
||||
koboldai_vars.lua_koboldbridge.generated[i+1][koboldai_vars.generated_tkns] = int(input_ids[i, -1].item())
|
||||
|
||||
return self.regeneration_required or self.halt
|
||||
|
||||
|
||||
# Sets up dynamic world info scanner
|
||||
class DynamicWorldInfoScanCriteria(StoppingCriteria):
|
||||
def __init__(
|
||||
@@ -2024,6 +2067,7 @@ def patch_transformers():
|
||||
self.halt = False
|
||||
self.tokenizer = tokenizer
|
||||
self.excluded_world_info = excluded_world_info
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
@@ -2034,35 +2078,38 @@ def patch_transformers():
|
||||
if not koboldai_vars.inference_config.do_dynamic_wi:
|
||||
return False
|
||||
|
||||
koboldai_vars.generated_tkns += 1
|
||||
if(not koboldai_vars.standalone and koboldai_vars.lua_koboldbridge.generated_cols and koboldai_vars.generated_tkns != koboldai_vars.lua_koboldbridge.generated_cols):
|
||||
raise RuntimeError(f"Inconsistency detected between KoboldAI Python and Lua backends ({koboldai_vars.generated_tkns} != {koboldai_vars.lua_koboldbridge.generated_cols})")
|
||||
if(koboldai_vars.abort or koboldai_vars.generated_tkns >= koboldai_vars.genamt):
|
||||
self.regeneration_required = False
|
||||
self.halt = False
|
||||
koboldai_vars.abort = False
|
||||
return True
|
||||
if(koboldai_vars.standalone):
|
||||
# if(koboldai_vars.abort or koboldai_vars.generated_tkns >= koboldai_vars.genamt):
|
||||
# self.regeneration_required = False
|
||||
# self.halt = False
|
||||
# koboldai_vars.abort = False
|
||||
# return True
|
||||
|
||||
# Pertains to WI I think
|
||||
# if(koboldai_vars.standalone):
|
||||
# return False
|
||||
|
||||
# assert input_ids.ndim == 2
|
||||
assert len(self.excluded_world_info) == input_ids.shape[0]
|
||||
# self.regeneration_required = koboldai_vars.lua_koboldbridge.regeneration_required
|
||||
# self.halt = not koboldai_vars.lua_koboldbridge.generating
|
||||
# koboldai_vars.lua_koboldbridge.regeneration_required = False
|
||||
|
||||
# for i in range(koboldai_vars.numseqs):
|
||||
# koboldai_vars.lua_koboldbridge.generated[i+1][koboldai_vars.generated_tkns] = int(input_ids[i, -1].item())
|
||||
|
||||
if not koboldai_vars.dynamicscan:
|
||||
#return self.regeneration_required or self.halt
|
||||
return False
|
||||
|
||||
assert input_ids.ndim == 2
|
||||
assert len(self.excluded_world_info) == input_ids.shape[0]
|
||||
self.regeneration_required = koboldai_vars.lua_koboldbridge.regeneration_required
|
||||
self.halt = not koboldai_vars.lua_koboldbridge.generating
|
||||
koboldai_vars.lua_koboldbridge.regeneration_required = False
|
||||
|
||||
for i in range(koboldai_vars.numseqs):
|
||||
koboldai_vars.lua_koboldbridge.generated[i+1][koboldai_vars.generated_tkns] = int(input_ids[i, -1].item())
|
||||
|
||||
if(not koboldai_vars.dynamicscan):
|
||||
return self.regeneration_required or self.halt
|
||||
tail = input_ids[..., -koboldai_vars.generated_tkns:]
|
||||
for i, t in enumerate(tail):
|
||||
decoded = utils.decodenewlines(tokenizer.decode(t))
|
||||
_, found = checkworldinfo(decoded, force_use_txt=True, actions=koboldai_vars.actions)
|
||||
found -= self.excluded_world_info[i]
|
||||
if(len(found) != 0):
|
||||
self.regeneration_required = True
|
||||
if len(found) != 0:
|
||||
# self.regeneration_required = True
|
||||
model.core_stopper.regeneration_required = True
|
||||
return True
|
||||
break
|
||||
return self.regeneration_required or self.halt
|
||||
old_get_stopping_criteria = transformers.generation_utils.GenerationMixin._get_stopping_criteria
|
||||
@@ -2070,12 +2117,14 @@ def patch_transformers():
|
||||
global tokenizer
|
||||
stopping_criteria = old_get_stopping_criteria(self, *args, **kwargs)
|
||||
|
||||
self.core_stopper = CoreStopper()
|
||||
self.kai_scanner = DynamicWorldInfoScanCriteria(
|
||||
tokenizer=tokenizer,
|
||||
excluded_world_info=self.kai_scanner_excluded_world_info,
|
||||
)
|
||||
token_streamer = TokenStreamer(tokenizer=tokenizer)
|
||||
|
||||
stopping_criteria.insert(0, self.core_stopper)
|
||||
stopping_criteria.insert(0, self.kai_scanner)
|
||||
token_streamer = TokenStreamer(tokenizer=tokenizer)
|
||||
stopping_criteria.insert(0, token_streamer)
|
||||
@@ -4811,28 +4860,36 @@ def calcsubmit(txt):
|
||||
def legacy_generate(text: Union[str, list], min: int, max: int):
|
||||
koboldai_vars.lastctx = text
|
||||
|
||||
outputs = tpool.execute(
|
||||
out_batches = tpool.execute(
|
||||
raw_generate,
|
||||
text,
|
||||
max_length=koboldai_vars.genamt,
|
||||
do_streaming=True
|
||||
do_streaming=True,
|
||||
batch_count=koboldai_vars.numseqs,
|
||||
decode=False
|
||||
)
|
||||
|
||||
decoded_batches = tokenizer.batch_decode(out_batches)
|
||||
|
||||
# Lua bridge, genmod
|
||||
for i, output in enumerate(outputs):
|
||||
koboldai_vars.lua_koboldbridge.outputs[i + 1] = output
|
||||
for i in range(koboldai_vars.numseqs):
|
||||
koboldai_vars.lua_koboldbridge.generated[i + 1][koboldai_vars.generated_tkns] = int(out_batches[i, -1].item())
|
||||
koboldai_vars.lua_koboldbridge.outputs[i + 1] = utils.decodenewlines(tokenizer.decode(out_batches[i, -len(out_batches[i]):]))
|
||||
|
||||
# for i, output in enumerate(outputs):
|
||||
# koboldai_vars.lua_koboldbridge.outputs[i + 1] = output
|
||||
|
||||
execute_genmod()
|
||||
|
||||
if koboldai_vars.lua_koboldbridge.regeneration_required:
|
||||
koboldai_vars.lua_koboldbridge.regeneration_required = False
|
||||
genout = []
|
||||
for i in range(len(outputs)):
|
||||
for i in range(len(out_batches)):
|
||||
out = koboldai_vars.lua_koboldbridge.outputs[i + 1]
|
||||
genout.append({"generated_text": out})
|
||||
assert isinstance(out, str)
|
||||
else:
|
||||
genout = [{"generated_text": utils.decodenewlines(x)} for x in outputs]
|
||||
genout = [{"generated_text": utils.decodenewlines(x)} for x in decoded_batches]
|
||||
|
||||
koboldai_vars.actions.append_options([applyoutputformatting(x["generated_text"]) for x in genout])
|
||||
genout = [{"generated_text": x['text']} for x in koboldai_vars.actions.get_current_options()]
|
||||
@@ -4855,12 +4912,17 @@ def raw_generate(
|
||||
do_streaming: bool = False,
|
||||
do_dynamic_wi: bool = False,
|
||||
batch_count: int = 1,
|
||||
decode: bool = True,
|
||||
) -> List:
|
||||
|
||||
if isinstance(prompt, str):
|
||||
prompt_tokens = tokenizer.encode(prompt)
|
||||
else:
|
||||
prompt_tokens = prompt
|
||||
|
||||
# Some gen methods such as OAI don't return tokens.
|
||||
batch_decoded = None
|
||||
batch_encoded = None
|
||||
|
||||
if koboldai_vars.model == "Colab":
|
||||
raise NotImplementedError("Colab API raw_generate unsupported")
|
||||
@@ -4874,13 +4936,15 @@ def raw_generate(
|
||||
raise NotImplementedError("No loaded model")
|
||||
|
||||
if koboldai_vars.use_colab_tpu or model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"):
|
||||
batch_out = tpu_raw_generate(
|
||||
batch_encoded = tpu_raw_generate(
|
||||
prompt_tokens=prompt_tokens,
|
||||
max_length=max_length,
|
||||
batch_count=batch_count
|
||||
)
|
||||
elif model == "OAI":
|
||||
batch_decoded = ...
|
||||
else:
|
||||
batch_out = torch_raw_generate(
|
||||
batch_encoded = torch_raw_generate(
|
||||
prompt_tokens=prompt_tokens,
|
||||
max_new=max_length,
|
||||
do_streaming=do_streaming,
|
||||
@@ -4888,9 +4952,21 @@ def raw_generate(
|
||||
batch_count=batch_count
|
||||
)
|
||||
|
||||
decoded = tokenizer.batch_decode(batch_out[:, len(prompt_tokens):])
|
||||
|
||||
return [utils.decodenewlines(x) for x in decoded]
|
||||
assert batch_encoded or batch_decoded
|
||||
|
||||
# Shave prompt off of encoded response. Decoded does not return prompt.
|
||||
# TODO: Does MTJ generation shave this off automatically? Test it!
|
||||
if batch_encoded:
|
||||
batch_encoded = batch_encoded[:, len(prompt_tokens):]
|
||||
|
||||
if not decode:
|
||||
return batch_encoded
|
||||
|
||||
if not batch_decoded:
|
||||
batch_decoded = tokenizer.batch_decode(batch_encoded)
|
||||
|
||||
return [utils.decodenewlines(x) for x in batch_decoded]
|
||||
|
||||
def tpu_raw_generate(
|
||||
prompt_tokens: List[int],
|
||||
@@ -4961,7 +5037,7 @@ def torch_raw_generate(
|
||||
# Send text to generator and deal with output
|
||||
#==================================================================#
|
||||
|
||||
def _generate(txt, minimum, maximum, found_entries):
|
||||
def old_underscore_generate(txt, minimum, maximum, found_entries):
|
||||
if(koboldai_vars.full_determinism):
|
||||
torch.manual_seed(koboldai_vars.seed)
|
||||
|
||||
@@ -5000,19 +5076,30 @@ def _generate(txt, minimum, maximum, found_entries):
|
||||
)
|
||||
already_generated += len(genout[0]) - len(gen_in[0])
|
||||
assert already_generated <= koboldai_vars.genamt
|
||||
# If we are halting, we stop
|
||||
if(model.kai_scanner.halt or not model.kai_scanner.regeneration_required):
|
||||
break
|
||||
# if we require a generation, we continue
|
||||
|
||||
assert genout.ndim >= 2
|
||||
assert genout.shape[0] == koboldai_vars.numseqs
|
||||
|
||||
if(koboldai_vars.lua_koboldbridge.generated_cols and koboldai_vars.generated_tkns != koboldai_vars.lua_koboldbridge.generated_cols):
|
||||
raise RuntimeError("Inconsistency detected between KoboldAI Python and Lua backends")
|
||||
|
||||
if(already_generated != koboldai_vars.generated_tkns):
|
||||
raise RuntimeError("WI scanning error")
|
||||
|
||||
for r in range(koboldai_vars.numseqs):
|
||||
for c in range(already_generated):
|
||||
assert koboldai_vars.lua_koboldbridge.generated[r+1][c+1] is not None
|
||||
genout[r][genout.shape[-1] - already_generated + c] = koboldai_vars.lua_koboldbridge.generated[r+1][c+1]
|
||||
|
||||
encoded = []
|
||||
|
||||
# DYNAMIC WI:
|
||||
# IF WE FIND WORLD INFO MID-GENERATION, STOP, THEN ADD WI AND ADD NEW GENERATION
|
||||
|
||||
for i in range(koboldai_vars.numseqs):
|
||||
txt = utils.decodenewlines(tokenizer.decode(genout[i, -already_generated:]))
|
||||
winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True, actions=koboldai_vars.actions)
|
||||
@@ -5023,6 +5110,7 @@ def _generate(txt, minimum, maximum, found_entries):
|
||||
else:
|
||||
txt, _, _ = calcsubmitbudget(len(koboldai_vars.actions), winfo, mem, anotetxt, koboldai_vars.actions, submission=txt)
|
||||
encoded.append(torch.tensor(txt, dtype=torch.long, device=genout.device))
|
||||
|
||||
max_length = len(max(encoded, key=len))
|
||||
encoded = torch.stack(tuple(torch.nn.functional.pad(e, (max_length - len(e), 0), value=model.config.pad_token_id or model.config.eos_token_id) for e in encoded))
|
||||
genout = torch.cat(
|
||||
@@ -5032,6 +5120,7 @@ def _generate(txt, minimum, maximum, found_entries):
|
||||
),
|
||||
dim=-1
|
||||
)
|
||||
|
||||
if(koboldai_vars.sp is not None):
|
||||
soft_tokens = torch.arange(
|
||||
model.config.vocab_size,
|
||||
@@ -5049,7 +5138,7 @@ def _generate(txt, minimum, maximum, found_entries):
|
||||
return genout, already_generated
|
||||
|
||||
|
||||
def generate(txt, minimum, maximum, found_entries=None):
|
||||
def old_generate(txt, minimum, maximum, found_entries=None):
|
||||
koboldai_vars.generated_tkns = 0
|
||||
|
||||
if(found_entries is None):
|
||||
|
Reference in New Issue
Block a user