Merge pull request #19 from VE-FORBRYDERNE/multi-gpu

Multiple GPU support
henk717 2021-10-06 18:50:58 +02:00 committed by GitHub
commit bd063f7590
2 changed files with 167 additions and 132 deletions

View File

@@ -178,6 +178,99 @@ def getmodelname():
modelname = vars.model
return modelname
#==================================================================#
# Breakmodel configuration functions
#==================================================================#
def device_list(n_layers, primary=None, selected=None):
    device_count = torch.cuda.device_count()
    if(device_count < 2):
        primary = None
    gpu_blocks = breakmodel.gpu_blocks + (device_count - len(breakmodel.gpu_blocks))*[0]
    print(f"{colors.YELLOW} DEVICE ID | LAYERS | DEVICE NAME{colors.END}")
    for i in range(device_count):
        name = torch.cuda.get_device_name(i)
        if(len(name) > 47):
            name = "..." + name[-44:]
        row_color = colors.END
        sep_color = colors.YELLOW
        print(f"{row_color}{colors.YELLOW + '->' + row_color if i == selected else ' '} {'(primary)' if i == primary else ' '*9} {i:3} {sep_color}|{row_color} {gpu_blocks[i]:3} {sep_color}|{row_color} {name}{colors.END}")
    row_color = colors.END
    sep_color = colors.YELLOW
    print(f"{row_color} {' '*9} N/A {sep_color}|{row_color} {n_layers:3} {sep_color}|{row_color} (CPU){colors.END}")
def device_config(model):
global breakmodel, generator
import breakmodel
n_layers = model.config.num_layers
model.half().to('cpu')
gc.collect()
if(args.layers is not None):
try:
breakmodel.gpu_blocks = list(map(int, args.layers.split(',')))
assert len(breakmodel.gpu_blocks) <= torch.cuda.device_count()
assert sum(breakmodel.gpu_blocks) <= n_layers
n_layers -= sum(breakmodel.gpu_blocks)
except:
print("WARNING: --layers is malformatted. Please use the --help option to see correct usage of --layers. Defaulting to all layers on device 0.", file=sys.stderr)
breakmodel.gpu_blocks = [n_layers]
n_layers = 0
elif(args.breakmodel_layers is not None):
breakmodel.gpu_blocks = [n_layers - max(0, min(n_layers, args.breakmodel_layers))]
n_layers -= sum(breakmodel.gpu_blocks)
else:
device_count = torch.cuda.device_count()
if(device_count > 1):
print(colors.CYAN + "\nPlease select one of your GPUs to be your primary GPU.")
print("VRAM usage in your primary GPU will be higher than for your other ones.")
print("It is recommended you make your fastest GPU your primary GPU.")
device_list(n_layers)
while(True):
primaryselect = input("device ID> ")
if(primaryselect.isnumeric() and 0 <= int(primaryselect) < device_count):
breakmodel.primary_device = int(primaryselect)
else:
print(f"{colors.RED}Please enter an integer between 0 and {device_count-1}.{colors.END}")
else:
breakmodel.primary_device = 0
print(colors.PURPLE + "\nIf you don't have enough VRAM to run the model on a single GPU")
print("you can split the model between your CPU and your GPU(s), or between")
print("multiple GPUs if you have more than one.")
print("By putting more 'layers' on a GPU or CPU, more computations will be")
print("done on that device and more VRAM or RAM will be required on that device")
print("(roughly proportional to number of layers).")
print("It should be noted that GPUs are orders of magnitude faster than the CPU.")
print(f"This model has{colors.YELLOW} {n_layers} {colors.PURPLE}layers.{colors.END}\n")
for i in range(device_count):
device_list(n_layers, primary=breakmodel.primary_device, selected=i)
print(f"{colors.CYAN}\nHow many of the remaining{colors.YELLOW} {n_layers} {colors.CYAN}layers would you like to put into device {i}?\nYou can also enter -1 to allocate all remaining layers to this device.{colors.END}\n")
while(True):
layerselect = input("# of layers> ")
if((layerselect.isnumeric() or layerselect.strip() == '-1') and -1 <= int(layerselect) <= n_layers):
layerselect = int(layerselect)
layerselect = n_layers if layerselect == -1 else layerselect
breakmodel.gpu_blocks.append(layerselect)
n_layers -= layerselect
break
else:
print(f"{colors.RED}Please enter an integer between -1 and {n_layers}.{colors.END}")
if(n_layers == 0):
break
print(colors.PURPLE + "\nFinal device configuration:")
device_list(n_layers)
model.transformer.wte.to(breakmodel.primary_device)
model.transformer.ln_f.to(breakmodel.primary_device)
if(hasattr(model, 'lm_head')):
model.lm_head.to(breakmodel.primary_device)
if(not hasattr(model.config, 'rotary') or not model.config.rotary):
model.transformer.wpe.to(breakmodel.primary_device)
gc.collect()
GPTNeoModel.forward = breakmodel.new_forward
generator = model.generate
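# A minimal standalone sketch of the --layers parsing done above, assuming a
# hypothetical 28-layer model; parse_layers() is an illustrative helper, not a
# function that exists in this codebase.
def parse_layers(arg, n_layers):
    gpu_blocks = list(map(int, arg.split(',')))           # e.g. "8,9,11" -> [8, 9, 11]
    assert len(gpu_blocks) <= torch.cuda.device_count()   # at most one entry per GPU
    assert sum(gpu_blocks) <= n_layers                     # cannot assign more layers than exist
    return gpu_blocks, n_layers - sum(gpu_blocks)          # leftover layers stay in system RAM
# parse_layers("8,9,11", 28) -> ([8, 9, 11], 0): every layer lands on one of three GPUs
# parse_layers("10", 28)     -> ([10], 18): 10 layers on device 0, 18 streamed from system RAM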
#==================================================================#
# Startup
#==================================================================#
@@ -188,8 +281,9 @@ parser.add_argument("--remote", action='store_true', help="Optimizes KoboldAI fo
parser.add_argument("--model", help="Specify the Model Type to skip the Menu")
parser.add_argument("--path", help="Specify the Path for local models (For model NeoCustom or GPT2Custom)")
parser.add_argument("--cpu", action='store_true', help="By default unattended launches are on the GPU use this option to force CPU usage.")
parser.add_argument("--breakmodel", action='store_true', help="For models that support GPU-CPU hybrid generation, use this feature instead of GPU or CPU generation")
parser.add_argument("--breakmodel_layers", type=int, help="Specify the number of layers to commit to system RAM if --breakmodel is used")
parser.add_argument("--breakmodel", action='store_true', help=argparse.SUPPRESS)
parser.add_argument("--breakmodel_layers", type=int, help=argparse.SUPPRESS)
parser.add_argument("--layers", type=str, help="If using a model that supports hybrid generation, this is a comma-separated list that specifies how many layers to put on each GPU device. For example to put 8 layers on device 0, 9 layers on device 1 and 11 layers on device 2, use --layers 8,9,11")
parser.add_argument("--override_delete", action='store_true', help="Deleting stories from inside the browser is disabled if you are using --remote and enabled otherwise. Using this option will instead allow deleting stories if using --remote and prevent deleting stories otherwise.")
parser.add_argument("--override_rename", action='store_true', help="Renaming stories from inside the browser is disabled if you are using --remote and enabled otherwise. Using this option will instead allow renaming stories if using --remote and prevent renaming stories otherwise.")
parser.add_argument("--configname", help="Force a fixed configuration name to aid with config management.")
@@ -222,6 +316,12 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
print("{0}Looking for GPU support...{1}".format(colors.PURPLE, colors.END), end="")
vars.hascuda = torch.cuda.is_available()
vars.bmsupported = vars.model in ("EleutherAI/gpt-neo-1.3B", "EleutherAI/gpt-neo-2.7B", "NeoCustom")
if(args.breakmodel is not None and args.breakmodel):
print("WARNING: --breakmodel is no longer supported. Breakmodel mode is now automatically enabled when --layers is used (see --help for details).", file=sys.stderr)
if(args.breakmodel_layers is not None):
print("WARNING: --breakmodel_layers is deprecated. Use --layers instead (see --help for details).", file=sys.stderr)
if(not vars.bmsupported and (args.layers is not None or args.breakmodel_layers is not None)):
print("WARNING: This model does not support hybrid generation. --layers will be ignored.", file=sys.stderr)
if(vars.hascuda):
print("{0}FOUND!{1}".format(colors.GREEN, colors.END))
else:
@@ -232,18 +332,21 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
genselected = True
vars.usegpu = True
vars.breakmodel = False
if(vars.bmsupported):
vars.usegpu = False
vars.breakmodel = True
if(args.cpu):
vars.usegpu = False
vars.breakmodel = False
if(vars.bmsupported and args.breakmodel):
vars.usegpu = False
vars.breakmodel = True
elif(vars.hascuda):
if(vars.bmsupported):
print(colors.YELLOW + "You're using a model that supports GPU-CPU hybrid generation!\nCurrently only GPT-Neo models and GPT-J-6B support this feature.")
print("{0}Use GPU or CPU for generation?: (Default GPU){1}".format(colors.CYAN, colors.END))
if(vars.bmsupported):
print(f" 1 - GPU\n 2 - CPU\n 3 - Both (slower than GPU-only but uses less VRAM)\n")
print(colors.YELLOW + "You're using a model that supports hybrid generation!")
print("This feature allows you to split the model between the CPU and GPU(s)")
print("(slower than GPU-only but uses less VRAM) or between multiple GPUs")
print("(allowing you to use the combined VRAM of all your GPUs).")
print("Currently only GPT-Neo and GPT-J models support this feature.")
print("{0}Use hybrid generation or CPU-only generation?: (Default hybrid){1}".format(colors.CYAN, colors.END))
print(f" 1 - Hybrid generation\n 2 - CPU\n")
else:
print(" 1 - GPU\n 2 - CPU\n")
genselected = False
@@ -256,6 +359,11 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
vars.usegpu = True
genselected = True
elif(genselect.isnumeric() and int(genselect) == 1):
if(vars.bmsupported):
vars.breakmodel = True
vars.usegpu = False
genselected = True
else:
vars.breakmodel = False
vars.usegpu = True
genselected = True
@@ -263,10 +371,6 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
vars.breakmodel = False
vars.usegpu = False
genselected = True
elif(vars.bmsupported and genselect.isnumeric() and int(genselect) == 3):
vars.breakmodel = True
vars.usegpu = False
genselected = True
else:
print("{0}Please enter a valid selection.{1}".format(colors.RED, colors.END))
@@ -405,7 +509,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
model_config = open(vars.custmodpth + "/config.json", "r")
js = json.load(model_config)
if("model_type" in js):
model = vars.custmodpth
model = AutoModelForCausalLM.from_pretrained(vars.custmodpth)
else:
model = GPTNeoForCausalLM.from_pretrained(vars.custmodpth)
tokenizer = GPT2Tokenizer.from_pretrained(vars.custmodpth)
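# When config.json carries a "model_type" field, AutoModelForCausalLM can pick
# the right architecture on its own; older custom GPT-Neo checkpoints without
# that field fall back to GPTNeoForCausalLM. Loading a real model object here
# (instead of passing the path string on to the pipeline) is what lets
# device_config() split the model across devices further down.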
@@ -414,36 +518,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
if(vars.usegpu):
generator = pipeline('text-generation', model=model, tokenizer=tokenizer, device=0)
elif(vars.breakmodel): # Use both RAM and VRAM (breakmodel)
import breakmodel
n_layers = model.config.num_layers
breakmodel.total_blocks = n_layers
model.half().to('cpu')
gc.collect()
model.transformer.wte.to(breakmodel.gpu_device)
model.transformer.ln_f.to(breakmodel.gpu_device)
if(hasattr(model, 'lm_head')):
model.lm_head.to(breakmodel.gpu_device)
if(not hasattr(model.config, 'rotary') or not model.config.rotary):
model.transformer.wpe.to(breakmodel.gpu_device)
gc.collect()
if(args.breakmodel_layers is not None):
breakmodel.ram_blocks = max(0, min(n_layers, args.breakmodel_layers))
else:
print(colors.CYAN + "\nHow many layers would you like to put into system RAM?")
print("The more of them you put into system RAM, the slower it will run,")
print("but it will require less VRAM")
print("(roughly proportional to number of layers).")
print(f"This model has{colors.YELLOW} {n_layers} {colors.CYAN}layers.{colors.END}\n")
while(True):
layerselect = input("# of layers> ")
if(layerselect.isnumeric() and 0 <= int(layerselect) <= n_layers):
breakmodel.ram_blocks = int(layerselect)
break
else:
print(f"{colors.RED}Please enter an integer between 0 and {n_layers}.{colors.END}")
print(f"{colors.PURPLE}Will commit{colors.YELLOW} {breakmodel.ram_blocks} {colors.PURPLE}of{colors.YELLOW} {n_layers} {colors.PURPLE}layers to system RAM.{colors.END}")
GPTNeoModel.forward = breakmodel.new_forward
generator = model.generate
device_config(model)
else:
generator = pipeline('text-generation', model=model, tokenizer=tokenizer)
else:
@@ -465,37 +540,8 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
if(vars.usegpu):
generator = pipeline('text-generation', model=vars.model, device=0)
elif(vars.breakmodel): # Use both RAM and VRAM (breakmodel)
import breakmodel
model = AutoModelForCausalLM.from_pretrained(vars.model)
n_layers = model.config.num_layers
breakmodel.total_blocks = n_layers
model.half().to('cpu')
gc.collect()
model.transformer.wte.to(breakmodel.gpu_device)
model.transformer.ln_f.to(breakmodel.gpu_device)
if(hasattr(model, 'lm_head')):
model.lm_head.to(breakmodel.gpu_device)
if(not hasattr(model.config, 'rotary') or not model.config.rotary):
model.transformer.wpe.to(breakmodel.gpu_device)
gc.collect()
if(args.breakmodel_layers is not None):
breakmodel.ram_blocks = max(0, min(n_layers, args.breakmodel_layers))
else:
print(colors.CYAN + "\nHow many layers would you like to put into system RAM?")
print("The more of them you put into system RAM, the slower it will run,")
print("but it will require less VRAM")
print("(roughly proportional to number of layers).")
print(f"This model has{colors.YELLOW} {n_layers} {colors.CYAN}layers.{colors.END}\n")
while(True):
layerselect = input("# of layers> ")
if(layerselect.isnumeric() and 0 <= int(layerselect) <= n_layers):
breakmodel.ram_blocks = int(layerselect)
break
else:
print(f"{colors.RED}Please enter an integer between 0 and {n_layers}.{colors.END}")
print(f"{colors.PURPLE}Will commit{colors.YELLOW} {breakmodel.ram_blocks} {colors.PURPLE}of{colors.YELLOW} {n_layers} {colors.PURPLE}layers to system RAM.{colors.END}")
GPTNeoModel.forward = breakmodel.new_forward
generator = model.generate
device_config(model)
else:
generator = pipeline('text-generation', model=vars.model)
else:
@@ -1245,7 +1291,7 @@ def generate(txt, min, max):
# its first argument if we're using breakmodel, otherwise a string
# is fine
if(vars.hascuda and vars.breakmodel):
gen_in = tokenizer.encode(txt, return_tensors="pt", truncation=True).long().to(breakmodel.gpu_device)
gen_in = tokenizer.encode(txt, return_tensors="pt", truncation=True).long().to(breakmodel.primary_device)
else:
gen_in = txt

View File

@@ -1,10 +1,10 @@
'''
This is a MODIFIED version of arrmansa's low VRAM patch.
https://github.com/arrmansa/Basic-UI-for-GPT-J-6B-with-low-vram/blob/main/GPT-J-6B-Low-Vram-UI.ipynb
The ORIGINAL version of the patch is released under the Apache License 2.0
Copyright 2021 arrmansa
Copyright 2021 finetuneanon
Copyright 2018 The Hugging Face team
Released under the Apache License 2.0
Apache License
@@ -215,6 +215,8 @@ import torch
import torch.cuda.comm
import copy
import gc
import itertools
import bisect
from transformers.modeling_outputs import BaseModelOutputWithPast
@@ -222,17 +224,9 @@ from transformers.utils import logging
logger = logging.get_logger(__name__)
class MaxSharedRamBlocksException(Exception):
def __init__(self, i: int):
self.corrected_max_shared_ram_blocks = i
super().__init__('max_shared_ram_blocks is set too high, please set it to '+str(i))
breakmodel = True
gpu_device = 'cuda'
total_blocks = 24
ram_blocks = 7
max_shared_ram_blocks = None
gpu_blocks = []
primary_device = 0
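# Example of how a caller configures this module after the change (the block
# counts are hypothetical):
#   import breakmodel
#   breakmodel.gpu_blocks = [8, 9, 11]   # layers to place on CUDA devices 0, 1 and 2
#   breakmodel.primary_device = 0        # device that holds the embeddings, ln_f and lm_head
# Any layers not covered by gpu_blocks stay in system RAM and are streamed
# through the primary device during the forward pass.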
def new_forward(
@@ -250,54 +244,43 @@ def new_forward(
return_dict=None,
embs=None,
):
global max_shared_ram_blocks
assert len(gpu_blocks) <= torch.cuda.device_count()
assert sum(gpu_blocks) <= len(self.h)
ram_blocks = len(self.h) - sum(gpu_blocks)
cumulative_gpu_blocks = tuple(itertools.accumulate(gpu_blocks))
if breakmodel:
if max_shared_ram_blocks is None:
max_shared_ram_blocks = total_blocks
if not hasattr(self, 'extrastorage'):
setattr(self,"extrastorage",{})
torch.cuda.empty_cache()
for i in range(ram_blocks,len(self.h)):
self.h[i].to(gpu_device)
for i in range(ram_blocks):
self.h[i].to("cpu")
self.extrastorage[i] = copy.deepcopy(self.h[i])
smalltensor = torch.tensor(0).to(gpu_device)
smalltensor = torch.tensor(0).to(primary_device)
for param1 in self.h[i].parameters():
param1.data = smalltensor
self.h[i].to(gpu_device)
for i in range(len(self.h)):
for param in self.h[i].parameters():
param.requires_grad = False
param.data = param.data.detach()
gc.collect()
torch.cuda.empty_cache()
for i in range(ram_blocks):
self.h[i].to(primary_device)
for param in self.extrastorage[i].parameters():
param.requires_grad = False
if i < max_shared_ram_blocks:
try:
param.data = param.data.detach().pin_memory()
except:
raise MaxSharedRamBlocksException(i)
else:
param.data = param.data.detach()
gc.collect()
torch.cuda.empty_cache()
if ram_blocks:
for param1,param2 in zip(self.h[0].parameters(),self.extrastorage[0].parameters()):
param1.data = param2.data.to(gpu_device, non_blocking=False).detach()
param1.data = param2.data.to(primary_device, non_blocking=False).detach()
for param1,param2 in zip(self.h[ram_blocks-1].parameters(),self.extrastorage[ram_blocks-1].parameters()):
param1.data = param2.data.to(gpu_device, non_blocking=False).detach()
#END MODEL BREAK EDITS
param1.data = param2.data.to(primary_device, non_blocking=False).detach()
i = ram_blocks
for j in range(len(gpu_blocks)):
for _ in range(gpu_blocks[j]):
self.h[i].to(j)
i += 1
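# Routing with hypothetical numbers: for gpu_blocks = [8, 9, 11] on a 28-layer
# model, ram_blocks = 0 and cumulative_gpu_blocks = (8, 17, 28). The
# bisect_right lookup used further down then sends blocks 0-7 to device 0,
# blocks 8-16 to device 1 and blocks 17-27 to device 2; any block with
# i < ram_blocks stays pinned in system RAM and is streamed through the
# primary device instead.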
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@@ -331,7 +314,7 @@ def new_forward(
else:
past_length = past_key_values[0][0].size(-2)
device = input_ids.device if input_ids is not None else inputs_embeds.device
device = primary_device if breakmodel else input_ids.device if input_ids is not None else inputs_embeds.device
if position_ids is None:
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
@@ -368,6 +351,8 @@ def new_forward(
head_mask = self.get_head_mask(head_mask, self.config.num_layers)
if inputs_embeds is None:
if breakmodel:
input_ids = input_ids.to(primary_device)
inputs_embeds = self.wte(input_ids)
if embs is not None and not (use_cache is not None and use_cache and past_key_values is not None and len(past_key_values) > 0 and past_key_values[0] is not None):
@@ -382,7 +367,11 @@ def new_forward(
if hasattr(self, 'rotary') and self.rotary:
hidden_states = inputs_embeds
else:
if breakmodel:
position_ids = position_ids.to(primary_device)
position_embeds = self.wpe(position_ids)
if breakmodel:
position_embeds = position_embeds.to(primary_device)
hidden_states = inputs_embeds + position_embeds
if token_type_ids is not None:
@@ -397,9 +386,8 @@ def new_forward(
all_self_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
if breakmodel:
copystream = torch.cuda.Stream(device=0,priority = -1)
if breakmodel and ram_blocks:
copystream = torch.cuda.Stream(device=primary_device, priority=-1)
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
@@ -412,7 +400,6 @@ def new_forward(
with torch.cuda.stream(copystream):
torch.cuda.comm.broadcast(param2.data,out = [param1.data])
attn_type = self.config.attention_layers[i]
attn_mask = global_attention_mask
@@ -443,11 +430,13 @@ def new_forward(
head_mask[i],
)
else:
if breakmodel:
device = primary_device if i < ram_blocks else bisect.bisect_right(cumulative_gpu_blocks, i - ram_blocks)
outputs = block(
hidden_states,
layer_past=layer_past,
attention_mask=attn_mask,
head_mask=head_mask[i],
hidden_states.to(device) if breakmodel and hidden_states is not None else hidden_states,
layer_past=tuple(v.to(device) for v in layer_past if v is not None) if breakmodel and layer_past is not None and i >= ram_blocks and len(layer_past) and layer_past[0].device.index != device else layer_past,
attention_mask=attn_mask.to(device) if breakmodel and attn_mask is not None else attn_mask,
head_mask=head_mask[i].to(device) if breakmodel and head_mask[i] is not None else head_mask[i],
use_cache=use_cache,
output_attentions=output_attentions,
)
@@ -466,12 +455,13 @@ def new_forward(
torch.cuda.empty_cache()
if breakmodel:
if ram_blocks:
del copystream
torch.cuda.empty_cache()
hidden_states = hidden_states.to(primary_device)
hidden_states = self.ln_f(hidden_states)
if breakmodel:
hidden_states = hidden_states.to(primary_device)
hidden_states = hidden_states.view(*output_shape)
# Add last hidden state
@@ -480,7 +470,6 @@ def new_forward(
if not return_dict:
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=presents,