Mirror of https://github.com/KoboldAI/KoboldAI-Client.git, synced 2025-06-05 21:59:24 +02:00
Actually just steal from _generate
Why didn't I do this before? It's a mystery!
aiserver.py
@@ -2032,10 +2032,16 @@ def patch_transformers():
    and koboldai_vars.lua_koboldbridge.generated_cols
    and koboldai_vars.generated_tkns != koboldai_vars.lua_koboldbridge.generated_cols
):
    print("[TODO] Fix generated_cols")
    # raise RuntimeError(f"Inconsistency detected between KoboldAI Python and Lua backends ({koboldai_vars.generated_tkns} != {koboldai_vars.lua_koboldbridge.generated_cols})")
    raise RuntimeError(f"Inconsistency detected between KoboldAI Python and Lua backends ({koboldai_vars.generated_tkns} != {koboldai_vars.lua_koboldbridge.generated_cols})")

if koboldai_vars.abort:
if (
    koboldai_vars.abort
    or (
        koboldai_vars.inference_config.stop_at_genamt
        and
        koboldai_vars.generated_tkns >= koboldai_vars.genamt
    )
):
    koboldai_vars.abort = False
    self.regeneration_required = False
    self.halt = False
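For context on the new condition: generation is now cut off when either an abort has been requested, or stop_at_genamt is enabled and the number of generated tokens has reached genamt. Below is a minimal, self-contained sketch of that stopping rule as a standalone Hugging Face StoppingCriteria; the class name BudgetStopper and its attributes are illustrative, not the project's actual CoreStopper/DynamicWorldInfoScanCriteria classes.

import torch
from transformers import StoppingCriteria

class BudgetStopper(StoppingCriteria):
    """Stop generating once a token budget is hit or an abort is requested."""

    def __init__(self, prompt_length: int, target_new_tokens: int):
        self.prompt_length = prompt_length
        self.target_new_tokens = target_new_tokens
        self.abort = False  # can be flipped from another thread to cancel generation

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        generated = input_ids.shape[-1] - self.prompt_length
        return self.abort or generated >= self.target_new_tokens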
@@ -2063,8 +2069,8 @@ def patch_transformers():
    tokenizer,
    excluded_world_info: List[Set],
):
    self.regeneration_required = False
    self.halt = False
    # self.regeneration_required = False
    # self.halt = False
    self.tokenizer = tokenizer
    self.excluded_world_info = excluded_world_info

@@ -2078,27 +2084,9 @@ def patch_transformers():
if not koboldai_vars.inference_config.do_dynamic_wi:
    return False

# if(koboldai_vars.abort or koboldai_vars.generated_tkns >= koboldai_vars.genamt):
# self.regeneration_required = False
# self.halt = False
# koboldai_vars.abort = False
# return True

# Pertains to WI I think
# if(koboldai_vars.standalone):
# return False

# assert input_ids.ndim == 2
assert len(self.excluded_world_info) == input_ids.shape[0]
# self.regeneration_required = koboldai_vars.lua_koboldbridge.regeneration_required
# self.halt = not koboldai_vars.lua_koboldbridge.generating
# koboldai_vars.lua_koboldbridge.regeneration_required = False

# for i in range(koboldai_vars.numseqs):
# koboldai_vars.lua_koboldbridge.generated[i+1][koboldai_vars.generated_tkns] = int(input_ids[i, -1].item())

if not koboldai_vars.dynamicscan:
    #return self.regeneration_required or self.halt
    return False

tail = input_ids[..., -koboldai_vars.generated_tkns:]
@@ -2107,11 +2095,10 @@ def patch_transformers():
    _, found = checkworldinfo(decoded, force_use_txt=True, actions=koboldai_vars.actions)
    found -= self.excluded_world_info[i]
    if len(found) != 0:
        # self.regeneration_required = True
        model.core_stopper.regeneration_required = True
        return True
        break
return self.regeneration_required or self.halt
return False

old_get_stopping_criteria = transformers.generation_utils.GenerationMixin._get_stopping_criteria
def new_get_stopping_criteria(self, *args, **kwargs):
    global tokenizer
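The scan above decodes only the freshly generated tail of each sequence, checks it for world info, ignores entries already injected for that sequence, and then asks the outer loop to regenerate with the new entries in context. A rough sketch of that tail-decode-and-scan pattern in isolation; scan_for_world_info stands in for the project's checkworldinfo, and the other parameter names are illustrative.

import torch

def tail_triggers_rescan(input_ids: torch.LongTensor, generated_tkns: int,
                         tokenizer, excluded_world_info, scan_for_world_info) -> bool:
    """Return True when a newly generated tail mentions world info that has
    not been injected into that sequence's context yet."""
    tail = input_ids[..., -generated_tkns:]  # only the tokens produced so far
    for i, row in enumerate(tail):
        decoded = tokenizer.decode(row)
        found = scan_for_world_info(decoded) - excluded_world_info[i]
        if found:
            return True  # signal the outer loop to stop, insert WI, and continue
    return False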
@@ -4789,7 +4776,7 @@ def calcsubmit(txt):
    "TPUMeshTransformerGPTJ",
    "TPUMeshTransformerGPTNeoX"
):
    legacy_generate(subtxt, min, max)
    generate(subtxt, min, max, found_entries)
elif koboldai_vars.model == "Colab":
    sendtocolab(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
elif koboldai_vars.model == "API":
@@ -4857,52 +4844,117 @@ def calcsubmit(txt):
        # Send it!
        ikrequest(subtxt)

def legacy_generate(text: Union[str, list], min: int, max: int):
    koboldai_vars.lastctx = text
def core_generate(text: list, min: int, max: int, found_entries: set):
    # This generation function is tangled with koboldai_vars intentionally. It
    # is meant for the story and nothing else.

    out_batches = tpool.execute(
        raw_generate,
        text,
        max_length=koboldai_vars.genamt,
        do_streaming=True,
        batch_count=koboldai_vars.numseqs,
        decode=False
    )
    if koboldai_vars.full_determinism:
        torch.manual_seed(koboldai_vars.seed)

    decoded_batches = tokenizer.batch_decode(out_batches)
    gen_in = torch.tensor(text, dtype=torch.long)[None]
    if koboldai_vars.sp is not None:
        soft_tokens = torch.arange(
            model.config.vocab_size,
            model.config.vocab_size + koboldai_vars.sp.shape[0],
        )
        gen_in = torch.cat((soft_tokens[None], gen_in), dim=-1)

    # Lua bridge, genmod
    for i in range(koboldai_vars.numseqs):
        koboldai_vars.lua_koboldbridge.generated[i + 1][koboldai_vars.generated_tkns] = int(out_batches[i, -1].item())
        koboldai_vars.lua_koboldbridge.outputs[i + 1] = utils.decodenewlines(tokenizer.decode(out_batches[i, -len(out_batches[i]):]))
    assert gen_in.shape[-1] + koboldai_vars.genamt <= koboldai_vars.max_length

    # for i, output in enumerate(outputs):
    # koboldai_vars.lua_koboldbridge.outputs[i + 1] = output

    execute_genmod()

    if koboldai_vars.lua_koboldbridge.regeneration_required:
        koboldai_vars.lua_koboldbridge.regeneration_required = False
        genout = []
        for i in range(len(out_batches)):
            out = koboldai_vars.lua_koboldbridge.outputs[i + 1]
            genout.append({"generated_text": out})
            assert isinstance(out, str)
    if koboldai_vars.hascuda and koboldai_vars.usegpu:
        gen_in = gen_in.to(koboldai_vars.gpu_device)
    elif koboldai_vars.hascuda and koboldai_vars.breakmodel:
        gen_in = gen_in.to(breakmodel.primary_device)
    else:
        genout = [{"generated_text": utils.decodenewlines(x)} for x in decoded_batches]
        gen_in = gen_in.to("cpu")

    koboldai_vars.actions.append_options([applyoutputformatting(x["generated_text"]) for x in genout])
    genout = [{"generated_text": x['text']} for x in koboldai_vars.actions.get_current_options()]
    found_entries = found_entries or set()
    model.kai_scanner_excluded_world_info = found_entries

    if len(genout) == 1:
        genresult(genout[0]["generated_text"])
    else:
        restart_seq = koboldai_vars.lua_koboldbridge.restart_sequence
        if restart_seq and restart_seq > 0:
            genresult(genout[restart_seq - 1]["generated_text"])
        else:
            genselect(genout)
    set_aibusy(0)
    koboldai_vars._prompt = koboldai_vars.prompt

    with torch.no_grad():
        already_generated = 0
        numseqs = koboldai_vars.numseqs
        while True:
            # The reason this is a loop is due to how Dynamic WI works. We
            # cannot simply add the WI to the context mid-generation, so we
            # stop early, and then insert WI, then continue generating. That
            # stopping and continuing is this loop.

            genout = raw_generate(
                gen_in,
                # Real max length is handled by CoreStopper.
                max_length=int(2e9),
                do_streaming=True,
                do_dynamic_wi=True,
                batch_count=numseqs,
                decode=False,
            )
            already_generated += len(genout[0]) - len(gen_in[0])
            assert already_generated <= koboldai_vars.genamt

            # Generation stopped; why?
            # If we have been told to halt, we have reached our target token
            # amount (controlled by halt), or Dynamic WI has not told us to
            # stop temporarily to insert WI, we can assume that we are done
            # generating. We shall break.
            if model.core_stopper.halt or not model.core_stopper.regeneration_required:
                break

            # Now we are doing stuff for Dynamic WI.
            assert genout.ndim >= 2
            assert genout.shape[0] == koboldai_vars.numseqs

            if(koboldai_vars.lua_koboldbridge.generated_cols and koboldai_vars.generated_tkns != koboldai_vars.lua_koboldbridge.generated_cols):
                raise RuntimeError("Inconsistency detected between KoboldAI Python and Lua backends")

            if(already_generated != koboldai_vars.generated_tkns):
                raise RuntimeError("WI scanning error")

            for r in range(koboldai_vars.numseqs):
                for c in range(already_generated):
                    assert koboldai_vars.lua_koboldbridge.generated[r+1][c+1] is not None
                    genout[r][genout.shape[-1] - already_generated + c] = koboldai_vars.lua_koboldbridge.generated[r+1][c+1]

            encoded = []

            for i in range(koboldai_vars.numseqs):
                txt = utils.decodenewlines(tokenizer.decode(genout[i, -already_generated:]))
                winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True, actions=koboldai_vars.actions)
                found_entries[i].update(_found_entries)
                if koboldai_vars.alt_gen:
                    txt, _, _ = koboldai_vars.calc_ai_text(submitted_text=txt)
                    print("Using Alt Gen")
                else:
                    txt, _, _ = calcsubmitbudget(len(koboldai_vars.actions), winfo, mem, anotetxt, koboldai_vars.actions, submission=txt)
                encoded.append(torch.tensor(txt, dtype=torch.long, device=genout.device))

            max_length = len(max(encoded, key=len))
            encoded = torch.stack(tuple(torch.nn.functional.pad(e, (max_length - len(e), 0), value=model.config.pad_token_id or model.config.eos_token_id) for e in encoded))
            genout = torch.cat(
                (
                    encoded,
                    genout[..., -already_generated:],
                ),
                dim=-1
            )

            if(koboldai_vars.sp is not None):
                soft_tokens = torch.arange(
                    model.config.vocab_size,
                    model.config.vocab_size + koboldai_vars.sp.shape[0],
                    device=genout.device,
                )
                genout = torch.cat((soft_tokens.tile(koboldai_vars.numseqs, 1), genout), dim=-1)
            assert genout.shape[-1] + koboldai_vars.genamt - already_generated <= koboldai_vars.max_length
            diff = genout.shape[-1] - gen_in.shape[-1]
            minimum += diff
            maximum += diff
            gen_in = genout
            numseqs = 1

    return genout, already_generated

def raw_generate(
    # prompt is either a string (text) or a list (token ids)
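One detail worth calling out in the loop above: after world info is inserted, each sequence's rebuilt context can have a different length, so the contexts are left-padded to a common length before being stacked and glued back onto the already generated tail. A minimal sketch of just that padding step, with an assumed pad_id:

import torch
import torch.nn.functional as F

def left_pad_and_stack(sequences, pad_id):
    """Left-pad 1-D token tensors to a common length and stack them into a batch."""
    max_length = max(len(seq) for seq in sequences)
    return torch.stack([
        F.pad(seq, (max_length - len(seq), 0), value=pad_id)  # pad on the left only
        for seq in sequences
    ])

batch = left_pad_and_stack([torch.tensor([5, 6, 7]), torch.tensor([8, 9])], pad_id=0)
# tensor([[5, 6, 7],
#         [0, 8, 9]])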
@@ -4952,18 +5004,17 @@ def raw_generate(
        batch_count=batch_count
    )

    assert batch_encoded or batch_decoded
    assert batch_encoded is not None or batch_decoded is not None

    # Shave prompt off of encoded response. Decoded does not return prompt.
    # TODO: Does MTJ generation shave this off automatically? Test it!
    if batch_encoded:
        batch_encoded = batch_encoded[:, len(prompt_tokens):]
    if batch_encoded is not None:
        batch_encoded = batch_encoded[:, len(prompt_tokens) - 1:]

    if not decode:
        return batch_encoded

    if not batch_decoded:
    if batch_decoded is None:
        batch_decoded = tokenizer.batch_decode(batch_encoded)

    return [utils.decodenewlines(x) for x in batch_decoded]
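Context for the slicing change above: with decoder-only models, Hugging Face generate() returns the prompt followed by the continuation, so the newly generated tokens are obtained by slicing at the prompt length; the exact offset (with or without the - 1 seen here) depends on how the prompt tensor was built. A small illustrative example with made-up token ids:

import torch

prompt_tokens = [15496, 11, 995]                                 # encoded prompt
full_output = torch.tensor([[15496, 11, 995, 318, 257, 1332]])   # prompt + continuation

# Keep only the continuation by shaving the echoed prompt off the front.
new_tokens = full_output[:, len(prompt_tokens):]
# tensor([[ 318,  257, 1332]])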
@@ -4997,7 +5048,7 @@ def tpu_raw_generate(
    return genout

def torch_raw_generate(
    prompt_tokens: List[int],
    prompt_tokens: Union[List[int], torch.Tensor],
    max_new: int,

    do_streaming: bool = False,
@@ -5008,10 +5059,16 @@ def torch_raw_generate(
    koboldai_vars.inference_config.do_streaming = do_streaming
    koboldai_vars.inference_config.do_dynamic_wi = do_dynamic_wi

    # Makes stopping criteria hook happy
    model.kai_scanner_excluded_world_info = []
    # Dynamic WI depends on this!!! This is a main gen call.
    koboldai_vars.inference_config.stop_at_genamt = do_dynamic_wi

    gen_in = torch.tensor(prompt_tokens, dtype=torch.long)[None]
    # Makes stopping criteria hook happy
    model.kai_scanner_excluded_world_info = model.kai_scanner_excluded_world_info or set()

    if not isinstance(prompt_tokens, torch.Tensor):
        gen_in = torch.tensor(prompt_tokens, dtype=torch.long)[None]
    else:
        gen_in = prompt_tokens

    device = "cpu"
    if koboldai_vars.hascuda and koboldai_vars.usegpu:
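torch_raw_generate now accepts either a plain token list or an already batched tensor. A minimal sketch of that normalization, assuming a list represents a single unbatched prompt:

from typing import List, Union
import torch

def to_batched_input(prompt_tokens: Union[List[int], torch.Tensor]) -> torch.Tensor:
    """Return a (batch, seq_len) LongTensor no matter how the prompt arrives."""
    if isinstance(prompt_tokens, torch.Tensor):
        return prompt_tokens  # assumed to already carry a batch dimension
    return torch.tensor(prompt_tokens, dtype=torch.long)[None]

to_batched_input([1, 2, 3]).shape                                # torch.Size([1, 3])
to_batched_input(torch.ones(2, 4, dtype=torch.long)).shape       # torch.Size([2, 4])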
@@ -5037,108 +5094,7 @@ def torch_raw_generate(
# Send text to generator and deal with output
#==================================================================#

def old_underscore_generate(txt, minimum, maximum, found_entries):
    if(koboldai_vars.full_determinism):
        torch.manual_seed(koboldai_vars.seed)

    gen_in = torch.tensor(txt, dtype=torch.long)[None]
    if(koboldai_vars.sp is not None):
        soft_tokens = torch.arange(
            model.config.vocab_size,
            model.config.vocab_size + koboldai_vars.sp.shape[0],
        )
        gen_in = torch.cat((soft_tokens[None], gen_in), dim=-1)
    assert gen_in.shape[-1] + koboldai_vars.genamt <= koboldai_vars.max_length

    if(koboldai_vars.hascuda and koboldai_vars.usegpu):
        gen_in = gen_in.to(koboldai_vars.gpu_device)
    elif(koboldai_vars.hascuda and koboldai_vars.breakmodel):
        gen_in = gen_in.to(breakmodel.primary_device)
    else:
        gen_in = gen_in.to('cpu')

    model.kai_scanner_excluded_world_info = found_entries

    koboldai_vars._prompt = koboldai_vars.prompt

    with torch.no_grad():
        already_generated = 0
        numseqs = koboldai_vars.numseqs
        while True:
            genout = generator(
                gen_in,
                do_sample=True,
                max_length=int(2e9),
                repetition_penalty=1.0,
                bad_words_ids=koboldai_vars.badwordsids,
                use_cache=True,
                num_return_sequences=numseqs
            )
            already_generated += len(genout[0]) - len(gen_in[0])
            assert already_generated <= koboldai_vars.genamt
            # If we are halting, we stop
            if(model.kai_scanner.halt or not model.kai_scanner.regeneration_required):
                break
            # if we require a generation, we continue

            assert genout.ndim >= 2
            assert genout.shape[0] == koboldai_vars.numseqs

            if(koboldai_vars.lua_koboldbridge.generated_cols and koboldai_vars.generated_tkns != koboldai_vars.lua_koboldbridge.generated_cols):
                raise RuntimeError("Inconsistency detected between KoboldAI Python and Lua backends")

            if(already_generated != koboldai_vars.generated_tkns):
                raise RuntimeError("WI scanning error")

            for r in range(koboldai_vars.numseqs):
                for c in range(already_generated):
                    assert koboldai_vars.lua_koboldbridge.generated[r+1][c+1] is not None
                    genout[r][genout.shape[-1] - already_generated + c] = koboldai_vars.lua_koboldbridge.generated[r+1][c+1]

            encoded = []

            # DYNAMIC WI:
            # IF WE FIND WORLD INFO MID-GENERATION, STOP, THEN ADD WI AND ADD NEW GENERATION

            for i in range(koboldai_vars.numseqs):
                txt = utils.decodenewlines(tokenizer.decode(genout[i, -already_generated:]))
                winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True, actions=koboldai_vars.actions)
                found_entries[i].update(_found_entries)
                if koboldai_vars.alt_gen:
                    txt, _, _ = koboldai_vars.calc_ai_text(submitted_text=txt)
                    print("Using Alt Gen")
                else:
                    txt, _, _ = calcsubmitbudget(len(koboldai_vars.actions), winfo, mem, anotetxt, koboldai_vars.actions, submission=txt)
                encoded.append(torch.tensor(txt, dtype=torch.long, device=genout.device))

            max_length = len(max(encoded, key=len))
            encoded = torch.stack(tuple(torch.nn.functional.pad(e, (max_length - len(e), 0), value=model.config.pad_token_id or model.config.eos_token_id) for e in encoded))
            genout = torch.cat(
                (
                    encoded,
                    genout[..., -already_generated:],
                ),
                dim=-1
            )

            if(koboldai_vars.sp is not None):
                soft_tokens = torch.arange(
                    model.config.vocab_size,
                    model.config.vocab_size + koboldai_vars.sp.shape[0],
                    device=genout.device,
                )
                genout = torch.cat((soft_tokens.tile(koboldai_vars.numseqs, 1), genout), dim=-1)
            assert genout.shape[-1] + koboldai_vars.genamt - already_generated <= koboldai_vars.max_length
            diff = genout.shape[-1] - gen_in.shape[-1]
            minimum += diff
            maximum += diff
            gen_in = genout
            numseqs = 1

    return genout, already_generated


def old_generate(txt, minimum, maximum, found_entries=None):
def generate(txt, minimum, maximum, found_entries=None):
    koboldai_vars.generated_tkns = 0

    if(found_entries is None):
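For reference, the removed function drives the model directly with max_length=int(2e9) and relies on the patched stopping criteria (model.kai_scanner) to end generation at the right point. A hedged sketch of the same pattern that instead passes a criterion through generate()'s public stopping_criteria argument rather than monkey-patching _get_stopping_criteria; the model, prompt, and the BudgetStopper class (from the earlier sketch) are illustrative:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteriaList

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

gen_in = tokenizer("Once upon a time", return_tensors="pt").input_ids
stopper = BudgetStopper(prompt_length=gen_in.shape[-1], target_new_tokens=40)

with torch.no_grad():
    genout = model.generate(
        gen_in,
        do_sample=True,
        max_length=int(2e9),  # effectively unbounded; the stopper decides when to halt
        stopping_criteria=StoppingCriteriaList([stopper]),
    )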
@@ -5158,7 +5114,7 @@ def old_generate(txt, minimum, maximum, found_entries=None):

    # Submit input text to generator
    try:
        genout, already_generated = tpool.execute(_generate, txt, minimum, maximum, found_entries)
        genout, already_generated = tpool.execute(core_generate, txt, minimum, maximum, found_entries)
    except Exception as e:
        if(issubclass(type(e), lupa.LuaError)):
            koboldai_vars.lua_koboldbridge.obliterate_multiverse()
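The call site keeps tpool.execute and only swaps in core_generate: the blocking generation work runs in eventlet's thread pool so the greenlets serving the UI sockets stay responsive. A minimal sketch of that offloading pattern; slow_generation is a stand-in for the real blocking call:

import time
from eventlet import tpool

def slow_generation(prompt):
    time.sleep(2)  # stands in for a blocking model call
    return prompt + " ...generated text"

# Runs the blocking function in a real OS thread; the calling greenlet
# yields to the eventlet hub until the result is ready.
result = tpool.execute(slow_generation, "Once upon a time")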
@@ -800,11 +800,9 @@ class system_settings(settings):
    @dataclass
    class _inference_config:
        do_streaming: bool = False

        # NOTE: DynamicWorldInfoScanCriteria handles not only dynamic world
        # info, but also max length, aborting, regeneration requests, etc
        # for kobold-rooted stuff. This would be nice to change in the future.
        do_dynamic_wi: bool = False
        # Genamt stopping is mostly tied to Dynamic WI
        stop_at_genamt: bool = False
    self.inference_config = _inference_config()
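The _inference_config dataclass groups the per-call flags that the patched stopping criteria read back out of koboldai_vars. A standalone sketch of the same shape; the InferenceConfig name and the usage lines below are illustrative:

from dataclasses import dataclass

@dataclass
class InferenceConfig:
    do_streaming: bool = False
    do_dynamic_wi: bool = False
    # Stopping at genamt is mostly tied to dynamic world info scanning.
    stop_at_genamt: bool = False

config = InferenceConfig()

# A story-level generation call enables dynamic WI, which in turn
# enables stopping at the configured genamt.
config.do_dynamic_wi = True
config.stop_at_genamt = config.do_dynamic_wi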