Merge pull request #126 from VE-FORBRYDERNE/opt
Update OPT models and fix 20B model on TPU
This commit is contained in:
commit
2be1f5088f
|
@ -12,6 +12,8 @@
|
||||||
"decoder.embed_tokens.weight": {"mtj": {"module": "embedding_shard/~/linear", "param": "w", "transforms": ["no_transpose", "vocab_pad"]}},
|
"decoder.embed_tokens.weight": {"mtj": {"module": "embedding_shard/~/linear", "param": "w", "transforms": ["no_transpose", "vocab_pad"]}},
|
||||||
"decoder.project_in.weight": {"mtj": {"module": "embedding_shard", "param": "project_in"}},
|
"decoder.project_in.weight": {"mtj": {"module": "embedding_shard", "param": "project_in"}},
|
||||||
"decoder.embed_positions.weight": {"mtj": {"module": "embedding_shard", "param": "pos_embs", "transforms": ["no_transpose", "remove_first_two_rows"]}},
|
"decoder.embed_positions.weight": {"mtj": {"module": "embedding_shard", "param": "pos_embs", "transforms": ["no_transpose", "remove_first_two_rows"]}},
|
||||||
|
"decoder.final_layer_norm.weight": {"mtj": {"module": "projection_shard/~/replicated_layer_norm", "param": "scale"}},
|
||||||
|
"decoder.final_layer_norm.bias": {"mtj": {"module": "projection_shard/~/replicated_layer_norm", "param": "offset"}},
|
||||||
"decoder.project_out.weight": {"mtj": {"module": "projection_shard", "param": "project_out"}}
|
"decoder.project_out.weight": {"mtj": {"module": "projection_shard", "param": "project_out"}}
|
||||||
},
|
},
|
||||||
"layer_weights": {
|
"layer_weights": {
|
||||||
|
|
|
@ -1119,6 +1119,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
|
||||||
return old_encode(s).ids
|
return old_encode(s).ids
|
||||||
return encode
|
return encode
|
||||||
tokenizer.encode = new_encode(tokenizer.encode)
|
tokenizer.encode = new_encode(tokenizer.encode)
|
||||||
|
tokenizer._koboldai_header = []
|
||||||
elif not hf_checkpoint:
|
elif not hf_checkpoint:
|
||||||
if not isinstance(params["tokenizer_class"], str) or not any(params["tokenizer_class"].endswith(s) for s in ("Tokenizer", "TokenizerFast")):
|
if not isinstance(params["tokenizer_class"], str) or not any(params["tokenizer_class"].endswith(s) for s in ("Tokenizer", "TokenizerFast")):
|
||||||
raise ValueError("`tokenizer_class` must be a string ending in 'Tokenizer' or 'TokenizerFast'")
|
raise ValueError("`tokenizer_class` must be a string ending in 'Tokenizer' or 'TokenizerFast'")
|
||||||
|
|
Loading…
Reference in New Issue