mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-01-09 06:48:23 +01:00
TPU Fix Fix
This commit is contained in:
parent
b4cb09590f
commit
d88f109073
@ -1195,11 +1195,6 @@ def load_model(path: str, driver_version="tpu_driver_nightly", hf_checkpoint=Fal
|
|||||||
thread_resources_env = maps.ResourceEnv(maps.Mesh(devices, ('dp', 'mp')), ())
|
thread_resources_env = maps.ResourceEnv(maps.Mesh(devices, ('dp', 'mp')), ())
|
||||||
maps.thread_resources.env = thread_resources_env
|
maps.thread_resources.env = thread_resources_env
|
||||||
|
|
||||||
|
|
||||||
global shard_xmap, batch_xmap
|
|
||||||
shard_xmap = __shard_xmap()
|
|
||||||
batch_xmap = __batch_xmap(shard_dim=cores_per_replica)
|
|
||||||
|
|
||||||
global badwords
|
global badwords
|
||||||
# These are the tokens that we don't want the AI to ever write
|
# These are the tokens that we don't want the AI to ever write
|
||||||
badwords = jnp.array(vars.badwordsids).squeeze()
|
badwords = jnp.array(vars.badwordsids).squeeze()
|
||||||
@ -1421,4 +1416,4 @@ def load_model(path: str, driver_version="tpu_driver_nightly", hf_checkpoint=Fal
|
|||||||
#network.state = network.move_xmap(network.state, np.zeros(cores_per_replica))
|
#network.state = network.move_xmap(network.state, np.zeros(cores_per_replica))
|
||||||
global shard_xmap, batch_xmap
|
global shard_xmap, batch_xmap
|
||||||
shard_xmap = __shard_xmap()
|
shard_xmap = __shard_xmap()
|
||||||
batch_xmap = __batch_xmap(shard_dim=cores_per_replica)
|
batch_xmap = __batch_xmap(shard_dim=cores_per_replica)
|
||||||
|
Loading…
Reference in New Issue
Block a user