Fully migrated the text-to-AI calculation to the new backend. Sentence splitting is fully supported, so the prompt, actions, etc. won't be cut off by partial sentences.

ebolam
2022-09-22 11:47:59 -04:00
parent faed1444ef
commit d1c2e6e506
2 changed files with 201 additions and 159 deletions
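
The idea behind the sentence splitting: the prompt and all actions are joined into one string, split on sentence boundaries, and the token budget is then filled with whole sentences instead of raw character slices. A rough standalone sketch of that splitting step follows (the regex mirrors the one used in calc_ai_text; the function name and the whitespace-based token count are illustrative stand-ins, not the real tokenizer):

import re

def split_into_sentences(action_text):
    # Split after ., ! or ? followed by whitespace, keeping a trailing space on each piece.
    pieces = [s + " " for s in re.split(r"(?<=[.!?])\s+", action_text)]
    # The last sentence should not carry the artificial trailing space.
    pieces[-1] = pieces[-1][:-1]
    # Each entry: [sentence, action IDs covered, token length, included in AI context].
    # len(s.split()) is only a stand-in for a real tokenizer's token count.
    return [[s, [], len(s.split()), False] for s in pieces]

if __name__ == "__main__":
    sample = "You open the door. A cold wind blows in! What do you do?"
    for entry in split_into_sentences(sample):
        print(entry)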

View File

@@ -3211,24 +3211,25 @@ def lua_compute_context(submission, entries, folders, kwargs):
     while(folders[i] is not None):
         allowed_folders.add(int(folders[i]))
         i += 1
-    winfo, mem, anotetxt, _ = calcsubmitbudgetheader(
-        submission,
-        allowed_entries=allowed_entries,
-        allowed_folders=allowed_folders,
-        force_use_txt=True,
-        scan_story=kwargs["scan_story"] if kwargs["scan_story"] != None else True,
-    )
+    #winfo, mem, anotetxt, _ = calcsubmitbudgetheader(
+    #    submission,
+    #    allowed_entries=allowed_entries,
+    #    allowed_folders=allowed_folders,
+    #    force_use_txt=True,
+    #    scan_story=kwargs["scan_story"] if kwargs["scan_story"] != None else True,
+    #)
     if koboldai_vars.alt_gen:
         txt, _, _ = koboldai_vars.calc_ai_text()
         print("Using Alt Gen")
     else:
-        txt, _, _ = calcsubmitbudget(
-            len(actions),
-            winfo,
-            mem,
-            anotetxt,
-            actions,
-        )
+        #txt, _, _ = calcsubmitbudget(
+        #    len(actions),
+        #    winfo,
+        #    mem,
+        #    anotetxt,
+        #    actions,
+        #)
+        txt, _, _ = koboldai_vars.calc_ai_text(method=1)
     return utils.decodenewlines(tokenizer.decode(txt))
 #==================================================================#
@@ -4826,15 +4827,16 @@ def calcsubmit(txt):
     anoteadded = False # In case our budget runs out before we hit A.N. depth
     actionlen = len(koboldai_vars.actions)
-    winfo, mem, anotetxt, found_entries = calcsubmitbudgetheader(txt)
+    #winfo, mem, anotetxt, found_entries = calcsubmitbudgetheader(txt)
     # For all transformers models
     if(koboldai_vars.model != "InferKit"):
         if koboldai_vars.alt_gen:
             subtxt, min, max = koboldai_vars.calc_ai_text(submitted_text=txt)
-            print("Using Alt Gen")
+            logger.debug("Using Alt Gen")
         else:
-            subtxt, min, max = calcsubmitbudget(actionlen, winfo, mem, anotetxt, koboldai_vars.actions, submission=txt)
+            #subtxt, min, max = calcsubmitbudget(actionlen, winfo, mem, anotetxt, koboldai_vars.actions, submission=txt)
+            subtxt, min, max = koboldai_vars.calc_ai_text(submitted_text=txt, method=1)
         if(actionlen == 0):
             if(not koboldai_vars.use_colab_tpu and koboldai_vars.model not in ["Colab", "API", "CLUSTER", "OAI", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"]):
                 generate(subtxt, min, max, found_entries=found_entries)
@@ -4976,13 +4978,14 @@ def _generate(txt, minimum, maximum, found_entries):
             encoded = []
             for i in range(koboldai_vars.numseqs):
                 txt = utils.decodenewlines(tokenizer.decode(genout[i, -already_generated:]))
-                winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True, actions=koboldai_vars.actions)
+                #winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True, actions=koboldai_vars.actions)
                 found_entries[i].update(_found_entries)
                 if koboldai_vars.alt_gen:
                     txt, _, _ = koboldai_vars.calc_ai_text(submitted_text=txt)
-                    print("Using Alt Gen")
+                    logger.debug("Using Alt Gen")
                 else:
-                    txt, _, _ = calcsubmitbudget(len(koboldai_vars.actions), winfo, mem, anotetxt, koboldai_vars.actions, submission=txt)
+                    #txt, _, _ = calcsubmitbudget(len(koboldai_vars.actions), winfo, mem, anotetxt, koboldai_vars.actions, submission=txt)
+                    txt, _, _ = koboldai_vars.calc_ai_text(submitted_text=txt, method=1)
                 encoded.append(torch.tensor(txt, dtype=torch.long, device=genout.device))
             max_length = len(max(encoded, key=len))
             encoded = torch.stack(tuple(torch.nn.functional.pad(e, (max_length - len(e), 0), value=model.config.pad_token_id or model.config.eos_token_id) for e in encoded))
@@ -5530,13 +5533,14 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
             encoded = []
             for i in range(koboldai_vars.numseqs):
                 txt = utils.decodenewlines(tokenizer.decode(past[i]))
-                winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True, actions=koboldai_vars.actions)
+                #winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True, actions=koboldai_vars.actions)
                 found_entries[i].update(_found_entries)
                 if koboldai_vars.alt_gen:
                     txt, _, _ = koboldai_vars.calc_ai_text(submitted_text=txt)
-                    print("Using Alt Gen: {}".format(tokenizer.decode(txt)))
+                    logger.debug("Using Alt Gen")
                 else:
-                    txt, _, _ = calcsubmitbudget(len(koboldai_vars.actions), winfo, mem, anotetxt, koboldai_vars.actions, submission=txt)
+                    #txt, _, _ = calcsubmitbudget(len(koboldai_vars.actions), winfo, mem, anotetxt, koboldai_vars.actions, submission=txt)
+                    txt, _, _ = koboldai_vars.calc_ai_text(submitted_text=txt, method=1)
                 encoded.append(np.array(txt, dtype=np.uint32))
             max_length = len(max(encoded, key=len))
             encoded = np.stack(tuple(np.pad(e, (max_length - len(e), 0), constant_values=tpu_mtj_backend.pad_token_id) for e in encoded))

View File

@@ -118,7 +118,7 @@ class koboldai_vars(object):
         if self.tokenizer is None:
             used_tokens = 99999999999999999999999
         else:
-            used_tokens = 0 if self.sp_length is None else self.sp_length
+            used_tokens = 0 if self.sp_length is None else self.sp_length + len(self.tokenizer._koboldai_header)
         text = ""
         # TODO: We may want to replace the "text" variable with a list-type
@@ -152,59 +152,18 @@ class koboldai_vars(object):
context.append({"type": "world_info", "text": wi_text}) context.append({"type": "world_info", "text": wi_text})
text += wi_text text += wi_text
#Add prompt lenght/text if we're set to always use prompt
if self.useprompt:
self.max_prompt_length if self.prompt_length > self.max_prompt_length else self.prompt_length
if prompt_length + used_tokens < token_budget:
used_tokens += self.max_prompt_length if self.prompt_length > self.max_prompt_length else self.prompt_length
#Find World Info entries in prompt
for wi in self.worldinfo_v2:
if wi['uid'] not in used_world_info:
#Check to see if we have the keys/secondary keys in the text so far
match = False
for key in wi['key']:
if key in self.prompt:
match = True
break
if wi['selective'] and match:
match = False
for key in wi['keysecondary']:
if key in self.prompt:
match=True
break
if match:
if used_tokens+0 if 'token_length' not in wi else wi['token_length'] <= token_budget:
used_tokens+=wi['token_length']
used_world_info.append(wi['uid'])
wi_text = wi['content']
context.append({"type": "world_info", "text": wi_text})
text += wi_text
self.worldinfo_v2.set_world_info_used(wi['uid'])
prompt_text = self.prompt #we're going to split our actions by sentence for better context. We'll add in which actions the sentence covers. Prompt will be added at a -1 ID
if self.tokenizer and self.prompt_length > self.max_prompt_length: actions = {i: self.actions[i] for i in range(len(self.actions))}
prompt_text = self.tokenizer.decode(self.tokenizer.encode(self.prompt)[-self.max_prompt_length-1:]) actions[-1] = self.prompt
action_text = self.prompt + str(self.actions)
text += prompt_text ###########action_text_split = [sentence, actions used in sentence, token length, included in AI context]################
context.append({"type": "prompt", "text": prompt_text}) action_text_split = [[x+" ", [], 0 if self.tokenizer is None else len(self.tokenizer.encode(x+" ")), False] for x in re.split("(?<=[.!?])\s+", action_text)]
self.prompt_in_ai = True #The last action shouldn't have the extra space from the sentence splitting, so let's remove it
else:
self.prompt_in_ai = False
#remove author's notes token length (Add the text later)
used_tokens += self.authornote_length
#Start going through the actions backwards, adding it to the text if it fits and look for world info entries
game_text = ""
game_context = []
authors_note_final = self.authornotetemplate.replace("<|>", self.authornote)
used_all_tokens = False
#we're going to split our actions by sentence for better context. We'll add in which actions the sentence covers
if self.actions.action_count >= 0:
action_text = str(self.actions)
action_text_split = [[x+" ", []] for x in re.split("(?<=[.!?])\s+", action_text)]
action_text_split[-1][0] = action_text_split[-1][0][:-1] action_text_split[-1][0] = action_text_split[-1][0][:-1]
Action_Position = [0, len(self.actions[0])] #First element is the action item, second is how much text is left action_text_split[-1][2] = 0 if self.tokenizer is None else len(self.tokenizer.encode(action_text_split[-1][0]))
Action_Position = [-1, len(actions[-1])] #First element is the action item, second is how much text is left
Sentence_Position = [0, len(action_text_split[0][0])] Sentence_Position = [0, len(action_text_split[0][0])]
while True: while True:
advance_action = False advance_action = False
@@ -225,21 +184,81 @@ class koboldai_vars(object):
             if advance_action:
                 Action_Position[0] += 1
-                if Action_Position[0] >= len(self.actions):
+                if Action_Position[0] >= max(actions):
                     break
-                Action_Position[1] = len(self.actions[Action_Position[0]])
+                Action_Position[1] = len(actions[Action_Position[0]])
             if advance_sentence:
                 Sentence_Position[0] += 1
                 if Sentence_Position[0] >= len(action_text_split):
                     break
                 Sentence_Position[1] = len(action_text_split[Sentence_Position[0]][0])
         #OK, action_text_split now contains a list of [sentence including trailing space if needed, [action IDs that sentence includes]]
+        #Add prompt lenght/text if we're set to always use prompt
+        if self.useprompt:
+            prompt_length = 0
+            prompt_text = ""
+            for item in action_text_split:
+                if -1 not in item[1]:
+                    #We've finished going through our prompt. Stop
+                    break
+                if prompt_length + item[2] < self.max_prompt_length:
+                    prompt_length += item[2]
+                    item[3] = True
+                    prompt_text += item[0]
+            if prompt_length + used_tokens < token_budget:
+                used_tokens += prompt_length
+                #Find World Info entries in prompt
+                for wi in self.worldinfo_v2:
+                    if wi['uid'] not in used_world_info:
+                        #Check to see if we have the keys/secondary keys in the text so far
+                        match = False
+                        for key in wi['key']:
+                            if key in prompt_text:
+                                match = True
+                                break
+                        if wi['selective'] and match:
+                            match = False
+                            for key in wi['keysecondary']:
+                                if key in prompt_text:
+                                    match=True
+                                    break
+                        if match:
+                            if used_tokens+0 if 'token_length' not in wi else wi['token_length'] <= token_budget:
+                                used_tokens+=wi['token_length']
+                                used_world_info.append(wi['uid'])
+                                wi_text = wi['content']
+                                context.append({"type": "world_info", "text": wi_text})
+                                text += wi_text
+                                self.worldinfo_v2.set_world_info_used(wi['uid'])
+                prompt_text = prompt_text
+                if self.tokenizer and self.prompt_length > self.max_prompt_length:
+                    prompt_text = self.tokenizer.decode(self.tokenizer.encode(self.prompt)[-self.max_prompt_length-1:])
+                #We'll add the prompt text AFTER we go through the game text as the world info needs to come first if we're in method 1 rather than method 2
+                self.prompt_in_ai = True
+            else:
+                self.prompt_in_ai = False
+        #remove author's notes token length (Add the text later)
+        used_tokens += self.authornote_length
+        #Start going through the actions backwards, adding it to the text if it fits and look for world info entries
+        game_text = ""
+        game_context = []
+        authors_note_final = self.authornotetemplate.replace("<|>", self.authornote)
+        used_all_tokens = False
         for action in range(len(self.actions)):
             self.actions.set_action_in_ai(action, used=False)
         for i in range(len(action_text_split)-1, -1, -1):
+            if action_text_split[i][3]:
+                #We've hit an item we've already included. Stop
+                break;
             if len(action_text_split) - i - 1 == self.andepth and self.authornote != "":
                 game_text = "{}{}".format(authors_note_final, game_text)
                 game_context.insert(0, {"type": "authors_note", "text": authors_note_final})
@@ -250,6 +269,7 @@ class koboldai_vars(object):
game_text = "{}{}".format(selected_text, game_text) game_text = "{}{}".format(selected_text, game_text)
game_context.insert(0, {"type": "action", "text": selected_text}) game_context.insert(0, {"type": "action", "text": selected_text})
for action in action_text_split[i][1]: for action in action_text_split[i][1]:
if action >= 0:
self.actions.set_action_in_ai(action) self.actions.set_action_in_ai(action)
#Now we need to check for used world info entries #Now we need to check for used world info entries
for wi in self.worldinfo_v2: for wi in self.worldinfo_v2:
@@ -266,60 +286,78 @@ class koboldai_vars(object):
                             if key in selected_text:
                                 match=True
                                 break
+                    if method == 1:
+                        if len(action_text_split) - i > self.widepth:
+                            match = False
                     if match:
                         if used_tokens+0 if 'token_length' not in wi or wi['token_length'] is None else wi['token_length'] <= token_budget:
                             used_tokens+=wi['token_length']
                             used_world_info.append(wi['uid'])
                             wi_text = wi["content"]
-                            game_text = "{}{}".format(wi_text, game_text)
-                            game_context.insert(0, {"type": "world_info", "text": wi_text})
+                            if method == 1:
+                                text = "{}{}".format(wi_text, game_text)
+                                context.insert(0, {"type": "world_info", "text": wi_text})
+                            else:
+                                game_text = "{}{}".format(wi_text, game_text)
+                                game_context.insert(0, {"type": "world_info", "text": wi_text})
                             self.worldinfo_v2.set_world_info_used(wi['uid'])
                         else:
                             used_all_tokens = True
-        else:
-            action_text_split = []
         #if we don't have enough actions to get to author's note depth then we just add it right before the game text
         if len(action_text_split) < self.andepth and self.authornote != "":
             game_text = "{}{}".format(authors_note_final, game_text)
             game_context.insert(0, {"type": "authors_note", "text": authors_note_final})
-        if not self.useprompt:
-            prompt_length = self.max_prompt_length if self.prompt_length > self.max_prompt_length else self.prompt_length
+        if self.useprompt:
+            text += prompt_text
+            context.append({"type": "prompt", "text": prompt_text})
+        else self.useprompt:
+            prompt_length = 0
+            prompt_text = ""
+            for item in action_text_split:
+                if -1 not in item[1]:
+                    #We've finished going through our prompt. Stop
+                    break
+                if prompt_length + item[2] < self.max_prompt_length:
+                    prompt_length += item[2]
+                    item[3] = True
+                    prompt_text += item[0]
             if prompt_length + used_tokens < token_budget:
-                used_tokens += self.max_prompt_length if self.prompt_length > self.max_prompt_length else self.prompt_length
+                used_tokens += prompt_length
                 #Find World Info entries in prompt
                 for wi in self.worldinfo_v2:
                     if wi['uid'] not in used_world_info:
                         #Check to see if we have the keys/secondary keys in the text so far
                         match = False
                         for key in wi['key']:
-                            if key in self.prompt:
+                            if key in prompt_text:
                                 match = True
                                 break
                         if wi['selective'] and match:
                             match = False
                             for key in wi['keysecondary']:
-                                if key in self.prompt:
+                                if key in prompt_text:
                                     match=True
                                     break
                         if match:
-                            if used_tokens+0 if 'token_length' not in wi or wi['token_length'] is None else wi['token_length'] <= token_budget:
+                            if used_tokens+0 if 'token_length' not in wi else wi['token_length'] <= token_budget:
                                 used_tokens+=wi['token_length']
                                 used_world_info.append(wi['uid'])
-                                wi_text = wi["content"]
+                                wi_text = wi['content']
+                                text += wi_text
                                 context.append({"type": "world_info", "text": wi_text})
-                                text += wi_text
                                 self.worldinfo_v2.set_world_info_used(wi['uid'])
-            prompt_text = self.prompt
-            if self.tokenizer and self.prompt_length > self.max_prompt_length:
-                prompt_text = self.tokenizer.decode(self.tokenizer.encode(self.prompt)[-self.max_prompt_length-1:])
+                self.prompt_in_ai = True
+                prompt_text = prompt_text
+                if self.tokenizer and self.prompt_length > self.max_prompt_length:
+                    prompt_text = self.tokenizer.decode(self.tokenizer.encode(self.prompt)[-self.max_prompt_length-1:])
+            else:
+                self.prompt_in_ai = False
+                prompt_text = ""
             text += prompt_text
             context.append({"type": "prompt", "text": prompt_text})
-            self.prompt_in_ai = True
-        else:
-            self.prompt_in_ai = False
         text += game_text
         context += game_context
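
Read together, the new calc_ai_text flow is roughly: build the sentence list with the prompt at action ID -1, optionally reserve the prompt sentences up front when useprompt is set, then walk the remaining sentences from newest to oldest, adding whole sentences while the token budget lasts and scanning each one for world info keys. A simplified, self-contained sketch of that backward fill follows (the function name, the fixed budget, and the bare list structure are illustrative; this is not the code above):

def fill_backwards(action_text_split, token_budget, used_tokens=0):
    # action_text_split entries: [sentence, action IDs, token length, already included]
    # Walk newest-to-oldest, adding whole sentences until the budget runs out
    # or we reach a sentence that was already reserved (e.g. as prompt text).
    game_text = ""
    for item in reversed(action_text_split):
        if item[3]:
            break  # already included earlier (prompt handling)
        if used_tokens + item[2] > token_budget:
            break  # out of budget; stop on a sentence boundary
        used_tokens += item[2]
        item[3] = True
        game_text = item[0] + game_text
    return game_text, used_tokens

if __name__ == "__main__":
    sentences = [["It was dark. ", [-1], 3, False],
                 ["You lit a torch. ", [0], 4, False],
                 ["The cave went on.", [1], 4, False]]
    print(fill_backwards(sentences, token_budget=8))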