Mirror of https://github.com/KoboldAI/KoboldAI-Client.git, synced 2025-06-05 21:59:24 +02:00
Allow TPU models to specify settings/config in config.json
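As a sketch of the intended flow (the module name tpu_mtj_backend and the caller-side reading of config.json are assumptions; only load_model and its parameter names appear in the diff below), a model's config.json can now carry overrides that are forwarded as keyword arguments:

    import json

    import tpu_mtj_backend  # module patched by this commit (name assumed)

    # Read the model's config.json; which file is consulted is an assumption here.
    with open("my-model/config.json") as f:
        config = json.load(f)

    # Keys present in config.json override the hard-coded defaults in
    # default_params (e.g. "seq": 2048); missing keys fall back to them.
    tpu_mtj_backend.load_model("my-model", **config)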
@@ -791,12 +791,24 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs)
         "pe_rotary_dims": 64,
         "seq": 2048,
         "cores_per_replica": 8,
+        "tokenizer_class": "GPT2TokenizerFast",
+        "tokenizer": "gpt2",
     }
     params = kwargs
+    if "compat" in params:
+        default_params["compat"] = params["compat"]
+    if default_params["compat"] == "fairseq_lm":
+        default_params["tokenizer"] = "KoboldAI/fairseq-dense-125M"
     for param in default_params:
         if param not in params:
             params[param] = default_params[param]
+
+    # Load tokenizer
+    if not isinstance(params["tokenizer_class"], str) or not any(params["tokenizer_class"].endswith(s) for s in ("Tokenizer", "TokenizerFast")):
+        raise ValueError("`tokenizer_class` must be a string ending in 'Tokenizer' or 'TokenizerFast'")
+    tokenizer_class = getattr(__import__("transformers"), params["tokenizer_class"])
+    tokenizer = tokenizer_class.from_pretrained(params["tokenizer"])
 
     # Disable JAX warnings about these two functions having been renamed
     jax.host_count = jax.process_count
     jax.host_id = jax.process_index
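The merge logic above gives explicit settings precedence over the defaults, with compat able to redirect the default tokenizer. A self-contained toy distillation of that precedence (the function name merge_params and the "compat": "j" default are assumptions; the rest mirrors the hunk):

    def merge_params(**kwargs):
        default_params = {
            "compat": "j",  # assumed default; the top of the dict is outside this hunk
            "seq": 2048,
            "cores_per_replica": 8,
            "tokenizer_class": "GPT2TokenizerFast",
            "tokenizer": "gpt2",
        }
        params = kwargs
        if "compat" in params:
            default_params["compat"] = params["compat"]
        if default_params["compat"] == "fairseq_lm":
            default_params["tokenizer"] = "KoboldAI/fairseq-dense-125M"
        # Defaults only fill in keys the caller did not supply.
        for param in default_params:
            if param not in params:
                params[param] = default_params[param]
        return params

    # fairseq models pick up the fairseq tokenizer unless one is given explicitly:
    assert merge_params(compat="fairseq_lm")["tokenizer"] == "KoboldAI/fairseq-dense-125M"
    assert merge_params(tokenizer="gpt2-medium")["tokenizer"] == "gpt2-medium"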
@@ -819,7 +831,6 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs)
     devices = np.array(jax.devices()[:cores_per_replica]).reshape(mesh_shape)
     thread_resources_env = maps.ResourceEnv(maps.Mesh(devices, ('dp', 'mp')), ())
     maps.thread_resources.env = thread_resources_env
-    tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')
 
     global shard_xmap, batch_xmap
     shard_xmap = __shard_xmap()
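The line removed above is the old hard-coded GPT-2 tokenizer load; it became redundant once the tokenizer is resolved from params earlier in load_model. The pattern used there, isolated as a standalone sketch (the helper name resolve_tokenizer is hypothetical; the validation mirrors the first hunk):

    import transformers

    def resolve_tokenizer(tokenizer_class: str, tokenizer_name: str):
        # Restrict the dynamic attribute lookup to names that look like
        # Hugging Face tokenizer classes, as the added validation does.
        if not isinstance(tokenizer_class, str) or not tokenizer_class.endswith(("Tokenizer", "TokenizerFast")):
            raise ValueError("`tokenizer_class` must be a string ending in 'Tokenizer' or 'TokenizerFast'")
        cls = getattr(transformers, tokenizer_class)
        return cls.from_pretrained(tokenizer_name)

    tokenizer = resolve_tokenizer("GPT2TokenizerFast", "gpt2")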