Mirror of https://github.com/KoboldAI/KoboldAI-Client.git
Commit dddde7dbc3
@@ -2,9 +2,10 @@ torch >= 1.9, < 1.13
 numpy
 tqdm
 requests
-dm-haiku == 0.0.5
-jax == 0.2.21
-jaxlib >= 0.1.69, <= 0.3.7
+dm-haiku==0.0.9
+jax==0.3.25
+jaxlib==0.3.25
+transformers == 4.28.0
 chex == 0.1.5
 transformers == 4.24.0
 huggingface_hub==0.12.1
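Note: the new pins can be sanity-checked at runtime. A minimal sketch, assuming the packages above are installed in the active environment:

    # Print the installed versions of the pinned packages (Python 3.8+).
    import importlib.metadata

    for pkg in ("dm-haiku", "jax", "jaxlib", "chex", "transformers", "huggingface_hub"):
        print(pkg, importlib.metadata.version(pkg))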
@@ -1049,7 +1049,7 @@ def read_neox_checkpoint(state, path, config, checkpoint_shards=2):
         raise RuntimeError(error)
 
 
-def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpoint=False, **kwargs) -> None:
+def load_model(path: str, driver_version="tpu_driver_nightly", hf_checkpoint=False, socketio_queue=None, initial_load=False, logger=None, **kwargs) -> None:
     global thread_resources_env, seq, tokenizer, network, params, pad_token_id
 
     if "pad_token_id" in kwargs:
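The signature change switches the default TPU driver and threads UI-related state into the loader. A minimal sketch of a call against the new signature; the module name and argument values are assumptions for illustration (KoboldAI keeps this code in tpu_mtj_backend.py), not taken from this diff:

    # Hypothetical caller against the patched signature.
    import tpu_mtj_backend

    tpu_mtj_backend.load_model(
        "models/gpt-j-6b",                    # path: local model directory (illustrative)
        driver_version="tpu_driver_nightly",  # new default TPU driver
        hf_checkpoint=True,                   # load weights from a Hugging Face checkpoint
        socketio_queue=None,                  # optional queue for progress events (new)
        initial_load=True,                    # first load after startup (new)
        logger=None,                          # optional logger instance (new)
    )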
@@ -1195,6 +1195,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
     thread_resources_env = maps.ResourceEnv(maps.Mesh(devices, ('dp', 'mp')), ())
     maps.thread_resources.env = thread_resources_env
+
 
     global shard_xmap, batch_xmap
     shard_xmap = __shard_xmap()
     batch_xmap = __batch_xmap(shard_dim=cores_per_replica)
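The resource-env lines follow the jax.experimental.maps pattern from jax 0.3.x. A minimal standalone sketch of the same setup; the device grid shape here is an assumption, not taken from the file:

    # Build a (dp, mp) device mesh and install it as the thread's resource env,
    # assuming jax 0.3.x where jax.experimental.maps provides Mesh/ResourceEnv.
    import numpy as np
    import jax
    from jax.experimental import maps

    devices = np.array(jax.devices()).reshape((1, jax.device_count()))
    thread_resources_env = maps.ResourceEnv(maps.Mesh(devices, ('dp', 'mp')), ())
    maps.thread_resources.env = thread_resources_env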
@@ -1244,6 +1245,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
     from tqdm.auto import tqdm
     import functools
+
 
     def callback(model_dict, f, **_):
         if callback.nested:
             return
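The callback defined here uses a function attribute as a re-entrancy guard. A minimal sketch of that guard pattern in isolation; the names are illustrative:

    # Function-attribute re-entrancy guard, as used by callback() above.
    def callback(*args, **kwargs):
        if callback.nested:
            return                # a nested invocation; do nothing
        callback.nested = True
        try:
            pass                  # real work would happen here
        finally:
            callback.nested = False
    callback.nested = False       # initialize the guard flag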
@@ -1317,19 +1319,20 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
                     #if "no_transpose" not in transforms and tensor.ndim == 2:
                     #    tensor = tensor.T
                     tensor.unsqueeze_(0)
-                    if tensor.dtype is torch.float16 or tensor.dtype is torch.float32:
-                        tensor = tensor.bfloat16()
 
                     # Shard the tensor so that parts of the tensor can be used
                     # on different TPU cores
+                    tensor = reshard_reverse(
+                        tensor,
+                        params["cores_per_replica"],
+                        network.state["params"][spec["module"]][spec["param"]].shape,
+                    )
+                    tensor = jnp.array(tensor.detach())
+                    if tensor.dtype is torch.float16 or tensor.dtype is torch.float32:
+                        tensor = tensor.bfloat16()
+
                     network.state["params"][spec["module"]][spec["param"]] = move_xmap(
-                        jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(
-                            reshard_reverse(
-                                tensor,
-                                params["cores_per_replica"],
-                                network.state["params"][spec["module"]][spec["param"]].shape,
-                            )
-                        )).copy(),
+                        tensor,
                         np.empty(params["cores_per_replica"]),
                     )
 
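The rewritten block reshards the checkpoint tensor and casts it before handing it to move_xmap, replacing the old DLPack round trip with a plain jnp.array copy of the detached tensor. A minimal sketch of that handoff in isolation, with illustrative shapes:

    # Copy a detached CPU torch tensor into a JAX array without DLPack.
    import torch
    import jax.numpy as jnp

    t = torch.ones((1, 8, 8), dtype=torch.float32)
    j = jnp.array(t.detach().numpy())  # explicit copy into JAX
    print(j.shape, j.dtype)            # (1, 8, 8) float32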
@@ -1416,3 +1419,6 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
             model = GPTNeoForCausalLM.from_pretrained(vars.model, revision=vars.revision, cache_dir="cache")
 
     #network.state = network.move_xmap(network.state, np.zeros(cores_per_replica))
+    global shard_xmap, batch_xmap
+    shard_xmap = __shard_xmap()
+    batch_xmap = __batch_xmap(shard_dim=cores_per_replica)
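These trailing lines rebuild the module-level xmaps after the Hugging Face weights have been loaded; the global statement is what makes the assignments rebind the module-level names rather than create locals. A minimal sketch of that rebinding pattern, with illustrative names:

    # Rebinding a module-level callable from inside a function needs `global`.
    _fn = None

    def reload_fn():
        global _fn            # without this, _fn below would be a new local
        _fn = lambda x: x + 1

    reload_fn()
    print(_fn(1))             # 2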