diff --git a/aiserver.py b/aiserver.py
index 6e1efe27..8bcdf608 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -4874,21 +4874,28 @@ def core_generate(text: list, min: int, max: int, found_entries: set):
     with torch.no_grad():
         already_generated = 0
         numseqs = koboldai_vars.numseqs
-        while True:
+
+        do_loop = True
+
+        while do_loop:
             # The reason this is a loop is due to how Dynamic WI works. We
             # cannot simply add the WI to the context mid-generation, so we
             # stop early, and then insert WI, then continue generating. That
             # stopping and continuing is this loop.
 
-            genout = raw_generate(
+            result = raw_generate(
                 gen_in,
-                # Real max length is handled by CoreStopper.
-                max_length=int(2e9),
+                max_length=koboldai_vars.genamt,
                 do_streaming=True,
                 do_dynamic_wi=True,
                 batch_count=numseqs,
-                decode=False,
+                # Real max length is handled by CoreStopper.
+                bypass_hf_maxlength=True,
             )
+
+            do_loop = not result.is_whole_generation
+            genout = result.encoded
+
             already_generated += len(genout[0]) - len(gen_in[0])
             assert already_generated <= koboldai_vars.genamt
 
@@ -4954,6 +4961,25 @@ def core_generate(text: list, min: int, max: int, found_entries: set):
 
     return genout, already_generated
 
+class GenerationResult:
+    def __init__(
+        self,
+        out_batches: list,
+        prompt: list,
+
+        # Controls if generate() does its looping thing. This should only be
+        # done for HF models that use that StoppingCondition
+        is_whole_generation: bool
+    ):
+        # Shave prompt off of encoded response. Decoded does not return prompt.
+        # TODO: Does MTJ generation shave this off automatically? Test it!
+        print("shape", out_batches.shape)
+        self.encoded = out_batches[:, len(prompt) - 1:]
+        self.prompt = prompt
+        self.is_whole_generation = is_whole_generation
+
+        self.decoded = [utils.decodenewlines(tokenizer.decode(enc)) for enc in self.encoded]
+
 def raw_generate(
     # prompt is either a string (text) or a list (token ids)
     prompt: Union[str, list],
@@ -4962,28 +4988,17 @@ def raw_generate(
     do_streaming: bool = False,
     do_dynamic_wi: bool = False,
     batch_count: int = 1,
-    decode: bool = True,
-) -> List:
+    bypass_hf_maxlength: bool = False,
+) -> GenerationResult:
 
-    if isinstance(prompt, str):
-        prompt_decoded = prompt
-        prompt_tokens = tokenizer.encode(prompt)
-    else:
-        prompt_decoded = tokenizer.decode(prompt)
-        prompt_tokens = prompt
+    prompt_tokens = tokenizer.encode(prompt) if isinstance(prompt, str) else prompt
 
-    # Some gen methods such as OAI don't return tokens.
-    batch_decoded = None
-    batch_encoded = None
-
     if koboldai_vars.model == "Colab":
         raise NotImplementedError("Colab API raw_generate unsupported")
     elif koboldai_vars.model == "API":
         raise NotImplementedError("API raw_generate unsupported")
     elif koboldai_vars.model == "CLUSTER":
         raise NotImplementedError("Cluster raw_generate unsupported")
-    elif koboldai_vars.model == "OAI":
-        raise NotImplementedError("OpenAI raw_generate unsupported")
    elif koboldai_vars.model == "ReadOnly":
         raise NotImplementedError("No loaded model")
 
@@ -4993,31 +5008,30 @@ def raw_generate(
             max_length=max_length,
             batch_count=batch_count
         )
+        return GenerationResult(
+            out_batches=batch_encoded, prompt=prompt_tokens, is_whole_generation=False
+        )
     elif model == "OAI":
-        batch_decoded = ...
-    else:
-        batch_encoded = torch_raw_generate(
+        batch_encoded = oai_raw_generate(
             prompt_tokens=prompt_tokens,
-            max_new=max_length,
-            do_streaming=do_streaming,
-            do_dynamic_wi=do_dynamic_wi,
+            max_length=max_length,
             batch_count=batch_count
         )
-
-    assert batch_encoded is not None or batch_decoded is not None
-
-    # Shave prompt off of encoded response. Decoded does not return prompt.
-    # TODO: Does MTJ generation shave this off automatically? Test it!
-    if batch_encoded is not None:
-        batch_encoded = batch_encoded[:, len(prompt_tokens) - 1:]
+        return GenerationResult(
+            out_batches=batch_encoded, prompt=prompt_tokens, is_whole_generation=False
+        )
-    if not decode:
-        return batch_encoded
-
-    if batch_decoded is None:
-        batch_decoded = tokenizer.batch_decode(batch_encoded)
-
-    return [utils.decodenewlines(x) for x in batch_decoded]
+    # Torch HF
+    batch_encoded = torch_raw_generate(
+        prompt_tokens=prompt_tokens,
+        max_new=max_length if not bypass_hf_maxlength else int(2e9),
+        do_streaming=do_streaming,
+        do_dynamic_wi=do_dynamic_wi,
+        batch_count=batch_count
+    )
+    return GenerationResult(
+        out_batches=batch_encoded, prompt=prompt_tokens, is_whole_generation=True
+    )
 
 
 def tpu_raw_generate(
     prompt_tokens: List[int],
@@ -5090,6 +5104,84 @@ def torch_raw_generate(
     return genout
 
 
+def oai_raw_generate(
+    prompt_tokens: List[int],
+    max_length: int,
+    batch_count: int,
+):
+    # Taken mainly from oairequest()
+
+    decoded_prompt = utils.decodenewlines(tokenizer.decode(prompt_tokens))
+
+    # Log request to console
+    if not koboldai_vars.quiet:
+        print("{0}Len:{1}, Txt:{2}{3}".format(colors.YELLOW, len(decoded_prompt), decoded_prompt, colors.END))
+
+    # Store context in memory to use it for comparison with generated content
+    koboldai_vars.lastctx = decoded_prompt
+
+    # Build request JSON data
+    # GooseAI is a subtype of OAI. So to check if it's this type, we check the configname as a workaround
+    # as the koboldai_vars.model will always be OAI
+    if 'GooseAI' in koboldai_vars.configname:
+        reqdata = {
+            'prompt': decoded_prompt,
+            'max_tokens': koboldai_vars.genamt,
+            'temperature': koboldai_vars.temp,
+            'top_a': koboldai_vars.top_a,
+            'top_p': koboldai_vars.top_p,
+            'top_k': koboldai_vars.top_k,
+            'tfs': koboldai_vars.tfs,
+            'typical_p': koboldai_vars.typical,
+            'repetition_penalty': koboldai_vars.rep_pen,
+            'repetition_penalty_slope': koboldai_vars.rep_pen_slope,
+            'repetition_penalty_range': koboldai_vars.rep_pen_range,
+            'n': koboldai_vars.numseqs,
+            # TODO: Implement streaming
+            'stream': False
+        }
+    else:
+        reqdata = {
+            'prompt': decoded_prompt,
+            'max_tokens': koboldai_vars.genamt,
+            'temperature': koboldai_vars.temp,
+            'top_p': koboldai_vars.top_p,
+            'n': koboldai_vars.numseqs,
+            'stream': False
+        }
+
+    req = requests.post(
+        koboldai_vars.oaiurl,
+        json = reqdata,
+        headers = {
+            'Authorization': 'Bearer '+koboldai_vars.oaiapikey,
+            'Content-Type': 'application/json'
+        }
+    )
+
+    # Deal with the response
+    if(req.status_code == 200):
+        outputs = [out["text"] for out in req.json()["choices"]]
+
+        decoded_genout = [{"generated_text": utils.decodenewlines(txt)}
+                          for txt in outputs]
+
+        if not koboldai_vars.quiet:
+            print("{0}{1}{2}".format(colors.CYAN, decoded_genout, colors.END))
+
+        return [tokenizer.encode(x["generated_text"]) for x in decoded_genout]
+    else:
+        # Send error message to web client
+        er = req.json()
+        if("error" in er):
+            type = er["error"]["type"]
+            message = er["error"]["message"]
+
+        errmsg = "OpenAI API Error: {0} - {1}".format(type, message)
+        emit('from_server', {'cmd': 'errmsg', 'data': errmsg}, broadcast=True, room="UI_1")
+        set_aibusy(0)
+        return []
+
 #==================================================================#
 # Send text to generator and deal with output
 #==================================================================#
@@ -6323,102 +6415,6 @@ def ikrequest(txt):
         emit('from_server', {'cmd': 'errmsg', 'data': errmsg}, broadcast=True, room="UI_1")
         set_aibusy(0)
 
-#==================================================================#
-# Assembles game data into a request to OpenAI API
-#==================================================================#
-def oairequest(txt, min, max):
-    # Log request to console
-    if not koboldai_vars.quiet:
-        print("{0}Len:{1}, Txt:{2}{3}".format(colors.YELLOW, len(txt), txt, colors.END))
-
-    # Store context in memory to use it for comparison with generated content
-    koboldai_vars.lastctx = txt
-
-    # Build request JSON data
-    # GooseAI is a subntype of OAI. So to check if it's this type, we check the configname as a workaround
-    # as the koboldai_vars.model will always be OAI
-    if 'GooseAI' in koboldai_vars.configname:
-        reqdata = {
-            'prompt': txt,
-            'max_tokens': koboldai_vars.genamt,
-            'temperature': koboldai_vars.temp,
-            'top_a': koboldai_vars.top_a,
-            'top_p': koboldai_vars.top_p,
-            'top_k': koboldai_vars.top_k,
-            'tfs': koboldai_vars.tfs,
-            'typical_p': koboldai_vars.typical,
-            'repetition_penalty': koboldai_vars.rep_pen,
-            'repetition_penalty_slope': koboldai_vars.rep_pen_slope,
-            'repetition_penalty_range': koboldai_vars.rep_pen_range,
-            'n': koboldai_vars.numseqs,
-            'stream': False
-        }
-    else:
-        reqdata = {
-            'prompt': txt,
-            'max_tokens': koboldai_vars.genamt,
-            'temperature': koboldai_vars.temp,
-            'top_p': koboldai_vars.top_p,
-            'n': koboldai_vars.numseqs,
-            'stream': False
-        }
-
-    req = requests.post(
-        koboldai_vars.oaiurl,
-        json = reqdata,
-        headers = {
-            'Authorization': 'Bearer '+koboldai_vars.oaiapikey,
-            'Content-Type': 'application/json'
-        }
-    )
-
-    # Deal with the response
-    if(req.status_code == 200):
-        outputs = [out["text"] for out in req.json()["choices"]]
-
-        for idx in range(len(outputs)):
-            koboldai_vars.lua_koboldbridge.outputs[idx+1] = outputs[idx]
-
-        execute_outmod()
-        if (koboldai_vars.lua_koboldbridge.regeneration_required):
-            koboldai_vars.lua_koboldbridge.regeneration_required = False
-            genout = []
-            for i in range(len(outputs)):
-                genout.append(
-                    {"generated_text": koboldai_vars.lua_koboldbridge.outputs[i + 1]})
-                assert type(genout[-1]["generated_text"]) is str
-        else:
-            genout = [
-                {"generated_text": utils.decodenewlines(txt)}
-                for txt in outputs]
-
-        koboldai_vars.actions.append_options([applyoutputformatting(x["generated_text"]) for x in genout])
-        genout = [{"generated_text": x['text']} for x in koboldai_vars.actions.get_current_options()]
-        if (len(genout) == 1):
-            genresult(genout[0]["generated_text"])
-        else:
-            if (koboldai_vars.lua_koboldbridge.restart_sequence is not None and
-                    koboldai_vars.lua_koboldbridge.restart_sequence > 0):
-                genresult(genout[koboldai_vars.lua_koboldbridge.restart_sequence - 1][
-                    "generated_text"])
-            else:
-                genselect(genout)
-
-        if not koboldai_vars.quiet:
-            print("{0}{1}{2}".format(colors.CYAN, genout, colors.END))
-
-        set_aibusy(0)
-    else:
-        # Send error message to web client
-        er = req.json()
-        if("error" in er):
-            type = er["error"]["type"]
-            message = er["error"]["message"]
-
-        errmsg = "OpenAI API Error: {0} - {1}".format(type, message)
-        emit('from_server', {'cmd': 'errmsg', 'data': errmsg}, broadcast=True, room="UI_1")
-        set_aibusy(0)
-
 #==================================================================#
 # Forces UI to Play mode
 #==================================================================#
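
Usage sketch (not part of the patch): the snippet below shows how a caller inside aiserver.py might consume the GenerationResult returned by the new raw_generate(), using only the names introduced in the diff above (encoded, decoded, is_whole_generation, max_length, batch_count). The prompt text and parameter values are invented for illustration, and the snippet assumes aiserver.py's module globals (tokenizer, koboldai_vars, and a loaded model) are already set up.

    # Hypothetical call site; raw_generate() accepts either a string or a list of token ids.
    result = raw_generate(
        "You are standing in an open field west of a white house.",
        max_length=80,
        batch_count=2,
    )

    # result.decoded has the prompt shaved off and newlines restored, so each
    # entry is just the continuation text for one batch item.
    for continuation in result.decoded:
        print(continuation)

    # core_generate() keeps looping for Dynamic WI insertion only while
    # is_whole_generation is False (do_loop = not result.is_whole_generation).
    if not result.is_whole_generation:
        print("Generation stopped early; the caller is expected to generate again.")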