Mirror of https://github.com/KoboldAI/KoboldAI-Client.git (synced 2025-06-05 21:59:24 +02:00)

Commit: Gen gen gen

Changed files: aiserver.py | 31 changed lines
@@ -1807,11 +1807,13 @@ def patch_transformers():
         scores: torch.FloatTensor,
         **kwargs,
     ) -> bool:
-        if not koboldai_vars.inference_config.do_dynamic_wi:
+        if not koboldai_vars.inference_config.do_streaming:
             return False

         if not koboldai_vars.output_streaming:
             return False

-        print([utils.decodenewlines(tokenizer.decode(x[-1])) for x in input_ids])
+        koboldai_vars.actions.stream_tokens([utils.decodenewlines(tokenizer.decode(x[-1])) for x in input_ids])

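Note: the hunk above turns the stopping-criteria hook into a token streamer: it never halts generation, it only decodes the newest token of each batch sequence and forwards the strings to the UI. A minimal self-contained sketch of that pattern, assuming the standard transformers StoppingCriteria interface (TokenStreamer and on_token are illustrative names, not KoboldAI's):

from typing import Callable, List

import torch
from transformers import StoppingCriteria

class TokenStreamer(StoppingCriteria):
    """Streams each newly generated token; never stops generation itself."""

    def __init__(self, tokenizer, on_token: Callable[[List[str]], None]):
        self.tokenizer = tokenizer
        self.on_token = on_token  # e.g. a function pushing strings to the UI

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Decode only the last token of every sequence in the batch and
        # hand the strings to the callback; False means "keep generating".
        self.on_token([self.tokenizer.decode(seq[-1]) for seq in input_ids])
        return False
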
@@ -4617,23 +4619,17 @@ def legacy_generate(text: Union[str, list], min: int, max: int):

     koboldai_vars.lastctx = text

-    print("Pregen")
-    print(koboldai_vars.max_length)
     outputs = raw_generate(
         text,
         max_length=koboldai_vars.genamt,
         do_streaming=True
     )
-    print(f"postgen: {outputs}")

     # Lua bridge, genmod
     for i, output in enumerate(outputs):
         koboldai_vars.lua_koboldbridge.outputs[i + 1] = output

-    print("post lua")
-
     execute_genmod()
-    print("post genmod")

     if koboldai_vars.lua_koboldbridge.regeneration_required:
         koboldai_vars.lua_koboldbridge.regeneration_required = False
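Note: the loop above hands the batch to Lua userscripts. Lua arrays are 1-indexed while Python's enumerate starts at 0, hence the i + 1 key. A trivial stand-in for the bridge object (lua_outputs replaces the real koboldai_vars.lua_koboldbridge.outputs):

outputs = ["first continuation", "second continuation"]
lua_outputs = {}  # stands in for koboldai_vars.lua_koboldbridge.outputs
for i, output in enumerate(outputs):
    lua_outputs[i + 1] = output
print(lua_outputs)  # {1: 'first continuation', 2: 'second continuation'}
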
@@ -4644,14 +4640,10 @@ def legacy_generate(text: Union[str, list], min: int, max: int):
             assert isinstance(out, str)
     else:
         genout = [{"generated_text": utils.decodenewlines(x)} for x in outputs]

-    print("post assign genout")
-
     koboldai_vars.actions.append_options([applyoutputformatting(x["generated_text"]) for x in genout])
     genout = [{"generated_text": x['text']} for x in koboldai_vars.actions.get_current_options()]

-    print("post genout assign")
-
     if len(genout) == 1:
         genresult(genout[0]["generated_text"])
     else:
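Note: after output formatting, every candidate is appended to the story's option list and read back; a single option is committed directly, several go to the chooser UI. A hedged, self-contained sketch of that dispatch (the three helpers below mock KoboldAI's real applyoutputformatting/genresult/genselect):

def applyoutputformatting(text: str) -> str:
    # Stand-in for KoboldAI's output formatter (trimming, substitutions, ...).
    return text.strip()

def genresult(text: str) -> None:
    print("committed:", text)

def genselect(options: list) -> None:
    print("choose one of:", [o["generated_text"] for o in options])

genout = [{"generated_text": " the fox jumped over the fence.\n"}]
genout = [{"generated_text": applyoutputformatting(o["generated_text"])} for o in genout]
if len(genout) == 1:
    genresult(genout[0]["generated_text"])
else:
    genselect(genout)
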
@@ -4660,9 +4652,7 @@ def legacy_generate(text: Union[str, list], min: int, max: int):
             genresult(genout[restart_seq - 1]["generated_text"])
         else:
             genselect(genout)
-    print("post whatever that is")
     set_aibusy(0)
-    print("post busy")

 def raw_generate(
     # prompt is either a string (text) or a list (token ids)
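Note: per the signature comment above, raw_generate accepts either raw text or pre-tokenized ids. A minimal sketch of that normalization step (encode_fn stands in for the real tokenizer):

from typing import List, Union

def normalize_prompt(prompt: Union[str, List[int]], encode_fn) -> List[int]:
    # Strings are tokenized; token-id lists pass through untouched.
    return encode_fn(prompt) if isinstance(prompt, str) else prompt

fake_encode = lambda s: [ord(c) for c in s]  # toy tokenizer for the demo
print(normalize_prompt("hi", fake_encode))        # [104, 105]
print(normalize_prompt([104, 105], fake_encode))  # [104, 105]
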
@@ -4699,13 +4689,15 @@ def raw_generate(
     else:
         batch_out = torch_raw_generate(
             prompt_tokens=prompt_tokens,
-            max_length=max_length,
+            max_new=max_length,
             do_streaming=do_streaming,
             do_dynamic_wi=do_dynamic_wi,
             batch_count=batch_count
         )

-    return [utils.decodenewlines(tokenizer.decode(x)) for x in batch_out]
+    decoded = tokenizer.batch_decode(batch_out[:, len(prompt_tokens):])
+
+    return [utils.decodenewlines(x) for x in decoded]

 def tpu_raw_generate(
     prompt_tokens: List[int],
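Note: the new decode path slices the shared prompt off every returned sequence before decoding, so callers get only the continuation text. A runnable sketch of that slice, assuming generate-style output (prompt followed by new tokens) and using gpt2 purely as an illustrative tokenizer:

import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative choice

prompt_tokens = tokenizer.encode("The quick brown")
# Fake a batch of two generate() results; continuations are truncated to
# two tokens each so the tensor is rectangular.
cont_a = tokenizer.encode(" fox jumps over the fence")[:2]
cont_b = tokenizer.encode(" dog sleeps on the porch")[:2]
batch_out = torch.tensor([prompt_tokens + cont_a, prompt_tokens + cont_b])

# Slice off the prompt columns, then decode only the new tokens.
decoded = tokenizer.batch_decode(batch_out[:, len(prompt_tokens):])
print(decoded)  # two short continuation strings, prompt text excluded
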
@@ -4737,7 +4729,7 @@ def tpu_raw_generate(

 def torch_raw_generate(
     prompt_tokens: List[int],
-    max_length: int,
+    max_new: int,

     do_streaming: bool = False,
     do_dynamic_wi: bool = False,
@@ -4759,21 +4751,16 @@ def torch_raw_generate(
         device = breakmodel.primary_device
     gen_in = gen_in.to(device)

-    print("okay...")
-
     with torch.no_grad():
-        print(f"in {max_length}")
         genout = generator(
             gen_in,
             do_sample=True,
-            max_length=max_length,
+            max_length=min(len(prompt_tokens) + max_new, koboldai_vars.max_length),
             repetition_penalty=1.0,
             bad_words_ids=koboldai_vars.badwordsids,
             use_cache=True,
             num_return_sequences=batch_count,
         )
-    print("out")
-    print("wtf")

     return genout
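Note: the rewritten max_length asks the model for at most max_new fresh tokens while never exceeding the context budget (koboldai_vars.max_length). A worked example of the clamp (the 2048 budget is illustrative):

def clamp_max_length(prompt_len: int, max_new: int, context_budget: int = 2048) -> int:
    # Total sequence length = prompt plus requested new tokens, capped
    # at the context budget.
    return min(prompt_len + max_new, context_budget)

print(clamp_max_length(100, 80))   # 180: room for all 80 new tokens
print(clamp_max_length(2000, 80))  # 2048: clamped, only 48 new tokens fit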