Mirror of https://github.com/KoboldAI/KoboldAI-Client.git
Pin driver to the one from JAX 0.3.25
@@ -1095,7 +1095,7 @@ def read_neox_checkpoint(state, path, config, checkpoint_shards=2):
 
     koboldai_vars.status_message = ""
 
-def load_model(path: str, driver_version="tpu_driver_nightly", hf_checkpoint=False, socketio_queue=None, initial_load=False, logger=None, **kwargs) -> None:
+def load_model(path: str, driver_version="tpu_driver_20221109", hf_checkpoint=False, socketio_queue=None, initial_load=False, logger=None, **kwargs) -> None:
     global thread_resources_env, seq, tokenizer, network, params, pad_token_id
 
     if "pad_token_id" in kwargs:
@@ -1504,4 +1504,4 @@ def load_model(path: str, driver_version="tpu_driver_nightly", hf_checkpoint=False, socketio_queue=None, initial_load=False, logger=None, **kwargs) -> None:
     #network.state = network.move_xmap(network.state, np.zeros(cores_per_replica))
     global shard_xmap, batch_xmap
     shard_xmap = __shard_xmap()
     batch_xmap = __batch_xmap(shard_dim=cores_per_replica)
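For context: load_model hands driver_version to the Colab TPU setup step, which asks the Colab runtime to load a specific TPU driver build before JAX connects to it. Below is a minimal sketch of that mechanism, assuming the jax.tools.colab_tpu-style setup used in the JAX 0.3.x era; the helper name request_tpu_driver and its default argument are illustrative, not code from this repository.

import os
import requests
from jax.config import config  # JAX 0.3.x-era import path

def request_tpu_driver(driver_version: str = "tpu_driver_20221109") -> None:
    # The Colab TPU runtime listens on port 8475; POSTing to
    # /requestversion/<version> makes it load that driver build.
    colab_tpu_addr = os.environ["COLAB_TPU_ADDR"].split(":")[0]
    requests.post(f"http://{colab_tpu_addr}:8475/requestversion/{driver_version}")
    # Point JAX at the remote TPU driver we just requested.
    config.FLAGS.jax_xla_backend = "tpu_driver"
    config.FLAGS.jax_backend_target = "grpc://" + os.environ["COLAB_TPU_ADDR"]

Pinning "tpu_driver_20221109" instead of "tpu_driver_nightly" keeps the remote driver in lockstep with JAX 0.3.25, so a new nightly driver build can no longer break the TPU backend out from under users.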