Multiple GPU support

This commit is contained in:
Gnome Ann 2021-10-05 09:38:57 -04:00
parent 0937bb33e7
commit a283d34b27
2 changed files with 144 additions and 91 deletions

View File

@ -178,6 +178,88 @@ def getmodelname():
modelname = vars.model
return modelname
#==================================================================#
# Breakmodel configuration functions
#==================================================================#
def device_list(n_layers, primary=None, selected=None):
device_count = torch.cuda.device_count()
if(device_count < 2):
primary = None
gpu_blocks = breakmodel.gpu_blocks + (device_count - len(breakmodel.gpu_blocks))*[0]
print(f"{colors.YELLOW} DEVICE ID | LAYERS | DEVICE NAME{colors.END}")
for i in range(device_count):
name = torch.cuda.get_device_name(i)
if(len(name) > 47):
name = "..." + name[-44:]
row_color = colors.END
sep_color = colors.YELLOW
print(f"{row_color}{colors.YELLOW + '->' + row_color if i == selected else ' '} {'(primary)' if i == primary else ' '*9} {i:3} {sep_color}|{row_color} {gpu_blocks[i]:3} {sep_color}|{row_color} {name}{colors.END}")
row_color = colors.END
sep_color = colors.YELLOW
print(f"{row_color} {' '*9} N/A {sep_color}|{row_color} {n_layers:3} {sep_color}|{row_color} (CPU){colors.END}")
def device_config(model):
global breakmodel, generator
import breakmodel
n_layers = model.config.num_layers
model.half().to('cpu')
gc.collect()
if(args.breakmodel_layers is not None):
breakmodel.gpu_blocks = [n_layers - max(0, min(n_layers, args.breakmodel_layers))]
else:
device_count = torch.cuda.device_count()
if(device_count > 1):
print(colors.CYAN + "\nPlease select one of your GPUs to be your primary GPU.")
print("VRAM usage in your primary GPU will be higher than for your other ones.")
print("It is recommended you make your fastest GPU your primary GPU.")
device_list(n_layers)
while(True):
primaryselect = input("device ID> ")
if(primaryselect.isnumeric() and 0 <= int(primaryselect) < device_count):
breakmodel.primary_device = int(primaryselect)
else:
print(f"{colors.RED}Please enter an integer between 0 and {device_count-1}.{colors.END}")
else:
breakmodel.primary_device = 0
print(colors.PURPLE + "\nIf you don't have enough VRAM to run the model on a single GPU")
print("you can split the model between your CPU and your GPU(s), or between")
print("multiple GPUs if you have more than one.")
print("By putting more 'layers' on a GPU or CPU, more computations will be")
print("done on that device and more VRAM or RAM will be required on that device")
print("(roughly proportional to number of layers).")
print("It should be noted that GPUs are orders of magnitude faster than the CPU.")
print(f"This model has{colors.YELLOW} {n_layers} {colors.PURPLE}layers.{colors.END}\n")
for i in range(device_count):
device_list(n_layers, primary=breakmodel.primary_device, selected=i)
print(f"{colors.CYAN}\nHow many of the remaining{colors.YELLOW} {n_layers} {colors.CYAN}layers would you like to put into device {i}?\nYou can also enter -1 to allocate all remaining layers to this device.{colors.END}\n")
while(True):
layerselect = input("# of layers> ")
if((layerselect.isnumeric() or layerselect.strip() == '-1') and -1 <= int(layerselect) <= n_layers):
layerselect = int(layerselect)
layerselect = n_layers if layerselect == -1 else layerselect
breakmodel.gpu_blocks.append(layerselect)
n_layers -= layerselect
break
else:
print(f"{colors.RED}Please enter an integer between -1 and {n_layers}.{colors.END}")
if(n_layers == 0):
break
print(colors.PURPLE + "\nFinal device configuration:")
device_list(n_layers)
model.transformer.wte.to(breakmodel.primary_device)
model.transformer.ln_f.to(breakmodel.primary_device)
if(hasattr(model, 'lm_head')):
model.lm_head.to(breakmodel.primary_device)
if(not hasattr(model.config, 'rotary') or not model.config.rotary):
model.transformer.wpe.to(breakmodel.primary_device)
gc.collect()
GPTNeoModel.forward = breakmodel.new_forward
generator = model.generate
#==================================================================#
# Startup
#==================================================================#
@ -414,36 +496,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
if(vars.usegpu):
generator = pipeline('text-generation', model=model, tokenizer=tokenizer, device=0)
elif(vars.breakmodel): # Use both RAM and VRAM (breakmodel)
import breakmodel
n_layers = model.config.num_layers
breakmodel.total_blocks = n_layers
model.half().to('cpu')
gc.collect()
model.transformer.wte.to(breakmodel.embedding_device)
model.transformer.ln_f.to(breakmodel.layernormfinal_device)
if(hasattr(model, 'lm_head')):
model.lm_head.to(breakmodel.embedding_device)
if(not hasattr(model.config, 'rotary') or not model.config.rotary):
model.transformer.wpe.to(breakmodel.positional_device)
gc.collect()
if(args.breakmodel_layers is not None):
breakmodel.ram_blocks = max(0, min(n_layers, args.breakmodel_layers))
else:
print(colors.CYAN + "\nHow many layers would you like to put into system RAM?")
print("The more of them you put into system RAM, the slower it will run,")
print("but it will require less VRAM")
print("(roughly proportional to number of layers).")
print(f"This model has{colors.YELLOW} {n_layers} {colors.CYAN}layers.{colors.END}\n")
while(True):
layerselect = input("# of layers> ")
if(layerselect.isnumeric() and 0 <= int(layerselect) <= n_layers):
breakmodel.ram_blocks = int(layerselect)
break
else:
print(f"{colors.RED}Please enter an integer between 0 and {n_layers}.{colors.END}")
print(f"{colors.PURPLE}Will commit{colors.YELLOW} {breakmodel.ram_blocks} {colors.PURPLE}of{colors.YELLOW} {n_layers} {colors.PURPLE}layers to system RAM.{colors.END}")
GPTNeoModel.forward = breakmodel.new_forward
generator = model.generate
device_config(model)
else:
generator = pipeline('text-generation', model=model, tokenizer=tokenizer)
else:
@ -465,37 +518,8 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
if(vars.usegpu):
generator = pipeline('text-generation', model=vars.model, device=0)
elif(vars.breakmodel): # Use both RAM and VRAM (breakmodel)
import breakmodel
model = AutoModelForCausalLM.from_pretrained(vars.model)
n_layers = model.config.num_layers
breakmodel.total_blocks = n_layers
model.half().to('cpu')
gc.collect()
model.transformer.wte.to(breakmodel.embedding_device)
model.transformer.ln_f.to(breakmodel.layernormfinal_device)
if(hasattr(model, 'lm_head')):
model.lm_head.to(breakmodel.embedding_device)
if(not hasattr(model.config, 'rotary') or not model.config.rotary):
model.transformer.wpe.to(breakmodel.positional_device)
gc.collect()
if(args.breakmodel_layers is not None):
breakmodel.ram_blocks = max(0, min(n_layers, args.breakmodel_layers))
else:
print(colors.CYAN + "\nHow many layers would you like to put into system RAM?")
print("The more of them you put into system RAM, the slower it will run,")
print("but it will require less VRAM")
print("(roughly proportional to number of layers).")
print(f"This model has{colors.YELLOW} {n_layers} {colors.CYAN}layers.{colors.END}\n")
while(True):
layerselect = input("# of layers> ")
if(layerselect.isnumeric() and 0 <= int(layerselect) <= n_layers):
breakmodel.ram_blocks = int(layerselect)
break
else:
print(f"{colors.RED}Please enter an integer between 0 and {n_layers}.{colors.END}")
print(f"{colors.PURPLE}Will commit{colors.YELLOW} {breakmodel.ram_blocks} {colors.PURPLE}of{colors.YELLOW} {n_layers} {colors.PURPLE}layers to system RAM.{colors.END}")
GPTNeoModel.forward = breakmodel.new_forward
generator = model.generate
device_config(model)
else:
generator = pipeline('text-generation', model=vars.model)
else:
@ -1245,7 +1269,7 @@ def generate(txt, min, max):
# its first argument if we're using breakmodel, otherwise a string
# is fine
if(vars.hascuda and vars.breakmodel):
gen_in = tokenizer.encode(txt, return_tensors="pt", truncation=True).long().to(breakmodel.embedding_device)
gen_in = tokenizer.encode(txt, return_tensors="pt", truncation=True).long().to(breakmodel.primary_device)
else:
gen_in = txt

View File

@ -215,6 +215,8 @@ import torch
import torch.cuda.comm
import copy
import gc
import itertools
import bisect
from transformers.modeling_outputs import BaseModelOutputWithPast
@ -222,23 +224,9 @@ from transformers.utils import logging
logger = logging.get_logger(__name__)
class MaxSharedRamBlocksException(Exception):
def __init__(self, i: int):
self.corrected_max_shared_ram_blocks = i
super().__init__('max_shared_ram_blocks is set too high, please set it to '+str(i))
breakmodel = True
devices = ['cpu', 'cuda']
total_blocks = 24
ram_blocks = 7
max_shared_ram_blocks = None
# I highly suggest these all be set to the same device unless you really know what you're doing!
# (They can all be set to any CPU or GPU device, except layernormfinal_device which can only be a GPU device)
embedding_device = devices[1] # Dealing with text embedding is computationally expensive, I suggest you set this to your fastest device
positional_device = devices[1] # Only used for GPT-Neo (not used for GPT-J)
layernormfinal_device = devices[1] # This setting is unique in that this MUST be set to a GPU device, this cannot be set to 'cpu'
gpu_blocks = []
primary_device = 0
def new_forward(
@ -256,21 +244,41 @@ def new_forward(
return_dict=None,
embs=None,
):
global max_shared_ram_blocks
assert len(gpu_blocks) <= torch.cuda.device_count()
assert sum(gpu_blocks) <= len(self.h)
ram_blocks = len(self.h) - sum(gpu_blocks)
cumulative_gpu_blocks = tuple(itertools.accumulate(gpu_blocks))
if breakmodel:
if max_shared_ram_blocks is None:
max_shared_ram_blocks = total_blocks
if not hasattr(self, 'extrastorage'):
setattr(self,"extrastorage",{})
torch.cuda.empty_cache()
for i in range(ram_blocks):
self.h[i].to(devices[0])
self.h[i].to("cpu")
self.extrastorage[i] = copy.deepcopy(self.h[i])
smalltensor = torch.tensor(0).to(primary_device)
for param1 in self.h[i].parameters():
param1.data = smalltensor
self.h[i].to(primary_device)
for param in self.extrastorage[i].parameters():
param.requires_grad = False
param.data = param.data.detach().pin_memory()
gc.collect()
torch.cuda.empty_cache()
for i in range(ram_blocks,len(self.h)):
self.h[i].to(devices[1])
if ram_blocks:
for param1,param2 in zip(self.h[0].parameters(),self.extrastorage[0].parameters()):
param1.data = param2.data.to(primary_device, non_blocking=False).detach()
for param1,param2 in zip(self.h[ram_blocks-1].parameters(),self.extrastorage[ram_blocks-1].parameters()):
param1.data = param2.data.to(primary_device, non_blocking=False).detach()
i = ram_blocks
for j in range(len(gpu_blocks)):
for _ in range(gpu_blocks[j]):
self.h[i].to(j)
i += 1
@ -306,7 +314,7 @@ def new_forward(
else:
past_length = past_key_values[0][0].size(-2)
device = positional_device if breakmodel else input_ids.device if input_ids is not None else inputs_embeds.device
device = primary_device if breakmodel else input_ids.device if input_ids is not None else inputs_embeds.device
if position_ids is None:
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
@ -344,7 +352,7 @@ def new_forward(
if inputs_embeds is None:
if breakmodel:
input_ids = input_ids.to(embedding_device)
input_ids = input_ids.to(primary_device)
inputs_embeds = self.wte(input_ids)
if embs is not None and not (use_cache is not None and use_cache and past_key_values is not None and len(past_key_values) > 0 and past_key_values[0] is not None):
@ -360,10 +368,10 @@ def new_forward(
hidden_states = inputs_embeds
else:
if breakmodel:
position_ids = position_ids.to(positional_device)
position_ids = position_ids.to(primary_device)
position_embeds = self.wpe(position_ids)
if breakmodel:
position_embeds = position_embeds.to(embedding_device)
position_embeds = position_embeds.to(primary_device)
hidden_states = inputs_embeds + position_embeds
if token_type_ids is not None:
@ -377,8 +385,21 @@ def new_forward(
presents = () if use_cache else None
all_self_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
if breakmodel and ram_blocks:
copystream = torch.cuda.Stream(device=0,priority = -1)
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
if breakmodel:
if i in range(ram_blocks):
index1 = (i+1)%ram_blocks
for param1,param2 in zip(self.h[index1].parameters(),self.h[(i-1)%ram_blocks].parameters()):
param1.data = param2.data
for param1,param2 in zip(self.h[index1].parameters(),self.extrastorage[index1].parameters()):
with torch.cuda.stream(copystream):
torch.cuda.comm.broadcast(param2.data,out = [param1.data])
attn_type = self.config.attention_layers[i]
attn_mask = global_attention_mask
@ -410,7 +431,7 @@ def new_forward(
)
else:
if breakmodel:
device = devices[0] if i < ram_blocks else devices[1]
device = primary_device if i < ram_blocks else bisect.bisect_right(cumulative_gpu_blocks, i - ram_blocks)
outputs = block(
hidden_states.to(device) if breakmodel and hidden_states is not None else hidden_states,
layer_past=tuple(v.to(device) for v in layer_past if v is not None) if breakmodel and layer_past is not None else layer_past,
@ -428,11 +449,19 @@ def new_forward(
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
if breakmodel:
if i in range(ram_blocks):
torch.cuda.synchronize()
torch.cuda.empty_cache()
if breakmodel:
hidden_states = hidden_states.to(layernormfinal_device)
if ram_blocks:
del copystream
torch.cuda.empty_cache()
hidden_states = hidden_states.to(primary_device)
hidden_states = self.ln_f(hidden_states)
if breakmodel:
hidden_states = hidden_states.to(embedding_device)
hidden_states = hidden_states.to(primary_device)
hidden_states = hidden_states.view(*output_shape)
# Add last hidden state