#==================================================================#
# Load Model
#==================================================================#

def tpumtjgetsofttokens():
    """Return the soft-token IDs for the TPU MTJ backend.

    The IDs are the consecutive integers starting at
    ``n_vocab + n_vocab_padding`` (both read from ``tpu_mtj_backend.params``)
    and spanning ``vars.sp_length`` entries, as a ``np.uint32`` array.

    Side effect: when ``vars.sp`` is ``None`` it is initialised with an
    all-zero float32 placeholder that has been padded, reshaped to
    ``(cores_per_replica, -1, embed_dim)`` and passed through
    ``tpu_mtj_backend.shard_xmap``.
    NOTE(review): ``vars.sp`` is presumably the loaded soft prompt tensor and
    ``vars.sp_length`` its token count -- confirm against the callers.

    Returns:
        numpy.ndarray: 1-D ``uint32`` array of length ``vars.sp_length``.
    """
    # numpy is imported lazily into module scope. This has to happen before
    # the branch below: np.arange() runs on every call, so importing only
    # when vars.sp is None would raise NameError if vars.sp was already set
    # and nothing else had imported numpy yet.
    global np
    if 'np' not in globals():
        import numpy as np

    params = tpu_mtj_backend.params

    if vars.sp is None:
        embed_dim = params.get("d_embed", params["d_model"])
        tensor = np.zeros((1, embed_dim), dtype=np.float32)
        rows = tensor.shape[0]
        # `seq % -cores` lies in (-cores, 0], so subtracting it rounds `seq`
        # up to the next multiple of cores_per_replica; the reshape below
        # needs the row count to divide evenly across the cores.
        padding_amount = (
            params["seq"]
            - (params["seq"] % -params["cores_per_replica"])
            - rows
        )
        tensor = np.pad(tensor, ((0, padding_amount), (0, 0)))
        tensor = tensor.reshape(params["cores_per_replica"], -1, embed_dim)
        vars.sp = tpu_mtj_backend.shard_xmap(tensor)

    # Soft-token IDs start right after the padded vocabulary.
    first_soft_id = params["n_vocab"] + params["n_vocab_padding"]
    return np.arange(
        first_soft_id,
        first_soft_id + vars.sp_length,
        dtype=np.uint32,
    )