mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-02-17 12:10:49 +01:00
Remove TPU Colab's dependency on optax and chex
This commit is contained in:
parent
551565c5ac
commit
7bf6c9a23f
@ -2,7 +2,6 @@ torch >= 1.9, <= 1.11
|
||||
numpy
|
||||
tqdm
|
||||
requests
|
||||
optax >= 0.0.5, <= 0.0.9
|
||||
dm-haiku == 0.0.5
|
||||
jax == 0.2.21
|
||||
jaxlib >= 0.1.69, <= 0.3.7
|
||||
@ -17,4 +16,3 @@ eventlet
|
||||
lupa==1.10
|
||||
markdown
|
||||
bleach==4.1.0
|
||||
chex==0.1.4
|
||||
|
@ -30,7 +30,7 @@ SOFTWARE.
|
||||
import utils
|
||||
|
||||
import multiprocessing
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar
|
||||
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, TypeVar
|
||||
import progressbar
|
||||
import time
|
||||
import os
|
||||
@ -45,7 +45,6 @@ from jax.config import config
|
||||
from jax.experimental import maps
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
import optax
|
||||
import haiku as hk
|
||||
from transformers import AutoTokenizer, GPT2TokenizerFast, AutoModelForCausalLM, GPTNeoForCausalLM
|
||||
from tokenizers import Tokenizer
|
||||
@ -120,6 +119,14 @@ def __batch_xmap(shard_dim=1):
|
||||
return inner
|
||||
|
||||
|
||||
class _EmptyState(NamedTuple):
|
||||
pass
|
||||
|
||||
class _DummyOptimizer:
    """Stand-in optimizer that only provides ``init`` and keeps no state.

    ``load_model`` stores an instance of this class in
    ``params["optimizer"]`` in place of ``optax.scale(0)``, so that
    ``optax`` does not have to be installed.

    NOTE: only ``init`` is implemented; there is no ``update`` method.
    """

    @staticmethod
    def init(*args, **kwargs):
        """Return an empty optimizer state, ignoring every argument.

        Declared as a static method so the call behaves the same whether it
        is made on the class or on an instance; all positional and keyword
        arguments are accepted and discarded.
        """
        return _EmptyState()
|
||||
|
||||
|
||||
def apply_repetition_penalty_dynamic(logits, tokens, repetition_penalty, generated_index, gen_length, rpslope, rprange):
|
||||
'''
|
||||
This gets called by generate_loop_fn to apply repetition penalty
|
||||
@ -1148,7 +1155,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
|
||||
|
||||
cores_per_replica = params["cores_per_replica"]
|
||||
seq = params["seq"]
|
||||
params["optimizer"] = optax.scale(0)
|
||||
params["optimizer"] = _DummyOptimizer()
|
||||
mesh_shape = (1, cores_per_replica)
|
||||
devices = np.array(jax.devices()[:cores_per_replica]).reshape(mesh_shape)
|
||||
thread_resources_env = maps.ResourceEnv(maps.Mesh(devices, ('dp', 'mp')), ())
|
||||
|
Loading…
x
Reference in New Issue
Block a user