diff --git a/breakmodel.py b/breakmodel.py
index d112e4b5..087a112a 100644
--- a/breakmodel.py
+++ b/breakmodel.py
@@ -215,6 +215,7 @@
 import torch
 import torch.cuda.comm
 import copy
 import gc
+import sys
 import itertools
 import bisect
@@ -237,6 +238,7 @@ def move_hidden_layers(transformer):
     transformer.extrastorage = {}
     torch.cuda.empty_cache()
 
+    able_to_pin_layers = True
     for i in range(ram_blocks):
         transformer.h[i].to("cpu")
         transformer.extrastorage[i] = copy.deepcopy(transformer.h[i])
@@ -246,7 +248,13 @@
         transformer.h[i].to(primary_device)
         for param in transformer.extrastorage[i].parameters():
             param.requires_grad = False
-            param.data = param.data.detach().pin_memory()
+            param.data = param.data.detach()
+            if able_to_pin_layers:
+                try:
+                    param.data = param.data.pin_memory()
+                except:
+                    able_to_pin_layers = False
+                    print(f"WARNING: You only have enough shared GPU memory for {i} out of {ram_blocks} CPU layers. Expect suboptimal speed.", file=sys.stderr)
         gc.collect()
         torch.cuda.empty_cache()
 
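For context on the pattern introduced above: pinning (page-locking) the CPU-resident layer copies lets CUDA perform faster host-to-GPU transfers, but the amount of memory that can be pinned is limited, so the patch disables further pinning after the first pin_memory() failure and leaves the remaining layers in ordinary pageable memory instead of crashing. Below is a minimal standalone sketch of that best-effort pinning pattern. It is not part of the patch itself; the helper name pin_params_best_effort and the narrower except Exception clause are assumptions made for the example.

import sys

import torch


def pin_params_best_effort(modules):
    # Try to pin every parameter of `modules` (a sequence of torch.nn.Module).
    # Stop attempting to pin after the first failure, which typically means
    # the OS/driver limit on page-locked memory has been reached, and leave
    # the remaining parameters in ordinary pageable CPU memory.
    able_to_pin = True
    for i, module in enumerate(modules):
        for param in module.parameters():
            param.requires_grad = False
            param.data = param.data.detach()
            if able_to_pin:
                try:
                    # pin_memory() returns a page-locked copy of the tensor.
                    param.data = param.data.pin_memory()
                except Exception:
                    able_to_pin = False
                    print(f"WARNING: could only pin {i} of {len(modules)} modules; "
                          "expect slower CPU-to-GPU copies.", file=sys.stderr)
    return able_to_pin

The flag is checked before every attempt so that, once pinning fails, the remaining parameters are simply left unpinned rather than triggering repeated allocation errors, which is why the patch replaces the unconditional detach().pin_memory() call with the guarded version.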