GuiAworld
2022-08-18 15:36:20 -03:00
5 changed files with 122 additions and 43 deletions

View File

@@ -2881,13 +2881,17 @@ def lua_compute_context(submission, entries, folders, kwargs):
        force_use_txt=True,
        scan_story=kwargs["scan_story"] if kwargs["scan_story"] != None else True,
    )
-    txt, _, _ = calcsubmitbudget(
-        len(actions),
-        winfo,
-        mem,
-        anotetxt,
-        actions,
-    )
+    if koboldai_vars.alt_gen:
+        txt, _, _ = koboldai_vars.calc_ai_text()
+        print("Using Alt Gen: {}".format(tokenizer.decode(txt)))
+    else:
+        txt, _, _ = calcsubmitbudget(
+            len(actions),
+            winfo,
+            mem,
+            anotetxt,
+            actions,
+        )
    return utils.decodenewlines(tokenizer.decode(txt))
#==================================================================#
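Note: the same alt_gen branch is repeated at every submit path touched by this commit (lua_compute_context, calcsubmit, _generate, tpumtjgenerate). A minimal sketch of the shared pattern, not part of the commit; build_context is a hypothetical name and the callables are stubs standing in for the real functions:

    # Hypothetical helper illustrating the repeated branch; stub callables
    # stand in for koboldai_vars.calc_ai_text, calcsubmitbudget, tokenizer.decode.
    def build_context(alt_gen, calc_ai_text, calcsubmitbudget, decode, submission=""):
        if alt_gen:
            txt, mn, mx = calc_ai_text(submitted_text=submission)
            print("Using Alt Gen: {}".format(decode(txt)))
            return txt, mn, mx
        return calcsubmitbudget(submission=submission)

    # Stub usage: with alt_gen enabled, the alternative context builder is used.
    tokens, mn, mx = build_context(
        True,
        lambda submitted_text="": ([1, 2, 3], 3, 83),   # stub calc_ai_text
        lambda submission="": ([9, 9], 2, 82),          # stub calcsubmitbudget
        lambda toks: "<decoded {} tokens>".format(len(toks)),
    )
    print(tokens, mn, mx)  # [1, 2, 3] 3 83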
@@ -4458,7 +4462,11 @@ def calcsubmit(txt):
    # For all transformers models
    if(koboldai_vars.model != "InferKit"):
-        subtxt, min, max = calcsubmitbudget(actionlen, winfo, mem, anotetxt, koboldai_vars.actions, submission=txt)
+        if koboldai_vars.alt_gen:
+            subtxt, min, max = koboldai_vars.calc_ai_text(submitted_text=txt)
+            print("Using Alt Gen: {}".format(tokenizer.decode(subtxt)))
+        else:
+            subtxt, min, max = calcsubmitbudget(actionlen, winfo, mem, anotetxt, koboldai_vars.actions, submission=txt)
        if(actionlen == 0):
            if(not koboldai_vars.use_colab_tpu and koboldai_vars.model not in ["Colab", "API", "OAI", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"]):
                generate(subtxt, min, max, found_entries=found_entries)
@@ -4601,7 +4609,11 @@ def _generate(txt, minimum, maximum, found_entries):
                txt = utils.decodenewlines(tokenizer.decode(genout[i, -already_generated:]))
                winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True, actions=koboldai_vars._actions)
                found_entries[i].update(_found_entries)
-                txt, _, _ = calcsubmitbudget(len(koboldai_vars._actions), winfo, mem, anotetxt, koboldai_vars._actions, submission=txt)
+                if koboldai_vars.alt_gen:
+                    txt, _, _ = koboldai_vars.calc_ai_text(submitted_text=txt)
+                    print("Using Alt Gen: {}".format(tokenizer.decode(txt)))
+                else:
+                    txt, _, _ = calcsubmitbudget(len(koboldai_vars._actions), winfo, mem, anotetxt, koboldai_vars._actions, submission=txt)
                encoded.append(torch.tensor(txt, dtype=torch.long, device=genout.device))
            max_length = len(max(encoded, key=len))
            encoded = torch.stack(tuple(torch.nn.functional.pad(e, (max_length - len(e), 0), value=model.config.pad_token_id or model.config.eos_token_id) for e in encoded))
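Note: the torch path left-pads every context to the longest one in the batch before stacking, so the newest tokens stay right-aligned for generation. A self-contained sketch of the same idiom with made-up data; the pad id 0 stands in for model.config.pad_token_id:

    import torch
    import torch.nn.functional as F

    # Made-up token sequences of unequal length.
    encoded = [torch.tensor([5, 6, 7]), torch.tensor([8]), torch.tensor([1, 2, 3, 4])]
    max_length = len(max(encoded, key=len))

    # pad=(left, right): prepend max_length - len(e) pad tokens, append none.
    batch = torch.stack(tuple(
        F.pad(e, (max_length - len(e), 0), value=0)  # 0 stands in for pad_token_id
        for e in encoded
    ))
    print(batch)
    # tensor([[0, 5, 6, 7],
    #         [0, 0, 0, 8],
    #         [1, 2, 3, 4]])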
@@ -4998,7 +5010,11 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
                txt = utils.decodenewlines(tokenizer.decode(past[i]))
                winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True, actions=koboldai_vars._actions)
                found_entries[i].update(_found_entries)
-                txt, _, _ = calcsubmitbudget(len(koboldai_vars._actions), winfo, mem, anotetxt, koboldai_vars._actions, submission=txt)
+                if koboldai_vars.alt_gen:
+                    txt, _, _ = koboldai_vars.calc_ai_text(submitted_text=txt)
+                    print("Using Alt Gen: {}".format(tokenizer.decode(txt)))
+                else:
+                    txt, _, _ = calcsubmitbudget(len(koboldai_vars._actions), winfo, mem, anotetxt, koboldai_vars._actions, submission=txt)
                encoded.append(np.array(txt, dtype=np.uint32))
            max_length = len(max(encoded, key=len))
            encoded = np.stack(tuple(np.pad(e, (max_length - len(e), 0), constant_values=tpu_mtj_backend.pad_token_id) for e in encoded))
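Note: the TPU path does the same left-padding with numpy; np.pad takes (before, after) counts where torch's F.pad takes (left, right). A short sketch, again with 0 standing in for tpu_mtj_backend.pad_token_id:

    import numpy as np

    encoded = [np.array([5, 6, 7], dtype=np.uint32), np.array([8], dtype=np.uint32)]
    max_length = len(max(encoded, key=len))

    # (before, after) = (max_length - len(e), 0) prepends the pad id.
    batch = np.stack(tuple(
        np.pad(e, (max_length - len(e), 0), constant_values=0)
        for e in encoded
    ))
    print(batch)  # [[5 6 7]
                  #  [0 0 8]]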

View File

@@ -383,6 +383,21 @@ gensettingstf = [
"sub_path": "UI",
"classname": "user",
"name": "show_probs"
},
{
"uitype": "toggle",
"unit": "bool",
"label": "Alt Text Generation",
"id": "alttextgen",
"min": 0,
"max": 1,
"step": 1,
"default": 0,
"tooltip": "Inserts world info entries behind text that first triggers them for better context with unlimited depth",
"menu_path": "Settings",
"sub_path": "Other",
"classname": "system",
"name": "alt_gen"
}
]
@@ -601,6 +616,21 @@ gensettingsik =[{
"sub_path": "UI",
"classname": "user",
"name": "show_probs"
},
{
"uitype": "toggle",
"unit": "bool",
"label": "Alt Text Generation",
"id": "alttextgen",
"min": 0,
"max": 1,
"step": 1,
"default": 0,
"tooltip": "Inserts world info entries behind text that first triggers them for better context with unlimited depth",
"menu_path": "Settings",
"sub_path": "Other",
"classname": "system",
"name": "alt_gen"
}
]
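Note: both lists register the same toggle, once for the transformers settings menu and once for the InferKit one. "classname": "system" and "name": "alt_gen" are what tie the UI control to the system_settings.alt_gen attribute added later in this commit. A rough sketch of that dispatch, with a simplified stand-in for the real settings classes:

    # Simplified stand-in; only the lookup by (classname, name) is the point here.
    class SystemSettings:
        def __init__(self):
            self.alt_gen = False  # matches "default": 0 in the entry above

    settings_objects = {"system": SystemSettings()}

    def apply_ui_setting(classname, name, value):
        # A toggle's "unit" is "bool", so the UI's 0/1 becomes False/True.
        setattr(settings_objects[classname], name, bool(value))

    apply_ui_setting("system", "alt_gen", 1)
    print(settings_objects["system"].alt_gen)  # True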

View File

@@ -36,7 +36,6 @@ def process_variable_changes(socketio, classname, name, value, old_value, debug_message=None):
        socketio.emit("var_changed", {"classname": "actions", "name": "Action Count", "old_value": None, "value":value.action_count}, broadcast=True, room="UI_2")
        for i in range(len(value.actions)):
-            print(value.actions[i])
            socketio.emit("var_changed", {"classname": "story", "name": "actions", "old_value": None, "value":{"id": i, "action": value.actions[i]}}, broadcast=True, room="UI_2")
    elif isinstance(value, KoboldWorldInfo):
        value.send_to_ui()
@@ -109,7 +108,7 @@ class koboldai_vars(object):
    def reset_model(self):
        self._model_settings.reset_for_model_load()
-    def calc_ai_text(self):
+    def calc_ai_text(self, submitted_text=""):
        token_budget = self.max_length
        used_world_info = []
        used_tokens = self.sp_length
@@ -180,7 +179,7 @@ class koboldai_vars(object):
game_text = ""
used_all_tokens = False
for i in range(len(self.actions)-1, -1, -1):
if len(self.actions) - i == self.andepth:
if len(self.actions) - i == self.andepth and self.authornote != "":
game_text = "{}{}".format(self.authornotetemplate.replace("<|>", self.authornote), game_text)
if self.actions.actions[i]["Selected Text Length"]+used_tokens <= token_budget and not used_all_tokens:
used_tokens += self.actions.actions[i]["Selected Text Length"]
@@ -212,7 +211,7 @@ class koboldai_vars(object):
                used_all_tokens = True
        #if we don't have enough actions to get to author's note depth then we just add it right before the game text
-        if len(self.actions) < self.andepth:
+        if len(self.actions) < self.andepth and self.authornote != "":
            game_text = "{}{}".format(self.authornotetemplate.replace("<|>", self.authornote), game_text)
        if not self.useprompt:
@@ -244,7 +243,11 @@ class koboldai_vars(object):
            text += self.prompt
        text += game_text
-        return text
+        if self.tokenizer is None:
+            tokens = []
+        else:
+            tokens = self.tokenizer.encode(text)
+        return tokens, used_tokens, used_tokens+self.genamt
    def __setattr__(self, name, value):
        if name[0] == "_" or name == "tokenizer":
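Note: calc_ai_text now returns the same (tokens, min, max) shape as calcsubmitbudget, which is what lets the call sites in the first file swap one for the other, and the None check covers the window before a model (and hence a tokenizer) is loaded. A toy illustration of that contract; WhitespaceTokenizer and calc_ai_text_sketch are stand-ins, not the real implementations:

    # Stand-in tokenizer: one token per whitespace-separated word.
    class WhitespaceTokenizer:
        def encode(self, text):
            return list(range(len(text.split())))

    def calc_ai_text_sketch(text, tokenizer, used_tokens, genamt):
        # No model loaded yet means no tokenizer; fall back to an empty
        # context instead of raising.
        tokens = [] if tokenizer is None else tokenizer.encode(text)
        # min is the context already spent; max leaves room for genamt new tokens.
        return tokens, used_tokens, used_tokens + genamt

    tokens, mn, mx = calc_ai_text_sketch("once upon a time", WhitespaceTokenizer(), 4, 80)
    print(len(tokens), mn, mx)  # 4 4 84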
@@ -691,6 +694,7 @@ class system_settings(settings):
        self.full_determinism = False  # Whether or not full determinism is enabled
        self.seed_specified = False  # Whether or not the current RNG seed was specified by the user (in their settings file)
        self.seed = None  # The current RNG seed (as an int), or None if unknown
+        self.alt_gen = False  # Use the calc_ai_text method for generating text to go to the AI
    def __setattr__(self, name, value):
@@ -1040,6 +1044,8 @@ class KoboldStoryRegister(object):
        if action_id in self.actions:
            old_options = self.actions[action_id]["Options"]
            if option_number < len(self.actions[action_id]["Options"]):
+                if "Probabilities" not in self.actions[action_id]["Options"][option_number]:
+                    self.actions[action_id]["Options"][option_number]["Probabilities"] = []
                self.actions[action_id]["Options"][option_number]['Probabilities'].append(probabilities)
                process_variable_changes(self.socketio, "story", 'actions', {"id": action_id, 'action': self.actions[action_id]}, None)
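Note: the two added lines are the usual initialize-before-append guard for a dict key; dict.setdefault collapses it to one step, shown here on a toy option dict:

    option = {"text": "some option"}  # no "Probabilities" key yet

    # Creates the empty list on first use, then appends either way.
    option.setdefault("Probabilities", []).append([0.4, 0.6])
    option.setdefault("Probabilities", []).append([0.9, 0.1])
    print(option["Probabilities"])  # [[0.4, 0.6], [0.9, 0.1]]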
@@ -1329,8 +1335,8 @@ class KoboldWorldInfo(object):
    def reset_used_in_game(self):
        for key in self.world_info:
-            if self.world_info[key]["used_in_game"] != constant:
-                self.world_info[key]["used_in_game"] = constant
+            if self.world_info[key]["used_in_game"] != self.world_info[key]["constant"]:
+                self.world_info[key]["used_in_game"] = self.world_info[key]["constant"]
                self.socketio.emit("world_info_entry", self.world_info[key], broadcast=True, room="UI_2")
    def set_world_info_used(self, uid):

View File

@@ -233,7 +233,7 @@ function do_story_text_updates(data) {
            item.setAttribute("world_info_uids", "");
            item.classList.remove("pulse")
            item.scrollIntoView();
-            assign_world_info_to_action(action_item = item);
+            assign_world_info_to_action(item, null);
        } else {
            var span = document.createElement("span");
            span.id = 'Selected Text Chunk '+data.value.id;
@@ -268,7 +268,7 @@ function do_story_text_updates(data) {
            story_area.append(span);
            span.scrollIntoView();
-            assign_world_info_to_action(action_item = span);
+            assign_world_info_to_action(span, null);
        }
@@ -315,6 +315,49 @@ function do_story_text_length_updates(data) {
}
+function do_probabilities(data) {
+    console.log(data);
+    if (document.getElementById('probabilities_'+data.value.id)) {
+        prob_area = document.getElementById('probabilities_'+data.value.id);
+    } else {
+        probabilities = document.getElementById('probabilities');
+        prob_area = document.createElement('span');
+        prob_area.id = 'probabilities_'+data.value.id;
+        probabilities.append(prob_area);
+    }
+    //Clear out any previously rendered probabilities
+    while (prob_area.firstChild) {
+        prob_area.removeChild(prob_area.lastChild);
+    }
+    //Create the table: one row group per generated token
+    table = document.createElement("table");
+    table.border = 1;
+    if ("Probabilities" in data.value.action) {
+        for (token of data.value.action.Probabilities) {
+            actual_text = document.createElement("td");
+            actual_text.setAttribute("rowspan", token.length);
+            actual_text.textContent = "Word Goes Here";
+            for (const [index, word] of token.entries()) {
+                tr = document.createElement("tr");
+                if (index == 0) {
+                    tr.append(actual_text);
+                }
+                decoded = document.createElement("td");
+                decoded.textContent = word.decoded;
+                tr.append(decoded);
+                score = document.createElement("td");
+                score.textContent = (word.score*100).toFixed(2)+"%";
+                tr.append(score);
+                table.append(tr);
+            }
+        }
+    }
+    prob_area.append(table);
+    //prob_area.textContent = data.value.action["Probabilities"];
+}
function do_presets(data) {
for (select of document.getElementsByClassName('presets')) {
//clear out the preset list
@@ -388,38 +431,18 @@ function do_ai_busy(data) {
}
function var_changed(data) {
-    if ((data.classname =="actions") && (data.name == 'Probabilities')) {
-        console.log(data);
-    }
    //console.log({"name": data.name, "data": data});
    //Special Case for Actions
    if ((data.classname == "story") && (data.name == "actions")) {
-        console.log(data);
        do_story_text_updates(data);
        create_options(data);
        do_story_text_length_updates(data);
+        do_probabilities(data);
+        if (data.value.action['In AI Input']) {
+            document.getElementById('Selected Text Chunk '+data.value.id).classList.add("within_max_length");
+        } else {
+            document.getElementById('Selected Text Chunk '+data.value.id).classList.remove("within_max_length");
+        }
-    //Special Case for Story Text
-    } else if ((data.classname == "actions") && (data.name == "Selected Text")) {
-        //do_story_text_updates(data);
-    //Special Case for Story Options
-    } else if ((data.classname == "actions") && (data.name == "Options")) {
-        //create_options(data);
-    //Special Case for Story Text Length
-    } else if ((data.classname == "actions") && (data.name == "Selected Text Length")) {
-        //do_story_text_length_updates(data);
-    //Special Case for Story Text Length
-    } else if ((data.classname == "actions") && (data.name == "In AI Input")) {
-        //console.log(data.value);
-        //if (data.value['In AI Input']) {
-        //    document.getElementById('Selected Text Chunk '+data.value.id).classList.add("within_max_length");
-        //} else {
-        //    document.getElementById('Selected Text Chunk '+data.value.id).classList.remove("within_max_length");
-        //}
    //Special Case for Presets
    } else if ((data.classname == 'model') && (data.name == 'presets')) {
        do_presets(data);
@@ -1239,7 +1262,7 @@ function world_info_entry(data) {
    }
    $('#world_info_constant_'+data.uid).bootstrapToggle();
-    assign_world_info_to_action(uid=data.uid);
+    assign_world_info_to_action(null, data.uid);
    update_token_lengths();
@@ -1902,11 +1925,11 @@ function dragend(e) {
    e.preventDefault();
}
-function assign_world_info_to_action(uid=null, action_item=null) {
+function assign_world_info_to_action(action_item, uid) {
    if (Object.keys(world_info_data).length > 0) {
        if (uid != null) {
            var worldinfo_to_check = {};
-            worldinfo_to_check[uid] = world_info_data[uid]
+            worldinfo_to_check[uid] = world_info_data[uid];
        } else {
            var worldinfo_to_check = world_info_data;
        }
@@ -1915,6 +1938,7 @@ function assign_world_info_to_action(uid=null, action_item=null) {
    } else {
        var actions = document.getElementById("Selected Text").children;
    }
+
    for (action of actions) {
        //First check to see if we have a key in the text
        var words = Array.prototype.slice.call( action.children );

View File

@@ -77,6 +77,9 @@
            {% endwith %}
        {% endwith %}
        </div>
+        <div id="probabilities">
+        </div>
    </div>
    <div id="setting_menu_settings" class="hidden settings_category_area">
        <div class="force_center">