Commit Graph

6 Commits

Author SHA1 Message Date
Gnome Ann d2d338d314 Improve TPU backend compilation times with `numseqs > 1`
A Python `for` loop was replaced with a `jax.lax.scan` call so that JAX
only compiles the `transformer.generate_initial` function one time
instead of `numseqs` times. This is because JAX unrolls Python built-in
loops like `for`. The compilation times should now be about the same as
they were before the upgrade to JAX 0.2.21.
2021-11-30 19:22:40 -05:00
Gnome Ann c1e7c1643f Fix unbound axis error in tpu_mtj_backend.py when `numseqs > 1` 2021-11-30 14:06:46 -05:00
Gnome Ann 3c349e6aaf Modify TPU backend code to support JAX 0.2.21
The original one supported versions of JAX up to 0.2.12, and possibly also some
earlier versions. This new code supports exclusively JAX 0.2.21 and does not
work with any earlier or later versions of JAX. However, this new code benefits
from not needing to recompile when changing "Amount To Generate" and also from
supporting stopping generation early, which makes an implementation of Dynamic
World Info Scan finally possible.
2021-11-30 10:13:02 -05:00
Gnome Ann 691febacd6 Fix a typo in tpu_mtj_backend.py 2021-11-22 12:53:19 -05:00
Gnome Ann e068aa9f26 Add soft prompt support to TPU backend 2021-11-21 18:08:04 -05:00
Gnome Ann a65c4de840 Integrate TPU backend
This commit puts the TPU backend code directly in to the KoboldAI code
to make it easier to modify.
2021-11-19 18:06:57 -05:00