Remove TPU Colab's dependency on optax and chex
@@ -2,7 +2,6 @@ torch >= 1.9, <= 1.11
 numpy
 tqdm
 requests
-optax >= 0.0.5, <= 0.0.9
 dm-haiku == 0.0.5
 jax == 0.2.21
 jaxlib >= 0.1.69, <= 0.3.7
@@ -17,4 +16,3 @@ eventlet
 lupa==1.10
 markdown
 bleach==4.1.0
-chex==0.1.4
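With optax and chex removed from the TPU backend's requirements (the two deletions above), a freshly built Colab environment no longer pulls in either package. A quick illustrative check, not part of the commit, to confirm a rebuilt environment matches the trimmed requirements:

# Illustrative only: neither optax nor chex should need to be importable
# once the trimmed requirements are installed.
import importlib.util

for mod in ("optax", "chex"):
    present = importlib.util.find_spec(mod) is not None
    print(f"{mod}: {'installed' if present else 'not installed'}")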
@@ -30,7 +30,7 @@ SOFTWARE.
 import utils
 
 import multiprocessing
-from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar
+from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, TypeVar
 import progressbar
 import time
 import os
@@ -45,7 +45,6 @@ from jax.config import config
 from jax.experimental import maps
 import jax.numpy as jnp
 import numpy as np
-import optax
 import haiku as hk
 from transformers import AutoTokenizer, GPT2TokenizerFast, AutoModelForCausalLM, GPTNeoForCausalLM
 from tokenizers import Tokenizer
@@ -120,6 +119,14 @@ def __batch_xmap(shard_dim=1):
     return inner
 
 
+class _EmptyState(NamedTuple):
+    pass
+
+class _DummyOptimizer:
+    def init(*args, **kwargs):
+        return _EmptyState()
+
+
 def apply_repetition_penalty_dynamic(logits, tokens, repetition_penalty, generated_index, gen_length, rpslope, rprange):
     '''
     This gets called by generate_loop_fn to apply repetition penalty
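The stub classes added above work because the TPU Colab backend only runs inference: the network setup code presumably asks params["optimizer"] for an initial state and never applies a gradient update, and optax.scale(0).init(...) would have returned an empty NamedTuple (ScaleState()) anyway. A minimal sketch of how such a stub satisfies that single call site; init_network below is a hypothetical stand-in for the real setup path, not code from this repository:

from typing import Any, NamedTuple


class _EmptyState(NamedTuple):
    pass


class _DummyOptimizer:
    def init(*args, **kwargs):
        # Mirrors optax.scale(0).init(...), which also yields an empty
        # NamedTuple; no optimizer state is needed for inference.
        return _EmptyState()


def init_network(params: dict) -> Any:
    # Hypothetical stand-in for the network setup path: the optimizer is
    # only asked for an initial state and never for an update step.
    return params["optimizer"].init(None)


opt_state = init_network({"optimizer": _DummyOptimizer()})
print(opt_state)  # _EmptyState()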
@@ -1148,7 +1155,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
 
     cores_per_replica = params["cores_per_replica"]
     seq = params["seq"]
-    params["optimizer"] = optax.scale(0)
+    params["optimizer"] = _DummyOptimizer()
     mesh_shape = (1, cores_per_replica)
     devices = np.array(jax.devices()[:cores_per_replica]).reshape(mesh_shape)
     thread_resources_env = maps.ResourceEnv(maps.Mesh(devices, ('dp', 'mp')), ())
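Because _EmptyState is an empty NamedTuple, just as the ScaleState produced by optax.scale(0) was, the optimizer entry swapped in above contributes no pytree leaves, so downstream code that flattens or checkpoints the network state should see the same structure as before. A small hedged check of that structural claim (redefining _EmptyState locally so the snippet is self-contained):

from typing import NamedTuple

import jax


class _EmptyState(NamedTuple):
    pass


# An empty NamedTuple flattens to zero leaves, i.e. there is nothing for a
# checkpoint to store or restore on behalf of the optimizer.
leaves, treedef = jax.tree_util.tree_flatten(_EmptyState())
print(leaves)  # []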