diff --git a/.gitignore b/.gitignore
index b97d1d30..90669874 100644
--- a/.gitignore
+++ b/.gitignore
@@ -25,6 +25,8 @@ softprompts
models
!models/models go here.txt
Uninstall
+flask_session
+accelerate-disk-cache
.ipynb_checkpoints
# Ignore PyCharm project files.
diff --git a/aiserver.py b/aiserver.py
index 564d6ce0..9d34fe9c 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -224,7 +224,7 @@ class vars:
model_type = "" # Model Type (Automatically taken from the model config)
noai = False # Runs the script without starting up the transformers pipeline
aibusy = False # Stops submissions while the AI is working
- max_length = 2048 # Maximum number of tokens to submit per action
+ max_length = 1024 # Maximum number of tokens to submit per action
ikmax = 3000 # Maximum number of characters to submit to InferKit
genamt = 80 # Amount of text for each action to generate
ikgen = 200 # Number of characters for InferKit to generate
@@ -646,6 +646,11 @@ def move_model_to_devices(model):
import breakmodel
if(utils.HAS_ACCELERATE):
+ import accelerate.utils
+ for key, value in model.state_dict().items():
+ target_dtype = torch.float32 if breakmodel.primary_device == "cpu" else torch.float16
+ if(value.dtype is not target_dtype):
+ accelerate.utils.set_module_tensor_to_device(model, key, target_dtype)
disk_blocks = breakmodel.disk_blocks
gpu_blocks = breakmodel.gpu_blocks
ram_blocks = len(utils.layers_module_names) - sum(gpu_blocks)
@@ -5544,9 +5549,6 @@ def loadRequest(loadpath, filename=None):
ln = len(vars.actions[vars.actions.get_last_key()].rstrip())
footer += vars.actions[vars.actions.get_last_key()][ln:]
vars.actions[vars.actions.get_last_key()] = vars.actions[vars.actions.get_last_key()][:ln]
- if(len(vars.actions) == 0):
- vars.gamestarted = False
-
# Try not to break older save files
if("authorsnote" in js):
diff --git a/environments/finetuneanon.yml b/environments/finetuneanon.yml
index b49f0bd7..5411d2ce 100644
--- a/environments/finetuneanon.yml
+++ b/environments/finetuneanon.yml
@@ -6,6 +6,7 @@ channels:
dependencies:
- colorama
- flask-socketio
+ - flask-session
- pytorch
- cudatoolkit=11.1
- tensorflow-gpu
diff --git a/environments/huggingface.yml b/environments/huggingface.yml
index 205d5e31..f24c5336 100644
--- a/environments/huggingface.yml
+++ b/environments/huggingface.yml
@@ -6,6 +6,7 @@ channels:
dependencies:
- colorama
- flask-socketio
+ - flask-session
- pytorch=1.11.*
- python=3.8.*
- cudatoolkit=11.1
diff --git a/environments/rocm-finetune.yml b/environments/rocm-finetune.yml
index 5672ed21..60b17d98 100644
--- a/environments/rocm-finetune.yml
+++ b/environments/rocm-finetune.yml
@@ -5,6 +5,7 @@ channels:
dependencies:
- colorama
- flask-socketio
+ - flask-session
- python=3.8.*
- eventlet
- markdown
diff --git a/environments/rocm.yml b/environments/rocm.yml
index 8ade341f..4778208d 100644
--- a/environments/rocm.yml
+++ b/environments/rocm.yml
@@ -5,6 +5,7 @@ channels:
dependencies:
- colorama
- flask-socketio
+ - flask-session
- python=3.8.*
- eventlet
- markdown
diff --git a/gensettings.py b/gensettings.py
index 3d188b16..636b7985 100644
--- a/gensettings.py
+++ b/gensettings.py
@@ -17,7 +17,7 @@ gensettingstf = [
"id": "settemp",
"min": 0.1,
"max": 2.0,
- "step": 0.05,
+ "step": 0.01,
"default": 0.5,
"tooltip": "Randomness of sampling. High values can increase creativity but may make text less sensible. Lower values will make text more predictable but can become repetitious."
},
@@ -28,7 +28,7 @@ gensettingstf = [
"id": "settopp",
"min": 0.0,
"max": 1.0,
- "step": 0.05,
+ "step": 0.01,
"default": 0.9,
"tooltip": "Used to discard unlikely text in the sampling process. Lower values will make text more predictable but can become repetitious. (Put this value on 1 to disable its effect)"
},
@@ -50,7 +50,7 @@ gensettingstf = [
"id": "settfs",
"min": 0.0,
"max": 1.0,
- "step": 0.05,
+ "step": 0.01,
"default": 1.0,
"tooltip": "Alternative sampling method; it is recommended to disable top_p and top_k (set top_p to 1 and top_k to 0) if using this. 0.95 is thought to be a good value. (Put this value on 1 to disable its effect)"
},
@@ -61,7 +61,7 @@ gensettingstf = [
"id": "settypical",
"min": 0.0,
"max": 1.0,
- "step": 0.05,
+ "step": 0.01,
"default": 1.0,
"tooltip": "Alternative sampling method described in the paper \"Typical Decoding for Natural Language Generation\" (10.48550/ARXIV.2202.00666). The paper suggests 0.2 as a good value for this setting. Set this setting to 1 to disable its effect."
},
diff --git a/maps/bloom.json b/maps/bloom.json
new file mode 100644
index 00000000..e3f5feb9
--- /dev/null
+++ b/maps/bloom.json
@@ -0,0 +1,30 @@
+{
+ "mtj_compat": "bloom",
+ "mtj_pe": "alibi",
+ "mtj_config_map": {
+ "d_model": "n_embed",
+ "n_heads": "num_attention_heads",
+ "layers": "n_layer"
+ },
+ "static_weights": {
+ "word_embeddings.weight": {"mtj": {"module": "embedding_shard/~/linear", "param": "w", "transforms": ["no_transpose", "vocab_pad"]}},
+ "word_embeddings_layernorm.weight": {"mtj": {"module": "embedding_shard/~/replicated_layer_norm", "param": "scale"}},
+ "word_embeddings_layernorm.bias": {"mtj": {"module": "embedding_shard/~/replicated_layer_norm", "param": "offset"}},
+ "ln_f.weight": {"mtj": {"module": "projection_shard/~/replicated_layer_norm", "param": "scale"}},
+ "ln_f.bias": {"mtj": {"module": "projection_shard/~/replicated_layer_norm", "param": "offset"}}
+ },
+ "layer_weights": {
+ "h.{layer}.self_attention.query_key_value.weight": {"mtj": {"module": "layer_{layer}/~/combined_qkv", "param": "w"}},
+ "h.{layer}.self_attention.query_key_value.bias": {"mtj": {"module": "layer_{layer}/~/combined_qkv", "param": "b"}},
+ "h.{layer}.self_attention.dense.weight": {"mtj": {"module": "layer_{layer}/~/linear_3", "param": "w"}},
+ "h.{layer}.self_attention.dense.bias": {"mtj": {"module": "layer_{layer}/~/linear_3", "param": "b", "transforms": ["divide_by_shards"]}},
+ "h.{layer}.mlp.dense_h_to_4h.weight": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "w"}},
+ "h.{layer}.mlp.dense_h_to_4h.bias": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "b"}},
+ "h.{layer}.mlp.dense_4h_to_h.weight": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "w"}},
+ "h.{layer}.mlp.dense_4h_to_h.bias": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "b", "transforms": ["divide_by_shards"]}},
+ "h.{layer}.input_layernorm.weight": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm", "param": "scale"}},
+ "h.{layer}.input_layernorm.bias": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm", "param": "offset"}},
+ "h.{layer}.post_attention_layernorm.weight": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm_1", "param": "scale"}},
+ "h.{layer}.post_attention_layernorm.bias": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm_1", "param": "offset"}}
+ }
+}
diff --git a/requirements_mtj.txt b/requirements_mtj.txt
index d80604f6..eb7cf79c 100644
--- a/requirements_mtj.txt
+++ b/requirements_mtj.txt
@@ -5,6 +5,7 @@ requests
optax >= 0.0.5, <= 0.0.9
dm-haiku == 0.0.5
jax == 0.2.21
+jaxlib >= 0.1.69, <= 0.3.7
transformers >= 4.19
progressbar2
git+https://github.com/VE-FORBRYDERNE/mesh-transformer-jax@ck
diff --git a/static/application.js b/static/application.js
index aef93daa..2388aa23 100644
--- a/static/application.js
+++ b/static/application.js
@@ -87,6 +87,7 @@ var wiscroll = 0;
var editmode = false;
var connected = false;
var newly_loaded = true;
+var all_modified_chunks = new Set();
var modified_chunks = new Set();
var empty_chunks = new Set();
var gametext_bound = false;
@@ -129,6 +130,7 @@ var adventure = false;
var chatmode = false;
var sliders_throttle = getThrottle(200);
+var submit_throttle = null;
//=================================================================//
// METHODS
@@ -892,6 +894,17 @@ function dosubmit(disallow_abort) {
return;
}
chunkOnFocusOut("override");
+ // Wait for editor changes to be applied before submitting
+ submit_throttle = getThrottle(70);
+ submit_throttle.txt = txt;
+ submit_throttle.disallow_abort = disallow_abort;
+ submit_throttle(0, _dosubmit);
+}
+
+function _dosubmit() {
+ var txt = submit_throttle.txt;
+ var disallow_abort = submit_throttle.disallow_abort;
+ submit_throttle = null;
input_text.val("");
hideMessage();
hidegenseqs();
@@ -1523,14 +1536,30 @@ function chunkOnTextInput(event) {
r.deleteContents();
}
- // In Chrome the added <br/> will go outside of the chunks if we press
+ // In Chrome and Safari the added <br/> will go outside of the chunks if we press
// enter at the end of the story in the editor, so this is here
// to put the <br/> back in the right place
var br = $("#_EDITOR_LINEBREAK_")[0];
if(br.parentNode === game_text[0]) {
+ var parent = br.previousSibling;
if(br.previousSibling.nodeType !== 1) {
+ parent = br.previousSibling.previousSibling;
br.previousSibling.previousSibling.appendChild(br.previousSibling);
}
+ if(parent.lastChild.tagName === "BR") {
+ parent.lastChild.remove(); // Chrome and Safari also insert an extra <br/> in this case for some reason so we need to remove it
+ if(using_webkit_patch) {
+ // Safari on iOS has a bug where it selects all text in the last chunk of the story when this happens so we collapse the selection to the end of the chunk in that case
+ setTimeout(function() {
+ var s = getSelection();
+ var r = s.getRangeAt(0);
+ r.selectNodeContents(parent);
+ r.collapse(false);
+ s.removeAllRanges();
+ s.addRange(r);
+ }, 2);
+ }
+ }
br.previousSibling.appendChild(br);
r.selectNodeContents(br.parentNode);
s.removeAllRanges();
@@ -1712,6 +1741,7 @@ function applyChunkDeltas(nodes) {
var chunks = Array.from(buildChunkSetFromNodeArray(nodes));
for(var i = 0; i < chunks.length; i++) {
modified_chunks.add(chunks[i]);
+ all_modified_chunks.add(chunks[i]);
}
setTimeout(function() {
var chunks = Array.from(modified_chunks);
@@ -1722,12 +1752,18 @@ function applyChunkDeltas(nodes) {
if(!selected_chunks.has(chunks[i])) {
modified_chunks.delete(chunks[i]);
socket.send({'cmd': 'inlineedit', 'chunk': chunks[i], 'data': formatChunkInnerText(chunk)});
+ if(submit_throttle !== null) {
+ submit_throttle(0, _dosubmit);
+ }
}
empty_chunks.delete(chunks[i]);
} else {
if(!selected_chunks.has(chunks[i])) {
modified_chunks.delete(chunks[i]);
socket.send({'cmd': 'inlineedit', 'chunk': chunks[i], 'data': formatChunkInnerText(chunk)});
+ if(submit_throttle !== null) {
+ submit_throttle(0, _dosubmit);
+ }
}
empty_chunks.add(chunks[i]);
}
@@ -1749,6 +1785,9 @@ function syncAllModifiedChunks(including_selected_chunks=false) {
empty_chunks.delete(chunks[i]);
}
socket.send({'cmd': 'inlineedit', 'chunk': chunks[i], 'data': data});
+ if(submit_throttle !== null) {
+ submit_throttle(0, _dosubmit);
+ }
}
}
}
@@ -1801,10 +1840,16 @@ function restorePrompt() {
if(this.innerText.trim().length) {
saved_prompt = this.innerText.trim();
socket.send({'cmd': 'inlinedelete', 'data': this.getAttribute("n")});
+ if(submit_throttle !== null) {
+ submit_throttle(0, _dosubmit);
+ }
this.parentNode.removeChild(this);
return false;
}
socket.send({'cmd': 'inlinedelete', 'data': this.getAttribute("n")});
+ if(submit_throttle !== null) {
+ submit_throttle(0, _dosubmit);
+ }
this.parentNode.removeChild(this);
});
}
@@ -1819,6 +1864,9 @@ function restorePrompt() {
modified_chunks.delete('0');
empty_chunks.delete('0');
socket.send({'cmd': 'inlineedit', 'chunk': '0', 'data': saved_prompt});
+ if(submit_throttle !== null) {
+ submit_throttle(0, _dosubmit);
+ }
}
function deleteEmptyChunks() {
@@ -1830,13 +1878,21 @@ function deleteEmptyChunks() {
restorePrompt();
} else {
socket.send({'cmd': 'inlinedelete', 'data': chunks[i]});
+ if(submit_throttle !== null) {
+ submit_throttle(0, _dosubmit);
+ }
}
}
if(modified_chunks.has('0')) {
modified_chunks.delete(chunks[i]);
socket.send({'cmd': 'inlineedit', 'chunk': chunks[i], 'data': formatChunkInnerText(document.getElementById("n0"))});
+ if(submit_throttle !== null) {
+ submit_throttle(0, _dosubmit);
+ }
+ }
+ if(gamestarted) {
+ saved_prompt = formatChunkInnerText($("#n0")[0]);
}
- saved_prompt = formatChunkInnerText($("#n0")[0]);
}
function highlightEditingChunks() {
@@ -1860,11 +1916,29 @@ function highlightEditingChunks() {
}
function cleanupChunkWhitespace() {
+ unbindGametext();
+
+ var chunks = Array.from(all_modified_chunks);
+ for(var i = 0; i < chunks.length; i++) {
+ var original_chunk = document.getElementById("n" + chunks[i]);
+ if(original_chunk === null || original_chunk.innerText.trim().length === 0) {
+ all_modified_chunks.delete(chunks[i]);
+ modified_chunks.delete(chunks[i]);
+ empty_chunks.add(chunks[i]);
+ }
+ }
+
// Merge empty chunks with the next chunk
var chunks = Array.from(empty_chunks);
chunks.sort(function(e) {parseInt(e)});
for(var i = 0; i < chunks.length; i++) {
+ if(chunks[i] == "0") {
+ continue;
+ }
var original_chunk = document.getElementById("n" + chunks[i]);
+ if(original_chunk === null) {
+ continue;
+ }
var chunk = original_chunk.nextSibling;
while(chunk) {
if(chunk.tagName === "CHUNK") {
@@ -1874,11 +1948,14 @@ function cleanupChunkWhitespace() {
}
if(chunk) {
chunk.innerText = original_chunk.innerText + chunk.innerText;
+ if(original_chunk.innerText.length != 0 && !modified_chunks.has(chunk.getAttribute("n"))) {
+ modified_chunks.add(chunk.getAttribute("n"));
+ }
}
original_chunk.innerText = "";
}
// Move whitespace at the end of non-empty chunks into the beginning of the next non-empty chunk
- var chunks = Array.from(modified_chunks);
+ var chunks = Array.from(all_modified_chunks);
chunks.sort(function(e) {parseInt(e)});
for(var i = 0; i < chunks.length; i++) {
var original_chunk = document.getElementById("n" + chunks[i]);
@@ -1892,9 +1969,14 @@ function cleanupChunkWhitespace() {
var ln = original_chunk.innerText.trimEnd().length;
if (chunk) {
chunk.innerText = original_chunk.innerText.substring(ln) + chunk.innerText;
+ if(ln != original_chunk.innerText.length && !modified_chunks.has(chunk.getAttribute("n"))) {
+ modified_chunks.add(chunk.getAttribute("n"));
+ }
}
original_chunk.innerText = original_chunk.innerText.substring(0, ln);
}
+
+ bindGametext();
}
// This gets run every time the text in a chunk is edited
@@ -1976,6 +2058,7 @@ function chunkOnFocusOut(event) {
return;
}
cleanupChunkWhitespace();
+ all_modified_chunks = new Set();
syncAllModifiedChunks(true);
setTimeout(function() {
var blurred = game_text[0] !== document.activeElement;
@@ -2185,6 +2268,7 @@ $(document).ready(function(){
unbindGametext();
allowedit = gamestarted && $("#allowediting").prop('checked');
game_text.attr('contenteditable', allowedit);
+ all_modified_chunks = new Set();
modified_chunks = new Set();
empty_chunks = new Set();
game_text.html(msg.data);
@@ -2739,6 +2823,12 @@ $(document).ready(function(){
chunkOnFocusOut
);
mutation_observer = new MutationObserver(chunkOnDOMMutate);
+ $("#gamescreen").on('click', function(e) {
+ if(this !== e.target) {
+ return;
+ }
+ document.activeElement.blur();
+ });
// This is required for the editor to work correctly in Firefox on desktop
// because the gods of HTML and JavaScript say so
diff --git a/torch_lazy_loader.py b/torch_lazy_loader.py
index 5e633c83..9e411261 100644
--- a/torch_lazy_loader.py
+++ b/torch_lazy_loader.py
@@ -52,7 +52,7 @@ import pickle
import torch
import utils
from torch.nn import Module
-from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
+from typing import Any, Callable, Dict, Optional, Tuple, Union
_EXTRA_STATE_KEY_SUFFIX = '_extra_state'
@@ -73,7 +73,7 @@ STORAGE_TYPE_MAP = {
class LazyTensor:
- def __init__(self, storage_type: Type[torch._StorageBase], key: str, location: str, dtype: Optional[torch.dtype] = None, seek_offset: Optional[int] = None, shape: Optional[Tuple[int, ...]] = None, stride: Optional[Tuple[int, ...]] = None, requires_grad=False, backward_hooks: Any = None):
+ def __init__(self, storage_type, key: str, location: str, dtype: Optional[torch.dtype] = None, seek_offset: Optional[int] = None, shape: Optional[Tuple[int, ...]] = None, stride: Optional[Tuple[int, ...]] = None, requires_grad=False, backward_hooks: Any = None):
self.storage_type = storage_type
self.key = key
self.location = location
diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py
index 8daa8dee..7b0f6807 100644
--- a/tpu_mtj_backend.py
+++ b/tpu_mtj_backend.py
@@ -1246,13 +1246,14 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
if utils.num_shards is not None:
utils.current_shard += 1
for key in sorted(model_dict.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)):
+ model_spec_key = max((k for k in model_spec.keys() if key.endswith(k)), key=len, default=None)
# Some model weights are used by transformers but not by MTJ.
# We have to materialize these weights anyways because
# transformers will throw a tantrum otherwise. To attain
# the least possible memory usage, we create them as meta
# tensors, which don't take up any actual CPU or TPU memory.
- if key not in model_spec:
+ if model_spec_key is None:
model_dict[key] = torch.empty(model_dict[key].shape, dtype=model_dict[key].dtype, device="meta")
utils.bar.update(1)
continue
@@ -1267,7 +1268,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
if current_offset != model_dict[key].seek_offset:
f.read(model_dict[key].seek_offset - current_offset)
current_offset = model_dict[key].seek_offset
- spec = model_spec[key]
+ spec = model_spec[model_spec_key]
transforms = set(spec.get("transforms", ()))
if not isinstance(model_dict[key], torch_lazy_loader.LazyTensor):
error = f"Duplicate key {repr(key)}"
diff --git a/userscripts/kaipreset_basic_phrase_bias.lua b/userscripts/kaipreset_basic_phrase_bias.lua
index b5176e55..2a846923 100644
--- a/userscripts/kaipreset_basic_phrase_bias.lua
+++ b/userscripts/kaipreset_basic_phrase_bias.lua
@@ -183,8 +183,8 @@ function userscript.genmod()
max_overlap[i] = 0
local s = {}
local z = {[0] = 0}
- local l = 1
- local r = 1
+ local l = 0
+ local r = 0
local n_s = math.min(n_tokens, bias_entry.n_tokens)
local j = 0
for k = 1, n_s do