Mirror of https://github.com/KoboldAI/KoboldAI-Client.git
Fix: TPU driver error
The to_dlpack/from_dlpack round-trip was causing issues with tensors under the newer JAX version.
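In short: instead of handing PyTorch tensors to JAX zero-copy through DLPack, the loader now materializes each tensor as a fresh JAX array. A minimal standalone sketch of the two conversion routes (illustrative only, not KoboldAI code; assumes a CPU float32 tensor):

    import torch
    import jax
    import jax.numpy as jnp

    t = torch.ones(2, 3, dtype=torch.float32)

    # Old route: zero-copy hand-off via the DLPack protocol. This is the
    # call chain the commit removes; it broke under newer JAX releases.
    j_old = jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(t)).copy()

    # New route: build a fresh JAX array from the tensor's data. One extra
    # host copy, but no dependence on the torch/JAX DLPack interop layer.
    j_new = jnp.array(t.detach())

    assert bool((j_old == j_new).all())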
@@ -2,9 +2,9 @@ torch >= 1.9, < 1.13
 numpy
 tqdm
 requests
-dm-haiku == 0.0.5
-jax == 0.2.21
-jaxlib >= 0.1.69, <= 0.3.7
+dm-haiku==0.0.9
+jax==0.3.25
+jaxlib==0.3.25
 transformers == 4.28.0
 chex == 0.1.5
 huggingface_hub==0.12.1
@@ -1095,7 +1095,7 @@ def read_neox_checkpoint(state, path, config, checkpoint_shards=2):
 
     koboldai_vars.status_message = ""
 
-def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpoint=False, socketio_queue=None, initial_load=False, logger=None, **kwargs) -> None:
+def load_model(path: str, driver_version="tpu_driver_nightly", hf_checkpoint=False, socketio_queue=None, initial_load=False, logger=None, **kwargs) -> None:
     global thread_resources_env, seq, tokenizer, network, params, pad_token_id
 
     if "pad_token_id" in kwargs:
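The default driver_version changes from a pinned dev build to the nightly TPU driver. On Colab-style TPUs this string is posted to the TPU's version endpoint before JAX attaches to it; a rough sketch of that handshake in the usual mesh-transformer-jax style (setup code assumed for illustration, not part of this diff; relies on the COLAB_TPU_ADDR environment variable):

    import os
    import requests
    from jax.config import config

    def setup_colab_tpu(driver_version: str = "tpu_driver_nightly") -> None:
        # Ask the Colab TPU runtime to load the requested driver build.
        host = os.environ["COLAB_TPU_ADDR"].split(":")[0]
        requests.post(f"http://{host}:8475/requestversion/{driver_version}")

        # Point JAX at the TPU driver backend over gRPC.
        config.FLAGS.jax_xla_backend = "tpu_driver"
        config.FLAGS.jax_backend_target = "grpc://" + os.environ["COLAB_TPU_ADDR"]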
@@ -1270,11 +1270,6 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
     logger.message(f"KoboldAI has finished loading and is available at the following link for UI 1: {koboldai_vars.cloudflare_link}")
     logger.message(f"KoboldAI has finished loading and is available at the following link for UI 2: {koboldai_vars.cloudflare_link}/new_ui")
 
-
-    global shard_xmap, batch_xmap
-    shard_xmap = __shard_xmap()
-    batch_xmap = __batch_xmap(shard_dim=cores_per_replica)
-
     global badwords
     # These are the tokens that we don't want the AI to ever write
     badwords = jnp.array(koboldai_vars.badwordsids).squeeze()
@@ -1401,19 +1396,20 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
                 #if "no_transpose" not in transforms and tensor.ndim == 2:
                 #    tensor = tensor.T
                 tensor.unsqueeze_(0)
-                if tensor.dtype is torch.float16 or tensor.dtype is torch.float32:
-                    tensor = tensor.bfloat16()
 
 
                 # Shard the tensor so that parts of the tensor can be used
                 # on different TPU cores
+                tensor = reshard_reverse(
+                    tensor,
+                    params["cores_per_replica"],
+                    network.state["params"][spec["module"]][spec["param"]].shape,
+                )
+                tensor = jnp.array(tensor.detach())
+                if tensor.dtype is torch.float16 or tensor.dtype is torch.float32:
+                    tensor = tensor.bfloat16()
+
                 network.state["params"][spec["module"]][spec["param"]] = move_xmap(
-                    jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(
-                        reshard_reverse(
-                            tensor,
-                            params["cores_per_replica"],
-                            network.state["params"][spec["module"]][spec["param"]].shape,
-                        )
-                    )).copy(),
+                    tensor,
                     np.empty(params["cores_per_replica"]),
                 )
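The reshard_reverse call above reshapes each weight so that every TPU core receives its own slice before move_xmap distributes the shards. A toy illustration of that layout change (not KoboldAI's reshard_reverse, which also handles replicated and transposed parameters):

    import numpy as np

    def split_across_cores(weight: np.ndarray, cores: int) -> np.ndarray:
        # Reshape (rows, cols) into (cores, rows // cores, cols) so that
        # shard i can be placed on core i. Assumes rows divides evenly.
        rows, cols = weight.shape
        assert rows % cores == 0, "weight axis must divide evenly across cores"
        return weight.reshape(cores, rows // cores, cols)

    w = np.arange(32, dtype=np.float32).reshape(8, 4)
    print(split_across_cores(w, cores=4).shape)  # -> (4, 2, 4)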
@@ -1506,3 +1502,6 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
             model = GPTNeoForCausalLM.from_pretrained(koboldai_vars.model, revision=koboldai_vars.revision, cache_dir="cache")
 
     #network.state = network.move_xmap(network.state, np.zeros(cores_per_replica))
+    global shard_xmap, batch_xmap
+    shard_xmap = __shard_xmap()
+    batch_xmap = __batch_xmap(shard_dim=cores_per_replica)
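With the new default, existing call sites pick up the nightly driver without changes; a hypothetical invocation (the model path is illustrative only):

    # Hypothetical call site; driver_version now defaults to "tpu_driver_nightly".
    load_model("models/gpt-j-6b", hf_checkpoint=True, initial_load=True)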