Remove TPU Colab's dependency on optax and chex

vfbd 2022-09-15 13:47:48 -04:00
parent 551565c5ac
commit 7bf6c9a23f
2 changed files with 10 additions and 5 deletions

View File

@@ -2,7 +2,6 @@ torch >= 1.9, <= 1.11
 numpy
 tqdm
 requests
-optax >= 0.0.5, <= 0.0.9
 dm-haiku == 0.0.5
 jax == 0.2.21
 jaxlib >= 0.1.69, <= 0.3.7
@@ -17,4 +16,3 @@ eventlet
 lupa==1.10
 markdown
 bleach==4.1.0
-chex==0.1.4

View File

@@ -30,7 +30,7 @@ SOFTWARE.
 import utils
 import multiprocessing
-from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar
+from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, TypeVar
 import progressbar
 import time
 import os
@@ -45,7 +45,6 @@ from jax.config import config
 from jax.experimental import maps
 import jax.numpy as jnp
 import numpy as np
-import optax
 import haiku as hk
 from transformers import AutoTokenizer, GPT2TokenizerFast, AutoModelForCausalLM, GPTNeoForCausalLM
 from tokenizers import Tokenizer
@@ -120,6 +119,14 @@ def __batch_xmap(shard_dim=1):
     return inner
+class _EmptyState(NamedTuple):
+    pass
+class _DummyOptimizer:
+    def init(*args, **kwargs):
+        return _EmptyState()
 def apply_repetition_penalty_dynamic(logits, tokens, repetition_penalty, generated_index, gen_length, rpslope, rprange):
     '''
     This gets called by generate_loop_fn to apply repetition penalty
@@ -1148,7 +1155,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
     cores_per_replica = params["cores_per_replica"]
     seq = params["seq"]
-    params["optimizer"] = optax.scale(0)
+    params["optimizer"] = _DummyOptimizer()
     mesh_shape = (1, cores_per_replica)
     devices = np.array(jax.devices()[:cores_per_replica]).reshape(mesh_shape)
     thread_resources_env = maps.ResourceEnv(maps.Mesh(devices, ('dp', 'mp')), ())
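
Why the substitution works, as a minimal sketch (not part of the commit): optax.scale(0).init(params) returns an empty ScaleState NamedTuple, so any object whose init() returns an equally empty NamedTuple is a drop-in replacement, assuming the inference-only TPU backend only ever calls optimizer.init(...) and never applies gradient updates. The params dict below is a hypothetical stand-in for the real parameter pytree.

# Sketch of the optimizer stand-in, mirroring the classes added above.
from typing import NamedTuple

class _EmptyState(NamedTuple):
    pass

class _DummyOptimizer:
    def init(*args, **kwargs):
        # Mirrors optax.scale(0).init(params), which also returns an empty
        # NamedTuple (ScaleState()), but without requiring optax or chex.
        return _EmptyState()

params = {"w": [0.0, 1.0]}                  # hypothetical parameter pytree
opt_state = _DummyOptimizer().init(params)  # -> _EmptyState()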