Bug fix for probabilities (fixes issue 344)

This commit is contained in:
ebolam
2022-12-16 20:05:15 -05:00
parent eabbd73bf3
commit 802bef8c37

View File

@@ -1538,13 +1538,14 @@ class KoboldStoryRegister(object):
self.actions[action_id]["Selected Text"] = text self.actions[action_id]["Selected Text"] = text
self.actions[action_id]["Time"] = self.actions[action_id].get("Time", int(time.time())) self.actions[action_id]["Time"] = self.actions[action_id].get("Time", int(time.time()))
if 'Probabilities' in self.actions[action_id]: if 'Probabilities' in self.actions[action_id]:
tokens = self.koboldai_vars.tokenizer.encode(text) if self.koboldai_vars.tokenizer is not None:
for token_num in range(len(self.actions[action_id]["Probabilities"])): tokens = self.koboldai_vars.tokenizer.encode(text)
for token_option in range(len(self.actions[action_id]["Probabilities"][token_num])): for token_num in range(len(self.actions[action_id]["Probabilities"])):
if token_num < len(tokens): for token_option in range(len(self.actions[action_id]["Probabilities"][token_num])):
self.actions[action_id]["Probabilities"][token_num][token_option]["Used"] = tokens[token_num] == self.actions[action_id]["Probabilities"][token_num][token_option]["tokenId"] if token_num < len(tokens):
else: self.actions[action_id]["Probabilities"][token_num][token_option]["Used"] = tokens[token_num] == self.actions[action_id]["Probabilities"][token_num][token_option]["tokenId"]
self.actions[action_id]["Probabilities"][token_num][token_option]["Used"] = False else:
self.actions[action_id]["Probabilities"][token_num][token_option]["Used"] = False
selected_text_length = 0 selected_text_length = 0
self.actions[action_id]["Selected Text Length"] = selected_text_length self.actions[action_id]["Selected Text Length"] = selected_text_length
for item in self.actions[action_id]["Options"]: for item in self.actions[action_id]["Options"]:
@@ -1591,13 +1592,14 @@ class KoboldStoryRegister(object):
del item['stream_id'] del item['stream_id']
found = True found = True
if 'Probabilities' in item: if 'Probabilities' in item:
tokens = self.koboldai_vars.tokenizer.encode(option) if self.koboldai_vars.tokenizer is not None:
for token_num in range(len(item["Probabilities"])): tokens = self.koboldai_vars.tokenizer.encode(option)
for token_option in range(len(item["Probabilities"][token_num])): for token_num in range(len(item["Probabilities"])):
if token_num < len(tokens): for token_option in range(len(item["Probabilities"][token_num])):
item["Probabilities"][token_num][token_option]["Used"] = tokens[token_num] == item["Probabilities"][token_num][token_option]["tokenId"] if token_num < len(tokens):
else: item["Probabilities"][token_num][token_option]["Used"] = tokens[token_num] == item["Probabilities"][token_num][token_option]["tokenId"]
item["Probabilities"][token_num][token_option]["Used"] = False else:
item["Probabilities"][token_num][token_option]["Used"] = False
break break
elif item['text'] == option: elif item['text'] == option:
found = True found = True
@@ -1605,13 +1607,14 @@ class KoboldStoryRegister(object):
del item['stream_id'] del item['stream_id']
found = True found = True
if 'Probabilities' in item: if 'Probabilities' in item:
tokens = self.koboldai_vars.tokenizer.encode(option) if self.koboldai_vars.tokenizer is not None:
for token_num in range(len(item["Probabilities"])): tokens = self.koboldai_vars.tokenizer.encode(option)
for token_option in range(len(item["Probabilities"][token_num])): for token_num in range(len(item["Probabilities"])):
if token_num < len(tokens): for token_option in range(len(item["Probabilities"][token_num])):
item["Probabilities"][token_num][token_option]["Used"] = tokens[token_num] == item["Probabilities"][token_num][token_option]["tokenId"] if token_num < len(tokens):
else: item["Probabilities"][token_num][token_option]["Used"] = tokens[token_num] == item["Probabilities"][token_num][token_option]["tokenId"]
item["Probabilities"][token_num][token_option]["Used"] = False else:
item["Probabilities"][token_num][token_option]["Used"] = False
break break
if not found: if not found: