diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py index f878d690..205f9cc4 100644 --- a/tpu_mtj_backend.py +++ b/tpu_mtj_backend.py @@ -1195,11 +1195,6 @@ def load_model(path: str, driver_version="tpu_driver_nightly", hf_checkpoint=Fal thread_resources_env = maps.ResourceEnv(maps.Mesh(devices, ('dp', 'mp')), ()) maps.thread_resources.env = thread_resources_env - - global shard_xmap, batch_xmap - shard_xmap = __shard_xmap() - batch_xmap = __batch_xmap(shard_dim=cores_per_replica) - global badwords # These are the tokens that we don't want the AI to ever write badwords = jnp.array(vars.badwordsids).squeeze() @@ -1421,4 +1416,4 @@ def load_model(path: str, driver_version="tpu_driver_nightly", hf_checkpoint=Fal #network.state = network.move_xmap(network.state, np.zeros(cores_per_replica)) global shard_xmap, batch_xmap shard_xmap = __shard_xmap() - batch_xmap = __batch_xmap(shard_dim=cores_per_replica) \ No newline at end of file + batch_xmap = __batch_xmap(shard_dim=cores_per_replica)