Multiple GPU support

This commit is contained in:
Gnome Ann
2021-10-05 09:38:57 -04:00
parent 0937bb33e7
commit a283d34b27
2 changed files with 144 additions and 91 deletions

View File

@ -215,6 +215,8 @@ import torch
import torch.cuda.comm
import copy
import gc
import itertools
import bisect
from transformers.modeling_outputs import BaseModelOutputWithPast
@ -222,23 +224,9 @@ from transformers.utils import logging
logger = logging.get_logger(__name__)
class MaxSharedRamBlocksException(Exception):
def __init__(self, i: int):
self.corrected_max_shared_ram_blocks = i
super().__init__('max_shared_ram_blocks is set too high, please set it to '+str(i))
breakmodel = True
devices = ['cpu', 'cuda']
total_blocks = 24
ram_blocks = 7
max_shared_ram_blocks = None
# I highly suggest these all be set to the same device unless you really know what you're doing!
# (They can all be set to any CPU or GPU device, except layernormfinal_device which can only be a GPU device)
embedding_device = devices[1] # Dealing with text embedding is computationally expensive, I suggest you set this to your fastest device
positional_device = devices[1] # Only used for GPT-Neo (not used for GPT-J)
layernormfinal_device = devices[1] # This setting is unique in that this MUST be set to a GPU device, this cannot be set to 'cpu'
gpu_blocks = []
primary_device = 0
def new_forward(
@ -256,21 +244,41 @@ def new_forward(
return_dict=None,
embs=None,
):
global max_shared_ram_blocks
assert len(gpu_blocks) <= torch.cuda.device_count()
assert sum(gpu_blocks) <= len(self.h)
ram_blocks = len(self.h) - sum(gpu_blocks)
cumulative_gpu_blocks = tuple(itertools.accumulate(gpu_blocks))
if breakmodel:
if max_shared_ram_blocks is None:
max_shared_ram_blocks = total_blocks
if not hasattr(self, 'extrastorage'):
setattr(self,"extrastorage",{})
torch.cuda.empty_cache()
for i in range(ram_blocks):
self.h[i].to(devices[0])
self.h[i].to("cpu")
self.extrastorage[i] = copy.deepcopy(self.h[i])
smalltensor = torch.tensor(0).to(primary_device)
for param1 in self.h[i].parameters():
param1.data = smalltensor
self.h[i].to(primary_device)
for param in self.extrastorage[i].parameters():
param.requires_grad = False
param.data = param.data.detach().pin_memory()
gc.collect()
torch.cuda.empty_cache()
for i in range(ram_blocks,len(self.h)):
self.h[i].to(devices[1])
if ram_blocks:
for param1,param2 in zip(self.h[0].parameters(),self.extrastorage[0].parameters()):
param1.data = param2.data.to(primary_device, non_blocking=False).detach()
for param1,param2 in zip(self.h[ram_blocks-1].parameters(),self.extrastorage[ram_blocks-1].parameters()):
param1.data = param2.data.to(primary_device, non_blocking=False).detach()
i = ram_blocks
for j in range(len(gpu_blocks)):
for _ in range(gpu_blocks[j]):
self.h[i].to(j)
i += 1
@ -306,7 +314,7 @@ def new_forward(
else:
past_length = past_key_values[0][0].size(-2)
device = positional_device if breakmodel else input_ids.device if input_ids is not None else inputs_embeds.device
device = primary_device if breakmodel else input_ids.device if input_ids is not None else inputs_embeds.device
if position_ids is None:
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
@ -344,7 +352,7 @@ def new_forward(
if inputs_embeds is None:
if breakmodel:
input_ids = input_ids.to(embedding_device)
input_ids = input_ids.to(primary_device)
inputs_embeds = self.wte(input_ids)
if embs is not None and not (use_cache is not None and use_cache and past_key_values is not None and len(past_key_values) > 0 and past_key_values[0] is not None):
@ -360,10 +368,10 @@ def new_forward(
hidden_states = inputs_embeds
else:
if breakmodel:
position_ids = position_ids.to(positional_device)
position_ids = position_ids.to(primary_device)
position_embeds = self.wpe(position_ids)
if breakmodel:
position_embeds = position_embeds.to(embedding_device)
position_embeds = position_embeds.to(primary_device)
hidden_states = inputs_embeds + position_embeds
if token_type_ids is not None:
@ -377,8 +385,21 @@ def new_forward(
presents = () if use_cache else None
all_self_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
if breakmodel and ram_blocks:
copystream = torch.cuda.Stream(device=0,priority = -1)
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
if breakmodel:
if i in range(ram_blocks):
index1 = (i+1)%ram_blocks
for param1,param2 in zip(self.h[index1].parameters(),self.h[(i-1)%ram_blocks].parameters()):
param1.data = param2.data
for param1,param2 in zip(self.h[index1].parameters(),self.extrastorage[index1].parameters()):
with torch.cuda.stream(copystream):
torch.cuda.comm.broadcast(param2.data,out = [param1.data])
attn_type = self.config.attention_layers[i]
attn_mask = global_attention_mask
@ -410,7 +431,7 @@ def new_forward(
)
else:
if breakmodel:
device = devices[0] if i < ram_blocks else devices[1]
device = primary_device if i < ram_blocks else bisect.bisect_right(cumulative_gpu_blocks, i - ram_blocks)
outputs = block(
hidden_states.to(device) if breakmodel and hidden_states is not None else hidden_states,
layer_past=tuple(v.to(device) for v in layer_past if v is not None) if breakmodel and layer_past is not None else layer_past,
@ -428,11 +449,19 @@ def new_forward(
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
if breakmodel:
if i in range(ram_blocks):
torch.cuda.synchronize()
torch.cuda.empty_cache()
if breakmodel:
hidden_states = hidden_states.to(layernormfinal_device)
if ram_blocks:
del copystream
torch.cuda.empty_cache()
hidden_states = hidden_states.to(primary_device)
hidden_states = self.ln_f(hidden_states)
if breakmodel:
hidden_states = hidden_states.to(embedding_device)
hidden_states = hidden_states.to(primary_device)
hidden_states = hidden_states.view(*output_shape)
# Add last hidden state