diff --git a/aiserver.py b/aiserver.py
index 67635a57..b219dfd2 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -4534,32 +4534,27 @@ def calcsubmit(txt):
             print("Using Alt Gen")
         else:
             subtxt, min, max = calcsubmitbudget(actionlen, winfo, mem, anotetxt, koboldai_vars.actions, submission=txt)
-        if(actionlen == 0):
-            if(not koboldai_vars.use_colab_tpu and koboldai_vars.model not in ["Colab", "API", "CLUSTER", "OAI", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"]):
-                generate(subtxt, min, max, found_entries=found_entries)
-            elif(koboldai_vars.model == "Colab"):
-                sendtocolab(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
-            elif(koboldai_vars.model == "API"):
-                sendtoapi(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
-            elif(koboldai_vars.model == "CLUSTER"):
-                sendtocluster(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
-            elif(koboldai_vars.model == "OAI"):
-                oairequest(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
-            elif(koboldai_vars.use_colab_tpu or koboldai_vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")):
-                tpumtjgenerate(subtxt, min, max, found_entries=found_entries)
-        else:
-            if(not koboldai_vars.use_colab_tpu and koboldai_vars.model not in ["Colab", "API", "CLUSTER", "OAI", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"]):
-                generate(subtxt, min, max, found_entries=found_entries)
-            elif(koboldai_vars.model == "Colab"):
-                sendtocolab(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
-            elif(koboldai_vars.model == "API"):
-                sendtoapi(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
-            elif(koboldai_vars.model == "CLUSTER"):
-                sendtocluster(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
-            elif(koboldai_vars.model == "OAI"):
-                oairequest(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
-            elif(koboldai_vars.use_colab_tpu or koboldai_vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")):
-                tpumtjgenerate(subtxt, min, max, found_entries=found_entries)
+
+        if not koboldai_vars.use_colab_tpu and koboldai_vars.model not in (
+            "Colab",
+            "API",
+            "CLUSTER",
+            "OAI",
+            "TPUMeshTransformerGPTJ",
+            "TPUMeshTransformerGPTNeoX"
+        ):
+            legacy_generate(subtxt, min, max)
+            # generate(subtxt, min, max, found_entries=found_entries)
+        elif koboldai_vars.model == "Colab":
+            sendtocolab(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
+        elif koboldai_vars.model == "API":
+            sendtoapi(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
+        elif koboldai_vars.model == "CLUSTER":
+            sendtocluster(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
+        elif koboldai_vars.model == "OAI":
+            oairequest(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
+        elif koboldai_vars.use_colab_tpu or koboldai_vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"):
+            tpumtjgenerate(subtxt, min, max, found_entries=found_entries)
 
     # For InferKit web API
     else:
@@ -4617,18 +4612,28 @@ def calcsubmit(txt):
         # Send it!
         ikrequest(subtxt)
 
-def legacy_generate(text):
+def legacy_generate(text: Union[str, list], min: int, max: int):
     # Architected after oairequest
    koboldai_vars.lastctx = text
 
-    outputs = raw_generate(text)
+    print("Pregen")
+    print(koboldai_vars.max_length)
+    outputs = raw_generate(
+        text,
+        max_length=koboldai_vars.genamt,
+        do_streaming=True
+    )
+    print(f"postgen: {outputs}")
 
     # Lua bridge, genmod
     for i, output in enumerate(outputs):
         koboldai_vars.lua_koboldbridge.outputs[i + 1] = output
 
+    print("post lua")
+
+    execute_genmod()
+
+    print("post genmod")
+
     if koboldai_vars.lua_koboldbridge.regeneration_required:
         koboldai_vars.lua_koboldbridge.regeneration_required = False
@@ -4639,10 +4644,14 @@ def legacy_generate(text):
             assert isinstance(out, str)
     else:
         genout = [{"generated_text": utils.decodenewlines(x)} for x in outputs]
+
+    print("post assign genout")
 
     koboldai_vars.actions.append_options([applyoutputformatting(x["generated_text"]) for x in genout])
     genout = [{"generated_text": x['text']} for x in koboldai_vars.actions.get_current_options()]
 
+    print("post genout assign")
+
     if len(genout) == 1:
         genresult(genout[0]["generated_text"])
     else:
@@ -4651,7 +4660,9 @@ def legacy_generate(text):
             genresult(genout[restart_seq - 1]["generated_text"])
         else:
             genselect(genout)
+    print("post whatever that is")
     set_aibusy(0)
+    print("post busy")
 
 def raw_generate(
     # prompt is either a string (text) or a list (token ids)
@@ -4697,7 +4708,7 @@ def raw_generate(
     return [utils.decodenewlines(tokenizer.decode(x)) for x in batch_out]
 
 def tpu_raw_generate(
-    prompt_tokens: list[int],
+    prompt_tokens: List[int],
     max_length: int,
     batch_count: int,
 ):
@@ -4725,7 +4736,7 @@ def tpu_raw_generate(
     return genout
 
 def torch_raw_generate(
-    prompt_tokens: list[int],
+    prompt_tokens: List[int],
     max_length: int,
 
     do_streaming: bool = False,
@@ -4748,7 +4759,10 @@ def torch_raw_generate(
         device = breakmodel.primary_device
     gen_in = gen_in.to(device)
 
+    print("okay...")
+
     with torch.no_grad():
+        print(f"in {max_length}")
         genout = generator(
             gen_in,
             do_sample=True,
@@ -4758,8 +4772,10 @@ def torch_raw_generate(
             use_cache=True,
             num_return_sequences=batch_count,
         )
+    print("out")
+
+    print("wtf")
 
-    return genout[0]
+    return genout
 
 #==================================================================#
 # Send text to generator and deal with output
@@ -7984,7 +8000,6 @@ def UI_2_generate_raw():
     except NotImplementedError as e:
         return Response(json.dumps({"error": str(e)}), status=500)
 
-    print(f"{out=}")
     return out
 
 #==================================================================#