Actually just steal from _generate

Why didn't I do this before? It's a mystery!
somebody
2022-09-19 21:06:14 -05:00
parent 2ed7a9dcee
commit 4d3a80e4a6
2 changed files with 136 additions and 182 deletions
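The heart of this change: core_generate() now runs the same stop-and-continue loop that _generate used for Dynamic World Info, instead of a single raw_generate() call. A simplified, self-contained sketch of that pattern follows; generate_step and insert_world_info are hypothetical stand-ins for the raw_generate call and the WI re-encoding shown in the diff below.

from typing import Callable, List, Tuple

def stop_and_continue_generate(
    prompt: List[int],
    target_new_tokens: int,
    generate_step: Callable[[List[int]], Tuple[List[int], bool]],
    insert_world_info: Callable[[List[int]], List[int]],
) -> List[int]:
    context = list(prompt)
    already_generated = 0
    while True:
        # generate_step returns the newly generated tokens plus a flag saying
        # whether generation paused because triggered world info must be
        # inserted before continuing.
        new_tokens, needs_wi = generate_step(context)
        context += new_tokens
        already_generated += len(new_tokens)
        if already_generated >= target_new_tokens or not needs_wi:
            break
        # Splice the triggered world info into the context, then resume.
        context = insert_world_info(context)
    return context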


@@ -2032,10 +2032,16 @@ def patch_transformers():
and koboldai_vars.lua_koboldbridge.generated_cols
and koboldai_vars.generated_tkns != koboldai_vars.lua_koboldbridge.generated_cols
):
print("[TODO] Fix generated_cols")
# raise RuntimeError(f"Inconsistency detected between KoboldAI Python and Lua backends ({koboldai_vars.generated_tkns} != {koboldai_vars.lua_koboldbridge.generated_cols})")
raise RuntimeError(f"Inconsistency detected between KoboldAI Python and Lua backends ({koboldai_vars.generated_tkns} != {koboldai_vars.lua_koboldbridge.generated_cols})")
if koboldai_vars.abort:
if (
koboldai_vars.abort
or (
koboldai_vars.inference_config.stop_at_genamt
and
koboldai_vars.generated_tkns >= koboldai_vars.genamt
)
):
koboldai_vars.abort = False
self.regeneration_required = False
self.halt = False
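For reference, the stop condition this hunk introduces, written as a minimal standalone transformers StoppingCriteria; state below is a stand-in for koboldai_vars and its inference_config, not the real objects, and the bool return type follows the interface as used in this diff.

import torch
import transformers

class GenamtStopper(transformers.StoppingCriteria):
    def __init__(self, state):
        self.state = state  # stand-in for koboldai_vars

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Stop when the user aborted, or when genamt stopping is enabled and
        # the target number of new tokens has been produced.
        return self.state.abort or (
            self.state.inference_config.stop_at_genamt
            and self.state.generated_tkns >= self.state.genamt
        )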
@@ -2063,8 +2069,8 @@ def patch_transformers():
tokenizer,
excluded_world_info: List[Set],
):
self.regeneration_required = False
self.halt = False
# self.regeneration_required = False
# self.halt = False
self.tokenizer = tokenizer
self.excluded_world_info = excluded_world_info
@@ -2078,27 +2084,9 @@ def patch_transformers():
if not koboldai_vars.inference_config.do_dynamic_wi:
return False
# if(koboldai_vars.abort or koboldai_vars.generated_tkns >= koboldai_vars.genamt):
# self.regeneration_required = False
# self.halt = False
# koboldai_vars.abort = False
# return True
# Pertains to WI I think
# if(koboldai_vars.standalone):
# return False
# assert input_ids.ndim == 2
assert len(self.excluded_world_info) == input_ids.shape[0]
# self.regeneration_required = koboldai_vars.lua_koboldbridge.regeneration_required
# self.halt = not koboldai_vars.lua_koboldbridge.generating
# koboldai_vars.lua_koboldbridge.regeneration_required = False
# for i in range(koboldai_vars.numseqs):
# koboldai_vars.lua_koboldbridge.generated[i+1][koboldai_vars.generated_tkns] = int(input_ids[i, -1].item())
if not koboldai_vars.dynamicscan:
#return self.regeneration_required or self.halt
return False
tail = input_ids[..., -koboldai_vars.generated_tkns:]
@@ -2107,11 +2095,10 @@ def patch_transformers():
_, found = checkworldinfo(decoded, force_use_txt=True, actions=koboldai_vars.actions)
found -= self.excluded_world_info[i]
if len(found) != 0:
# self.regeneration_required = True
model.core_stopper.regeneration_required = True
return True
break
return self.regeneration_required or self.halt
return False
old_get_stopping_criteria = transformers.generation_utils.GenerationMixin._get_stopping_criteria
def new_get_stopping_criteria(self, *args, **kwargs):
global tokenizer
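The hook above follows the usual wrap-and-append pattern. A minimal sketch, assuming the transformers module path used in this diff (transformers.generation_utils; newer releases move it under transformers.generation), with CoreStopper as a placeholder for the criteria classes built in patch_transformers():

import torch
import transformers
import transformers.generation_utils

class CoreStopper(transformers.StoppingCriteria):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        return False  # the real stopper checks abort, genamt and Lua state

old_get_stopping_criteria = transformers.generation_utils.GenerationMixin._get_stopping_criteria

def new_get_stopping_criteria(self, *args, **kwargs):
    # Let transformers build its usual criteria list, then append ours.
    criteria = old_get_stopping_criteria(self, *args, **kwargs)
    criteria.append(CoreStopper())
    return criteria

transformers.generation_utils.GenerationMixin._get_stopping_criteria = new_get_stopping_criteria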
@@ -4789,7 +4776,7 @@ def calcsubmit(txt):
"TPUMeshTransformerGPTJ",
"TPUMeshTransformerGPTNeoX"
):
legacy_generate(subtxt, min, max)
generate(subtxt, min, max, found_entries)
elif koboldai_vars.model == "Colab":
sendtocolab(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
elif koboldai_vars.model == "API":
@@ -4857,52 +4844,117 @@ def calcsubmit(txt):
# Send it!
ikrequest(subtxt)
def legacy_generate(text: Union[str, list], min: int, max: int):
koboldai_vars.lastctx = text
def core_generate(text: list, min: int, max: int, found_entries: set):
# This generation function is tangled with koboldai_vars intentionally. It
# is meant for the story and nothing else.
out_batches = tpool.execute(
raw_generate,
text,
max_length=koboldai_vars.genamt,
do_streaming=True,
batch_count=koboldai_vars.numseqs,
decode=False
)
if koboldai_vars.full_determinism:
torch.manual_seed(koboldai_vars.seed)
decoded_batches = tokenizer.batch_decode(out_batches)
gen_in = torch.tensor(text, dtype=torch.long)[None]
if koboldai_vars.sp is not None:
soft_tokens = torch.arange(
model.config.vocab_size,
model.config.vocab_size + koboldai_vars.sp.shape[0],
)
gen_in = torch.cat((soft_tokens[None], gen_in), dim=-1)
# Lua bridge, genmod
for i in range(koboldai_vars.numseqs):
koboldai_vars.lua_koboldbridge.generated[i + 1][koboldai_vars.generated_tkns] = int(out_batches[i, -1].item())
koboldai_vars.lua_koboldbridge.outputs[i + 1] = utils.decodenewlines(tokenizer.decode(out_batches[i, -len(out_batches[i]):]))
assert gen_in.shape[-1] + koboldai_vars.genamt <= koboldai_vars.max_length
# for i, output in enumerate(outputs):
# koboldai_vars.lua_koboldbridge.outputs[i + 1] = output
execute_genmod()
if koboldai_vars.lua_koboldbridge.regeneration_required:
koboldai_vars.lua_koboldbridge.regeneration_required = False
genout = []
for i in range(len(out_batches)):
out = koboldai_vars.lua_koboldbridge.outputs[i + 1]
genout.append({"generated_text": out})
assert isinstance(out, str)
if koboldai_vars.hascuda and koboldai_vars.usegpu:
gen_in = gen_in.to(koboldai_vars.gpu_device)
elif koboldai_vars.hascuda and koboldai_vars.breakmodel:
gen_in = gen_in.to(breakmodel.primary_device)
else:
genout = [{"generated_text": utils.decodenewlines(x)} for x in decoded_batches]
gen_in = gen_in.to("cpu")
koboldai_vars.actions.append_options([applyoutputformatting(x["generated_text"]) for x in genout])
genout = [{"generated_text": x['text']} for x in koboldai_vars.actions.get_current_options()]
found_entries = found_entries or set()
model.kai_scanner_excluded_world_info = found_entries
if len(genout) == 1:
genresult(genout[0]["generated_text"])
else:
restart_seq = koboldai_vars.lua_koboldbridge.restart_sequence
if restart_seq and restart_seq > 0:
genresult(genout[restart_seq - 1]["generated_text"])
else:
genselect(genout)
set_aibusy(0)
koboldai_vars._prompt = koboldai_vars.prompt
with torch.no_grad():
already_generated = 0
numseqs = koboldai_vars.numseqs
while True:
# This is a loop because of how Dynamic WI works: we cannot simply
# add the WI to the context mid-generation, so we stop early, insert
# the WI, and then continue generating. That stop-and-continue cycle
# is this loop.
genout = raw_generate(
gen_in,
# Real max length is handled by CoreStopper.
max_length=int(2e9),
do_streaming=True,
do_dynamic_wi=True,
batch_count=numseqs,
decode=False,
)
already_generated += len(genout[0]) - len(gen_in[0])
assert already_generated <= koboldai_vars.genamt
# Generation stopped; why?
# If we have been told to halt, or we have reached our target token
# count (also signalled through halt), or Dynamic WI has not asked us
# to pause so it can insert WI, we can assume we are done generating
# and break out of the loop.
if model.core_stopper.halt or not model.core_stopper.regeneration_required:
break
# Now we are doing stuff for Dynamic WI.
assert genout.ndim >= 2
assert genout.shape[0] == koboldai_vars.numseqs
if(koboldai_vars.lua_koboldbridge.generated_cols and koboldai_vars.generated_tkns != koboldai_vars.lua_koboldbridge.generated_cols):
raise RuntimeError("Inconsistency detected between KoboldAI Python and Lua backends")
if(already_generated != koboldai_vars.generated_tkns):
raise RuntimeError("WI scanning error")
for r in range(koboldai_vars.numseqs):
for c in range(already_generated):
assert koboldai_vars.lua_koboldbridge.generated[r+1][c+1] is not None
genout[r][genout.shape[-1] - already_generated + c] = koboldai_vars.lua_koboldbridge.generated[r+1][c+1]
encoded = []
for i in range(koboldai_vars.numseqs):
txt = utils.decodenewlines(tokenizer.decode(genout[i, -already_generated:]))
winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True, actions=koboldai_vars.actions)
found_entries[i].update(_found_entries)
if koboldai_vars.alt_gen:
txt, _, _ = koboldai_vars.calc_ai_text(submitted_text=txt)
print("Using Alt Gen")
else:
txt, _, _ = calcsubmitbudget(len(koboldai_vars.actions), winfo, mem, anotetxt, koboldai_vars.actions, submission=txt)
encoded.append(torch.tensor(txt, dtype=torch.long, device=genout.device))
max_length = len(max(encoded, key=len))
encoded = torch.stack(tuple(torch.nn.functional.pad(e, (max_length - len(e), 0), value=model.config.pad_token_id or model.config.eos_token_id) for e in encoded))
genout = torch.cat(
(
encoded,
genout[..., -already_generated:],
),
dim=-1
)
if(koboldai_vars.sp is not None):
soft_tokens = torch.arange(
model.config.vocab_size,
model.config.vocab_size + koboldai_vars.sp.shape[0],
device=genout.device,
)
genout = torch.cat((soft_tokens.tile(koboldai_vars.numseqs, 1), genout), dim=-1)
assert genout.shape[-1] + koboldai_vars.genamt - already_generated <= koboldai_vars.max_length
diff = genout.shape[-1] - gen_in.shape[-1]
minimum += diff
maximum += diff
gen_in = genout
numseqs = 1
return genout, already_generated
def raw_generate(
# prompt is either a string (text) or a list (token ids)
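The trickiest part of the loop above is rebuilding the context after WI insertion: each sequence's re-encoded context can have a different length, so the contexts are left-padded to a common length before the already-generated tail is re-attached. A condensed sketch of just that step; pad_id stands in for model.config.pad_token_id or eos_token_id.

import torch
from typing import List

def rebuild_contexts(encoded: List[torch.Tensor], generated_tail: torch.Tensor, pad_id: int) -> torch.Tensor:
    # encoded: one 1-D tensor of re-encoded context per sequence; lengths may differ.
    # generated_tail: (numseqs, already_generated) tokens produced so far.
    max_length = max(len(e) for e in encoded)
    padded = torch.stack(tuple(
        torch.nn.functional.pad(e, (max_length - len(e), 0), value=pad_id)
        for e in encoded
    ))
    # Result shape: (numseqs, max_length + already_generated)
    return torch.cat((padded, generated_tail), dim=-1)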
@@ -4952,18 +5004,17 @@ def raw_generate(
batch_count=batch_count
)
assert batch_encoded or batch_decoded
assert batch_encoded is not None or batch_decoded is not None
# Shave prompt off of encoded response. Decoded does not return prompt.
# TODO: Does MTJ generation shave this off automatically? Test it!
if batch_encoded:
batch_encoded = batch_encoded[:, len(prompt_tokens):]
if batch_encoded is not None:
batch_encoded = batch_encoded[:, len(prompt_tokens) - 1:]
if not decode:
return batch_encoded
if not batch_decoded:
if batch_decoded is None:
batch_decoded = tokenizer.batch_decode(batch_encoded)
return [utils.decodenewlines(x) for x in batch_decoded]
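A condensed sketch of raw_generate's post-processing as restructured above: exactly one of batch_encoded / batch_decoded is set, the prompt prefix is shaved off the encoded form, and decoding only happens when the caller asked for text. The exact shave offset follows the backend's behavior; this sketch just drops prompt_len tokens.

import torch
from typing import List, Optional

def postprocess(batch_encoded: Optional[torch.Tensor],
                batch_decoded: Optional[List[str]],
                prompt_len: int,
                tokenizer,
                decode: bool):
    assert batch_encoded is not None or batch_decoded is not None
    if batch_encoded is not None:
        # Drop the prompt prefix; only the continuation is returned.
        batch_encoded = batch_encoded[:, prompt_len:]
    if not decode:
        return batch_encoded
    if batch_decoded is None:
        batch_decoded = tokenizer.batch_decode(batch_encoded)
    return batch_decoded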
@@ -4997,7 +5048,7 @@ def tpu_raw_generate(
return genout
def torch_raw_generate(
prompt_tokens: List[int],
prompt_tokens: Union[List[int], torch.Tensor],
max_new: int,
do_streaming: bool = False,
@@ -5008,10 +5059,16 @@ def torch_raw_generate(
koboldai_vars.inference_config.do_streaming = do_streaming
koboldai_vars.inference_config.do_dynamic_wi = do_dynamic_wi
# Makes stopping criteria hook happy
model.kai_scanner_excluded_world_info = []
# Dynamic WI depends on this!!! This is a main gen call.
koboldai_vars.inference_config.stop_at_genamt = do_dynamic_wi
gen_in = torch.tensor(prompt_tokens, dtype=torch.long)[None]
# Makes stopping criteria hook happy
model.kai_scanner_excluded_world_info = model.kai_scanner_excluded_world_info or set()
if not isinstance(prompt_tokens, torch.Tensor):
gen_in = torch.tensor(prompt_tokens, dtype=torch.long)[None]
else:
gen_in = prompt_tokens
device = "cpu"
if koboldai_vars.hascuda and koboldai_vars.usegpu:
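A small sketch of the new prompt handling: torch_raw_generate now accepts either raw token ids or an already-batched tensor, as core_generate passes on later iterations of its loop. to_batched_input is an illustrative helper, not part of the codebase.

import torch
from typing import List, Union

def to_batched_input(prompt_tokens: Union[List[int], torch.Tensor]) -> torch.Tensor:
    if isinstance(prompt_tokens, torch.Tensor):
        return prompt_tokens  # assumed to already be shaped (batch, seq_len)
    return torch.tensor(prompt_tokens, dtype=torch.long)[None]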
@@ -5037,108 +5094,7 @@ def torch_raw_generate(
# Send text to generator and deal with output
#==================================================================#
def old_underscore_generate(txt, minimum, maximum, found_entries):
if(koboldai_vars.full_determinism):
torch.manual_seed(koboldai_vars.seed)
gen_in = torch.tensor(txt, dtype=torch.long)[None]
if(koboldai_vars.sp is not None):
soft_tokens = torch.arange(
model.config.vocab_size,
model.config.vocab_size + koboldai_vars.sp.shape[0],
)
gen_in = torch.cat((soft_tokens[None], gen_in), dim=-1)
assert gen_in.shape[-1] + koboldai_vars.genamt <= koboldai_vars.max_length
if(koboldai_vars.hascuda and koboldai_vars.usegpu):
gen_in = gen_in.to(koboldai_vars.gpu_device)
elif(koboldai_vars.hascuda and koboldai_vars.breakmodel):
gen_in = gen_in.to(breakmodel.primary_device)
else:
gen_in = gen_in.to('cpu')
model.kai_scanner_excluded_world_info = found_entries
koboldai_vars._prompt = koboldai_vars.prompt
with torch.no_grad():
already_generated = 0
numseqs = koboldai_vars.numseqs
while True:
genout = generator(
gen_in,
do_sample=True,
max_length=int(2e9),
repetition_penalty=1.0,
bad_words_ids=koboldai_vars.badwordsids,
use_cache=True,
num_return_sequences=numseqs
)
already_generated += len(genout[0]) - len(gen_in[0])
assert already_generated <= koboldai_vars.genamt
# If we are halting, we stop
if(model.kai_scanner.halt or not model.kai_scanner.regeneration_required):
break
# if we require a generation, we continue
assert genout.ndim >= 2
assert genout.shape[0] == koboldai_vars.numseqs
if(koboldai_vars.lua_koboldbridge.generated_cols and koboldai_vars.generated_tkns != koboldai_vars.lua_koboldbridge.generated_cols):
raise RuntimeError("Inconsistency detected between KoboldAI Python and Lua backends")
if(already_generated != koboldai_vars.generated_tkns):
raise RuntimeError("WI scanning error")
for r in range(koboldai_vars.numseqs):
for c in range(already_generated):
assert koboldai_vars.lua_koboldbridge.generated[r+1][c+1] is not None
genout[r][genout.shape[-1] - already_generated + c] = koboldai_vars.lua_koboldbridge.generated[r+1][c+1]
encoded = []
# DYNAMIC WI:
# IF WE FIND WORLD INFO MID-GENERATION, STOP, THEN ADD WI AND ADD NEW GENERATION
for i in range(koboldai_vars.numseqs):
txt = utils.decodenewlines(tokenizer.decode(genout[i, -already_generated:]))
winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True, actions=koboldai_vars.actions)
found_entries[i].update(_found_entries)
if koboldai_vars.alt_gen:
txt, _, _ = koboldai_vars.calc_ai_text(submitted_text=txt)
print("Using Alt Gen")
else:
txt, _, _ = calcsubmitbudget(len(koboldai_vars.actions), winfo, mem, anotetxt, koboldai_vars.actions, submission=txt)
encoded.append(torch.tensor(txt, dtype=torch.long, device=genout.device))
max_length = len(max(encoded, key=len))
encoded = torch.stack(tuple(torch.nn.functional.pad(e, (max_length - len(e), 0), value=model.config.pad_token_id or model.config.eos_token_id) for e in encoded))
genout = torch.cat(
(
encoded,
genout[..., -already_generated:],
),
dim=-1
)
if(koboldai_vars.sp is not None):
soft_tokens = torch.arange(
model.config.vocab_size,
model.config.vocab_size + koboldai_vars.sp.shape[0],
device=genout.device,
)
genout = torch.cat((soft_tokens.tile(koboldai_vars.numseqs, 1), genout), dim=-1)
assert genout.shape[-1] + koboldai_vars.genamt - already_generated <= koboldai_vars.max_length
diff = genout.shape[-1] - gen_in.shape[-1]
minimum += diff
maximum += diff
gen_in = genout
numseqs = 1
return genout, already_generated
def old_generate(txt, minimum, maximum, found_entries=None):
def generate(txt, minimum, maximum, found_entries=None):
koboldai_vars.generated_tkns = 0
if(found_entries is None):
@@ -5158,7 +5114,7 @@ def old_generate(txt, minimum, maximum, found_entries=None):
# Submit input text to generator
try:
genout, already_generated = tpool.execute(_generate, txt, minimum, maximum, found_entries)
genout, already_generated = tpool.execute(core_generate, txt, minimum, maximum, found_entries)
except Exception as e:
if(issubclass(type(e), lupa.LuaError)):
koboldai_vars.lua_koboldbridge.obliterate_multiverse()


@@ -800,11 +800,9 @@ class system_settings(settings):
@dataclass
class _inference_config:
do_streaming: bool = False
# NOTE: DynamicWorldInfoScanCriteria handles not only dynamic world
# info, but also max length, aborting, regeneration requests, etc
# for kobold-rooted stuff. This would be nice to change in the future.
do_dynamic_wi: bool = False
# Genamt stopping is mostly tied to Dynamic WI
stop_at_genamt: bool = False
self.inference_config = _inference_config()
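Illustrative only: how a generation entry point might toggle these flags before calling into the backend, mirroring torch_raw_generate in the diff above. configure_for_story_generation is a hypothetical helper, not part of the codebase.

from dataclasses import dataclass

@dataclass
class _inference_config:
    do_streaming: bool = False
    do_dynamic_wi: bool = False
    stop_at_genamt: bool = False

inference_config = _inference_config()

def configure_for_story_generation() -> None:
    # Streaming and Dynamic WI are enabled for main story generation; genamt
    # stopping rides along with Dynamic WI, as noted in the comment above.
    inference_config.do_streaming = True
    inference_config.do_dynamic_wi = True
    inference_config.stop_at_genamt = inference_config.do_dynamic_wi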