Mirror of https://github.com/KoboldAI/KoboldAI-Client.git, synced 2025-06-05 21:59:24 +02:00

Commit: Gen 10000
Changed file: aiserver.py (266)
@@ -4874,21 +4874,28 @@ def core_generate(text: list, min: int, max: int, found_entries: set):
     with torch.no_grad():
         already_generated = 0
         numseqs = koboldai_vars.numseqs
-        while True:
+        do_loop = True
+
+        while do_loop:
             # The reason this is a loop is due to how Dynamic WI works. We
             # cannot simply add the WI to the context mid-generation, so we
             # stop early, and then insert WI, then continue generating. That
             # stopping and continuing is this loop.

-            genout = raw_generate(
+            result = raw_generate(
                 gen_in,
-                # Real max length is handled by CoreStopper.
-                max_length=int(2e9),
+                max_length=koboldai_vars.genamt,
                 do_streaming=True,
                 do_dynamic_wi=True,
                 batch_count=numseqs,
-                decode=False,
+                # Real max length is handled by CoreStopper.
+                bypass_hf_maxlength=True,
             )

+            do_loop = not result.is_whole_generation
+            genout = result.encoded
+
             already_generated += len(genout[0]) - len(gen_in[0])
             assert already_generated <= koboldai_vars.genamt

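The hunk above replaces the open-ended `while True:` with a `do_loop` flag driven by whether the backend reports the generation as complete. Below is a minimal, self-contained sketch of that control flow; `ToyResult` and `toy_generate` are hypothetical stand-ins, not KoboldAI's `GenerationResult`/`raw_generate` (the real loop would also insert Dynamic WI between iterations).

# Minimal sketch of the stop/insert/continue loop; ToyResult and toy_generate
# are hypothetical stand-ins, not KoboldAI APIs.
from dataclasses import dataclass
from typing import List

@dataclass
class ToyResult:
    encoded: List[List[int]]        # batch of prompt + generated token ids
    is_whole_generation: bool       # False when generation was cut short early

def toy_generate(batch: List[List[int]], target_len: int) -> ToyResult:
    # Emit a few tokens per call; pretend a stopper interrupts until target_len.
    out = [seq + [0, 0, 0] for seq in batch]
    return ToyResult(encoded=out, is_whole_generation=len(out[0]) >= target_len)

gen_in = [[10, 11, 12]]
do_loop = True
while do_loop:
    result = toy_generate(gen_in, target_len=12)
    do_loop = not result.is_whole_generation
    gen_in = result.encoded          # resume from everything generated so far

print(len(gen_in[0]))                # 12 -- generation continued across interruptions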
@@ -4954,6 +4961,25 @@ def core_generate(text: list, min: int, max: int, found_entries: set):

     return genout, already_generated

+class GenerationResult:
+    def __init__(
+        self,
+        out_batches: list,
+        prompt: list,
+
+        # Controls if generate() does it's looping thing. This should only be
+        # done for HF models that use that StoppingCondition
+        is_whole_generation: bool
+    ):
+        # Shave prompt off of encoded response. Decoded does not return prompt.
+        # TODO: Does MTJ generation shave this off automatically? Test it!
+        print("shape", out_batches.shape)
+        self.encoded = out_batches[:, len(prompt) - 1:]
+        self.prompt = prompt
+        self.is_whole_generation = is_whole_generation
+
+        self.decoded = [utils.decodenewlines(tokenizer.decode(enc)) for enc in self.encoded]
+
 def raw_generate(
     # prompt is either a string (text) or a list (token ids)
     prompt: Union[str, list],
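The new `GenerationResult` wrapper keeps the output batches, the prompt, and the completion flag together, and shaves the prompt off of the encoded output on construction. The snippet below reproduces that `len(prompt) - 1` cut on plain nested lists as an illustration only (the real class slices a tensor batch and decodes with the global tokenizer); note that the slice keeps the last prompt token.

# Illustration of the prompt-shaving slice in GenerationResult.__init__;
# plain Python lists stand in for the tensor batches the real code slices.
prompt = [5, 6, 7]
out_batches = [
    [5, 6, 7, 101, 102],     # sequence 1: prompt + new tokens
    [5, 6, 7, 201, 202],     # sequence 2
]

# Equivalent of out_batches[:, len(prompt) - 1:] for a nested list:
encoded = [seq[len(prompt) - 1:] for seq in out_batches]
print(encoded)               # [[7, 101, 102], [7, 201, 202]]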
@@ -4962,19 +4988,10 @@ def raw_generate(
     do_streaming: bool = False,
     do_dynamic_wi: bool = False,
     batch_count: int = 1,
-    decode: bool = True,
-) -> List:
+    bypass_hf_maxlength: bool = False,
+) -> GenerationResult:

-    if isinstance(prompt, str):
-        prompt_decoded = prompt
-        prompt_tokens = tokenizer.encode(prompt)
-    else:
-        prompt_decoded = tokenizer.decode(prompt)
-        prompt_tokens = prompt
-
-    # Some gen methods such as OAI don't return tokens.
-    batch_decoded = None
-    batch_encoded = None
+    prompt_tokens = tokenizer.encode(prompt) if isinstance(prompt, str) else prompt

     if koboldai_vars.model == "Colab":
         raise NotImplementedError("Colab API raw_generate unsupported")
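`raw_generate` now accepts either a text prompt or a list of token ids and collapses both cases into `prompt_tokens` with one conditional expression. A toy version of that normalization, using a stand-in `encode()` rather than the global `tokenizer`:

# Toy version of the prompt normalization; encode() is a stand-in tokenizer.
from typing import List, Union

def encode(text: str) -> List[int]:
    return [ord(c) for c in text]

def normalize(prompt: Union[str, List[int]]) -> List[int]:
    return encode(prompt) if isinstance(prompt, str) else prompt

print(normalize("hi"))          # [104, 105]
print(normalize([104, 105]))    # [104, 105] -- token ids pass through unchanged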
@@ -4982,8 +4999,6 @@ def raw_generate(
         raise NotImplementedError("API raw_generate unsupported")
     elif koboldai_vars.model == "CLUSTER":
         raise NotImplementedError("Cluster raw_generate unsupported")
-    elif koboldai_vars.model == "OAI":
-        raise NotImplementedError("OpenAI raw_generate unsupported")
     elif koboldai_vars.model == "ReadOnly":
         raise NotImplementedError("No loaded model")

@@ -4993,31 +5008,30 @@ def raw_generate(
             max_length=max_length,
             batch_count=batch_count
         )
+        return GenerationResult(
+            out_batches=batch_encoded, prompt=prompt_tokens, is_whole_generation=False
+        )
     elif model == "OAI":
-        batch_decoded = ...
-    else:
-        batch_encoded = torch_raw_generate(
+        batch_encoded = oai_raw_generate(
             prompt_tokens=prompt_tokens,
-            max_new=max_length,
-            do_streaming=do_streaming,
-            do_dynamic_wi=do_dynamic_wi,
+            max_length=max_length,
             batch_count=batch_count
         )
+        return GenerationResult(
+            out_batches=batch_encoded, prompt=prompt_tokens, is_whole_generation=False
+        )

-    assert batch_encoded is not None or batch_decoded is not None
-
-    # Shave prompt off of encoded response. Decoded does not return prompt.
-    # TODO: Does MTJ generation shave this off automatically? Test it!
-    if batch_encoded is not None:
-        batch_encoded = batch_encoded[:, len(prompt_tokens) - 1:]
-
-    if not decode:
-        return batch_encoded
-
-    if batch_decoded is None:
-        batch_decoded = tokenizer.batch_decode(batch_encoded)
-
-    return [utils.decodenewlines(x) for x in batch_decoded]
+    # Torch HF
+    batch_encoded = torch_raw_generate(
+        prompt_tokens=prompt_tokens,
+        max_new=max_length if not bypass_hf_maxlength else int(2e9),
+        do_streaming=do_streaming,
+        do_dynamic_wi=do_dynamic_wi,
+        batch_count=batch_count
+    )
+    return GenerationResult(
+        out_batches=batch_encoded, prompt=prompt_tokens, is_whole_generation=True
+    )

 def tpu_raw_generate(
     prompt_tokens: List[int],
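After this hunk every backend branch wraps its batches in a `GenerationResult` and returns immediately, instead of falling through to the shared shave/decode logic that used to sit at the end of `raw_generate`. A compressed sketch of that shape, with hypothetical stand-in backends rather than the real `tpu_raw_generate`/`oai_raw_generate`/`torch_raw_generate`:

# Sketch of the per-branch "wrap and return" dispatch; ToyResult and the
# backends are hypothetical stand-ins for KoboldAI's class and functions.
from dataclasses import dataclass
from typing import List

@dataclass
class ToyResult:
    out_batches: List[List[int]]
    prompt: List[int]
    is_whole_generation: bool

def tpu_backend(tokens: List[int]) -> List[List[int]]:
    return [tokens + [1, 1]]

def torch_backend(tokens: List[int]) -> List[List[int]]:
    return [tokens + [2, 2]]

def toy_raw_generate(model: str, prompt_tokens: List[int]) -> ToyResult:
    if model == "TPU":
        batches = tpu_backend(prompt_tokens)
        return ToyResult(batches, prompt_tokens, is_whole_generation=False)
    # default path: torch-style backend
    batches = torch_backend(prompt_tokens)
    return ToyResult(batches, prompt_tokens, is_whole_generation=True)

print(toy_raw_generate("TPU", [3, 4]).out_batches)    # [[3, 4, 1, 1]]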
@@ -5090,6 +5104,84 @@ def torch_raw_generate(

     return genout

+def oai_raw_generate(
+    prompt_tokens: List[int],
+    max_length: int,
+    batch_count: int,
+):
+    # Taken mainly from oairequest()
+
+    decoded_prompt = utils.decodenewlines(tokenizer.decode(prompt_tokens))
+
+    # Log request to console
+    if not koboldai_vars.quiet:
+        print("{0}Len:{1}, Txt:{2}{3}".format(colors.YELLOW, len(decoded_prompt), decoded_prompt, colors.END))
+
+    # Store context in memory to use it for comparison with generated content
+    koboldai_vars.lastctx = decoded_prompt
+
+    # Build request JSON data
+    # GooseAI is a subntype of OAI. So to check if it's this type, we check the configname as a workaround
+    # as the koboldai_vars.model will always be OAI
+    if 'GooseAI' in koboldai_vars.configname:
+        reqdata = {
+            'prompt': decoded_prompt,
+            'max_tokens': koboldai_vars.genamt,
+            'temperature': koboldai_vars.temp,
+            'top_a': koboldai_vars.top_a,
+            'top_p': koboldai_vars.top_p,
+            'top_k': koboldai_vars.top_k,
+            'tfs': koboldai_vars.tfs,
+            'typical_p': koboldai_vars.typical,
+            'repetition_penalty': koboldai_vars.rep_pen,
+            'repetition_penalty_slope': koboldai_vars.rep_pen_slope,
+            'repetition_penalty_range': koboldai_vars.rep_pen_range,
+            'n': koboldai_vars.numseqs,
+            # TODO: Implement streaming
+            'stream': False
+        }
+    else:
+        reqdata = {
+            'prompt': decoded_prompt,
+            'max_tokens': koboldai_vars.genamt,
+            'temperature': koboldai_vars.temp,
+            'top_p': koboldai_vars.top_p,
+            'n': koboldai_vars.numseqs,
+            'stream': False
+        }
+
+    req = requests.post(
+        koboldai_vars.oaiurl,
+        json = reqdata,
+        headers = {
+            'Authorization': 'Bearer '+koboldai_vars.oaiapikey,
+            'Content-Type': 'application/json'
+        }
+    )
+
+    # Deal with the response
+    if(req.status_code == 200):
+        outputs = [out["text"] for out in req.json()["choices"]]
+
+        decoded_genout = [{"generated_text": utils.decodenewlines(txt)}
+                          for txt in outputs]
+
+        if not koboldai_vars.quiet:
+            print("{0}{1}{2}".format(colors.CYAN, decoded_genout, colors.END))
+
+        return [tokenizer.encode(x) for x in decoded_genout]
+    else:
+        # Send error message to web client
+        er = req.json()
+        if("error" in er):
+            type = er["error"]["type"]
+            message = er["error"]["message"]
+
+        errmsg = "OpenAI API Error: {0} - {1}".format(type, message)
+        emit('from_server', {'cmd': 'errmsg', 'data': errmsg}, broadcast=True, room="UI_1")
+        set_aibusy(0)
+        return []
+
 #==================================================================#
 # Send text to generator and deal with output
 #==================================================================#
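The new `oai_raw_generate` decodes the prompt, POSTs it to the configured endpoint, and pulls each completion out of `choices[i]["text"]` before re-encoding. Below is a hedged sketch of that round trip against a generic OpenAI-style completions endpoint; the URL, key, and payload values are placeholders, not KoboldAI configuration.

# Hedged sketch of the request/response round trip; the endpoint URL, API key
# and payload values are placeholders, not KoboldAI's koboldai_vars settings.
from typing import List
import requests

def completions_request(prompt: str, url: str, api_key: str, n: int = 1) -> List[str]:
    reqdata = {
        "prompt": prompt,
        "max_tokens": 80,
        "temperature": 0.7,
        "n": n,
        "stream": False,
    }
    req = requests.post(
        url,
        json=reqdata,
        headers={
            "Authorization": "Bearer " + api_key,
            "Content-Type": "application/json",
        },
    )
    if req.status_code == 200:
        # Same shape as the code above: each completion lives in choices[i]["text"].
        return [choice["text"] for choice in req.json()["choices"]]
    raise RuntimeError(req.json().get("error", {}).get("message", "request failed"))

# Usage (placeholder endpoint):
# texts = completions_request("Once upon a time", "https://api.example.com/v1/completions", "sk-...")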
@@ -6323,102 +6415,6 @@ def ikrequest(txt):
         emit('from_server', {'cmd': 'errmsg', 'data': errmsg}, broadcast=True, room="UI_1")
         set_aibusy(0)

-#==================================================================#
-# Assembles game data into a request to OpenAI API
-#==================================================================#
-def oairequest(txt, min, max):
-    # Log request to console
-    if not koboldai_vars.quiet:
-        print("{0}Len:{1}, Txt:{2}{3}".format(colors.YELLOW, len(txt), txt, colors.END))
-
-    # Store context in memory to use it for comparison with generated content
-    koboldai_vars.lastctx = txt
-
-    # Build request JSON data
-    # GooseAI is a subntype of OAI. So to check if it's this type, we check the configname as a workaround
-    # as the koboldai_vars.model will always be OAI
-    if 'GooseAI' in koboldai_vars.configname:
-        reqdata = {
-            'prompt': txt,
-            'max_tokens': koboldai_vars.genamt,
-            'temperature': koboldai_vars.temp,
-            'top_a': koboldai_vars.top_a,
-            'top_p': koboldai_vars.top_p,
-            'top_k': koboldai_vars.top_k,
-            'tfs': koboldai_vars.tfs,
-            'typical_p': koboldai_vars.typical,
-            'repetition_penalty': koboldai_vars.rep_pen,
-            'repetition_penalty_slope': koboldai_vars.rep_pen_slope,
-            'repetition_penalty_range': koboldai_vars.rep_pen_range,
-            'n': koboldai_vars.numseqs,
-            'stream': False
-        }
-    else:
-        reqdata = {
-            'prompt': txt,
-            'max_tokens': koboldai_vars.genamt,
-            'temperature': koboldai_vars.temp,
-            'top_p': koboldai_vars.top_p,
-            'n': koboldai_vars.numseqs,
-            'stream': False
-        }
-
-    req = requests.post(
-        koboldai_vars.oaiurl,
-        json = reqdata,
-        headers = {
-            'Authorization': 'Bearer '+koboldai_vars.oaiapikey,
-            'Content-Type': 'application/json'
-        }
-    )
-
-    # Deal with the response
-    if(req.status_code == 200):
-        outputs = [out["text"] for out in req.json()["choices"]]
-
-        for idx in range(len(outputs)):
-            koboldai_vars.lua_koboldbridge.outputs[idx+1] = outputs[idx]
-
-        execute_outmod()
-        if (koboldai_vars.lua_koboldbridge.regeneration_required):
-            koboldai_vars.lua_koboldbridge.regeneration_required = False
-            genout = []
-            for i in range(len(outputs)):
-                genout.append(
-                    {"generated_text": koboldai_vars.lua_koboldbridge.outputs[i + 1]})
-                assert type(genout[-1]["generated_text"]) is str
-        else:
-            genout = [
-                {"generated_text": utils.decodenewlines(txt)}
-                for txt in outputs]
-
-        koboldai_vars.actions.append_options([applyoutputformatting(x["generated_text"]) for x in genout])
-        genout = [{"generated_text": x['text']} for x in koboldai_vars.actions.get_current_options()]
-        if (len(genout) == 1):
-            genresult(genout[0]["generated_text"])
-        else:
-            if (koboldai_vars.lua_koboldbridge.restart_sequence is not None and
-                    koboldai_vars.lua_koboldbridge.restart_sequence > 0):
-                genresult(genout[koboldai_vars.lua_koboldbridge.restart_sequence - 1][
-                              "generated_text"])
-            else:
-                genselect(genout)
-
-        if not koboldai_vars.quiet:
-            print("{0}{1}{2}".format(colors.CYAN, genout, colors.END))
-
-        set_aibusy(0)
-    else:
-        # Send error message to web client
-        er = req.json()
-        if("error" in er):
-            type = er["error"]["type"]
-            message = er["error"]["message"]
-
-        errmsg = "OpenAI API Error: {0} - {1}".format(type, message)
-        emit('from_server', {'cmd': 'errmsg', 'data': errmsg}, broadcast=True, room="UI_1")
-        set_aibusy(0)
-
 #==================================================================#
 # Forces UI to Play mode
 #==================================================================#