Merge branch 'main' into dependency-fix

This commit is contained in:
vfbd
2022-09-15 17:33:48 -04:00
7 changed files with 314 additions and 89 deletions

View File

@@ -30,7 +30,7 @@ SOFTWARE.
import utils
import multiprocessing
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, TypeVar
import progressbar
import time
import os
@@ -45,7 +45,6 @@ from jax.config import config
from jax.experimental import maps
import jax.numpy as jnp
import numpy as np
import optax
import haiku as hk
from transformers import AutoTokenizer, GPT2TokenizerFast, AutoModelForCausalLM, GPTNeoForCausalLM
from tokenizers import Tokenizer
@@ -136,6 +135,14 @@ def __batch_xmap(shard_dim=1):
return inner
class _EmptyState(NamedTuple):
pass
class _DummyOptimizer:
def init(*args, **kwargs):
return _EmptyState()
def apply_repetition_penalty_dynamic(logits, tokens, repetition_penalty, generated_index, gen_length, rpslope, rprange):
'''
This gets called by generate_loop_fn to apply repetition penalty
@@ -1167,7 +1174,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
cores_per_replica = params["cores_per_replica"]
seq = params["seq"]
params["optimizer"] = optax.scale(0)
params["optimizer"] = _DummyOptimizer()
mesh_shape = (1, cores_per_replica)
devices = np.array(jax.devices()[:cores_per_replica]).reshape(mesh_shape)
thread_resources_env = maps.ResourceEnv(maps.Mesh(devices, ('dp', 'mp')), ())