Merge pull request #127 from VE-FORBRYDERNE/tracer

Fix JAX UnexpectedTracerError
This commit is contained in:
henk717 2022-06-23 19:29:51 +02:00 committed by GitHub
commit d94f29a68a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 3 additions and 3 deletions

View File

@ -547,7 +547,7 @@ class PenalizingCausalTransformer(CausalTransformer):
compiling_callback() compiling_callback()
numseqs = numseqs_aux.shape[0] numseqs = numseqs_aux.shape[0]
# These are the tokens that we don't want the AI to ever write # These are the tokens that we don't want the AI to ever write
self.badwords = jnp.array(vars.badwordsids).squeeze() badwords = jnp.array(vars.badwordsids).squeeze()
@hk.transform @hk.transform
def generate_sample(context, ctx_length): def generate_sample(context, ctx_length):
# Give the initial context to the transformer # Give the initial context to the transformer
@ -605,7 +605,7 @@ class PenalizingCausalTransformer(CausalTransformer):
# Remove any tokens in the badwords list by setting # Remove any tokens in the badwords list by setting
# their logits to negative infinity which effectively # their logits to negative infinity which effectively
# makes their probabilities of being chosen zero # makes their probabilities of being chosen zero
logits = logits.at[self.badwords].set(-jnp.inf) logits = logits.at[badwords].set(-jnp.inf)
# Use the sampler (kobold_sample_static) to pick one token # Use the sampler (kobold_sample_static) to pick one token
# based on the logits array as a 0D uint32 array # based on the logits array as a 0D uint32 array
# (higher logit means higher probability of being # (higher logit means higher probability of being
@ -1140,9 +1140,9 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
tpu_address = tpu_address.replace("grpc://", "") tpu_address = tpu_address.replace("grpc://", "")
tpu_address_without_port = tpu_address.split(':', 1)[0] tpu_address_without_port = tpu_address.split(':', 1)[0]
url = f'http://{tpu_address_without_port}:8475/requestversion/{driver_version}' url = f'http://{tpu_address_without_port}:8475/requestversion/{driver_version}'
requests.post(url)
config.FLAGS.jax_xla_backend = "tpu_driver" config.FLAGS.jax_xla_backend = "tpu_driver"
config.FLAGS.jax_backend_target = "grpc://" + tpu_address config.FLAGS.jax_backend_target = "grpc://" + tpu_address
requests.post(url)
spinner.terminate() spinner.terminate()
print() print()