Use dematerialized loading in TPU backend for lower device memory usage
This commit is contained in:
parent
fd7ba9f70e
commit
7ec549c726
|
@ -443,9 +443,9 @@ def sample_func(data, key, numseqs_aux, badwords, repetition_penalty, generated_
|
||||||
return carry
|
return carry
|
||||||
|
|
||||||
class PenalizingCausalTransformer(CausalTransformer):
|
class PenalizingCausalTransformer(CausalTransformer):
|
||||||
def __init__(self, config):
|
def __init__(self, config, **kwargs):
|
||||||
# Initialize
|
# Initialize
|
||||||
super().__init__(config)
|
super().__init__(config, **kwargs)
|
||||||
def generate_static(state, key, ctx, ctx_length, gen_length, numseqs_aux, sampler_options, soft_embeddings=None):
|
def generate_static(state, key, ctx, ctx_length, gen_length, numseqs_aux, sampler_options, soft_embeddings=None):
|
||||||
compiling_callback()
|
compiling_callback()
|
||||||
numseqs = numseqs_aux.shape[0]
|
numseqs = numseqs_aux.shape[0]
|
||||||
|
@ -832,6 +832,6 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs)
|
||||||
if not path.endswith("/"):
|
if not path.endswith("/"):
|
||||||
path += "/"
|
path += "/"
|
||||||
|
|
||||||
network = PenalizingCausalTransformer(params)
|
network = PenalizingCausalTransformer(params, dematerialized=True)
|
||||||
network.state = read_ckpt_lowmem(network.state, path, devices.shape[1])
|
network.state = read_ckpt_lowmem(network.state, path, devices.shape[1])
|
||||||
network.state = network.move_xmap(network.state, np.zeros(cores_per_replica))
|
network.state = network.move_xmap(network.state, np.zeros(cores_per_replica))
|
||||||
|
|
Loading…
Reference in New Issue