ebolam
2022-06-06 21:37:35 -04:00
parent df76bc4b41
commit ae1aed0916

@@ -934,27 +934,28 @@ def general_startup():
 #==================================================================#
 def tpumtjgetsofttokens():
+    import tpu_mtj_backend
     soft_tokens = None
     if(vars.sp is None):
         global np
         if 'np' not in globals():
             import numpy as np
         tensor = np.zeros((1, tpu_mtj_backend.params.get("d_embed", tpu_mtj_backend.params["d_model"])), dtype=np.float32)
         rows = tensor.shape[0]
         padding_amount = tpu_mtj_backend.params["seq"] - (tpu_mtj_backend.params["seq"] % -tpu_mtj_backend.params["cores_per_replica"]) - rows
         tensor = np.pad(tensor, ((0, padding_amount), (0, 0)))
         tensor = tensor.reshape(
             tpu_mtj_backend.params["cores_per_replica"],
             -1,
             tpu_mtj_backend.params.get("d_embed", tpu_mtj_backend.params["d_model"]),
         )
         vars.sp = tpu_mtj_backend.shard_xmap(tensor)
     soft_tokens = np.arange(
         tpu_mtj_backend.params["n_vocab"] + tpu_mtj_backend.params["n_vocab_padding"],
         tpu_mtj_backend.params["n_vocab"] + tpu_mtj_backend.params["n_vocab_padding"] + vars.sp_length,
         dtype=np.uint32
     )
     return soft_tokens

 def get_model_info(model, directory=""):
     # if the model is in the api list
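For reference, the function in the hunk above pads a dummy soft-prompt tensor so its row count splits evenly across the TPU cores (that is what the negative modulo does), then hands out soft-token IDs placed just past the padded vocabulary. A minimal sketch of that arithmetic, using made-up values in place of the real tpu_mtj_backend.params and vars.sp_length, looks like this:

import numpy as np

# Hypothetical config values, not the actual model parameters.
params = {"seq": 2048, "cores_per_replica": 8, "d_embed": 4096,
          "n_vocab": 50257, "n_vocab_padding": 143}
sp_length = 20  # assumed soft-prompt length (vars.sp_length in the real code)

rows = 1  # the dummy soft-prompt tensor starts with a single row
# seq % -cores_per_replica rounds seq up to the next multiple of cores_per_replica;
# subtracting rows gives the number of zero rows to pad on.
padding_amount = params["seq"] - (params["seq"] % -params["cores_per_replica"]) - rows

tensor = np.zeros((rows, params["d_embed"]), dtype=np.float32)
tensor = np.pad(tensor, ((0, padding_amount), (0, 0)))
tensor = tensor.reshape(params["cores_per_replica"], -1, params["d_embed"])
print(tensor.shape)  # (8, 256, 4096) with the values above

# Soft-prompt tokens are addressed by IDs immediately after the padded vocabulary.
soft_tokens = np.arange(params["n_vocab"] + params["n_vocab_padding"],
                        params["n_vocab"] + params["n_vocab_padding"] + sp_length,
                        dtype=np.uint32)
print(soft_tokens[:3])  # [50400 50401 50402]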