From 7ec549c7262280678bcdd32ca44945484f1c3c22 Mon Sep 17 00:00:00 2001
From: Gnome Ann <>
Date: Tue, 22 Feb 2022 19:43:13 -0500
Subject: [PATCH] Use dematerialized loading in TPU backend for lower device memory usage

---
 tpu_mtj_backend.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py
index 40059425..00f9a510 100644
--- a/tpu_mtj_backend.py
+++ b/tpu_mtj_backend.py
@@ -443,9 +443,9 @@ def sample_func(data, key, numseqs_aux, badwords, repetition_penalty, generated_
     return carry
 
 class PenalizingCausalTransformer(CausalTransformer):
-    def __init__(self, config):
+    def __init__(self, config, **kwargs):
         # Initialize
-        super().__init__(config)
+        super().__init__(config, **kwargs)
         def generate_static(state, key, ctx, ctx_length, gen_length, numseqs_aux, sampler_options, soft_embeddings=None):
             compiling_callback()
             numseqs = numseqs_aux.shape[0]
@@ -832,6 +832,6 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs)
     if not path.endswith("/"):
         path += "/"
 
-    network = PenalizingCausalTransformer(params)
+    network = PenalizingCausalTransformer(params, dematerialized=True)
     network.state = read_ckpt_lowmem(network.state, path, devices.shape[1])
     network.state = network.move_xmap(network.state, np.zeros(cores_per_replica))