Eternal gen work

This commit is contained in:
somebody
2022-09-18 22:30:43 -05:00
parent 386477e59c
commit 2ed7a9dcee

View File

@@ -2013,6 +2013,49 @@ def patch_transformers():
koboldai_vars.actions.stream_tokens([utils.decodenewlines(tokenizer.decode(x[-1])) for x in input_ids])
return False
class CoreStopper(StoppingCriteria):
# Controls core generation stuff; aborting, counting generated tokens, etc
def __init__(self):
self.regeneration_required = False
self.halt = False
def __call__(
self,
input_ids: torch.LongTensor,
scores: torch.FloatTensor,
**kwargs,
) -> bool:
koboldai_vars.generated_tkns += 1
if (
not koboldai_vars.standalone
and koboldai_vars.lua_koboldbridge.generated_cols
and koboldai_vars.generated_tkns != koboldai_vars.lua_koboldbridge.generated_cols
):
print("[TODO] Fix generated_cols")
# raise RuntimeError(f"Inconsistency detected between KoboldAI Python and Lua backends ({koboldai_vars.generated_tkns} != {koboldai_vars.lua_koboldbridge.generated_cols})")
if koboldai_vars.abort:
koboldai_vars.abort = False
self.regeneration_required = False
self.halt = False
return True
if koboldai_vars.standalone:
return False
assert input_ids.ndim == 2
self.regeneration_required = koboldai_vars.lua_koboldbridge.regeneration_required
self.halt = not koboldai_vars.lua_koboldbridge.generating
koboldai_vars.lua_koboldbridge.regeneration_required = False
for i in range(koboldai_vars.numseqs):
koboldai_vars.lua_koboldbridge.generated[i+1][koboldai_vars.generated_tkns] = int(input_ids[i, -1].item())
return self.regeneration_required or self.halt
# Sets up dynamic world info scanner
class DynamicWorldInfoScanCriteria(StoppingCriteria):
def __init__(
@@ -2024,6 +2067,7 @@ def patch_transformers():
self.halt = False
self.tokenizer = tokenizer
self.excluded_world_info = excluded_world_info
def __call__(
self,
input_ids: torch.LongTensor,
@@ -2034,35 +2078,38 @@ def patch_transformers():
if not koboldai_vars.inference_config.do_dynamic_wi:
return False
koboldai_vars.generated_tkns += 1
if(not koboldai_vars.standalone and koboldai_vars.lua_koboldbridge.generated_cols and koboldai_vars.generated_tkns != koboldai_vars.lua_koboldbridge.generated_cols):
raise RuntimeError(f"Inconsistency detected between KoboldAI Python and Lua backends ({koboldai_vars.generated_tkns} != {koboldai_vars.lua_koboldbridge.generated_cols})")
if(koboldai_vars.abort or koboldai_vars.generated_tkns >= koboldai_vars.genamt):
self.regeneration_required = False
self.halt = False
koboldai_vars.abort = False
return True
if(koboldai_vars.standalone):
# if(koboldai_vars.abort or koboldai_vars.generated_tkns >= koboldai_vars.genamt):
# self.regeneration_required = False
# self.halt = False
# koboldai_vars.abort = False
# return True
# Pertains to WI I think
# if(koboldai_vars.standalone):
# return False
# assert input_ids.ndim == 2
assert len(self.excluded_world_info) == input_ids.shape[0]
# self.regeneration_required = koboldai_vars.lua_koboldbridge.regeneration_required
# self.halt = not koboldai_vars.lua_koboldbridge.generating
# koboldai_vars.lua_koboldbridge.regeneration_required = False
# for i in range(koboldai_vars.numseqs):
# koboldai_vars.lua_koboldbridge.generated[i+1][koboldai_vars.generated_tkns] = int(input_ids[i, -1].item())
if not koboldai_vars.dynamicscan:
#return self.regeneration_required or self.halt
return False
assert input_ids.ndim == 2
assert len(self.excluded_world_info) == input_ids.shape[0]
self.regeneration_required = koboldai_vars.lua_koboldbridge.regeneration_required
self.halt = not koboldai_vars.lua_koboldbridge.generating
koboldai_vars.lua_koboldbridge.regeneration_required = False
for i in range(koboldai_vars.numseqs):
koboldai_vars.lua_koboldbridge.generated[i+1][koboldai_vars.generated_tkns] = int(input_ids[i, -1].item())
if(not koboldai_vars.dynamicscan):
return self.regeneration_required or self.halt
tail = input_ids[..., -koboldai_vars.generated_tkns:]
for i, t in enumerate(tail):
decoded = utils.decodenewlines(tokenizer.decode(t))
_, found = checkworldinfo(decoded, force_use_txt=True, actions=koboldai_vars.actions)
found -= self.excluded_world_info[i]
if(len(found) != 0):
self.regeneration_required = True
if len(found) != 0:
# self.regeneration_required = True
model.core_stopper.regeneration_required = True
return True
break
return self.regeneration_required or self.halt
old_get_stopping_criteria = transformers.generation_utils.GenerationMixin._get_stopping_criteria
@@ -2070,12 +2117,14 @@ def patch_transformers():
global tokenizer
stopping_criteria = old_get_stopping_criteria(self, *args, **kwargs)
self.core_stopper = CoreStopper()
self.kai_scanner = DynamicWorldInfoScanCriteria(
tokenizer=tokenizer,
excluded_world_info=self.kai_scanner_excluded_world_info,
)
token_streamer = TokenStreamer(tokenizer=tokenizer)
stopping_criteria.insert(0, self.core_stopper)
stopping_criteria.insert(0, self.kai_scanner)
token_streamer = TokenStreamer(tokenizer=tokenizer)
stopping_criteria.insert(0, token_streamer)
@@ -4811,28 +4860,36 @@ def calcsubmit(txt):
def legacy_generate(text: Union[str, list], min: int, max: int):
koboldai_vars.lastctx = text
outputs = tpool.execute(
out_batches = tpool.execute(
raw_generate,
text,
max_length=koboldai_vars.genamt,
do_streaming=True
do_streaming=True,
batch_count=koboldai_vars.numseqs,
decode=False
)
decoded_batches = tokenizer.batch_decode(out_batches)
# Lua bridge, genmod
for i, output in enumerate(outputs):
koboldai_vars.lua_koboldbridge.outputs[i + 1] = output
for i in range(koboldai_vars.numseqs):
koboldai_vars.lua_koboldbridge.generated[i + 1][koboldai_vars.generated_tkns] = int(out_batches[i, -1].item())
koboldai_vars.lua_koboldbridge.outputs[i + 1] = utils.decodenewlines(tokenizer.decode(out_batches[i, -len(out_batches[i]):]))
# for i, output in enumerate(outputs):
# koboldai_vars.lua_koboldbridge.outputs[i + 1] = output
execute_genmod()
if koboldai_vars.lua_koboldbridge.regeneration_required:
koboldai_vars.lua_koboldbridge.regeneration_required = False
genout = []
for i in range(len(outputs)):
for i in range(len(out_batches)):
out = koboldai_vars.lua_koboldbridge.outputs[i + 1]
genout.append({"generated_text": out})
assert isinstance(out, str)
else:
genout = [{"generated_text": utils.decodenewlines(x)} for x in outputs]
genout = [{"generated_text": utils.decodenewlines(x)} for x in decoded_batches]
koboldai_vars.actions.append_options([applyoutputformatting(x["generated_text"]) for x in genout])
genout = [{"generated_text": x['text']} for x in koboldai_vars.actions.get_current_options()]
@@ -4855,12 +4912,17 @@ def raw_generate(
do_streaming: bool = False,
do_dynamic_wi: bool = False,
batch_count: int = 1,
decode: bool = True,
) -> List:
if isinstance(prompt, str):
prompt_tokens = tokenizer.encode(prompt)
else:
prompt_tokens = prompt
# Some gen methods such as OAI don't return tokens.
batch_decoded = None
batch_encoded = None
if koboldai_vars.model == "Colab":
raise NotImplementedError("Colab API raw_generate unsupported")
@@ -4874,13 +4936,15 @@ def raw_generate(
raise NotImplementedError("No loaded model")
if koboldai_vars.use_colab_tpu or model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"):
batch_out = tpu_raw_generate(
batch_encoded = tpu_raw_generate(
prompt_tokens=prompt_tokens,
max_length=max_length,
batch_count=batch_count
)
elif model == "OAI":
batch_decoded = ...
else:
batch_out = torch_raw_generate(
batch_encoded = torch_raw_generate(
prompt_tokens=prompt_tokens,
max_new=max_length,
do_streaming=do_streaming,
@@ -4888,9 +4952,21 @@ def raw_generate(
batch_count=batch_count
)
decoded = tokenizer.batch_decode(batch_out[:, len(prompt_tokens):])
return [utils.decodenewlines(x) for x in decoded]
assert batch_encoded or batch_decoded
# Shave prompt off of encoded response. Decoded does not return prompt.
# TODO: Does MTJ generation shave this off automatically? Test it!
if batch_encoded:
batch_encoded = batch_encoded[:, len(prompt_tokens):]
if not decode:
return batch_encoded
if not batch_decoded:
batch_decoded = tokenizer.batch_decode(batch_encoded)
return [utils.decodenewlines(x) for x in batch_decoded]
def tpu_raw_generate(
prompt_tokens: List[int],
@@ -4961,7 +5037,7 @@ def torch_raw_generate(
# Send text to generator and deal with output
#==================================================================#
def _generate(txt, minimum, maximum, found_entries):
def old_underscore_generate(txt, minimum, maximum, found_entries):
if(koboldai_vars.full_determinism):
torch.manual_seed(koboldai_vars.seed)
@@ -5000,19 +5076,30 @@ def _generate(txt, minimum, maximum, found_entries):
)
already_generated += len(genout[0]) - len(gen_in[0])
assert already_generated <= koboldai_vars.genamt
# If we are halting, we stop
if(model.kai_scanner.halt or not model.kai_scanner.regeneration_required):
break
# if we require a generation, we continue
assert genout.ndim >= 2
assert genout.shape[0] == koboldai_vars.numseqs
if(koboldai_vars.lua_koboldbridge.generated_cols and koboldai_vars.generated_tkns != koboldai_vars.lua_koboldbridge.generated_cols):
raise RuntimeError("Inconsistency detected between KoboldAI Python and Lua backends")
if(already_generated != koboldai_vars.generated_tkns):
raise RuntimeError("WI scanning error")
for r in range(koboldai_vars.numseqs):
for c in range(already_generated):
assert koboldai_vars.lua_koboldbridge.generated[r+1][c+1] is not None
genout[r][genout.shape[-1] - already_generated + c] = koboldai_vars.lua_koboldbridge.generated[r+1][c+1]
encoded = []
# DYNAMIC WI:
# IF WE FIND WORLD INFO MID-GENERATION, STOP, THEN ADD WI AND ADD NEW GENERATION
for i in range(koboldai_vars.numseqs):
txt = utils.decodenewlines(tokenizer.decode(genout[i, -already_generated:]))
winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True, actions=koboldai_vars.actions)
@@ -5023,6 +5110,7 @@ def _generate(txt, minimum, maximum, found_entries):
else:
txt, _, _ = calcsubmitbudget(len(koboldai_vars.actions), winfo, mem, anotetxt, koboldai_vars.actions, submission=txt)
encoded.append(torch.tensor(txt, dtype=torch.long, device=genout.device))
max_length = len(max(encoded, key=len))
encoded = torch.stack(tuple(torch.nn.functional.pad(e, (max_length - len(e), 0), value=model.config.pad_token_id or model.config.eos_token_id) for e in encoded))
genout = torch.cat(
@@ -5032,6 +5120,7 @@ def _generate(txt, minimum, maximum, found_entries):
),
dim=-1
)
if(koboldai_vars.sp is not None):
soft_tokens = torch.arange(
model.config.vocab_size,
@@ -5049,7 +5138,7 @@ def _generate(txt, minimum, maximum, found_entries):
return genout, already_generated
def generate(txt, minimum, maximum, found_entries=None):
def old_generate(txt, minimum, maximum, found_entries=None):
koboldai_vars.generated_tkns = 0
if(found_entries is None):