Merge pull request #127 from VE-FORBRYDERNE/tracer
Fix JAX UnexpectedTracerError
This commit is contained in:
commit
d94f29a68a
|
@@ -547,7 +547,7 @@ class PenalizingCausalTransformer(CausalTransformer):
|
||||||
compiling_callback()
|
compiling_callback()
|
||||||
numseqs = numseqs_aux.shape[0]
|
numseqs = numseqs_aux.shape[0]
|
||||||
# These are the tokens that we don't want the AI to ever write
|
# These are the tokens that we don't want the AI to ever write
|
||||||
self.badwords = jnp.array(vars.badwordsids).squeeze()
|
badwords = jnp.array(vars.badwordsids).squeeze()
|
||||||
@hk.transform
|
@hk.transform
|
||||||
def generate_sample(context, ctx_length):
|
def generate_sample(context, ctx_length):
|
||||||
# Give the initial context to the transformer
|
# Give the initial context to the transformer
|
||||||
|
@@ -605,7 +605,7 @@ class PenalizingCausalTransformer(CausalTransformer):
|
||||||
# Remove any tokens in the badwords list by setting
|
# Remove any tokens in the badwords list by setting
|
||||||
# their logits to negative infinity which effectively
|
# their logits to negative infinity which effectively
|
||||||
# makes their probabilities of being chosen zero
|
# makes their probabilities of being chosen zero
|
||||||
logits = logits.at[self.badwords].set(-jnp.inf)
|
logits = logits.at[badwords].set(-jnp.inf)
|
||||||
# Use the sampler (kobold_sample_static) to pick one token
|
# Use the sampler (kobold_sample_static) to pick one token
|
||||||
# based on the logits array as a 0D uint32 array
|
# based on the logits array as a 0D uint32 array
|
||||||
# (higher logit means higher probability of being
|
# (higher logit means higher probability of being
|
||||||
|
@@ -1140,9 +1140,9 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
|
||||||
tpu_address = tpu_address.replace("grpc://", "")
|
tpu_address = tpu_address.replace("grpc://", "")
|
||||||
tpu_address_without_port = tpu_address.split(':', 1)[0]
|
tpu_address_without_port = tpu_address.split(':', 1)[0]
|
||||||
url = f'http://{tpu_address_without_port}:8475/requestversion/{driver_version}'
|
url = f'http://{tpu_address_without_port}:8475/requestversion/{driver_version}'
|
||||||
|
requests.post(url)
|
||||||
config.FLAGS.jax_xla_backend = "tpu_driver"
|
config.FLAGS.jax_xla_backend = "tpu_driver"
|
||||||
config.FLAGS.jax_backend_target = "grpc://" + tpu_address
|
config.FLAGS.jax_backend_target = "grpc://" + tpu_address
|
||||||
requests.post(url)
|
|
||||||
spinner.terminate()
|
spinner.terminate()
|
||||||
print()
|
print()
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue