Commit Graph

50 Commits

Author SHA1 Message Date
Gnome Ann 707316de31 Kaggle TPU support 2022-05-31 12:20:16 -04:00
Gnome Ann d4e8f56789 Remove debugging code from tpu_mtj_backend.py 2022-05-14 12:00:44 -04:00
Gnome Ann 0c5ca5261e Loading a sharded model will now display only one progress bar 2022-05-13 23:32:16 -04:00
Gnome Ann b1d8797a54 Allow TPU Colab to load sharded HF models 2022-05-12 23:51:40 -04:00
Gnome Ann 4fa5f1cd6a Add TPU support for OPT-350M
The 350M model seems to have a different structure than the other OPT models.
2022-05-12 22:21:15 -04:00
Gnome Ann f5e689a725 Upload maps/opt.json and update requirements 2022-05-12 19:09:31 -04:00
Gnome Ann b97b2a02d6 Add `--revision` command line flag 2022-05-10 22:14:56 -04:00
Gnome Ann c117bfd0ad Fix lazy loader 2022-04-08 19:38:15 -04:00
Gnome Ann fabbdf2bb1 Lazy loader Python 3.6 compatibility
The current lazy loader relies on a feature of the Python zipfile module
that was added in Python 3.7.0:

https://bugs.python.org/issue22908

This commit adds compatibility for Python 3.6; see the sketch below.
2022-04-02 15:02:54 -04:00
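
A minimal sketch of the kind of shim this implies, assuming the lazy loader only needs `seek()`/`tell()` on the file object returned by `ZipFile.open()` (the feature added in Python 3.7 by the issue above); the actual implementation in the repository may differ:

```python
import io
import sys
import zipfile

def open_seekable(zf: zipfile.ZipFile, name: str):
    """Return a seekable file object for a zip member.

    On Python 3.7+ the ZipExtFile returned by ZipFile.open() already
    supports seek()/tell() (bpo-22908); on Python 3.6 we fall back to
    buffering the whole member in an in-memory BytesIO, which is seekable.
    """
    f = zf.open(name)
    if sys.version_info >= (3, 7):
        return f
    with f:
        return io.BytesIO(f.read())
```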
Gnome Ann 67e28d2b5c Typical sampling needs to use nansum instead of sum
If `probs` contains a zero, the corresponding entry of `log_probs` will be
negative infinity, and the calculation of `neg_entropy` will then give NaN
because zero times infinity is mathematically indeterminate.

We need to use nansum so that those NaN values are treated as zeros and are
ignored in the entropy calculation (see the sketch below).
2022-03-28 00:02:31 -04:00
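
A minimal illustration of the issue described in the commit above, using plain NumPy with made-up numbers rather than the actual filter code:

```python
import numpy as np

probs = np.array([0.7, 0.3, 0.0])
log_probs = np.log(probs)                           # log(0) -> -inf (emits a divide warning)
neg_entropy_sum = np.sum(probs * log_probs)         # 0 * -inf -> nan, so the whole sum is nan
neg_entropy_nansum = np.nansum(probs * log_probs)   # nan terms treated as 0 -> approx -0.611
```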
Gnome Ann d5989d4c62 Hide division by zero warning in JAX typical filter
This warning happens when `np.log` gets an input containing zeros.
In that case, NumPy emits a warning and outputs negative infinity.

Negative infinity is the correct behaviour here, so we can safely ignore
the warning (see the sketch below).
2022-03-27 16:57:12 -04:00
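
One common way to silence that specific warning, shown with plain NumPy; whether the commit uses `np.errstate` or another mechanism is an assumption here:

```python
import numpy as np

probs = np.array([0.7, 0.3, 0.0])
with np.errstate(divide="ignore"):  # hide "divide by zero encountered in log"
    log_probs = np.log(probs)       # log(0) still correctly evaluates to -inf
```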
Gnome Ann 20e48b11d7 Typical sampling 2022-03-27 16:25:50 -04:00
Gnome Ann 73aecc0510 Divide NeoX replicated bias layers by 4 again instead of by 8 2022-03-20 01:04:55 -04:00
Gnome Ann 05fc46b253 Changing this again to divide by 8 2022-03-19 02:09:41 -04:00
Gnome Ann 6c20d0d657 Nevermind, dividing by 4 is actually correct... 2022-03-19 00:55:04 -04:00
Gnome Ann f16b61ec77 Should divide NeoX replicated parameters by 8 (not by 4)
Also suppresses the PyTorch 1.11 warning about transposing tensors with
ndim != 2 in the new code (see the sketch below)
2022-03-19 00:48:33 -04:00
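
A sketch of how a warning like that can be filtered; the message prefix below is an assumption about the warning's wording, not quoted from PyTorch:

```python
import warnings

# Ignore the PyTorch 1.11 UserWarning about calling .T on tensors whose
# ndim != 2; the exact message text is assumed here for illustration.
warnings.filterwarnings(
    "ignore",
    message=r"The use of `x\.T` on tensors of dimension other than 2",
    category=UserWarning,
)
```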
Gnome Ann c2c139e940 Change default PE type for NeoX to `neox_rotary` 2022-03-19 00:26:04 -04:00
Gnome Ann 85a4959efa Merge branch 'united' into neox 2022-03-18 11:19:03 -04:00
Gnome Ann c444260eac Silence PyTorch warning about transposing tensors with dimension != 2 2022-03-17 15:16:56 -04:00
Gnome Ann eaf190469d Add PyTorch 1.11 support for lazy loader 2022-03-17 12:51:41 -04:00
Gnome Ann 95c4251db9 Print two newlines before loading HF models 2022-03-15 13:58:53 -04:00
Gnome Ann 9e2848e48f Show parameter count when loading GPT-NeoX in Colab TPU instance 2022-03-15 13:55:27 -04:00
Gnome Ann 88f247d535 GPT-NeoX-20B support in Colab TPU instances 2022-03-14 23:14:20 -04:00
Gnome Ann 2b8c46338e Change current working directory to KoboldAI folder 2022-03-13 01:22:11 -05:00
Gnome Ann 48d07adb54 Also fallback to generic GPT2 tokenizer in Colab TPU instances 2022-03-12 23:19:35 -05:00
Gnome Ann a99eb8724d Use DLPack to convert PyTorch tensors to JAX arrays 2022-03-10 15:12:42 -05:00
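
The DLPack round trip this refers to is roughly the following, a sketch against the JAX and PyTorch APIs of that era; the exact code in `tpu_mtj_backend.py` may differ:

```python
import jax.dlpack
import torch
import torch.utils.dlpack

def torch_to_jax(tensor: torch.Tensor):
    """Convert a PyTorch tensor to a JAX array via DLPack, sharing memory
    instead of copying the underlying buffer."""
    return jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(tensor))

# Example: jax_array = torch_to_jax(torch.arange(4))
```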
henk717 68281184bf Remove Lowmem from TPU 2022-03-09 19:21:15 +01:00
Gnome Ann 0a258a6282 Support for loading HF models on TPU with `--colab_tpu` 2022-03-05 12:33:33 -05:00
Gnome Ann ad10ac8871 Allow TPU models to specify settings/config in config.json 2022-02-23 18:22:18 -05:00
Gnome Ann 7ec549c726 Use dematerialized loading in TPU backend for lower device memory usage 2022-02-22 19:43:13 -05:00
henk717 fca7f8659f Badwords unification
TPUs no longer use hardcoded badwords but instead use the badwords variable
2022-01-29 18:09:53 +01:00
Gnome Ann 3f18888eec Repetition penalty slope and range 2022-01-24 15:30:38 -05:00
Gnome Ann 3ba0e3f9d9 Dynamic TPU backend should support dynamic warpers and abort button 2022-01-17 14:10:32 -05:00
Gnome Ann 31735c4239 Fix np.take (https://github.com/google/jax/issues/3774) 2022-01-17 13:54:02 -05:00
Gnome Ann 33f9f2dc82 Show message when TPU backend is compiling 2022-01-16 21:09:10 -05:00
Gnome Ann f4eb896a69 Use original TPU backend if possible 2022-01-15 23:31:07 -05:00
Gnome Ann e0fdce2cc6 Fix TPU generation modifier 2022-01-14 23:00:06 -05:00
Gnome Ann 932c393d6a Add TPU support for dynamic WI scan and generation modifiers 2022-01-14 21:39:02 -05:00
Gnome Ann 0bef92419b Convert the `jit`ted function into ordinary NumPy operations 2022-01-14 15:05:21 -05:00
Gnome Ann 57a6886007 Move sampling into a `jax.jit`ted function 2022-01-14 02:23:19 -05:00
Gnome Ann 09c4fdcb2e Split `generate_xmap` into two xmaps 2022-01-13 00:56:00 -05:00
Gnome Ann a3d6dc93e8 xmaps for moving things onto TPU 2022-01-12 21:45:30 -05:00
Gnome Ann 8742453f95 Add safeguards for token budget and text formatting
* Error messages are now shown when the memory, author's note, etc. exceeds
  the token budget by itself
* Formatting options no longer break if there are empty chunks in the
  story (although there shouldn't be any in the first place)
* The number of generated tokens is now tracked from the Python side
2021-12-26 18:29:54 -05:00
Gnome Ann fbf3e7615b Add API for generated tokens and output text 2021-12-12 19:27:20 -05:00
Gnome Ann d2d338d314 Improve TPU backend compilation times with `numseqs > 1`
A Python `for` loop was replaced with a `jax.lax.scan` call so that JAX
only compiles the `transformer.generate_initial` function one time
instead of `numseqs` times. This is because JAX unrolls Python built-in
loops like `for` (see the sketch below). The compilation times should now be about the same as
they were before the upgrade to JAX 0.2.21.
2021-11-30 19:22:40 -05:00
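
A minimal sketch of the difference, with `step` as a hypothetical stand-in for `transformer.generate_initial`:

```python
import jax
import jax.numpy as jnp

def step(carry, _):
    # Stand-in for one generate_initial-style call; the real function is far larger.
    carry = carry + 1
    return carry, carry

def unrolled(numseqs):
    # Under jax.jit, this Python loop is unrolled, so `step` is traced numseqs times.
    carry, outs = jnp.zeros(()), []
    for _ in range(numseqs):
        carry, out = step(carry, None)
        outs.append(out)
    return jnp.stack(outs)

def scanned(numseqs):
    # jax.lax.scan traces `step` only once, regardless of numseqs.
    _, outs = jax.lax.scan(step, jnp.zeros(()), None, length=numseqs)
    return outs
```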
Gnome Ann c1e7c1643f Fix unbound axis error in tpu_mtj_backend.py when `numseqs > 1` 2021-11-30 14:06:46 -05:00
Gnome Ann 3c349e6aaf Modify TPU backend code to support JAX 0.2.21
The original code supported versions of JAX up to 0.2.12, and possibly also some
earlier versions. This new code supports only JAX 0.2.21 and does not work with
any earlier or later version. However, it benefits from not needing to recompile
when changing "Amount To Generate" and from supporting early stopping of
generation, which finally makes an implementation of Dynamic World Info Scan
possible.
2021-11-30 10:13:02 -05:00
Gnome Ann 691febacd6 Fix a typo in tpu_mtj_backend.py 2021-11-22 12:53:19 -05:00
Gnome Ann e068aa9f26 Add soft prompt support to TPU backend 2021-11-21 18:08:04 -05:00
Gnome Ann a65c4de840 Integrate TPU backend
This commit puts the TPU backend code directly into the KoboldAI codebase
to make it easier to modify.
2021-11-19 18:06:57 -05:00