Mirror of https://github.com/KoboldAI/KoboldAI-Client.git, synced 2025-06-05 21:59:24 +02:00
TPU Fix?
@@ -460,14 +460,14 @@ def sample_func(data, key, numseqs_aux, badwords, repetition_penalty, generated_
         return carry
 
 class PenalizingCausalTransformer(CausalTransformer):
-    def __init__(self, config, **kwargs):
+    def __init__(self, badwordsids, config, **kwargs):
         # Initialize
         super().__init__(config, **kwargs)
         def generate_static(state, key, ctx, ctx_length, gen_length, numseqs_aux, sampler_options, soft_embeddings=None):
            compiling_callback()
            numseqs = numseqs_aux.shape[0]
            # These are the tokens that we don't want the AI to ever write
-            badwords = jnp.array(koboldai_vars.badwordsids).squeeze()
+            badwords = jnp.array(badwordsids).squeeze()
            @hk.transform
            def generate_sample(context, ctx_length):
                # Give the initial context to the transformer
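
For context: the comment marks these as tokens the model must never write, and a common way to enforce that is to mask their logits before sampling. A minimal sketch of that idea in JAX; the function name mask_badwords and the toy vocabulary are illustrative assumptions, not code from this repository:

import jax.numpy as jnp

def mask_badwords(logits, badwords):
    # Push the logits of every forbidden token ID to -inf so sampling
    # can never select them; `badwords` is a 1-D array of token IDs.
    return logits.at[badwords].set(-jnp.inf)

logits = jnp.zeros(10)                      # toy 10-token vocabulary
badwords = jnp.array([[3], [7]]).squeeze()  # same squeeze pattern as above -> shape (2,)
print(mask_badwords(logits, badwords))      # tokens 3 and 7 are now -inf
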
@@ -941,7 +941,9 @@ def read_neox_checkpoint(state, path, config, checkpoint_shards=2):
 
     koboldai_vars.status_message = ""
 
-def load_model(path: str, model_type: str, driver_version="tpu_driver_20221109", hf_checkpoint=False, socketio_queue=None, initial_load=False, logger=None, **kwargs) -> None:
+import koboldai_settings
+
+def load_model(path: str, model_type: str, badwordsids=koboldai_settings.badwordsids_default, driver_version="tpu_driver_20221109", hf_checkpoint=False, socketio_queue=None, initial_load=False, logger=None, **kwargs) -> None:
     global thread_resources_env, seq, tokenizer, network, params, pad_token_id
 
     if kwargs.get("pad_token_id"):
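
With this hunk, load_model gains a badwordsids keyword defaulting to koboldai_settings.badwordsids_default, so existing callers keep working while TPU code paths can pass a model-specific list. A hypothetical call site; the importing module name, model path, model_type value, and explicit token IDs are assumptions for illustration, not taken from this commit:

import koboldai_settings
import tpu_mtj_backend  # assumed module name for the file this diff touches

# Rely on the shipped default bad-word list...
tpu_mtj_backend.load_model("models/example-6B/", "gptj")

# ...or override it with a model-specific nested list of token IDs.
tpu_mtj_backend.load_model(
    "models/example-6B/",
    "gptj",
    badwordsids=[[48585], [50256]],  # placeholder IDs, not the project's defaults
)

Note that the default is evaluated once, when the function is defined, so reassigning koboldai_settings.badwordsids_default later would not affect callers that omit the argument.
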
@@ -1119,12 +1121,12 @@ def load_model(path: str, model_type: str, driver_version="tpu_driver_20221109",
 
     global badwords
     # These are the tokens that we don't want the AI to ever write
-    badwords = jnp.array(koboldai_vars.badwordsids).squeeze()
+    badwords = jnp.array(badwordsids).squeeze()
 
     if not path.endswith("/"):
         path += "/"
 
-    network = PenalizingCausalTransformer(params, dematerialized=True)
+    network = PenalizingCausalTransformer(badwordsids, params, dematerialized=True)
 
     if not hf_checkpoint and koboldai_vars.model != "TPUMeshTransformerGPTNeoX":
         network.state = read_ckpt_lowmem(network.state, path, devices.shape[1])
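
The .squeeze() in the changed line suggests the bad-word IDs arrive as a nested list of single-element lists; squeezing turns the resulting (N, 1) array into a flat vector of token IDs. A quick illustration with placeholder IDs, not the project's actual defaults:

import jax.numpy as jnp

badwordsids = [[13], [198], [50256]]   # placeholder IDs in the assumed nested format
badwords = jnp.array(badwordsids)      # shape (3, 1)
print(badwords.shape)                  # (3, 1)
print(badwords.squeeze().shape)        # (3,) -- flat vector usable for indexing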