From 150ce033c93a076ad5cb0b6b0aa05d1574821b63 Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Sun, 5 Dec 2021 02:49:15 -0500 Subject: [PATCH] TPU backend no longer needs to recompile after changing softprompt --- aiserver.py | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/aiserver.py b/aiserver.py index 28da1d21..6b41d72c 100644 --- a/aiserver.py +++ b/aiserver.py @@ -1802,12 +1802,25 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None): raise ValueError("Dynamic world info scanning is not supported by the TPU backend yet") soft_tokens = None - if(vars.sp is not None): - soft_tokens = np.arange( - tpu_mtj_backend.params["n_vocab"] + tpu_mtj_backend.params["n_vocab_padding"], - tpu_mtj_backend.params["n_vocab"] + tpu_mtj_backend.params["n_vocab_padding"] + vars.sp_length, - dtype=np.uint32 + if(vars.sp is None): + global np + if 'np' not in globals(): + import numpy as np + tensor = np.zeros((1, tpu_mtj_backend.params["d_model"]), dtype=np.float32) + rows = tensor.shape[0] + padding_amount = tpu_mtj_backend.params["seq"] - (tpu_mtj_backend.params["seq"] % -tpu_mtj_backend.params["cores_per_replica"]) - rows + tensor = np.pad(tensor, ((0, padding_amount), (0, 0))) + tensor = tensor.reshape( + tpu_mtj_backend.params["cores_per_replica"], + -1, + tpu_mtj_backend.params["d_model"], ) + vars.sp = tensor + soft_tokens = np.arange( + tpu_mtj_backend.params["n_vocab"] + tpu_mtj_backend.params["n_vocab_padding"], + tpu_mtj_backend.params["n_vocab"] + tpu_mtj_backend.params["n_vocab_padding"] + vars.sp_length, + dtype=np.uint32 + ) genout = tpu_mtj_backend.infer( txt, @@ -2676,7 +2689,7 @@ def spRequest(filename): if(vars.model in ("TPUMeshTransformerGPTJ",)): rows = tensor.shape[0] - padding_amount = -(rows % -tpu_mtj_backend.params["cores_per_replica"]) + padding_amount = tpu_mtj_backend.params["seq"] - (tpu_mtj_backend.params["seq"] % -tpu_mtj_backend.params["cores_per_replica"]) - rows tensor = np.pad(tensor, ((0, padding_amount), (0, 0))) tensor = tensor.reshape( tpu_mtj_backend.params["cores_per_replica"],