From f2558e39d911ad1e6827b73077f81f1f812d188c Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Tue, 31 May 2022 13:25:41 -0400 Subject: [PATCH 1/2] Fix JAX UnexpectedTracerError --- tpu_mtj_backend.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py index fb2dc7ae..80709784 100644 --- a/tpu_mtj_backend.py +++ b/tpu_mtj_backend.py @@ -519,7 +519,7 @@ class PenalizingCausalTransformer(CausalTransformer): compiling_callback() numseqs = numseqs_aux.shape[0] # These are the tokens that we don't want the AI to ever write - self.badwords = jnp.array(vars.badwordsids).squeeze() + badwords = jnp.array(vars.badwordsids).squeeze() @hk.transform def generate_sample(context, ctx_length): # Give the initial context to the transformer @@ -577,7 +577,7 @@ class PenalizingCausalTransformer(CausalTransformer): # Remove any tokens in the badwords list by setting # their logits to negative infinity which effectively # makes their probabilities of being chosen zero - logits = logits.at[self.badwords].set(-jnp.inf) + logits = logits.at[badwords].set(-jnp.inf) # Use the sampler (kobold_sample_static) to pick one token # based on the logits array as a 0D uint32 array # (higher logit means higher probability of being @@ -1101,9 +1101,9 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo tpu_address = tpu_address.replace("grpc://", "") tpu_address_without_port = tpu_address.split(':', 1)[0] url = f'http://{tpu_address_without_port}:8475/requestversion/{driver_version}' + requests.post(url) config.FLAGS.jax_xla_backend = "tpu_driver" config.FLAGS.jax_backend_target = "grpc://" + tpu_address - requests.post(url) spinner.terminate() print() From 3da885d408a2727bfca5d454b60ca856d4f5b19a Mon Sep 17 00:00:00 2001 From: vfbd Date: Thu, 23 Jun 2022 15:02:43 -0400 Subject: [PATCH 2/2] GPT-NeoX HF model badwords fix --- aiserver.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/aiserver.py b/aiserver.py index 4081bd79..ebaa0da5 100644 --- a/aiserver.py +++ b/aiserver.py @@ -976,7 +976,7 @@ if(vars.model not in ["InferKit", "Colab", "OAI", "GooseAI" , "ReadOnly", "TPUMe if(vars.model_type == "opt"): vars.badwordsids = vars.badwordsids_opt - if(vars.model_type == "neox"): + if(vars.model_type == "gpt_neox"): vars.badwordsids = vars.badwordsids_neox if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "GooseAI" , "ReadOnly", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"]): @@ -1886,7 +1886,7 @@ else: import tpu_mtj_backend if(vars.model_type == "opt"): tpu_mtj_backend.pad_token_id = 1 - elif(vars.model == "TPUMeshTransformerGPTNeoX" or vars.model_type == "neox"): + elif(vars.model == "TPUMeshTransformerGPTNeoX" or vars.model_type == "gpt_neox"): tpu_mtj_backend.pad_token_id = 2 tpu_mtj_backend.vars = vars tpu_mtj_backend.warper_callback = tpumtjgenerate_warper_callback