Merge branch 'united' of https://github.com/ebolam/KoboldAI into united

This commit is contained in:
ebolam 2022-07-15 12:30:18 -04:00
commit 68d143b80c
13 changed files with 147 additions and 17 deletions

2
.gitignore vendored
View File

@ -25,6 +25,8 @@ softprompts
models models
!models/models go here.txt !models/models go here.txt
Uninstall Uninstall
flask_session
accelerate-disk-cache
.ipynb_checkpoints .ipynb_checkpoints
# Ignore PyCharm project files. # Ignore PyCharm project files.

View File

@ -224,7 +224,7 @@ class vars:
model_type = "" # Model Type (Automatically taken from the model config) model_type = "" # Model Type (Automatically taken from the model config)
noai = False # Runs the script without starting up the transformers pipeline noai = False # Runs the script without starting up the transformers pipeline
aibusy = False # Stops submissions while the AI is working aibusy = False # Stops submissions while the AI is working
max_length = 2048 # Maximum number of tokens to submit per action max_length = 1024 # Maximum number of tokens to submit per action
ikmax = 3000 # Maximum number of characters to submit to InferKit ikmax = 3000 # Maximum number of characters to submit to InferKit
genamt = 80 # Amount of text for each action to generate genamt = 80 # Amount of text for each action to generate
ikgen = 200 # Number of characters for InferKit to generate ikgen = 200 # Number of characters for InferKit to generate
@ -646,6 +646,11 @@ def move_model_to_devices(model):
import breakmodel import breakmodel
if(utils.HAS_ACCELERATE): if(utils.HAS_ACCELERATE):
import accelerate.utils
for key, value in model.state_dict().items():
target_dtype = torch.float32 if breakmodel.primary_device == "cpu" else torch.float16
if(value.dtype is not target_dtype):
accelerate.utils.set_module_tensor_to_device(model, key, target_dtype)
disk_blocks = breakmodel.disk_blocks disk_blocks = breakmodel.disk_blocks
gpu_blocks = breakmodel.gpu_blocks gpu_blocks = breakmodel.gpu_blocks
ram_blocks = len(utils.layers_module_names) - sum(gpu_blocks) ram_blocks = len(utils.layers_module_names) - sum(gpu_blocks)
@ -5544,9 +5549,6 @@ def loadRequest(loadpath, filename=None):
ln = len(vars.actions[vars.actions.get_last_key()].rstrip()) ln = len(vars.actions[vars.actions.get_last_key()].rstrip())
footer += vars.actions[vars.actions.get_last_key()][ln:] footer += vars.actions[vars.actions.get_last_key()][ln:]
vars.actions[vars.actions.get_last_key()] = vars.actions[vars.actions.get_last_key()][:ln] vars.actions[vars.actions.get_last_key()] = vars.actions[vars.actions.get_last_key()][:ln]
if(len(vars.actions) == 0):
vars.gamestarted = False
# Try not to break older save files # Try not to break older save files
if("authorsnote" in js): if("authorsnote" in js):

View File

@ -6,6 +6,7 @@ channels:
dependencies: dependencies:
- colorama - colorama
- flask-socketio - flask-socketio
- flask-session
- pytorch - pytorch
- cudatoolkit=11.1 - cudatoolkit=11.1
- tensorflow-gpu - tensorflow-gpu

View File

@ -6,6 +6,7 @@ channels:
dependencies: dependencies:
- colorama - colorama
- flask-socketio - flask-socketio
- flask-session
- pytorch=1.11.* - pytorch=1.11.*
- python=3.8.* - python=3.8.*
- cudatoolkit=11.1 - cudatoolkit=11.1

View File

@ -5,6 +5,7 @@ channels:
dependencies: dependencies:
- colorama - colorama
- flask-socketio - flask-socketio
- flask-session
- python=3.8.* - python=3.8.*
- eventlet - eventlet
- markdown - markdown

View File

@ -5,6 +5,7 @@ channels:
dependencies: dependencies:
- colorama - colorama
- flask-socketio - flask-socketio
- flask-session
- python=3.8.* - python=3.8.*
- eventlet - eventlet
- markdown - markdown

View File

@ -17,7 +17,7 @@ gensettingstf = [
"id": "settemp", "id": "settemp",
"min": 0.1, "min": 0.1,
"max": 2.0, "max": 2.0,
"step": 0.05, "step": 0.01,
"default": 0.5, "default": 0.5,
"tooltip": "Randomness of sampling. High values can increase creativity but may make text less sensible. Lower values will make text more predictable but can become repetitious." "tooltip": "Randomness of sampling. High values can increase creativity but may make text less sensible. Lower values will make text more predictable but can become repetitious."
}, },
@ -28,7 +28,7 @@ gensettingstf = [
"id": "settopp", "id": "settopp",
"min": 0.0, "min": 0.0,
"max": 1.0, "max": 1.0,
"step": 0.05, "step": 0.01,
"default": 0.9, "default": 0.9,
"tooltip": "Used to discard unlikely text in the sampling process. Lower values will make text more predictable but can become repetitious. (Put this value on 1 to disable its effect)" "tooltip": "Used to discard unlikely text in the sampling process. Lower values will make text more predictable but can become repetitious. (Put this value on 1 to disable its effect)"
}, },
@ -50,7 +50,7 @@ gensettingstf = [
"id": "settfs", "id": "settfs",
"min": 0.0, "min": 0.0,
"max": 1.0, "max": 1.0,
"step": 0.05, "step": 0.01,
"default": 1.0, "default": 1.0,
"tooltip": "Alternative sampling method; it is recommended to disable top_p and top_k (set top_p to 1 and top_k to 0) if using this. 0.95 is thought to be a good value. (Put this value on 1 to disable its effect)" "tooltip": "Alternative sampling method; it is recommended to disable top_p and top_k (set top_p to 1 and top_k to 0) if using this. 0.95 is thought to be a good value. (Put this value on 1 to disable its effect)"
}, },
@ -61,7 +61,7 @@ gensettingstf = [
"id": "settypical", "id": "settypical",
"min": 0.0, "min": 0.0,
"max": 1.0, "max": 1.0,
"step": 0.05, "step": 0.01,
"default": 1.0, "default": 1.0,
"tooltip": "Alternative sampling method described in the paper \"Typical Decoding for Natural Language Generation\" (10.48550/ARXIV.2202.00666). The paper suggests 0.2 as a good value for this setting. Set this setting to 1 to disable its effect." "tooltip": "Alternative sampling method described in the paper \"Typical Decoding for Natural Language Generation\" (10.48550/ARXIV.2202.00666). The paper suggests 0.2 as a good value for this setting. Set this setting to 1 to disable its effect."
}, },

30
maps/bloom.json Normal file
View File

@ -0,0 +1,30 @@
{
"mtj_compat": "bloom",
"mtj_pe": "alibi",
"mtj_config_map": {
"d_model": "n_embed",
"n_heads": "num_attention_heads",
"layers": "n_layer"
},
"static_weights": {
"word_embeddings.weight": {"mtj": {"module": "embedding_shard/~/linear", "param": "w", "transforms": ["no_transpose", "vocab_pad"]}},
"word_embeddings_layernorm.weight": {"mtj": {"module": "embedding_shard/~/replicated_layer_norm", "param": "scale"}},
"word_embeddings_layernorm.bias": {"mtj": {"module": "embedding_shard/~/replicated_layer_norm", "param": "offset"}},
"ln_f.weight": {"mtj": {"module": "projection_shard/~/replicated_layer_norm", "param": "scale"}},
"ln_f.bias": {"mtj": {"module": "projection_shard/~/replicated_layer_norm", "param": "offset"}}
},
"layer_weights": {
"h.{layer}.self_attention.query_key_value.weight": {"mtj": {"module": "layer_{layer}/~/combined_qkv", "param": "w"}},
"h.{layer}.self_attention.query_key_value.bias": {"mtj": {"module": "layer_{layer}/~/combined_qkv", "param": "b"}},
"h.{layer}.self_attention.dense.weight": {"mtj": {"module": "layer_{layer}/~/linear_3", "param": "w"}},
"h.{layer}.self_attention.dense.bias": {"mtj": {"module": "layer_{layer}/~/linear_3", "param": "b", "transforms": ["divide_by_shards"]}},
"h.{layer}.mlp.dense_h_to_4h.weight": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "w"}},
"h.{layer}.mlp.dense_h_to_4h.bias": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "b"}},
"h.{layer}.mlp.dense_4h_to_h.weight": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "w"}},
"h.{layer}.mlp.dense_4h_to_h.bias": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "b", "transforms": ["divide_by_shards"]}},
"h.{layer}.input_layernorm.weight": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm", "param": "scale"}},
"h.{layer}.input_layernorm.bias": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm", "param": "offset"}},
"h.{layer}.post_attention_layernorm.weight": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm_1", "param": "scale"}},
"h.{layer}.post_attention_layernorm.bias": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm_1", "param": "offset"}}
}
}

View File

@ -5,6 +5,7 @@ requests
optax >= 0.0.5, <= 0.0.9 optax >= 0.0.5, <= 0.0.9
dm-haiku == 0.0.5 dm-haiku == 0.0.5
jax == 0.2.21 jax == 0.2.21
jaxlib >= 0.1.69, <= 0.3.7
transformers >= 4.19 transformers >= 4.19
progressbar2 progressbar2
git+https://github.com/VE-FORBRYDERNE/mesh-transformer-jax@ck git+https://github.com/VE-FORBRYDERNE/mesh-transformer-jax@ck

View File

@ -87,6 +87,7 @@ var wiscroll = 0;
var editmode = false; var editmode = false;
var connected = false; var connected = false;
var newly_loaded = true; var newly_loaded = true;
var all_modified_chunks = new Set();
var modified_chunks = new Set(); var modified_chunks = new Set();
var empty_chunks = new Set(); var empty_chunks = new Set();
var gametext_bound = false; var gametext_bound = false;
@ -129,6 +130,7 @@ var adventure = false;
var chatmode = false; var chatmode = false;
var sliders_throttle = getThrottle(200); var sliders_throttle = getThrottle(200);
var submit_throttle = null;
//=================================================================// //=================================================================//
// METHODS // METHODS
@ -892,6 +894,17 @@ function dosubmit(disallow_abort) {
return; return;
} }
chunkOnFocusOut("override"); chunkOnFocusOut("override");
// Wait for editor changes to be applied before submitting
submit_throttle = getThrottle(70);
submit_throttle.txt = txt;
submit_throttle.disallow_abort = disallow_abort;
submit_throttle(0, _dosubmit);
}
function _dosubmit() {
var txt = submit_throttle.txt;
var disallow_abort = submit_throttle.disallow_abort;
submit_throttle = null;
input_text.val(""); input_text.val("");
hideMessage(); hideMessage();
hidegenseqs(); hidegenseqs();
@ -1523,14 +1536,30 @@ function chunkOnTextInput(event) {
r.deleteContents(); r.deleteContents();
} }
// In Chrome the added <br/> will go outside of the chunks if we press // In Chrome and Safari the added <br/> will go outside of the chunks if we press
// enter at the end of the story in the editor, so this is here // enter at the end of the story in the editor, so this is here
// to put the <br/> back in the right place // to put the <br/> back in the right place
var br = $("#_EDITOR_LINEBREAK_")[0]; var br = $("#_EDITOR_LINEBREAK_")[0];
if(br.parentNode === game_text[0]) { if(br.parentNode === game_text[0]) {
var parent = br.previousSibling;
if(br.previousSibling.nodeType !== 1) { if(br.previousSibling.nodeType !== 1) {
parent = br.previousSibling.previousSibling;
br.previousSibling.previousSibling.appendChild(br.previousSibling); br.previousSibling.previousSibling.appendChild(br.previousSibling);
} }
if(parent.lastChild.tagName === "BR") {
parent.lastChild.remove(); // Chrome and Safari also insert an extra <br/> in this case for some reason so we need to remove it
if(using_webkit_patch) {
// Safari on iOS has a bug where it selects all text in the last chunk of the story when this happens so we collapse the selection to the end of the chunk in that case
setTimeout(function() {
var s = getSelection();
var r = s.getRangeAt(0);
r.selectNodeContents(parent);
r.collapse(false);
s.removeAllRanges();
s.addRange(r);
}, 2);
}
}
br.previousSibling.appendChild(br); br.previousSibling.appendChild(br);
r.selectNodeContents(br.parentNode); r.selectNodeContents(br.parentNode);
s.removeAllRanges(); s.removeAllRanges();
@ -1712,6 +1741,7 @@ function applyChunkDeltas(nodes) {
var chunks = Array.from(buildChunkSetFromNodeArray(nodes)); var chunks = Array.from(buildChunkSetFromNodeArray(nodes));
for(var i = 0; i < chunks.length; i++) { for(var i = 0; i < chunks.length; i++) {
modified_chunks.add(chunks[i]); modified_chunks.add(chunks[i]);
all_modified_chunks.add(chunks[i]);
} }
setTimeout(function() { setTimeout(function() {
var chunks = Array.from(modified_chunks); var chunks = Array.from(modified_chunks);
@ -1722,12 +1752,18 @@ function applyChunkDeltas(nodes) {
if(!selected_chunks.has(chunks[i])) { if(!selected_chunks.has(chunks[i])) {
modified_chunks.delete(chunks[i]); modified_chunks.delete(chunks[i]);
socket.send({'cmd': 'inlineedit', 'chunk': chunks[i], 'data': formatChunkInnerText(chunk)}); socket.send({'cmd': 'inlineedit', 'chunk': chunks[i], 'data': formatChunkInnerText(chunk)});
if(submit_throttle !== null) {
submit_throttle(0, _dosubmit);
}
} }
empty_chunks.delete(chunks[i]); empty_chunks.delete(chunks[i]);
} else { } else {
if(!selected_chunks.has(chunks[i])) { if(!selected_chunks.has(chunks[i])) {
modified_chunks.delete(chunks[i]); modified_chunks.delete(chunks[i]);
socket.send({'cmd': 'inlineedit', 'chunk': chunks[i], 'data': formatChunkInnerText(chunk)}); socket.send({'cmd': 'inlineedit', 'chunk': chunks[i], 'data': formatChunkInnerText(chunk)});
if(submit_throttle !== null) {
submit_throttle(0, _dosubmit);
}
} }
empty_chunks.add(chunks[i]); empty_chunks.add(chunks[i]);
} }
@ -1749,6 +1785,9 @@ function syncAllModifiedChunks(including_selected_chunks=false) {
empty_chunks.delete(chunks[i]); empty_chunks.delete(chunks[i]);
} }
socket.send({'cmd': 'inlineedit', 'chunk': chunks[i], 'data': data}); socket.send({'cmd': 'inlineedit', 'chunk': chunks[i], 'data': data});
if(submit_throttle !== null) {
submit_throttle(0, _dosubmit);
}
} }
} }
} }
@ -1801,10 +1840,16 @@ function restorePrompt() {
if(this.innerText.trim().length) { if(this.innerText.trim().length) {
saved_prompt = this.innerText.trim(); saved_prompt = this.innerText.trim();
socket.send({'cmd': 'inlinedelete', 'data': this.getAttribute("n")}); socket.send({'cmd': 'inlinedelete', 'data': this.getAttribute("n")});
if(submit_throttle !== null) {
submit_throttle(0, _dosubmit);
}
this.parentNode.removeChild(this); this.parentNode.removeChild(this);
return false; return false;
} }
socket.send({'cmd': 'inlinedelete', 'data': this.getAttribute("n")}); socket.send({'cmd': 'inlinedelete', 'data': this.getAttribute("n")});
if(submit_throttle !== null) {
submit_throttle(0, _dosubmit);
}
this.parentNode.removeChild(this); this.parentNode.removeChild(this);
}); });
} }
@ -1819,6 +1864,9 @@ function restorePrompt() {
modified_chunks.delete('0'); modified_chunks.delete('0');
empty_chunks.delete('0'); empty_chunks.delete('0');
socket.send({'cmd': 'inlineedit', 'chunk': '0', 'data': saved_prompt}); socket.send({'cmd': 'inlineedit', 'chunk': '0', 'data': saved_prompt});
if(submit_throttle !== null) {
submit_throttle(0, _dosubmit);
}
} }
function deleteEmptyChunks() { function deleteEmptyChunks() {
@ -1830,13 +1878,21 @@ function deleteEmptyChunks() {
restorePrompt(); restorePrompt();
} else { } else {
socket.send({'cmd': 'inlinedelete', 'data': chunks[i]}); socket.send({'cmd': 'inlinedelete', 'data': chunks[i]});
if(submit_throttle !== null) {
submit_throttle(0, _dosubmit);
}
} }
} }
if(modified_chunks.has('0')) { if(modified_chunks.has('0')) {
modified_chunks.delete(chunks[i]); modified_chunks.delete(chunks[i]);
socket.send({'cmd': 'inlineedit', 'chunk': chunks[i], 'data': formatChunkInnerText(document.getElementById("n0"))}); socket.send({'cmd': 'inlineedit', 'chunk': chunks[i], 'data': formatChunkInnerText(document.getElementById("n0"))});
if(submit_throttle !== null) {
submit_throttle(0, _dosubmit);
}
}
if(gamestarted) {
saved_prompt = formatChunkInnerText($("#n0")[0]);
} }
saved_prompt = formatChunkInnerText($("#n0")[0]);
} }
function highlightEditingChunks() { function highlightEditingChunks() {
@ -1860,11 +1916,29 @@ function highlightEditingChunks() {
} }
function cleanupChunkWhitespace() { function cleanupChunkWhitespace() {
unbindGametext();
var chunks = Array.from(all_modified_chunks);
for(var i = 0; i < chunks.length; i++) {
var original_chunk = document.getElementById("n" + chunks[i]);
if(original_chunk === null || original_chunk.innerText.trim().length === 0) {
all_modified_chunks.delete(chunks[i]);
modified_chunks.delete(chunks[i]);
empty_chunks.add(chunks[i]);
}
}
// Merge empty chunks with the next chunk // Merge empty chunks with the next chunk
var chunks = Array.from(empty_chunks); var chunks = Array.from(empty_chunks);
chunks.sort(function(e) {parseInt(e)}); chunks.sort(function(e) {parseInt(e)});
for(var i = 0; i < chunks.length; i++) { for(var i = 0; i < chunks.length; i++) {
if(chunks[i] == "0") {
continue;
}
var original_chunk = document.getElementById("n" + chunks[i]); var original_chunk = document.getElementById("n" + chunks[i]);
if(original_chunk === null) {
continue;
}
var chunk = original_chunk.nextSibling; var chunk = original_chunk.nextSibling;
while(chunk) { while(chunk) {
if(chunk.tagName === "CHUNK") { if(chunk.tagName === "CHUNK") {
@ -1874,11 +1948,14 @@ function cleanupChunkWhitespace() {
} }
if(chunk) { if(chunk) {
chunk.innerText = original_chunk.innerText + chunk.innerText; chunk.innerText = original_chunk.innerText + chunk.innerText;
if(original_chunk.innerText.length != 0 && !modified_chunks.has(chunk.getAttribute("n"))) {
modified_chunks.add(chunk.getAttribute("n"));
}
} }
original_chunk.innerText = ""; original_chunk.innerText = "";
} }
// Move whitespace at the end of non-empty chunks into the beginning of the next non-empty chunk // Move whitespace at the end of non-empty chunks into the beginning of the next non-empty chunk
var chunks = Array.from(modified_chunks); var chunks = Array.from(all_modified_chunks);
chunks.sort(function(e) {parseInt(e)}); chunks.sort(function(e) {parseInt(e)});
for(var i = 0; i < chunks.length; i++) { for(var i = 0; i < chunks.length; i++) {
var original_chunk = document.getElementById("n" + chunks[i]); var original_chunk = document.getElementById("n" + chunks[i]);
@ -1892,9 +1969,14 @@ function cleanupChunkWhitespace() {
var ln = original_chunk.innerText.trimEnd().length; var ln = original_chunk.innerText.trimEnd().length;
if (chunk) { if (chunk) {
chunk.innerText = original_chunk.innerText.substring(ln) + chunk.innerText; chunk.innerText = original_chunk.innerText.substring(ln) + chunk.innerText;
if(ln != original_chunk.innerText.length && !modified_chunks.has(chunk.getAttribute("n"))) {
modified_chunks.add(chunk.getAttribute("n"));
}
} }
original_chunk.innerText = original_chunk.innerText.substring(0, ln); original_chunk.innerText = original_chunk.innerText.substring(0, ln);
} }
bindGametext();
} }
// This gets run every time the text in a chunk is edited // This gets run every time the text in a chunk is edited
@ -1976,6 +2058,7 @@ function chunkOnFocusOut(event) {
return; return;
} }
cleanupChunkWhitespace(); cleanupChunkWhitespace();
all_modified_chunks = new Set();
syncAllModifiedChunks(true); syncAllModifiedChunks(true);
setTimeout(function() { setTimeout(function() {
var blurred = game_text[0] !== document.activeElement; var blurred = game_text[0] !== document.activeElement;
@ -2185,6 +2268,7 @@ $(document).ready(function(){
unbindGametext(); unbindGametext();
allowedit = gamestarted && $("#allowediting").prop('checked'); allowedit = gamestarted && $("#allowediting").prop('checked');
game_text.attr('contenteditable', allowedit); game_text.attr('contenteditable', allowedit);
all_modified_chunks = new Set();
modified_chunks = new Set(); modified_chunks = new Set();
empty_chunks = new Set(); empty_chunks = new Set();
game_text.html(msg.data); game_text.html(msg.data);
@ -2739,6 +2823,12 @@ $(document).ready(function(){
chunkOnFocusOut chunkOnFocusOut
); );
mutation_observer = new MutationObserver(chunkOnDOMMutate); mutation_observer = new MutationObserver(chunkOnDOMMutate);
$("#gamescreen").on('click', function(e) {
if(this !== e.target) {
return;
}
document.activeElement.blur();
});
// This is required for the editor to work correctly in Firefox on desktop // This is required for the editor to work correctly in Firefox on desktop
// because the gods of HTML and JavaScript say so // because the gods of HTML and JavaScript say so

View File

@ -52,7 +52,7 @@ import pickle
import torch import torch
import utils import utils
from torch.nn import Module from torch.nn import Module
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union from typing import Any, Callable, Dict, Optional, Tuple, Union
_EXTRA_STATE_KEY_SUFFIX = '_extra_state' _EXTRA_STATE_KEY_SUFFIX = '_extra_state'
@ -73,7 +73,7 @@ STORAGE_TYPE_MAP = {
class LazyTensor: class LazyTensor:
def __init__(self, storage_type: Type[torch._StorageBase], key: str, location: str, dtype: Optional[torch.dtype] = None, seek_offset: Optional[int] = None, shape: Optional[Tuple[int, ...]] = None, stride: Optional[Tuple[int, ...]] = None, requires_grad=False, backward_hooks: Any = None): def __init__(self, storage_type, key: str, location: str, dtype: Optional[torch.dtype] = None, seek_offset: Optional[int] = None, shape: Optional[Tuple[int, ...]] = None, stride: Optional[Tuple[int, ...]] = None, requires_grad=False, backward_hooks: Any = None):
self.storage_type = storage_type self.storage_type = storage_type
self.key = key self.key = key
self.location = location self.location = location

View File

@ -1246,13 +1246,14 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
if utils.num_shards is not None: if utils.num_shards is not None:
utils.current_shard += 1 utils.current_shard += 1
for key in sorted(model_dict.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)): for key in sorted(model_dict.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)):
model_spec_key = max((k for k in model_spec.keys() if key.endswith(k)), key=len, default=None)
# Some model weights are used by transformers but not by MTJ. # Some model weights are used by transformers but not by MTJ.
# We have to materialize these weights anyways because # We have to materialize these weights anyways because
# transformers will throw a tantrum otherwise. To attain # transformers will throw a tantrum otherwise. To attain
# the least possible memory usage, we create them as meta # the least possible memory usage, we create them as meta
# tensors, which don't take up any actual CPU or TPU memory. # tensors, which don't take up any actual CPU or TPU memory.
if key not in model_spec: if model_spec_key is None:
model_dict[key] = torch.empty(model_dict[key].shape, dtype=model_dict[key].dtype, device="meta") model_dict[key] = torch.empty(model_dict[key].shape, dtype=model_dict[key].dtype, device="meta")
utils.bar.update(1) utils.bar.update(1)
continue continue
@ -1267,7 +1268,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
if current_offset != model_dict[key].seek_offset: if current_offset != model_dict[key].seek_offset:
f.read(model_dict[key].seek_offset - current_offset) f.read(model_dict[key].seek_offset - current_offset)
current_offset = model_dict[key].seek_offset current_offset = model_dict[key].seek_offset
spec = model_spec[key] spec = model_spec[model_spec_key]
transforms = set(spec.get("transforms", ())) transforms = set(spec.get("transforms", ()))
if not isinstance(model_dict[key], torch_lazy_loader.LazyTensor): if not isinstance(model_dict[key], torch_lazy_loader.LazyTensor):
error = f"Duplicate key {repr(key)}" error = f"Duplicate key {repr(key)}"

View File

@ -183,8 +183,8 @@ function userscript.genmod()
max_overlap[i] = 0 max_overlap[i] = 0
local s = {} local s = {}
local z = {[0] = 0} local z = {[0] = 0}
local l = 1 local l = 0
local r = 1 local r = 0
local n_s = math.min(n_tokens, bias_entry.n_tokens) local n_s = math.min(n_tokens, bias_entry.n_tokens)
local j = 0 local j = 0
for k = 1, n_s do for k = 1, n_s do