From 7bf6c9a23f451b1d8fc61b9cdf916ca4c864a5b4 Mon Sep 17 00:00:00 2001 From: vfbd Date: Thu, 15 Sep 2022 13:47:48 -0400 Subject: [PATCH] Remove TPU Colab's dependency on optax and chex --- requirements_mtj.txt | 2 -- tpu_mtj_backend.py | 13 ++++++++++--- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/requirements_mtj.txt b/requirements_mtj.txt index 90c68634..613e9203 100644 --- a/requirements_mtj.txt +++ b/requirements_mtj.txt @@ -2,7 +2,6 @@ torch >= 1.9, <= 1.11 numpy tqdm requests -optax >= 0.0.5, <= 0.0.9 dm-haiku == 0.0.5 jax == 0.2.21 jaxlib >= 0.1.69, <= 0.3.7 @@ -17,4 +16,3 @@ eventlet lupa==1.10 markdown bleach==4.1.0 -chex==0.1.4 diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py index da0511df..0c6667a2 100644 --- a/tpu_mtj_backend.py +++ b/tpu_mtj_backend.py @@ -30,7 +30,7 @@ SOFTWARE. import utils import multiprocessing -from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar +from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, TypeVar import progressbar import time import os @@ -45,7 +45,6 @@ from jax.config import config from jax.experimental import maps import jax.numpy as jnp import numpy as np -import optax import haiku as hk from transformers import AutoTokenizer, GPT2TokenizerFast, AutoModelForCausalLM, GPTNeoForCausalLM from tokenizers import Tokenizer @@ -120,6 +119,14 @@ def __batch_xmap(shard_dim=1): return inner +class _EmptyState(NamedTuple): + pass + +class _DummyOptimizer: + def init(*args, **kwargs): + return _EmptyState() + + def apply_repetition_penalty_dynamic(logits, tokens, repetition_penalty, generated_index, gen_length, rpslope, rprange): ''' This gets called by generate_loop_fn to apply repetition penalty @@ -1148,7 +1155,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo cores_per_replica = params["cores_per_replica"] seq = params["seq"] - params["optimizer"] = optax.scale(0) + params["optimizer"] = _DummyOptimizer() mesh_shape = (1, cores_per_replica) devices = np.array(jax.devices()[:cores_per_replica]).reshape(mesh_shape) thread_resources_env = maps.ResourceEnv(maps.Mesh(devices, ('dp', 'mp')), ())