Mirror of https://github.com/KoboldAI/KoboldAI-Client.git

Final touches

aiserver.py (21 changed lines)
@@ -4826,9 +4826,6 @@ def calcsubmit(txt):
         # Send it!
         ikrequest(subtxt)
 
-def __debug(*args):
-    print("[DBG] ", *args)
-
 def core_generate(text: list, min: int, max: int, found_entries: set):
     # This generation function is tangled with koboldai_vars intentionally. It
     # is meant for the story and nothing else.
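The hunk above deletes the ad-hoc `__debug` print helper outright rather than keeping it behind a switch. If similar tracing is wanted later, the standard-library logging module gives the same one-liner call sites with a level that can be silenced; a minimal sketch, assuming a logger name of our own choosing (not something aiserver.py defines):

import logging

# Hypothetical logger name; aiserver.py does not define this.
logger = logging.getLogger("koboldai.generate")
logging.basicConfig(level=logging.DEBUG, format="[DBG] %(message)s")

# Rough equivalent of the removed __debug("generate core", text) call:
logger.debug("generate core %r", ["some", "tokenized", "text"])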
@@ -4866,13 +4863,11 @@ def core_generate(text: list, min: int, max: int, found_entries: set):
 
     koboldai_vars._prompt = koboldai_vars.prompt
 
-    __debug("generate core", text)
     with torch.no_grad():
         already_generated = 0
         numseqs = koboldai_vars.numseqs
 
         while True:
-            __debug("generate loop start", text)
             # The reason this is a loop is due to how Dynamic WI works. We
             # cannot simply add the WI to the context mid-generation, so we
             # stop early, and then insert WI, then continue generating. That
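The comment kept as context here describes a stop-and-resume pattern: generate until a dynamic World Info keyword appears, splice the triggered entries into the context, then keep generating. A rough sketch of that control flow, with entirely hypothetical helper names (`generate_chunk`, `scan_for_wi`, and `insert_wi` are not KoboldAI functions):

def generate_with_dynamic_wi(context_tokens, budget, generate_chunk, scan_for_wi, insert_wi):
    # Sketch only: the real loop lives in core_generate and talks to koboldai_vars.
    generated = 0
    while generated < budget:
        # The backend either finishes the whole request or stops early
        # because a World Info keyword showed up in the output.
        chunk, stopped_for_wi = generate_chunk(context_tokens, budget - generated)
        context_tokens = context_tokens + chunk
        generated += len(chunk)
        if not stopped_for_wi:
            break
        # Splice the triggered entries into the context, then resume.
        entries = scan_for_wi(context_tokens)
        context_tokens = insert_wi(context_tokens, entries)
    return context_tokens, generated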
@@ -4888,15 +4883,12 @@ def core_generate(text: list, min: int, max: int, found_entries: set):
                 bypass_hf_maxlength=True,
             )
 
-            __debug("generate result", result.__dict__)
-
             genout = result.encoded
 
-            already_generated += len(genout[0]) - 1 # - len(gen_in[0])
+            already_generated += len(genout[0]) - 1
             assert already_generated <= koboldai_vars.genamt
 
             if result.is_whole_generation:
-                __debug("Outa here")
                 break
 
             # Generation stopped; why?
@@ -4953,13 +4945,9 @@ def core_generate(text: list, min: int, max: int, found_entries: set):
                 )
                 genout = torch.cat((soft_tokens.tile(koboldai_vars.numseqs, 1), genout), dim=-1)
             assert genout.shape[-1] + koboldai_vars.genamt - already_generated <= koboldai_vars.max_length
-            # diff = genout.shape[-1] - gen_in.shape[-1]
-            # minimum += diff
-            # maximum += diff
             gen_in = genout
             numseqs = 1
 
-    __debug("final out", genout, "already_gen", already_generated)
     return genout, already_generated
 
 class GenerationResult:
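For reference, the `soft_tokens.tile(koboldai_vars.numseqs, 1)` / `torch.cat(..., dim=-1)` line in this hunk prepends one copy of the soft-prompt placeholder ids to every sequence in the batch. A toy shape check with made-up values:

import torch

numseqs = 3
soft_tokens = torch.tensor([50257, 50258, 50259])    # (sp_length,) -- made-up ids
genout = torch.randint(0, 50000, (numseqs, 10))      # (numseqs, seq_length)

# tile() repeats the soft tokens once per sequence; cat() prepends them.
combined = torch.cat((soft_tokens.tile(numseqs, 1), genout), dim=-1)
assert combined.shape == (numseqs, soft_tokens.shape[0] + genout.shape[-1])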
@@ -4975,10 +4963,8 @@ class GenerationResult:
         # Controls if we should trim output by prompt length
         output_includes_prompt: bool = False,
     ):
-        # Shave prompt off of encoded response. Decoded does not return prompt.
-        # TODO: Does MTJ generation shave this off automatically? Test it!
-        print("shape", out_batches.shape)
-
+        # Shave prompt off of encoded response when needed (HF). Decoded does
+        # not return prompt.
         if output_includes_prompt:
             self.encoded = out_batches[:, len(prompt) - 1:]
         else:
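The new comment in this hunk says the prompt is shaved off the encoded output only when the backend includes it (the HF path, per the `output_includes_prompt` flag). The slice itself keeps everything from the last prompt token onward; a toy example with invented values:

import torch

prompt = [11, 22, 33, 44]                       # 4 prompt token ids (made up)
out_batches = torch.tensor([[11, 22, 33, 44, 55, 66],
                            [11, 22, 33, 44, 77, 88]])

encoded = out_batches[:, len(prompt) - 1:]
# tensor([[44, 55, 66],
#         [44, 77, 88]]) -- the last prompt token plus the generated tokens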
@@ -5065,7 +5051,6 @@ def tpu_raw_generate(
 
     # Mostly lifted from apiactionsubmit_tpumtjgenerate
     soft_tokens = tpumtjgetsofttokens()
-    __debug("we are generating with", prompt_tokens, "batch", batch_count, "soft tokens", soft_tokens)
 
     genout = tpool.execute(
         tpu_mtj_backend.infer_static,
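`tpool.execute` in this hunk is eventlet's thread-pool helper: it runs a blocking callable in a native thread so the green-threaded server is not stalled while the TPU backend works. A standalone sketch of that pattern, with a made-up `slow_infer` standing in for `tpu_mtj_backend.infer_static`:

import time
from eventlet import tpool

def slow_infer(tokens):
    # Stand-in for a long, blocking backend call.
    time.sleep(1)
    return [t + 1 for t in tokens]

# Only the calling green thread waits; other greenlets keep running.
result = tpool.execute(slow_infer, [1, 2, 3])
print(result)  # [2, 3, 4]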