Use dematerialized loading in TPU backend for lower device memory usage
This commit is contained in:
parent
fd7ba9f70e
commit
7ec549c726
|
@ -443,9 +443,9 @@ def sample_func(data, key, numseqs_aux, badwords, repetition_penalty, generated_
|
|||
return carry
|
||||
|
||||
class PenalizingCausalTransformer(CausalTransformer):
|
||||
def __init__(self, config):
|
||||
def __init__(self, config, **kwargs):
|
||||
# Initialize
|
||||
super().__init__(config)
|
||||
super().__init__(config, **kwargs)
|
||||
def generate_static(state, key, ctx, ctx_length, gen_length, numseqs_aux, sampler_options, soft_embeddings=None):
|
||||
compiling_callback()
|
||||
numseqs = numseqs_aux.shape[0]
|
||||
|
@ -832,6 +832,6 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs)
|
|||
if not path.endswith("/"):
|
||||
path += "/"
|
||||
|
||||
network = PenalizingCausalTransformer(params)
|
||||
network = PenalizingCausalTransformer(params, dematerialized=True)
|
||||
network.state = read_ckpt_lowmem(network.state, path, devices.shape[1])
|
||||
network.state = network.move_xmap(network.state, np.zeros(cores_per_replica))
|
||||
|
|
Loading…
Reference in New Issue