Merge pull request #126 from VE-FORBRYDERNE/opt

Update OPT models and fix 20B model on TPU
henk717 2022-06-21 23:19:03 +02:00 committed by GitHub
commit 2be1f5088f
2 changed files with 3 additions and 0 deletions


@@ -12,6 +12,8 @@
"decoder.embed_tokens.weight": {"mtj": {"module": "embedding_shard/~/linear", "param": "w", "transforms": ["no_transpose", "vocab_pad"]}}, "decoder.embed_tokens.weight": {"mtj": {"module": "embedding_shard/~/linear", "param": "w", "transforms": ["no_transpose", "vocab_pad"]}},
"decoder.project_in.weight": {"mtj": {"module": "embedding_shard", "param": "project_in"}}, "decoder.project_in.weight": {"mtj": {"module": "embedding_shard", "param": "project_in"}},
"decoder.embed_positions.weight": {"mtj": {"module": "embedding_shard", "param": "pos_embs", "transforms": ["no_transpose", "remove_first_two_rows"]}}, "decoder.embed_positions.weight": {"mtj": {"module": "embedding_shard", "param": "pos_embs", "transforms": ["no_transpose", "remove_first_two_rows"]}},
"decoder.final_layer_norm.weight": {"mtj": {"module": "projection_shard/~/replicated_layer_norm", "param": "scale"}},
"decoder.final_layer_norm.bias": {"mtj": {"module": "projection_shard/~/replicated_layer_norm", "param": "offset"}},
"decoder.project_out.weight": {"mtj": {"module": "projection_shard", "param": "project_out"}} "decoder.project_out.weight": {"mtj": {"module": "projection_shard", "param": "project_out"}}
}, },
"layer_weights": { "layer_weights": {


@@ -1119,6 +1119,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
             return old_encode(s).ids
         return encode
     tokenizer.encode = new_encode(tokenizer.encode)
+    tokenizer._koboldai_header = []
 elif not hf_checkpoint:
     if not isinstance(params["tokenizer_class"], str) or not any(params["tokenizer_class"].endswith(s) for s in ("Tokenizer", "TokenizerFast")):
         raise ValueError("`tokenizer_class` must be a string ending in 'Tokenizer' or 'TokenizerFast'")
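
For context on the one-line addition above: in this branch of `load_model`, `tokenizer.encode` is wrapped so callers get a plain list of token ids, and the new line attaches an empty `_koboldai_header` attribute that later code apparently expects (the commit title suggests this is what fixes the 20B model on TPU). Below is a minimal runnable sketch of that pattern; `DummyTokenizer` is a stand-in for the real tokenizer object, used only to make the example self-contained.

```python
# Hedged sketch of the tokenizer patch surrounding this hunk: wrap encode()
# so it returns a plain list of ids, and attach an empty _koboldai_header
# as the added line does. DummyTokenizer is an illustrative stand-in,
# not the project's actual tokenizer class.
class _Encoding:
    def __init__(self, ids):
        self.ids = ids

class DummyTokenizer:
    def encode(self, s: str):
        # Returns an Encoding-like object, as a `tokenizers` tokenizer would.
        return _Encoding([ord(c) for c in s])

tokenizer = DummyTokenizer()

def new_encode(old_encode):
    def encode(s):
        # Unwrap the Encoding object so downstream code gets a list of ids.
        return old_encode(s).ids
    return encode

tokenizer.encode = new_encode(tokenizer.encode)
tokenizer._koboldai_header = []  # the added line: attribute later code presumably looks for

print(tokenizer.encode("hi"))       # [104, 105]
print(tokenizer._koboldai_header)   # []
```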