Fix for Colab

This commit is contained in:
ebolam 2022-06-06 21:29:14 -04:00
parent edbf36a632
commit df76bc4b41
1 changed files with 25 additions and 23 deletions

View File

@ -931,7 +931,31 @@ def general_startup():
#==================================================================#
# Load Model
#==================================================================#
#==================================================================#
def tpumtjgetsofttokens():
soft_tokens = None
if(vars.sp is None):
global np
if 'np' not in globals():
import numpy as np
tensor = np.zeros((1, tpu_mtj_backend.params.get("d_embed", tpu_mtj_backend.params["d_model"])), dtype=np.float32)
rows = tensor.shape[0]
padding_amount = tpu_mtj_backend.params["seq"] - (tpu_mtj_backend.params["seq"] % -tpu_mtj_backend.params["cores_per_replica"]) - rows
tensor = np.pad(tensor, ((0, padding_amount), (0, 0)))
tensor = tensor.reshape(
tpu_mtj_backend.params["cores_per_replica"],
-1,
tpu_mtj_backend.params.get("d_embed", tpu_mtj_backend.params["d_model"]),
)
vars.sp = tpu_mtj_backend.shard_xmap(tensor)
soft_tokens = np.arange(
tpu_mtj_backend.params["n_vocab"] + tpu_mtj_backend.params["n_vocab_padding"],
tpu_mtj_backend.params["n_vocab"] + tpu_mtj_backend.params["n_vocab_padding"] + vars.sp_length,
dtype=np.uint32
)
return soft_tokens
def get_model_info(model, directory=""):
# if the model is in the api list
key = False
@ -1816,28 +1840,6 @@ def load_model(use_gpu=True, gpu_layers=None, initial_load=False, online_model="
return old_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs)
modeling_utils.get_checkpoint_shard_files = new_get_checkpoint_shard_files
def tpumtjgetsofttokens():
soft_tokens = None
if(vars.sp is None):
global np
if 'np' not in globals():
import numpy as np
tensor = np.zeros((1, tpu_mtj_backend.params.get("d_embed", tpu_mtj_backend.params["d_model"])), dtype=np.float32)
rows = tensor.shape[0]
padding_amount = tpu_mtj_backend.params["seq"] - (tpu_mtj_backend.params["seq"] % -tpu_mtj_backend.params["cores_per_replica"]) - rows
tensor = np.pad(tensor, ((0, padding_amount), (0, 0)))
tensor = tensor.reshape(
tpu_mtj_backend.params["cores_per_replica"],
-1,
tpu_mtj_backend.params.get("d_embed", tpu_mtj_backend.params["d_model"]),
)
vars.sp = tpu_mtj_backend.shard_xmap(tensor)
soft_tokens = np.arange(
tpu_mtj_backend.params["n_vocab"] + tpu_mtj_backend.params["n_vocab_padding"],
tpu_mtj_backend.params["n_vocab"] + tpu_mtj_backend.params["n_vocab_padding"] + vars.sp_length,
dtype=np.uint32
)
return soft_tokens
def tpumtjgenerate_warper_callback(scores) -> "np.array":
scores_shape = scores.shape