diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py
index db31b902..458a67bb 100644
--- a/tpu_mtj_backend.py
+++ b/tpu_mtj_backend.py
@@ -547,7 +547,7 @@ class PenalizingCausalTransformer(CausalTransformer):
         compiling_callback()
         numseqs = numseqs_aux.shape[0]
         # These are the tokens that we don't want the AI to ever write
-        self.badwords = jnp.array(vars.badwordsids).squeeze()
+        badwords = jnp.array(vars.badwordsids).squeeze()
         @hk.transform
         def generate_sample(context, ctx_length):
             # Give the initial context to the transformer
@@ -605,7 +605,7 @@ class PenalizingCausalTransformer(CausalTransformer):
                 # Remove any tokens in the badwords list by setting
                 # their logits to negative infinity which effectively
                 # makes their probabilities of being chosen zero
-                logits = logits.at[self.badwords].set(-jnp.inf)
+                logits = logits.at[badwords].set(-jnp.inf)
                 # Use the sampler (kobold_sample_static) to pick one token
                 # based on the logits array as a 0D uint32 array
                 # (higher logit means higher probability of being
@@ -1140,9 +1140,9 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
     tpu_address = tpu_address.replace("grpc://", "")
     tpu_address_without_port = tpu_address.split(':', 1)[0]
     url = f'http://{tpu_address_without_port}:8475/requestversion/{driver_version}'
+    requests.post(url)
     config.FLAGS.jax_xla_backend = "tpu_driver"
     config.FLAGS.jax_backend_target = "grpc://" + tpu_address
-    requests.post(url)
     spinner.terminate()
     print()
 