Commit Graph

18 Commits

Author SHA1 Message Date
Gnome Ann
3ba0e3f9d9 Dynamic TPU backend should support dynamic warpers and abort button 2022-01-17 14:10:32 -05:00
Gnome Ann
31735c4239 Fix np.take (https://github.com/google/jax/issues/3774) 2022-01-17 13:54:02 -05:00
Gnome Ann
33f9f2dc82 Show message when TPU backend is compiling 2022-01-16 21:09:10 -05:00
Gnome Ann
f4eb896a69 Use original TPU backend if possible 2022-01-15 23:31:07 -05:00
Gnome Ann
e0fdce2cc6 Fix TPU generation modifier 2022-01-14 23:00:06 -05:00
Gnome Ann
932c393d6a Add TPU support for dynamic WI scan and generation modifiers 2022-01-14 21:39:02 -05:00
Gnome Ann
0bef92419b Convert the jitted function into ordinary NumPy operations 2022-01-14 15:05:21 -05:00
Gnome Ann
57a6886007 Move sampling into a jax.jitted function 2022-01-14 02:23:19 -05:00
Gnome Ann
09c4fdcb2e Split generate_xmap into two xmaps 2022-01-13 00:56:00 -05:00
Gnome Ann
a3d6dc93e8 xmaps for moving things onto TPU 2022-01-12 21:45:30 -05:00
Gnome Ann
8742453f95 Add safeguards for token budget and text formatting
* Error messages are now shown when memory, author's note, etc. exceeds
  budget by itself
* Formatting options no longer break if there are empty chunks in the
  story (although there shouldn't be any in the first place)
* Number of generated tokens is now kept track of from Python
2021-12-26 18:29:54 -05:00
Gnome Ann
fbf3e7615b Add API for generated tokens and output text 2021-12-12 19:27:20 -05:00
Gnome Ann
d2d338d314 Improve TPU backend compilation times with numseqs > 1
A Python `for` loop was replaced with a `jax.lax.scan` call so that JAX
only compiles the `transformer.generate_initial` function one time
instead of `numseqs` times. This is because JAX unrolls Python built-in
loops like `for`. The compilation times should now be about the same as
they were before the upgrade to JAX 0.2.21.
2021-11-30 19:22:40 -05:00
Gnome Ann
c1e7c1643f Fix unbound axis error in tpu_mtj_backend.py when numseqs > 1 2021-11-30 14:06:46 -05:00
Gnome Ann
3c349e6aaf Modify TPU backend code to support JAX 0.2.21
The original one supported versions of JAX up to 0.2.12, and possibly also some
earlier versions. This new code supports exclusively JAX 0.2.21 and does not
work with any earlier or later versions of JAX. However, this new code benefits
from not needing to recompile when changing "Amount To Generate" and also from
supporting stopping generation early, which makes an implementation of Dynamic
World Info Scan finally possible.
2021-11-30 10:13:02 -05:00
Gnome Ann
691febacd6 Fix a typo in tpu_mtj_backend.py 2021-11-22 12:53:19 -05:00
Gnome Ann
e068aa9f26 Add soft prompt support to TPU backend 2021-11-21 18:08:04 -05:00
Gnome Ann
a65c4de840 Integrate TPU backend
This commit puts the TPU backend code directly in to the KoboldAI code
to make it easier to modify.
2021-11-19 18:06:57 -05:00