A Python `for` loop was replaced with a `jax.lax.scan` call so that JAX
only compiles the `transformer.generate_initial` function one time
instead of `numseqs` times. This is because JAX unrolls Python built-in
loops like `for`. The compilation times should now be about the same as
they were before the upgrade to JAX 0.2.21.
The original one supported versions of JAX up to 0.2.12, and possibly also some
earlier versions. This new code supports exclusively JAX 0.2.21 and does not
work with any earlier or later versions of JAX. However, this new code benefits
from not needing to recompile when changing "Amount To Generate" and also from
supporting stopping generation early, which makes an implementation of Dynamic
World Info Scan finally possible.