Commit Graph

55 Commits

Author SHA1 Message Date
Gnome Ann 0ea4fa9c87 Automatically calculate badwords and pad_token_id 2022-06-21 14:35:52 -04:00
Gnome Ann ea7d278ff4 Fix 20B TPU model 2022-06-21 13:16:45 -04:00
Gnome Ann 5e71f7fe97 Use slow tokenizer if fast tokenizer is not available 2022-06-17 21:08:37 -04:00
Gnome Ann 2d3db7b4ba Implement support for sampler order in the backend code 2022-06-13 19:12:23 -04:00
Gnome Ann fdb2a7fa4c Top-A sampling 2022-06-10 22:28:20 -04:00
Gnome Ann 707316de31 Kaggle TPU support 2022-05-31 12:20:16 -04:00
Gnome Ann d4e8f56789 Remove debugging code from tpu_mtj_backend.py 2022-05-14 12:00:44 -04:00
Gnome Ann 0c5ca5261e Loading a sharded model will now display only one progress bar 2022-05-13 23:32:16 -04:00
Gnome Ann b1d8797a54 Allow TPU Colab to load sharded HF models 2022-05-12 23:51:40 -04:00
Gnome Ann 4fa5f1cd6a Add TPU support for OPT-350M
The 350M model seems to have a different structure than the other ones ???
2022-05-12 22:21:15 -04:00
Gnome Ann f5e689a725 Upload maps/opt.json and update requirements 2022-05-12 19:09:31 -04:00
Gnome Ann b97b2a02d6 Add `--revision` command line flag 2022-05-10 22:14:56 -04:00
Gnome Ann c117bfd0ad Fix lazy loader 2022-04-08 19:38:15 -04:00
Gnome Ann fabbdf2bb1 Lazy loader Python 3.6 compatibility
The current lazy loader relies on a feature of the Python zipfile module
that was added in Python 3.7.0:

https://bugs.python.org/issue22908

This commit adds compatibility for Python 3.6.
2022-04-02 15:02:54 -04:00
Gnome Ann 67e28d2b5c Typical sampling needs to use nansum instead of sum
If `probs` is zero then `log_probs` will be negative infinity, and the
calculation of `neg_entropy` would then give NaN because zero times
infinity is a mathematically indeterminate value.

We need to use nansum so that those NaN values are treated as zeros to
ignore them in the entropy calculation.
2022-03-28 00:02:31 -04:00
Gnome Ann d5989d4c62 Hide division by zero warning in JAX typical filter
This warning happens when `np.log` gets an input containing zeros.
In that case, NumPy will throw a warning and output negative infinity.

Negative infinity is the correct behaviour here, so we can safely ignore
the warning.
2022-03-27 16:57:12 -04:00
Gnome Ann 20e48b11d7 Typical sampling 2022-03-27 16:25:50 -04:00
Gnome Ann 73aecc0510 Divide NeoX replicated bias layers by 4 again instead of by 8 2022-03-20 01:04:55 -04:00
Gnome Ann 05fc46b253 Changing this again to divide by 8 2022-03-19 02:09:41 -04:00
Gnome Ann 6c20d0d657 Nevermind, dividing by 4 is actually correct... 2022-03-19 00:55:04 -04:00
Gnome Ann f16b61ec77 Should divide NeoX replicated parameters by 8 (not by 4)
Also, suppresses the PyTorch 1.11 warning about transposing tensors with
ndim != 2 in the new code
2022-03-19 00:48:33 -04:00
Gnome Ann c2c139e940 Change default PE type for NeoX to `neox_rotary` 2022-03-19 00:26:04 -04:00
Gnome Ann 85a4959efa Merge branch 'united' into neox 2022-03-18 11:19:03 -04:00
Gnome Ann c444260eac Silence PyTorch warning about transposing tensors with dimension != 2 2022-03-17 15:16:56 -04:00
Gnome Ann eaf190469d Add PyTorch 1.11 support for lazy loader 2022-03-17 12:51:41 -04:00
Gnome Ann 95c4251db9 Print two newlines before loading HF models 2022-03-15 13:58:53 -04:00
Gnome Ann 9e2848e48f Show parameter count when loading GPT-NeoX in Colab TPU instance 2022-03-15 13:55:27 -04:00
Gnome Ann 88f247d535 GPT-NeoX-20B support in Colab TPU instances 2022-03-14 23:14:20 -04:00
Gnome Ann 2b8c46338e Change current working directory to KoboldAI folder 2022-03-13 01:22:11 -05:00
Gnome Ann 48d07adb54 Also fallback to generic GPT2 tokenizer in Colab TPU instances 2022-03-12 23:19:35 -05:00
Gnome Ann a99eb8724d Use DLPack to convert PyTorch tensors to JAX arrays 2022-03-10 15:12:42 -05:00
henk717 68281184bf Remove Lowmem from TPU 2022-03-09 19:21:15 +01:00
Gnome Ann 0a258a6282 Support for loading HF models on TPU with `--colab_tpu` 2022-03-05 12:33:33 -05:00
Gnome Ann ad10ac8871 Allow TPU models to specify settings/config in config.json 2022-02-23 18:22:18 -05:00
Gnome Ann 7ec549c726 Use dematerialized loading in TPU backend for lower device memory usage 2022-02-22 19:43:13 -05:00
henk717 fca7f8659f Badwords unification
TPU's no longer use hardcoded badwords but instead use the var
2022-01-29 18:09:53 +01:00
Gnome Ann 3f18888eec Repetition penalty slope and range 2022-01-24 15:30:38 -05:00
Gnome Ann 3ba0e3f9d9 Dynamic TPU backend should support dynamic warpers and abort button 2022-01-17 14:10:32 -05:00
Gnome Ann 31735c4239 Fix np.take (https://github.com/google/jax/issues/3774) 2022-01-17 13:54:02 -05:00
Gnome Ann 33f9f2dc82 Show message when TPU backend is compiling 2022-01-16 21:09:10 -05:00
Gnome Ann f4eb896a69 Use original TPU backend if possible 2022-01-15 23:31:07 -05:00
Gnome Ann e0fdce2cc6 Fix TPU generation modifier 2022-01-14 23:00:06 -05:00
Gnome Ann 932c393d6a Add TPU support for dynamic WI scan and generation modifiers 2022-01-14 21:39:02 -05:00
Gnome Ann 0bef92419b Convert the `jit`ted function into ordinary NumPy operations 2022-01-14 15:05:21 -05:00
Gnome Ann 57a6886007 Move sampling into a `jax.jit`ted function 2022-01-14 02:23:19 -05:00
Gnome Ann 09c4fdcb2e Split `generate_xmap` into two xmaps 2022-01-13 00:56:00 -05:00
Gnome Ann a3d6dc93e8 xmaps for moving things onto TPU 2022-01-12 21:45:30 -05:00
Gnome Ann 8742453f95 Add safeguards for token budget and text formatting
* Error messages are now shown when memory, author's note, etc. exceeds
  budget by itself
* Formatting options no longer break if there are empty chunks in the
  story (although there shouldn't be any in the first place)
* Number of generated tokens is now kept track of from Python
2021-12-26 18:29:54 -05:00
Gnome Ann fbf3e7615b Add API for generated tokens and output text 2021-12-12 19:27:20 -05:00
Gnome Ann d2d338d314 Improve TPU backend compilation times with `numseqs > 1`
A Python `for` loop was replaced with a `jax.lax.scan` call so that JAX
only compiles the `transformer.generate_initial` function one time
instead of `numseqs` times. This is because JAX unrolls Python built-in
loops like `for`. The compilation times should now be about the same as
they were before the upgrade to JAX 0.2.21.
2021-11-30 19:22:40 -05:00