Fix JAX UnexpectedTracerError

This commit is contained in:
Gnome Ann 2022-05-31 13:25:41 -04:00
parent 243543df13
commit f2558e39d9
1 changed files with 3 additions and 3 deletions

View File

@ -519,7 +519,7 @@ class PenalizingCausalTransformer(CausalTransformer):
compiling_callback()
numseqs = numseqs_aux.shape[0]
# These are the tokens that we don't want the AI to ever write
self.badwords = jnp.array(vars.badwordsids).squeeze()
badwords = jnp.array(vars.badwordsids).squeeze()
@hk.transform
def generate_sample(context, ctx_length):
# Give the initial context to the transformer
@ -577,7 +577,7 @@ class PenalizingCausalTransformer(CausalTransformer):
# Remove any tokens in the badwords list by setting
# their logits to negative infinity which effectively
# makes their probabilities of being chosen zero
logits = logits.at[self.badwords].set(-jnp.inf)
logits = logits.at[badwords].set(-jnp.inf)
# Use the sampler (kobold_sample_static) to pick one token
# based on the logits array as a 0D uint32 array
# (higher logit means higher probability of being
@ -1101,9 +1101,9 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
tpu_address = tpu_address.replace("grpc://", "")
tpu_address_without_port = tpu_address.split(':', 1)[0]
url = f'http://{tpu_address_without_port}:8475/requestversion/{driver_version}'
requests.post(url)
config.FLAGS.jax_xla_backend = "tpu_driver"
config.FLAGS.jax_backend_target = "grpc://" + tpu_address
requests.post(url)
spinner.terminate()
print()