TPU Fix Fix

This commit is contained in:
henk717 2023-04-23 18:49:25 +02:00 committed by GitHub
parent b4cb09590f
commit d88f109073
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 1 addition and 6 deletions

View File

@@ -1195,11 +1195,6 @@ def load_model(path: str, driver_version="tpu_driver_nightly", hf_checkpoint=Fal
thread_resources_env = maps.ResourceEnv(maps.Mesh(devices, ('dp', 'mp')), ())
maps.thread_resources.env = thread_resources_env
global shard_xmap, batch_xmap
shard_xmap = __shard_xmap()
batch_xmap = __batch_xmap(shard_dim=cores_per_replica)
global badwords
# These are the tokens that we don't want the AI to ever write
badwords = jnp.array(vars.badwordsids).squeeze()
@@ -1421,4 +1416,4 @@ def load_model(path: str, driver_version="tpu_driver_nightly", hf_checkpoint=Fal
#network.state = network.move_xmap(network.state, np.zeros(cores_per_replica))
global shard_xmap, batch_xmap
shard_xmap = __shard_xmap()
batch_xmap = __batch_xmap(shard_dim=cores_per_replica)
batch_xmap = __batch_xmap(shard_dim=cores_per_replica)