Merge pull request #127 from VE-FORBRYDERNE/tracer
Fix JAX UnexpectedTracerError
This commit is contained in:
commit
d94f29a68a
|
@ -547,7 +547,7 @@ class PenalizingCausalTransformer(CausalTransformer):
|
|||
compiling_callback()
|
||||
numseqs = numseqs_aux.shape[0]
|
||||
# These are the tokens that we don't want the AI to ever write
|
||||
self.badwords = jnp.array(vars.badwordsids).squeeze()
|
||||
badwords = jnp.array(vars.badwordsids).squeeze()
|
||||
@hk.transform
|
||||
def generate_sample(context, ctx_length):
|
||||
# Give the initial context to the transformer
|
||||
|
@ -605,7 +605,7 @@ class PenalizingCausalTransformer(CausalTransformer):
|
|||
# Remove any tokens in the badwords list by setting
|
||||
# their logits to negative infinity which effectively
|
||||
# makes their probabilities of being chosen zero
|
||||
logits = logits.at[self.badwords].set(-jnp.inf)
|
||||
logits = logits.at[badwords].set(-jnp.inf)
|
||||
# Use the sampler (kobold_sample_static) to pick one token
|
||||
# based on the logits array as a 0D uint32 array
|
||||
# (higher logit means higher probability of being
|
||||
|
@ -1140,9 +1140,9 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
|
|||
tpu_address = tpu_address.replace("grpc://", "")
|
||||
tpu_address_without_port = tpu_address.split(':', 1)[0]
|
||||
url = f'http://{tpu_address_without_port}:8475/requestversion/{driver_version}'
|
||||
requests.post(url)
|
||||
config.FLAGS.jax_xla_backend = "tpu_driver"
|
||||
config.FLAGS.jax_backend_target = "grpc://" + tpu_address
|
||||
requests.post(url)
|
||||
spinner.terminate()
|
||||
print()
|
||||
|
||||
|
|
Loading…
Reference in New Issue