Use dematerialized loading in TPU backend for lower device memory usage

This commit is contained in:
Gnome Ann 2022-02-22 19:43:13 -05:00
parent fd7ba9f70e
commit 7ec549c726

View File

@ -443,9 +443,9 @@ def sample_func(data, key, numseqs_aux, badwords, repetition_penalty, generated_
return carry
class PenalizingCausalTransformer(CausalTransformer):
def __init__(self, config):
def __init__(self, config, **kwargs):
# Initialize
super().__init__(config)
super().__init__(config, **kwargs)
def generate_static(state, key, ctx, ctx_length, gen_length, numseqs_aux, sampler_options, soft_embeddings=None):
compiling_callback()
numseqs = numseqs_aux.shape[0]
@ -832,6 +832,6 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs)
if not path.endswith("/"):
path += "/"
network = PenalizingCausalTransformer(params)
network = PenalizingCausalTransformer(params, dematerialized=True)
network.state = read_ckpt_lowmem(network.state, path, devices.shape[1])
network.state = network.move_xmap(network.state, np.zeros(cores_per_replica))