Mirror of https://github.com/KoboldAI/KoboldAI-Client.git (synced 2025-06-05 21:59:24 +02:00)

Commit: Eternal gen work

Changed file: aiserver.py (157 changed lines)
@@ -2013,6 +2013,49 @@ def patch_transformers():
             koboldai_vars.actions.stream_tokens([utils.decodenewlines(tokenizer.decode(x[-1])) for x in input_ids])
             return False

+    class CoreStopper(StoppingCriteria):
+        # Controls core generation stuff; aborting, counting generated tokens, etc
+        def __init__(self):
+            self.regeneration_required = False
+            self.halt = False
+
+        def __call__(
+            self,
+            input_ids: torch.LongTensor,
+            scores: torch.FloatTensor,
+            **kwargs,
+        ) -> bool:
+            koboldai_vars.generated_tkns += 1
+
+            if (
+                not koboldai_vars.standalone
+                and koboldai_vars.lua_koboldbridge.generated_cols
+                and koboldai_vars.generated_tkns != koboldai_vars.lua_koboldbridge.generated_cols
+            ):
+                print("[TODO] Fix generated_cols")
+                # raise RuntimeError(f"Inconsistency detected between KoboldAI Python and Lua backends ({koboldai_vars.generated_tkns} != {koboldai_vars.lua_koboldbridge.generated_cols})")
+
+            if koboldai_vars.abort:
+                koboldai_vars.abort = False
+                self.regeneration_required = False
+                self.halt = False
+                return True
+
+            if koboldai_vars.standalone:
+                return False
+
+            assert input_ids.ndim == 2
+
+            self.regeneration_required = koboldai_vars.lua_koboldbridge.regeneration_required
+            self.halt = not koboldai_vars.lua_koboldbridge.generating
+            koboldai_vars.lua_koboldbridge.regeneration_required = False
+
+            for i in range(koboldai_vars.numseqs):
+                koboldai_vars.lua_koboldbridge.generated[i+1][koboldai_vars.generated_tkns] = int(input_ids[i, -1].item())
+
+            return self.regeneration_required or self.halt
+
+
     # Sets up dynamic world info scanner
     class DynamicWorldInfoScanCriteria(StoppingCriteria):
         def __init__(
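Note (annotation, not part of the commit): CoreStopper follows the Hugging Face transformers StoppingCriteria protocol — generate() invokes each criterion once per newly sampled token and stops as soon as one of them returns True. A minimal, self-contained sketch of that mechanism, using a generic GPT-2 checkpoint rather than KoboldAI's loaded model and a made-up MaxNewTokensStopper in place of CoreStopper:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList

    class MaxNewTokensStopper(StoppingCriteria):
        # Counts one call per generated token, much like CoreStopper's generated_tkns counter.
        def __init__(self, max_new: int):
            self.generated = 0
            self.max_new = max_new

        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
            self.generated += 1
            return self.generated >= self.max_new  # True tells generate() to stop

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    input_ids = tokenizer("Hello there,", return_tensors="pt").input_ids
    out = model.generate(
        input_ids,
        max_length=64,
        stopping_criteria=StoppingCriteriaList([MaxNewTokensStopper(8)]),
    )
    print(tokenizer.decode(out[0]))

CoreStopper uses the same hook to count tokens, honor the abort flag, and mirror each generated token into the Lua bridge.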
@@ -2024,6 +2067,7 @@ def patch_transformers():
             self.halt = False
             self.tokenizer = tokenizer
             self.excluded_world_info = excluded_world_info
+
         def __call__(
             self,
             input_ids: torch.LongTensor,
@@ -2034,35 +2078,38 @@ def patch_transformers():
             if not koboldai_vars.inference_config.do_dynamic_wi:
                 return False

-            koboldai_vars.generated_tkns += 1
-            if(not koboldai_vars.standalone and koboldai_vars.lua_koboldbridge.generated_cols and koboldai_vars.generated_tkns != koboldai_vars.lua_koboldbridge.generated_cols):
-                raise RuntimeError(f"Inconsistency detected between KoboldAI Python and Lua backends ({koboldai_vars.generated_tkns} != {koboldai_vars.lua_koboldbridge.generated_cols})")
-            if(koboldai_vars.abort or koboldai_vars.generated_tkns >= koboldai_vars.genamt):
-                self.regeneration_required = False
-                self.halt = False
-                koboldai_vars.abort = False
-                return True
-            if(koboldai_vars.standalone):
-                return False
-
-            assert input_ids.ndim == 2
-            assert len(self.excluded_world_info) == input_ids.shape[0]
-            self.regeneration_required = koboldai_vars.lua_koboldbridge.regeneration_required
-            self.halt = not koboldai_vars.lua_koboldbridge.generating
-            koboldai_vars.lua_koboldbridge.regeneration_required = False
-
-            for i in range(koboldai_vars.numseqs):
-                koboldai_vars.lua_koboldbridge.generated[i+1][koboldai_vars.generated_tkns] = int(input_ids[i, -1].item())
-
-            if(not koboldai_vars.dynamicscan):
-                return self.regeneration_required or self.halt
+            # if(koboldai_vars.abort or koboldai_vars.generated_tkns >= koboldai_vars.genamt):
+            #     self.regeneration_required = False
+            #     self.halt = False
+            #     koboldai_vars.abort = False
+            #     return True
+
+            # Pertains to WI I think
+            # if(koboldai_vars.standalone):
+            #     return False
+
+            # assert input_ids.ndim == 2
+            assert len(self.excluded_world_info) == input_ids.shape[0]
+            # self.regeneration_required = koboldai_vars.lua_koboldbridge.regeneration_required
+            # self.halt = not koboldai_vars.lua_koboldbridge.generating
+            # koboldai_vars.lua_koboldbridge.regeneration_required = False
+
+            # for i in range(koboldai_vars.numseqs):
+            #     koboldai_vars.lua_koboldbridge.generated[i+1][koboldai_vars.generated_tkns] = int(input_ids[i, -1].item())
+
+            if not koboldai_vars.dynamicscan:
+                #return self.regeneration_required or self.halt
+                return False
+
             tail = input_ids[..., -koboldai_vars.generated_tkns:]
             for i, t in enumerate(tail):
                 decoded = utils.decodenewlines(tokenizer.decode(t))
                 _, found = checkworldinfo(decoded, force_use_txt=True, actions=koboldai_vars.actions)
                 found -= self.excluded_world_info[i]
-                if(len(found) != 0):
-                    self.regeneration_required = True
+                if len(found) != 0:
+                    # self.regeneration_required = True
+                    model.core_stopper.regeneration_required = True
+                    return True
                     break
             return self.regeneration_required or self.halt
     old_get_stopping_criteria = transformers.generation_utils.GenerationMixin._get_stopping_criteria
@@ -2070,12 +2117,14 @@ def patch_transformers():
         global tokenizer
         stopping_criteria = old_get_stopping_criteria(self, *args, **kwargs)

+        self.core_stopper = CoreStopper()
         self.kai_scanner = DynamicWorldInfoScanCriteria(
             tokenizer=tokenizer,
             excluded_world_info=self.kai_scanner_excluded_world_info,
         )
         token_streamer = TokenStreamer(tokenizer=tokenizer)

+        stopping_criteria.insert(0, self.core_stopper)
         stopping_criteria.insert(0, self.kai_scanner)
         token_streamer = TokenStreamer(tokenizer=tokenizer)
         stopping_criteria.insert(0, token_streamer)
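Note (annotation, not part of the commit): the hunk above works because transformers assembles its StoppingCriteriaList inside the private GenerationMixin._get_stopping_criteria helper, so wrapping that one method makes every generate() call pick up the KoboldAI criteria without touching any call sites. A rough, version-dependent sketch of the wrapping pattern (StopNothing is a placeholder criterion invented for illustration; newer transformers releases expose GenerationMixin at the top level rather than under generation_utils):

    import torch
    import transformers
    from transformers import StoppingCriteria

    class StopNothing(StoppingCriteria):
        # Placeholder criterion used only to show the wiring.
        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
            return False

    old_get_stopping_criteria = transformers.GenerationMixin._get_stopping_criteria

    def new_get_stopping_criteria(self, *args, **kwargs):
        # Let transformers build its usual criteria (max length, etc.), then prepend ours.
        stopping_criteria = old_get_stopping_criteria(self, *args, **kwargs)
        stopping_criteria.insert(0, StopNothing())
        return stopping_criteria

    transformers.GenerationMixin._get_stopping_criteria = new_get_stopping_criteria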
@@ -4811,28 +4860,36 @@ def calcsubmit(txt):
     def legacy_generate(text: Union[str, list], min: int, max: int):
         koboldai_vars.lastctx = text

-        outputs = tpool.execute(
+        out_batches = tpool.execute(
             raw_generate,
             text,
             max_length=koboldai_vars.genamt,
-            do_streaming=True
+            do_streaming=True,
+            batch_count=koboldai_vars.numseqs,
+            decode=False
         )

+        decoded_batches = tokenizer.batch_decode(out_batches)
+
         # Lua bridge, genmod
-        for i, output in enumerate(outputs):
-            koboldai_vars.lua_koboldbridge.outputs[i + 1] = output
+        for i in range(koboldai_vars.numseqs):
+            koboldai_vars.lua_koboldbridge.generated[i + 1][koboldai_vars.generated_tkns] = int(out_batches[i, -1].item())
+            koboldai_vars.lua_koboldbridge.outputs[i + 1] = utils.decodenewlines(tokenizer.decode(out_batches[i, -len(out_batches[i]):]))
+
+        # for i, output in enumerate(outputs):
+        #     koboldai_vars.lua_koboldbridge.outputs[i + 1] = output

         execute_genmod()

         if koboldai_vars.lua_koboldbridge.regeneration_required:
             koboldai_vars.lua_koboldbridge.regeneration_required = False
             genout = []
-            for i in range(len(outputs)):
+            for i in range(len(out_batches)):
                 out = koboldai_vars.lua_koboldbridge.outputs[i + 1]
                 genout.append({"generated_text": out})
                 assert isinstance(out, str)
         else:
-            genout = [{"generated_text": utils.decodenewlines(x)} for x in outputs]
+            genout = [{"generated_text": utils.decodenewlines(x)} for x in decoded_batches]

         koboldai_vars.actions.append_options([applyoutputformatting(x["generated_text"]) for x in genout])
         genout = [{"generated_text": x['text']} for x in koboldai_vars.actions.get_current_options()]
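Note (annotation, not part of the commit): tpool.execute here is eventlet's thread-pool helper — it runs the blocking raw_generate call on a worker thread and hands the return value back to the green thread, so the web server stays responsive during generation. Rough shape of the call, with a made-up blocking function standing in for raw_generate:

    from eventlet import tpool

    def slow_generate(prompt, max_length=80):
        # Stand-in for a blocking model call.
        return [prompt + " ..."]

    out_batches = tpool.execute(slow_generate, "Hello", max_length=80)
    print(out_batches)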
@@ -4855,6 +4912,7 @@ def raw_generate(
     do_streaming: bool = False,
     do_dynamic_wi: bool = False,
     batch_count: int = 1,
+    decode: bool = True,
 ) -> List:

     if isinstance(prompt, str):
@@ -4862,6 +4920,10 @@ def raw_generate(
     else:
         prompt_tokens = prompt

+    # Some gen methods such as OAI don't return tokens.
+    batch_decoded = None
+    batch_encoded = None
+
     if koboldai_vars.model == "Colab":
         raise NotImplementedError("Colab API raw_generate unsupported")
     elif koboldai_vars.model == "API":
@@ -4874,13 +4936,15 @@ def raw_generate(
         raise NotImplementedError("No loaded model")

     if koboldai_vars.use_colab_tpu or model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"):
-        batch_out = tpu_raw_generate(
+        batch_encoded = tpu_raw_generate(
             prompt_tokens=prompt_tokens,
             max_length=max_length,
             batch_count=batch_count
         )
+    elif model == "OAI":
+        batch_decoded = ...
     else:
-        batch_out = torch_raw_generate(
+        batch_encoded = torch_raw_generate(
             prompt_tokens=prompt_tokens,
             max_new=max_length,
             do_streaming=do_streaming,
@@ -4888,9 +4952,21 @@ def raw_generate(
             batch_count=batch_count
         )

-    decoded = tokenizer.batch_decode(batch_out[:, len(prompt_tokens):])
-
-    return [utils.decodenewlines(x) for x in decoded]
+    assert batch_encoded or batch_decoded
+
+    # Shave prompt off of encoded response. Decoded does not return prompt.
+    # TODO: Does MTJ generation shave this off automatically? Test it!
+    if batch_encoded:
+        batch_encoded = batch_encoded[:, len(prompt_tokens):]
+
+    if not decode:
+        return batch_encoded
+
+    if not batch_decoded:
+        batch_decoded = tokenizer.batch_decode(batch_encoded)
+
+    return [utils.decodenewlines(x) for x in batch_decoded]
+

 def tpu_raw_generate(
     prompt_tokens: List[int],
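Note (annotation, not part of the commit): batch_encoded comes back as a (batch_count, prompt_length + new_tokens) tensor of token IDs, so the prompt has to be sliced off before decoding, while batch_decoded paths already hold plain text. A small standalone sketch of the slicing and decoding, using a generic GPT-2 tokenizer and arbitrary stand-in token IDs:

    import torch
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    prompt_tokens = tokenizer("Once upon a time").input_ids

    # Pretend the model returned the prompt followed by three new tokens per sequence.
    fake_new_tokens = [318, 257, 1332]  # arbitrary ids, for illustration only
    batch_encoded = torch.tensor([prompt_tokens + fake_new_tokens] * 2)

    batch_encoded = batch_encoded[:, len(prompt_tokens):]   # shave the prompt off
    batch_decoded = tokenizer.batch_decode(batch_encoded)   # one string per sequence
    print(batch_decoded)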
@@ -4961,7 +5037,7 @@ def torch_raw_generate(
 # Send text to generator and deal with output
 #==================================================================#

-def _generate(txt, minimum, maximum, found_entries):
+def old_underscore_generate(txt, minimum, maximum, found_entries):
     if(koboldai_vars.full_determinism):
         torch.manual_seed(koboldai_vars.seed)

@@ -5000,19 +5076,30 @@ def _generate(txt, minimum, maximum, found_entries):
             )
             already_generated += len(genout[0]) - len(gen_in[0])
             assert already_generated <= koboldai_vars.genamt
+            # If we are halting, we stop
             if(model.kai_scanner.halt or not model.kai_scanner.regeneration_required):
                 break
+            # if we require a generation, we continue
+
             assert genout.ndim >= 2
             assert genout.shape[0] == koboldai_vars.numseqs
+
             if(koboldai_vars.lua_koboldbridge.generated_cols and koboldai_vars.generated_tkns != koboldai_vars.lua_koboldbridge.generated_cols):
                 raise RuntimeError("Inconsistency detected between KoboldAI Python and Lua backends")
+
             if(already_generated != koboldai_vars.generated_tkns):
                 raise RuntimeError("WI scanning error")
+
             for r in range(koboldai_vars.numseqs):
                 for c in range(already_generated):
                     assert koboldai_vars.lua_koboldbridge.generated[r+1][c+1] is not None
                     genout[r][genout.shape[-1] - already_generated + c] = koboldai_vars.lua_koboldbridge.generated[r+1][c+1]
+
             encoded = []
+
+            # DYNAMIC WI:
+            # IF WE FIND WORLD INFO MID-GENERATION, STOP, THEN ADD WI AND ADD NEW GENERATION
+
             for i in range(koboldai_vars.numseqs):
                 txt = utils.decodenewlines(tokenizer.decode(genout[i, -already_generated:]))
                 winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True, actions=koboldai_vars.actions)
@@ -5023,6 +5110,7 @@ def _generate(txt, minimum, maximum, found_entries):
                 else:
                     txt, _, _ = calcsubmitbudget(len(koboldai_vars.actions), winfo, mem, anotetxt, koboldai_vars.actions, submission=txt)
                 encoded.append(torch.tensor(txt, dtype=torch.long, device=genout.device))
+
             max_length = len(max(encoded, key=len))
             encoded = torch.stack(tuple(torch.nn.functional.pad(e, (max_length - len(e), 0), value=model.config.pad_token_id or model.config.eos_token_id) for e in encoded))
             genout = torch.cat(
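Note (annotation, not part of the commit): the max_length / pad / stack lines above left-pad the re-encoded sequences to a common length so they can be stacked into a single batch tensor; torch.nn.functional.pad takes (left, right) padding widths for a 1-D tensor. A tiny standalone illustration with a dummy pad value of 0:

    import torch

    encoded = [torch.tensor([5, 6, 7]), torch.tensor([8, 9])]
    max_length = len(max(encoded, key=len))
    padded = torch.stack(tuple(
        torch.nn.functional.pad(e, (max_length - len(e), 0), value=0)  # pad on the left only
        for e in encoded
    ))
    print(padded)  # tensor([[5, 6, 7],
                   #         [0, 8, 9]])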
@@ -5032,6 +5120,7 @@ def _generate(txt, minimum, maximum, found_entries):
                 ),
                 dim=-1
             )
+
             if(koboldai_vars.sp is not None):
                 soft_tokens = torch.arange(
                     model.config.vocab_size,
@@ -5049,7 +5138,7 @@ def _generate(txt, minimum, maximum, found_entries):
     return genout, already_generated


-def generate(txt, minimum, maximum, found_entries=None):
+def old_generate(txt, minimum, maximum, found_entries=None):
     koboldai_vars.generated_tkns = 0

     if(found_entries is None):