Remove TPU Colab's dependency on optax and chex

Author: vfbd
Date:   2022-09-15 13:47:48 -04:00
Parent: 551565c5ac
Commit: 7bf6c9a23f
2 changed files with 10 additions and 5 deletions
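The change works because the TPU Colab backend is inference-only: the `optax.scale(0)` optimizer was never used to apply gradients, only to supply an optimizer state when the network is initialized, and `optax.scale(0).init(params)` just returns an empty NamedTuple (optax's stateless `ScaleState`). A minimal sketch of the substitution, using the `_EmptyState` and `_DummyOptimizer` names from the diff below (the demo params dict is illustrative only):

    from typing import NamedTuple

    class _EmptyState(NamedTuple):
        pass

    class _DummyOptimizer:
        def init(*args, **kwargs):
            # optax.scale(0).init(params) likewise returns an empty
            # NamedTuple, so code that only ever calls .init() cannot
            # tell the difference.
            return _EmptyState()

    opt = _DummyOptimizer()
    print(opt.init({"w": 0.0}))  # -> _EmptyState()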

View File

@@ -2,7 +2,6 @@ torch >= 1.9, <= 1.11
 numpy
 tqdm
 requests
-optax >= 0.0.5, <= 0.0.9
 dm-haiku == 0.0.5
 jax == 0.2.21
 jaxlib >= 0.1.69, <= 0.3.7
@@ -17,4 +16,3 @@ eventlet
 lupa==1.10
 markdown
 bleach==4.1.0
-chex==0.1.4

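Dropping `import optax` from the backend (second file below) is what lets the `optax` pin go; the `chex` pin was presumably there only to satisfy optax's own requirements, so it is removed along with it.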
View File

@@ -30,7 +30,7 @@ SOFTWARE.
 import utils
 import multiprocessing
-from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar
+from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, TypeVar
 import progressbar
 import time
 import os
@@ -45,7 +45,6 @@ from jax.config import config
 from jax.experimental import maps
 import jax.numpy as jnp
 import numpy as np
-import optax
 import haiku as hk
 from transformers import AutoTokenizer, GPT2TokenizerFast, AutoModelForCausalLM, GPTNeoForCausalLM
 from tokenizers import Tokenizer
@@ -120,6 +119,14 @@ def __batch_xmap(shard_dim=1):
     return inner
+
+class _EmptyState(NamedTuple):
+    pass
+
+class _DummyOptimizer:
+    def init(*args, **kwargs):
+        return _EmptyState()
+
 def apply_repetition_penalty_dynamic(logits, tokens, repetition_penalty, generated_index, gen_length, rpslope, rprange):
     '''
     This gets called by generate_loop_fn to apply repetition penalty
@@ -1148,7 +1155,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpoint
     cores_per_replica = params["cores_per_replica"]
     seq = params["seq"]
-    params["optimizer"] = optax.scale(0)
+    params["optimizer"] = _DummyOptimizer()
     mesh_shape = (1, cores_per_replica)
     devices = np.array(jax.devices()[:cores_per_replica]).reshape(mesh_shape)
     thread_resources_env = maps.ResourceEnv(maps.Mesh(devices, ('dp', 'mp')), ())
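For context, the only thing the loader needs from `params["optimizer"]` is an `init` method, since no training step ever runs on the Colab TPU backend. A hedged sketch of the kind of call site the stand-in has to satisfy; the `init_state` helper and its field names are hypothetical, not taken from this diff:

    def init_state(optimizer, model_params):
        # Hypothetical call-site shape (not in this diff): the optimizer is
        # threaded into the network state only so the state pytree carries an
        # opt_state leaf; with _DummyOptimizer that leaf is just _EmptyState().
        return {
            "params": model_params,
            "opt_state": optimizer.init(model_params),
            "step": 0,
        }

Note that because `_DummyOptimizer` defines no `update` method, any accidental training call would fail loudly with an AttributeError, whereas `optax.scale(0)` would have silently applied a zero-scaled gradient.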