Gnome Ann
73aecc0510
Divide NeoX replicated bias layers by 4 again instead of by 8
2022-03-20 01:04:55 -04:00
Gnome Ann
05fc46b253
Changing this again to divide by 8
2022-03-19 02:09:41 -04:00
Gnome Ann
6c20d0d657
Never mind, dividing by 4 is actually correct...
2022-03-19 00:55:04 -04:00
Gnome Ann
f16b61ec77
Should divide NeoX replicated parameters by 8 (not by 4)
...
Also suppresses the PyTorch 1.11 warning about transposing tensors with
ndim != 2 in the new code
2022-03-19 00:48:33 -04:00
Gnome Ann
c2c139e940
Change default PE type for NeoX to `neox_rotary`
2022-03-19 00:26:04 -04:00
Gnome Ann
85a4959efa
Merge branch 'united' into neox
2022-03-18 11:19:03 -04:00
Gnome Ann
c444260eac
Silence PyTorch warning about transposing tensors with dimension != 2
2022-03-17 15:16:56 -04:00
Gnome Ann
eaf190469d
Add PyTorch 1.11 support for lazy loader
2022-03-17 12:51:41 -04:00
Gnome Ann
95c4251db9
Print two newlines before loading HF models
2022-03-15 13:58:53 -04:00
Gnome Ann
9e2848e48f
Show parameter count when loading GPT-NeoX in Colab TPU instance
2022-03-15 13:55:27 -04:00
Gnome Ann
88f247d535
GPT-NeoX-20B support in Colab TPU instances
2022-03-14 23:14:20 -04:00
Gnome Ann
2b8c46338e
Change current working directory to KoboldAI folder
2022-03-13 01:22:11 -05:00
Gnome Ann
48d07adb54
Also fall back to the generic GPT2 tokenizer in Colab TPU instances
2022-03-12 23:19:35 -05:00
Gnome Ann
a99eb8724d
Use DLPack to convert PyTorch tensors to JAX arrays
2022-03-10 15:12:42 -05:00
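Note: a minimal sketch of the PyTorch-to-JAX hand-off via DLPack that this commit refers to; the `weights` tensor below is a placeholder, not the model parameters the backend actually converts.

```python
# Sketch only: convert a PyTorch tensor to a JAX array through DLPack,
# avoiding an intermediate NumPy copy. `weights` is a placeholder tensor.
import torch
import torch.utils.dlpack
import jax.dlpack

def torch_to_jax(tensor):
    return jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(tensor))

weights = torch.randn(4, 4)
jax_weights = torch_to_jax(weights)  # same data, now usable from JAX
```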
henk717
68281184bf
Remove Lowmem from TPU
2022-03-09 19:21:15 +01:00
Gnome Ann
0a258a6282
Support for loading HF models on TPU with `--colab_tpu`
2022-03-05 12:33:33 -05:00
Gnome Ann
ad10ac8871
Allow TPU models to specify settings/config in config.json
2022-02-23 18:22:18 -05:00
Gnome Ann
7ec549c726
Use dematerialized loading in TPU backend for lower device memory usage
2022-02-22 19:43:13 -05:00
henk717
fca7f8659f
Badwords unification
...
TPUs no longer use a hardcoded badwords list and instead use the var
2022-01-29 18:09:53 +01:00
Gnome Ann
3f18888eec
Repetition penalty slope and range
2022-01-24 15:30:38 -05:00
Gnome Ann
3ba0e3f9d9
Dynamic TPU backend should support dynamic warpers and abort button
2022-01-17 14:10:32 -05:00
Gnome Ann
31735c4239
Fix np.take ( https://github.com/google/jax/issues/3774 )
2022-01-17 13:54:02 -05:00
Gnome Ann
33f9f2dc82
Show message when TPU backend is compiling
2022-01-16 21:09:10 -05:00
Gnome Ann
f4eb896a69
Use original TPU backend if possible
2022-01-15 23:31:07 -05:00
Gnome Ann
e0fdce2cc6
Fix TPU generation modifier
2022-01-14 23:00:06 -05:00
Gnome Ann
932c393d6a
Add TPU support for dynamic WI scan and generation modifiers
2022-01-14 21:39:02 -05:00
Gnome Ann
0bef92419b
Convert the `jit`ted function into ordinary NumPy operations
2022-01-14 15:05:21 -05:00
Gnome Ann
57a6886007
Move sampling into a `jax.jit`ted function
2022-01-14 02:23:19 -05:00
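Note: a minimal sketch (not the backend's actual sampler) of what moving sampling into a `jax.jit`ted function looks like; the vocabulary size and temperature handling are illustrative.

```python
# Sketch only: a jit-compiled sampling step. The real backend applies its
# logit warpers first; this just draws from temperature-scaled logits.
import jax
import jax.numpy as jnp

@jax.jit
def sample_token(logits, key, temperature=1.0):
    return jax.random.categorical(key, logits / temperature)

key = jax.random.PRNGKey(0)
logits = jnp.zeros(50257)          # placeholder vocabulary-sized logits
token = sample_token(logits, key)  # compiled on first call, cached afterwards
```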
Gnome Ann
09c4fdcb2e
Split `generate_xmap` into two xmaps
2022-01-13 00:56:00 -05:00
Gnome Ann
a3d6dc93e8
xmaps for moving things onto TPU
2022-01-12 21:45:30 -05:00
Gnome Ann
8742453f95
Add safeguards for token budget and text formatting
...
* Error messages are now shown when the memory, author's note, etc. exceeds the
token budget on its own
* Formatting options no longer break if there are empty chunks in the
story (although there shouldn't be any in the first place)
* The number of generated tokens is now tracked from Python
2021-12-26 18:29:54 -05:00
Gnome Ann
fbf3e7615b
Add API for generated tokens and output text
2021-12-12 19:27:20 -05:00
Gnome Ann
d2d338d314
Improve TPU backend compilation times with `numseqs > 1`
...
A Python `for` loop was replaced with a `jax.lax.scan` call so that JAX
only compiles the `transformer.generate_initial` function one time
instead of `numseqs` times. This is because JAX unrolls Python built-in
loops like `for`. The compilation times should now be about the same as
they were before the upgrade to JAX 0.2.21.
2021-11-30 19:22:40 -05:00
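Note: an illustration of the general pattern described above; `step` stands in for `transformer.generate_initial` and is not the backend's actual code.

```python
# Sketch only: under jax.jit, a Python `for` loop is unrolled, so its body is
# traced once per iteration, while jax.lax.scan traces the body a single time.
import jax
import jax.numpy as jnp

def step(carry, x):          # stand-in for transformer.generate_initial
    new_carry = carry + x
    return new_carry, new_carry

@jax.jit
def unrolled(xs):
    carry, outs = jnp.zeros(()), []
    for i in range(xs.shape[0]):       # unrolled: `step` traced per element
        carry, out = step(carry, xs[i])
        outs.append(out)
    return jnp.stack(outs)

@jax.jit
def scanned(xs):
    _, outs = jax.lax.scan(step, jnp.zeros(()), xs)  # `step` traced once
    return outs

xs = jnp.arange(5.0)
assert jnp.allclose(unrolled(xs), scanned(xs))
```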
Gnome Ann
c1e7c1643f
Fix unbound axis error in tpu_mtj_backend.py when `numseqs > 1`
2021-11-30 14:06:46 -05:00
Gnome Ann
3c349e6aaf
Modify TPU backend code to support JAX 0.2.21
...
The original one supported versions of JAX up to 0.2.12, and possibly also some
earlier versions. This new code supports exclusively JAX 0.2.21 and does not
work with any earlier or later versions of JAX. However, this new code benefits
from not needing to recompile when changing "Amount To Generate" and also from
supporting stopping generation early, which makes an implementation of Dynamic
World Info Scan finally possible.
2021-11-30 10:13:02 -05:00
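Note: the commit above does not show how it achieves this, but one common JAX pattern that gives both properties (shown purely as an illustration, not the backend's implementation) is a `lax.while_loop` over a fixed-size buffer, where the requested length is an ordinary runtime value and the loop predicate can also bail out early.

```python
# Illustration only: with a fixed-size buffer and lax.while_loop, the requested
# length is a traced runtime value (changing it does not recompile) and the
# predicate can stop generation early.
import jax
import jax.numpy as jnp

MAX_TOKENS = 16  # buffer size baked into the compiled program

def fake_sample(i):
    return (i * 7) % 11  # placeholder for a real sampling step

@jax.jit
def generate(num_tokens, stop_token):
    def cond(state):
        i, tokens, stopped = state
        return jnp.logical_and(i < num_tokens, jnp.logical_not(stopped))

    def body(state):
        i, tokens, stopped = state
        token = fake_sample(i)
        return i + 1, tokens.at[i].set(token), token == stop_token

    init = (jnp.int32(0), jnp.full(MAX_TOKENS, -1, dtype=jnp.int32),
            jnp.bool_(False))
    _, tokens, _ = jax.lax.while_loop(cond, body, init)
    return tokens

print(generate(5, 99))   # runs for the requested 5 tokens
print(generate(12, 3))   # same compiled program, stops early at the stop token
```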
Gnome Ann
691febacd6
Fix a typo in tpu_mtj_backend.py
2021-11-22 12:53:19 -05:00
Gnome Ann
e068aa9f26
Add soft prompt support to TPU backend
2021-11-21 18:08:04 -05:00
Gnome Ann
a65c4de840
Integrate TPU backend
...
This commit puts the TPU backend code directly into the KoboldAI code
to make it easier to modify.
2021-11-19 18:06:57 -05:00