Allow TPU models to specify settings/config in config.json

This commit is contained in:
Gnome Ann
2022-02-23 18:22:18 -05:00
parent 6151d16df0
commit ad10ac8871
2 changed files with 27 additions and 7 deletions

View File

@ -791,12 +791,24 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs)
"pe_rotary_dims": 64,
"seq": 2048,
"cores_per_replica": 8,
"tokenizer_class": "GPT2TokenizerFast",
"tokenizer": "gpt2",
}
params = kwargs
if "compat" in params:
default_params["compat"] = params["compat"]
if default_params["compat"] == "fairseq_lm":
default_params["tokenizer"] = "KoboldAI/fairseq-dense-125M"
for param in default_params:
if param not in params:
params[param] = default_params[param]
# Load tokenizer
if not isinstance(params["tokenizer_class"], str) or not any(params["tokenizer_class"].endswith(s) for s in ("Tokenizer", "TokenizerFast")):
raise ValueError("`tokenizer_class` must be a string ending in 'Tokenizer' or 'TokenizerFast'")
tokenizer_class = getattr(__import__("transformers"), params["tokenizer_class"])
tokenizer = tokenizer_class.from_pretrained(params["tokenizer"])
# Disable JAX warnings about these two functions having been renamed
jax.host_count = jax.process_count
jax.host_id = jax.process_index
@ -819,7 +831,6 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs)
devices = np.array(jax.devices()[:cores_per_replica]).reshape(mesh_shape)
thread_resources_env = maps.ResourceEnv(maps.Mesh(devices, ('dp', 'mp')), ())
maps.thread_resources.env = thread_resources_env
tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')
global shard_xmap, batch_xmap
shard_xmap = __shard_xmap()