Gnome Ann
0ea4fa9c87
Automatically calculate badwords and pad_token_id
2022-06-21 14:35:52 -04:00
Gnome Ann
ea7d278ff4
Fix 20B TPU model
2022-06-21 13:16:45 -04:00
Gnome Ann
5e71f7fe97
Use slow tokenizer if fast tokenizer is not available
2022-06-17 21:08:37 -04:00
Gnome Ann
2d3db7b4ba
Implement support for sampler order in the backend code
2022-06-13 19:12:23 -04:00
Gnome Ann
fdb2a7fa4c
Top-A sampling
2022-06-10 22:28:20 -04:00
Gnome Ann
707316de31
Kaggle TPU support
2022-05-31 12:20:16 -04:00
Gnome Ann
d4e8f56789
Remove debugging code from tpu_mtj_backend.py
2022-05-14 12:00:44 -04:00
Gnome Ann
0c5ca5261e
Loading a sharded model will now display only one progress bar
2022-05-13 23:32:16 -04:00
Gnome Ann
b1d8797a54
Allow TPU Colab to load sharded HF models
2022-05-12 23:51:40 -04:00
Gnome Ann
4fa5f1cd6a
Add TPU support for OPT-350M
...
The 350M model seems to have a different structure than the other ones ???
2022-05-12 22:21:15 -04:00
Gnome Ann
f5e689a725
Upload maps/opt.json and update requirements
2022-05-12 19:09:31 -04:00
Gnome Ann
b97b2a02d6
Add `--revision` command line flag
2022-05-10 22:14:56 -04:00
Gnome Ann
c117bfd0ad
Fix lazy loader
2022-04-08 19:38:15 -04:00
Gnome Ann
fabbdf2bb1
Lazy loader Python 3.6 compatibility
...
The current lazy loader relies on a feature of the Python zipfile module
that was added in Python 3.7.0:
https://bugs.python.org/issue22908
This commit adds compatibility for Python 3.6.
2022-04-02 15:02:54 -04:00
Gnome Ann
67e28d2b5c
Typical sampling needs to use nansum instead of sum
...
If `probs` is zero then `log_probs` will be negative infinity, and the
calculation of `neg_entropy` would then give NaN because zero times
infinity is a mathematically indeterminate value.
We need to use nansum so that those NaN values are treated as zeros to
ignore them in the entropy calculation.
2022-03-28 00:02:31 -04:00
Gnome Ann
d5989d4c62
Hide division by zero warning in JAX typical filter
...
This warning happens when `np.log` gets an input containing zeros.
In that case, NumPy will throw a warning and output negative infinity.
Negative infinity is the correct behaviour here, so we can safely ignore
the warning.
2022-03-27 16:57:12 -04:00
Gnome Ann
20e48b11d7
Typical sampling
2022-03-27 16:25:50 -04:00
Gnome Ann
73aecc0510
Divide NeoX replicated bias layers by 4 again instead of by 8
2022-03-20 01:04:55 -04:00
Gnome Ann
05fc46b253
Changing this again to divide by 8
2022-03-19 02:09:41 -04:00
Gnome Ann
6c20d0d657
Nevermind, dividing by 4 is actually correct...
2022-03-19 00:55:04 -04:00
Gnome Ann
f16b61ec77
Should divide NeoX replicated parameters by 8 (not by 4)
...
Also, suppresses the PyTorch 1.11 warning about transposing tensors with
ndim != 2 in the new code
2022-03-19 00:48:33 -04:00
Gnome Ann
c2c139e940
Change default PE type for NeoX to `neox_rotary`
2022-03-19 00:26:04 -04:00
Gnome Ann
85a4959efa
Merge branch 'united' into neox
2022-03-18 11:19:03 -04:00
Gnome Ann
c444260eac
Silence PyTorch warning about transposing tensors with dimension != 2
2022-03-17 15:16:56 -04:00
Gnome Ann
eaf190469d
Add PyTorch 1.11 support for lazy loader
2022-03-17 12:51:41 -04:00
Gnome Ann
95c4251db9
Print two newlines before loading HF models
2022-03-15 13:58:53 -04:00
Gnome Ann
9e2848e48f
Show parameter count when loading GPT-NeoX in Colab TPU instance
2022-03-15 13:55:27 -04:00
Gnome Ann
88f247d535
GPT-NeoX-20B support in Colab TPU instances
2022-03-14 23:14:20 -04:00
Gnome Ann
2b8c46338e
Change current working directory to KoboldAI folder
2022-03-13 01:22:11 -05:00
Gnome Ann
48d07adb54
Also fallback to generic GPT2 tokenizer in Colab TPU instances
2022-03-12 23:19:35 -05:00
Gnome Ann
a99eb8724d
Use DLPack to convert PyTorch tensors to JAX arrays
2022-03-10 15:12:42 -05:00
henk717
68281184bf
Remove Lowmem from TPU
2022-03-09 19:21:15 +01:00
Gnome Ann
0a258a6282
Support for loading HF models on TPU with `--colab_tpu`
2022-03-05 12:33:33 -05:00
Gnome Ann
ad10ac8871
Allow TPU models to specify settings/config in config.json
2022-02-23 18:22:18 -05:00
Gnome Ann
7ec549c726
Use dematerialized loading in TPU backend for lower device memory usage
2022-02-22 19:43:13 -05:00
henk717
fca7f8659f
Badwords unification
...
TPU's no longer use hardcoded badwords but instead use the var
2022-01-29 18:09:53 +01:00
Gnome Ann
3f18888eec
Repetition penalty slope and range
2022-01-24 15:30:38 -05:00
Gnome Ann
3ba0e3f9d9
Dynamic TPU backend should support dynamic warpers and abort button
2022-01-17 14:10:32 -05:00
Gnome Ann
31735c4239
Fix np.take ( https://github.com/google/jax/issues/3774 )
2022-01-17 13:54:02 -05:00
Gnome Ann
33f9f2dc82
Show message when TPU backend is compiling
2022-01-16 21:09:10 -05:00
Gnome Ann
f4eb896a69
Use original TPU backend if possible
2022-01-15 23:31:07 -05:00
Gnome Ann
e0fdce2cc6
Fix TPU generation modifier
2022-01-14 23:00:06 -05:00
Gnome Ann
932c393d6a
Add TPU support for dynamic WI scan and generation modifiers
2022-01-14 21:39:02 -05:00
Gnome Ann
0bef92419b
Convert the `jit`ted function into ordinary NumPy operations
2022-01-14 15:05:21 -05:00
Gnome Ann
57a6886007
Move sampling into a `jax.jit`ted function
2022-01-14 02:23:19 -05:00
Gnome Ann
09c4fdcb2e
Split `generate_xmap` into two xmaps
2022-01-13 00:56:00 -05:00
Gnome Ann
a3d6dc93e8
xmaps for moving things onto TPU
2022-01-12 21:45:30 -05:00
Gnome Ann
8742453f95
Add safeguards for token budget and text formatting
...
* Error messages are now shown when memory, author's note, etc. exceeds
budget by itself
* Formatting options no longer break if there are empty chunks in the
story (although there shouldn't be any in the first place)
* Number of generated tokens is now kept track of from Python
2021-12-26 18:29:54 -05:00
Gnome Ann
fbf3e7615b
Add API for generated tokens and output text
2021-12-12 19:27:20 -05:00
Gnome Ann
d2d338d314
Improve TPU backend compilation times with `numseqs > 1`
...
A Python `for` loop was replaced with a `jax.lax.scan` call so that JAX
only compiles the `transformer.generate_initial` function one time
instead of `numseqs` times. This is because JAX unrolls Python built-in
loops like `for`. The compilation times should now be about the same as
they were before the upgrade to JAX 0.2.21.
2021-11-30 19:22:40 -05:00