Mirror of https://github.com/KoboldAI/KoboldAI-Client.git
Merge pull request #81 from VE-FORBRYDERNE/dematerialized
Use dematerialized loading in TPU backend for lower device memory usage
Commit 6151d16df0
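"Dematerialized" loading means the transformer's parameters are never allocated as real device arrays at construction time; only their shapes and dtypes are tracked, and the actual weights are streamed in from the checkpoint afterwards. The sketch below illustrates the general idea in plain JAX. It is a minimal sketch, not the CausalTransformer internals: init_params, its shapes, and the zeros used in place of a checkpoint read are all hypothetical stand-ins.

import jax
import jax.numpy as jnp

# Hypothetical stand-in for the transformer's parameter initializer.
def init_params(key):
    return {
        "wte": jax.random.normal(key, (1000, 64), dtype=jnp.float32),
        "lm_head": jax.random.normal(key, (64, 1000), dtype=jnp.float32),
    }

key = jax.random.PRNGKey(0)

# Materialized init would allocate every parameter on device immediately:
#     params = init_params(key)

# Dematerialized init: jax.eval_shape traces the initializer abstractly and
# returns ShapeDtypeStructs (shape + dtype only), allocating no parameter
# buffers on the device.
abstract_params = jax.eval_shape(init_params, key)

# Each abstract leaf is later replaced with real data streamed from the
# checkpoint (zeros stand in for the checkpoint read here), so peak device
# memory stays close to the size of the weights themselves.
params = jax.tree_util.tree_map(
    lambda leaf: jnp.zeros(leaf.shape, leaf.dtype),
    abstract_params,
)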
@@ -443,9 +443,9 @@ def sample_func(data, key, numseqs_aux, badwords, repetition_penalty, generated_
         return carry

 class PenalizingCausalTransformer(CausalTransformer):
-    def __init__(self, config):
+    def __init__(self, config, **kwargs):
         # Initialize
-        super().__init__(config)
+        super().__init__(config, **kwargs)
     def generate_static(state, key, ctx, ctx_length, gen_length, numseqs_aux, sampler_options, soft_embeddings=None):
         compiling_callback()
         numseqs = numseqs_aux.shape[0]
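Forwarding **kwargs through PenalizingCausalTransformer.__init__ to the CausalTransformer base class is what lets load_model pass dematerialized=True (see the second hunk below) without the subclass having to know about that option.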
@@ -832,6 +832,6 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs)
     if not path.endswith("/"):
         path += "/"

-    network = PenalizingCausalTransformer(params)
+    network = PenalizingCausalTransformer(params, dematerialized=True)
     network.state = read_ckpt_lowmem(network.state, path, devices.shape[1])
     network.state = network.move_xmap(network.state, np.zeros(cores_per_replica))
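With this change, load_model constructs the network without materializing its parameters, streams the checkpoint into network.state via read_ckpt_lowmem, and then shards the state across the TPU cores with move_xmap. The device no longer has to hold a randomly initialized copy of the weights that is immediately overwritten by the checkpoint, which is where the lower device memory usage described in the PR title comes from.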