Mirror of https://github.com/KoboldAI/KoboldAI-Client.git (synced 2025-06-05 21:59:24 +02:00)

Commit: Gen 10000

Changed file: aiserver.py (266 changed lines)
@@ -4874,21 +4874,28 @@ def core_generate(text: list, min: int, max: int, found_entries: set):
 
     with torch.no_grad():
         already_generated = 0
         numseqs = koboldai_vars.numseqs
 
-        while True:
-            genout = raw_generate(
+        do_loop = True
+
+        while do_loop:
+            # The reason this is a loop is due to how Dynamic WI works. We
+            # cannot simply add the WI to the context mid-generation, so we
+            # stop early, and then insert WI, then continue generating. That
+            # stopping and continuing is this loop.
+            result = raw_generate(
                 gen_in,
-                # Real max length is handled by CoreStopper.
-                max_length=int(2e9),
+                max_length=koboldai_vars.genamt,
                 do_streaming=True,
                 do_dynamic_wi=True,
                 batch_count=numseqs,
-                decode=False,
+                # Real max length is handled by CoreStopper.
+                bypass_hf_maxlength=True,
             )
+
+            do_loop = not result.is_whole_generation
+            genout = result.encoded
+
             already_generated += len(genout[0]) - len(gen_in[0])
             assert already_generated <= koboldai_vars.genamt
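For reference, a minimal, self-contained sketch of the stop-and-continue pattern this loop implements. The fake_generate helper, the token values, and the fixed budget below are illustrative stand-ins, not code from this commit; the real loop calls raw_generate() and checks result.is_whole_generation.

# Illustrative sketch only: fake_generate, budget, and the token values are
# stand-ins, not part of the commit.
from dataclasses import dataclass
from typing import List

@dataclass
class FakeResult:
    encoded: List[int]          # tokens produced by this call
    is_whole_generation: bool   # False means generation stopped early (e.g. for WI)

def fake_generate(context: List[int], budget: int) -> FakeResult:
    # Pretend the backend stops after at most 3 tokens per call.
    step = min(3, budget)
    new_tokens = list(range(len(context), len(context) + step))
    return FakeResult(encoded=new_tokens, is_whole_generation=(step == budget))

context = [0, 1, 2]     # stands in for the prompt tokens (gen_in)
budget = 8              # stands in for koboldai_vars.genamt
already_generated = 0

do_loop = True
while do_loop:
    # A real implementation would insert World Info into `context` here
    # before continuing the interrupted generation.
    result = fake_generate(context, budget - already_generated)
    context += result.encoded
    already_generated += len(result.encoded)
    do_loop = not result.is_whole_generation
    assert already_generated <= budget

print(already_generated)  # 8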
@@ -4954,6 +4961,25 @@ def core_generate(text: list, min: int, max: int, found_entries: set):
 
     return genout, already_generated
 
+
+class GenerationResult:
+    def __init__(
+        self,
+        out_batches: list,
+        prompt: list,
+
+        # Controls if generate() does its looping thing. This should only be
+        # done for HF models that use that StoppingCondition
+        is_whole_generation: bool
+    ):
+        # Shave prompt off of encoded response. Decoded does not return prompt.
+        # TODO: Does MTJ generation shave this off automatically? Test it!
+        print("shape", out_batches.shape)
+        self.encoded = out_batches[:, len(prompt) - 1:]
+        self.prompt = prompt
+        self.is_whole_generation = is_whole_generation
+
+        self.decoded = [utils.decodenewlines(tokenizer.decode(enc)) for enc in self.encoded]
+
+
 def raw_generate(
     # prompt is either a string (text) or a list (token ids)
     prompt: Union[str, list],
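To make the slice in GenerationResult.__init__ concrete, here is a small example of what out_batches[:, len(prompt) - 1:] keeps, assuming torch is available; the tensor values and prompt ids are made up. Note that the slice starts at len(prompt) - 1, so the final prompt token is retained alongside the newly generated tokens.

# Illustrative sketch only; the tensor values and prompt below are made up.
import torch

prompt = [11, 12, 13]                         # 3 prompt token ids
out_batches = torch.tensor([
    [11, 12, 13, 101, 102, 103],              # batch 0: prompt + 3 new tokens
    [11, 12, 13, 201, 202, 203],              # batch 1
])

# Same slice as GenerationResult.__init__: starts at len(prompt) - 1,
# so the last prompt token (13) is kept along with the generated tokens.
encoded = out_batches[:, len(prompt) - 1:]
print(encoded.tolist())
# [[13, 101, 102, 103], [13, 201, 202, 203]]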
@@ -4962,28 +4988,17 @@ def raw_generate(
     do_streaming: bool = False,
     do_dynamic_wi: bool = False,
     batch_count: int = 1,
-    decode: bool = True,
-) -> List:
+    bypass_hf_maxlength: bool = False,
+) -> GenerationResult:
 
-    if isinstance(prompt, str):
-        prompt_decoded = prompt
-        prompt_tokens = tokenizer.encode(prompt)
-    else:
-        prompt_decoded = tokenizer.decode(prompt)
-        prompt_tokens = prompt
-
-    # Some gen methods such as OAI don't return tokens.
-    batch_decoded = None
-    batch_encoded = None
+    prompt_tokens = tokenizer.encode(prompt) if isinstance(prompt, str) else prompt
 
     if koboldai_vars.model == "Colab":
         raise NotImplementedError("Colab API raw_generate unsupported")
     elif koboldai_vars.model == "API":
         raise NotImplementedError("API raw_generate unsupported")
     elif koboldai_vars.model == "CLUSTER":
         raise NotImplementedError("Cluster raw_generate unsupported")
-    elif koboldai_vars.model == "OAI":
-        raise NotImplementedError("OpenAI raw_generate unsupported")
     elif koboldai_vars.model == "ReadOnly":
         raise NotImplementedError("No loaded model")
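A short sketch of the prompt normalization that the new one-liner performs, with a stub tokenizer standing in for the real one (StubTokenizer and its fake ids are assumptions for illustration only):

# Illustrative sketch only; StubTokenizer is a stand-in for the real tokenizer.
from typing import List, Union

class StubTokenizer:
    def encode(self, text: str) -> List[int]:
        # Stand-in: one fake id per whitespace-separated word.
        return [hash(w) % 1000 for w in text.split()]

tokenizer = StubTokenizer()

def normalize_prompt(prompt: Union[str, List[int]]) -> List[int]:
    # Same shape as the new one-liner in raw_generate():
    # accept either raw text or pre-encoded token ids.
    return tokenizer.encode(prompt) if isinstance(prompt, str) else prompt

print(normalize_prompt("hello world"))   # two fake ids
print(normalize_prompt([5, 6, 7]))       # passed through unchanged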
@@ -4993,31 +5008,30 @@ def raw_generate(
             max_length=max_length,
             batch_count=batch_count
         )
+        return GenerationResult(
+            out_batches=batch_encoded, prompt=prompt_tokens, is_whole_generation=False
+        )
     elif model == "OAI":
-        batch_decoded = ...
-    else:
-        batch_encoded = torch_raw_generate(
+        batch_encoded = oai_raw_generate(
             prompt_tokens=prompt_tokens,
-            max_new=max_length,
-            do_streaming=do_streaming,
-            do_dynamic_wi=do_dynamic_wi,
+            max_length=max_length,
             batch_count=batch_count
         )
+        return GenerationResult(
+            out_batches=batch_encoded, prompt=prompt_tokens, is_whole_generation=False
+        )
 
-    assert batch_encoded is not None or batch_decoded is not None
-
-    # Shave prompt off of encoded response. Decoded does not return prompt.
-    # TODO: Does MTJ generation shave this off automatically? Test it!
-    if batch_encoded is not None:
-        batch_encoded = batch_encoded[:, len(prompt_tokens) - 1:]
-
-    if not decode:
-        return batch_encoded
-
-    if batch_decoded is None:
-        batch_decoded = tokenizer.batch_decode(batch_encoded)
-
-    return [utils.decodenewlines(x) for x in batch_decoded]
+    # Torch HF
+    batch_encoded = torch_raw_generate(
+        prompt_tokens=prompt_tokens,
+        max_new=max_length if not bypass_hf_maxlength else int(2e9),
+        do_streaming=do_streaming,
+        do_dynamic_wi=do_dynamic_wi,
+        batch_count=batch_count
+    )
+    return GenerationResult(
+        out_batches=batch_encoded, prompt=prompt_tokens, is_whole_generation=True
+    )
 
 def tpu_raw_generate(
     prompt_tokens: List[int],
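From the caller's perspective, every branch of raw_generate() now returns a GenerationResult instead of a raw token batch or decoded list. A minimal sketch of that contract, using a hand-built stand-in object rather than the real class:

# Illustrative sketch only: FakeGenerationResult is a hand-built stand-in
# showing the fields a caller can rely on after this change.
from dataclasses import dataclass
from typing import List

@dataclass
class FakeGenerationResult:
    encoded: List[List[int]]          # per-batch token ids, prompt shaved off
    decoded: List[str]                # per-batch text, newlines already decoded
    is_whole_generation: bool         # True for the torch HF path, False for TPU/OAI

result = FakeGenerationResult(
    encoded=[[13, 101, 102, 103]],
    decoded=["... generated text ..."],
    is_whole_generation=False,
)

# core_generate() keeps looping (for Dynamic WI) only while the backend
# reports a partial generation:
do_loop = not result.is_whole_generation
genout = result.encoded
print(do_loop, genout)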
@@ -5090,6 +5104,84 @@ def torch_raw_generate(
 
     return genout
 
+
+def oai_raw_generate(
+    prompt_tokens: List[int],
+    max_length: int,
+    batch_count: int,
+):
+    # Taken mainly from oairequest()
+
+    decoded_prompt = utils.decodenewlines(tokenizer.decode(prompt_tokens))
+
+    # Log request to console
+    if not koboldai_vars.quiet:
+        print("{0}Len:{1}, Txt:{2}{3}".format(colors.YELLOW, len(decoded_prompt), decoded_prompt, colors.END))
+
+    # Store context in memory to use it for comparison with generated content
+    koboldai_vars.lastctx = decoded_prompt
+
+    # Build request JSON data
+    # GooseAI is a subtype of OAI. So to check if it's this type, we check the configname as a workaround
+    # as the koboldai_vars.model will always be OAI
+    if 'GooseAI' in koboldai_vars.configname:
+        reqdata = {
+            'prompt': decoded_prompt,
+            'max_tokens': koboldai_vars.genamt,
+            'temperature': koboldai_vars.temp,
+            'top_a': koboldai_vars.top_a,
+            'top_p': koboldai_vars.top_p,
+            'top_k': koboldai_vars.top_k,
+            'tfs': koboldai_vars.tfs,
+            'typical_p': koboldai_vars.typical,
+            'repetition_penalty': koboldai_vars.rep_pen,
+            'repetition_penalty_slope': koboldai_vars.rep_pen_slope,
+            'repetition_penalty_range': koboldai_vars.rep_pen_range,
+            'n': koboldai_vars.numseqs,
+            # TODO: Implement streaming
+            'stream': False
+        }
+    else:
+        reqdata = {
+            'prompt': decoded_prompt,
+            'max_tokens': koboldai_vars.genamt,
+            'temperature': koboldai_vars.temp,
+            'top_p': koboldai_vars.top_p,
+            'n': koboldai_vars.numseqs,
+            'stream': False
+        }
+
+    req = requests.post(
+        koboldai_vars.oaiurl,
+        json = reqdata,
+        headers = {
+            'Authorization': 'Bearer '+koboldai_vars.oaiapikey,
+            'Content-Type': 'application/json'
+        }
+    )
+
+    # Deal with the response
+    if(req.status_code == 200):
+        outputs = [out["text"] for out in req.json()["choices"]]
+
+        decoded_genout = [{"generated_text": utils.decodenewlines(txt)}
+                          for txt in outputs]
+
+        if not koboldai_vars.quiet:
+            print("{0}{1}{2}".format(colors.CYAN, decoded_genout, colors.END))
+
+        return [tokenizer.encode(x) for x in decoded_genout]
+    else:
+        # Send error message to web client
+        er = req.json()
+        if("error" in er):
+            type = er["error"]["type"]
+            message = er["error"]["message"]
+
+        errmsg = "OpenAI API Error: {0} - {1}".format(type, message)
+        emit('from_server', {'cmd': 'errmsg', 'data': errmsg}, broadcast=True, room="UI_1")
+        set_aibusy(0)
+        return []
+
 #==================================================================#
 #  Send text to generator and deal with output
 #==================================================================#
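For clarity, a small sketch of the response handling in oai_raw_generate(), using a hand-written dictionary that mimics the JSON shape expected from the OpenAI/GooseAI completions endpoint (the text values are made up, and the utils.decodenewlines step is omitted here):

# Illustrative sketch only; fake_response mimics the expected JSON shape.
fake_response = {
    "choices": [
        {"text": "First continuation."},
        {"text": "Second continuation."},
    ]
}

# Same extraction as oai_raw_generate(): one output string per requested
# sequence ('n' in reqdata), wrapped in the {"generated_text": ...} shape
# used elsewhere in aiserver.py.
outputs = [out["text"] for out in fake_response["choices"]]
decoded_genout = [{"generated_text": txt} for txt in outputs]
print(decoded_genout)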
@@ -6323,102 +6415,6 @@ def ikrequest(txt):
         emit('from_server', {'cmd': 'errmsg', 'data': errmsg}, broadcast=True, room="UI_1")
         set_aibusy(0)
 
-#==================================================================#
-#  Assembles game data into a request to OpenAI API
-#==================================================================#
-def oairequest(txt, min, max):
-    # Log request to console
-    if not koboldai_vars.quiet:
-        print("{0}Len:{1}, Txt:{2}{3}".format(colors.YELLOW, len(txt), txt, colors.END))
-
-    # Store context in memory to use it for comparison with generated content
-    koboldai_vars.lastctx = txt
-
-    # Build request JSON data
-    # GooseAI is a subtype of OAI. So to check if it's this type, we check the configname as a workaround
-    # as the koboldai_vars.model will always be OAI
-    if 'GooseAI' in koboldai_vars.configname:
-        reqdata = {
-            'prompt': txt,
-            'max_tokens': koboldai_vars.genamt,
-            'temperature': koboldai_vars.temp,
-            'top_a': koboldai_vars.top_a,
-            'top_p': koboldai_vars.top_p,
-            'top_k': koboldai_vars.top_k,
-            'tfs': koboldai_vars.tfs,
-            'typical_p': koboldai_vars.typical,
-            'repetition_penalty': koboldai_vars.rep_pen,
-            'repetition_penalty_slope': koboldai_vars.rep_pen_slope,
-            'repetition_penalty_range': koboldai_vars.rep_pen_range,
-            'n': koboldai_vars.numseqs,
-            'stream': False
-        }
-    else:
-        reqdata = {
-            'prompt': txt,
-            'max_tokens': koboldai_vars.genamt,
-            'temperature': koboldai_vars.temp,
-            'top_p': koboldai_vars.top_p,
-            'n': koboldai_vars.numseqs,
-            'stream': False
-        }
-
-    req = requests.post(
-        koboldai_vars.oaiurl,
-        json = reqdata,
-        headers = {
-            'Authorization': 'Bearer '+koboldai_vars.oaiapikey,
-            'Content-Type': 'application/json'
-        }
-    )
-
-    # Deal with the response
-    if(req.status_code == 200):
-        outputs = [out["text"] for out in req.json()["choices"]]
-
-        for idx in range(len(outputs)):
-            koboldai_vars.lua_koboldbridge.outputs[idx+1] = outputs[idx]
-
-        execute_outmod()
-        if (koboldai_vars.lua_koboldbridge.regeneration_required):
-            koboldai_vars.lua_koboldbridge.regeneration_required = False
-            genout = []
-            for i in range(len(outputs)):
-                genout.append(
-                    {"generated_text": koboldai_vars.lua_koboldbridge.outputs[i + 1]})
-                assert type(genout[-1]["generated_text"]) is str
-        else:
-            genout = [
-                {"generated_text": utils.decodenewlines(txt)}
-                for txt in outputs]
-
-        koboldai_vars.actions.append_options([applyoutputformatting(x["generated_text"]) for x in genout])
-        genout = [{"generated_text": x['text']} for x in koboldai_vars.actions.get_current_options()]
-        if (len(genout) == 1):
-            genresult(genout[0]["generated_text"])
-        else:
-            if (koboldai_vars.lua_koboldbridge.restart_sequence is not None and
-                    koboldai_vars.lua_koboldbridge.restart_sequence > 0):
-                genresult(genout[koboldai_vars.lua_koboldbridge.restart_sequence - 1][
-                              "generated_text"])
-            else:
-                genselect(genout)
-
-        if not koboldai_vars.quiet:
-            print("{0}{1}{2}".format(colors.CYAN, genout, colors.END))
-
-        set_aibusy(0)
-    else:
-        # Send error message to web client
-        er = req.json()
-        if("error" in er):
-            type = er["error"]["type"]
-            message = er["error"]["message"]
-
-        errmsg = "OpenAI API Error: {0} - {1}".format(type, message)
-        emit('from_server', {'cmd': 'errmsg', 'data': errmsg}, broadcast=True, room="UI_1")
-        set_aibusy(0)
 
 #==================================================================#
 #  Forces UI to Play mode
 #==================================================================#