mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-02-09 08:18:40 +01:00
Multiple GPU support
This commit is contained in:
parent
0937bb33e7
commit
a283d34b27
146
aiserver.py
146
aiserver.py
@ -178,6 +178,88 @@ def getmodelname():
|
||||
modelname = vars.model
|
||||
return modelname
|
||||
|
||||
#==================================================================#
|
||||
# Breakmodel configuration functions
|
||||
#==================================================================#
|
||||
def device_list(n_layers, primary=None, selected=None):
|
||||
device_count = torch.cuda.device_count()
|
||||
if(device_count < 2):
|
||||
primary = None
|
||||
gpu_blocks = breakmodel.gpu_blocks + (device_count - len(breakmodel.gpu_blocks))*[0]
|
||||
print(f"{colors.YELLOW} DEVICE ID | LAYERS | DEVICE NAME{colors.END}")
|
||||
for i in range(device_count):
|
||||
name = torch.cuda.get_device_name(i)
|
||||
if(len(name) > 47):
|
||||
name = "..." + name[-44:]
|
||||
row_color = colors.END
|
||||
sep_color = colors.YELLOW
|
||||
print(f"{row_color}{colors.YELLOW + '->' + row_color if i == selected else ' '} {'(primary)' if i == primary else ' '*9} {i:3} {sep_color}|{row_color} {gpu_blocks[i]:3} {sep_color}|{row_color} {name}{colors.END}")
|
||||
row_color = colors.END
|
||||
sep_color = colors.YELLOW
|
||||
print(f"{row_color} {' '*9} N/A {sep_color}|{row_color} {n_layers:3} {sep_color}|{row_color} (CPU){colors.END}")
|
||||
|
||||
def device_config(model):
|
||||
global breakmodel, generator
|
||||
import breakmodel
|
||||
n_layers = model.config.num_layers
|
||||
model.half().to('cpu')
|
||||
gc.collect()
|
||||
if(args.breakmodel_layers is not None):
|
||||
breakmodel.gpu_blocks = [n_layers - max(0, min(n_layers, args.breakmodel_layers))]
|
||||
else:
|
||||
device_count = torch.cuda.device_count()
|
||||
if(device_count > 1):
|
||||
print(colors.CYAN + "\nPlease select one of your GPUs to be your primary GPU.")
|
||||
print("VRAM usage in your primary GPU will be higher than for your other ones.")
|
||||
print("It is recommended you make your fastest GPU your primary GPU.")
|
||||
device_list(n_layers)
|
||||
while(True):
|
||||
primaryselect = input("device ID> ")
|
||||
if(primaryselect.isnumeric() and 0 <= int(primaryselect) < device_count):
|
||||
breakmodel.primary_device = int(primaryselect)
|
||||
else:
|
||||
print(f"{colors.RED}Please enter an integer between 0 and {device_count-1}.{colors.END}")
|
||||
else:
|
||||
breakmodel.primary_device = 0
|
||||
|
||||
print(colors.PURPLE + "\nIf you don't have enough VRAM to run the model on a single GPU")
|
||||
print("you can split the model between your CPU and your GPU(s), or between")
|
||||
print("multiple GPUs if you have more than one.")
|
||||
print("By putting more 'layers' on a GPU or CPU, more computations will be")
|
||||
print("done on that device and more VRAM or RAM will be required on that device")
|
||||
print("(roughly proportional to number of layers).")
|
||||
print("It should be noted that GPUs are orders of magnitude faster than the CPU.")
|
||||
print(f"This model has{colors.YELLOW} {n_layers} {colors.PURPLE}layers.{colors.END}\n")
|
||||
|
||||
for i in range(device_count):
|
||||
device_list(n_layers, primary=breakmodel.primary_device, selected=i)
|
||||
print(f"{colors.CYAN}\nHow many of the remaining{colors.YELLOW} {n_layers} {colors.CYAN}layers would you like to put into device {i}?\nYou can also enter -1 to allocate all remaining layers to this device.{colors.END}\n")
|
||||
while(True):
|
||||
layerselect = input("# of layers> ")
|
||||
if((layerselect.isnumeric() or layerselect.strip() == '-1') and -1 <= int(layerselect) <= n_layers):
|
||||
layerselect = int(layerselect)
|
||||
layerselect = n_layers if layerselect == -1 else layerselect
|
||||
breakmodel.gpu_blocks.append(layerselect)
|
||||
n_layers -= layerselect
|
||||
break
|
||||
else:
|
||||
print(f"{colors.RED}Please enter an integer between -1 and {n_layers}.{colors.END}")
|
||||
if(n_layers == 0):
|
||||
break
|
||||
|
||||
print(colors.PURPLE + "\nFinal device configuration:")
|
||||
device_list(n_layers)
|
||||
|
||||
model.transformer.wte.to(breakmodel.primary_device)
|
||||
model.transformer.ln_f.to(breakmodel.primary_device)
|
||||
if(hasattr(model, 'lm_head')):
|
||||
model.lm_head.to(breakmodel.primary_device)
|
||||
if(not hasattr(model.config, 'rotary') or not model.config.rotary):
|
||||
model.transformer.wpe.to(breakmodel.primary_device)
|
||||
gc.collect()
|
||||
GPTNeoModel.forward = breakmodel.new_forward
|
||||
generator = model.generate
|
||||
|
||||
#==================================================================#
|
||||
# Startup
|
||||
#==================================================================#
|
||||
@ -414,36 +496,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
|
||||
if(vars.usegpu):
|
||||
generator = pipeline('text-generation', model=model, tokenizer=tokenizer, device=0)
|
||||
elif(vars.breakmodel): # Use both RAM and VRAM (breakmodel)
|
||||
import breakmodel
|
||||
n_layers = model.config.num_layers
|
||||
breakmodel.total_blocks = n_layers
|
||||
model.half().to('cpu')
|
||||
gc.collect()
|
||||
model.transformer.wte.to(breakmodel.embedding_device)
|
||||
model.transformer.ln_f.to(breakmodel.layernormfinal_device)
|
||||
if(hasattr(model, 'lm_head')):
|
||||
model.lm_head.to(breakmodel.embedding_device)
|
||||
if(not hasattr(model.config, 'rotary') or not model.config.rotary):
|
||||
model.transformer.wpe.to(breakmodel.positional_device)
|
||||
gc.collect()
|
||||
if(args.breakmodel_layers is not None):
|
||||
breakmodel.ram_blocks = max(0, min(n_layers, args.breakmodel_layers))
|
||||
else:
|
||||
print(colors.CYAN + "\nHow many layers would you like to put into system RAM?")
|
||||
print("The more of them you put into system RAM, the slower it will run,")
|
||||
print("but it will require less VRAM")
|
||||
print("(roughly proportional to number of layers).")
|
||||
print(f"This model has{colors.YELLOW} {n_layers} {colors.CYAN}layers.{colors.END}\n")
|
||||
while(True):
|
||||
layerselect = input("# of layers> ")
|
||||
if(layerselect.isnumeric() and 0 <= int(layerselect) <= n_layers):
|
||||
breakmodel.ram_blocks = int(layerselect)
|
||||
break
|
||||
else:
|
||||
print(f"{colors.RED}Please enter an integer between 0 and {n_layers}.{colors.END}")
|
||||
print(f"{colors.PURPLE}Will commit{colors.YELLOW} {breakmodel.ram_blocks} {colors.PURPLE}of{colors.YELLOW} {n_layers} {colors.PURPLE}layers to system RAM.{colors.END}")
|
||||
GPTNeoModel.forward = breakmodel.new_forward
|
||||
generator = model.generate
|
||||
device_config(model)
|
||||
else:
|
||||
generator = pipeline('text-generation', model=model, tokenizer=tokenizer)
|
||||
else:
|
||||
@ -465,37 +518,8 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
|
||||
if(vars.usegpu):
|
||||
generator = pipeline('text-generation', model=vars.model, device=0)
|
||||
elif(vars.breakmodel): # Use both RAM and VRAM (breakmodel)
|
||||
import breakmodel
|
||||
model = AutoModelForCausalLM.from_pretrained(vars.model)
|
||||
n_layers = model.config.num_layers
|
||||
breakmodel.total_blocks = n_layers
|
||||
model.half().to('cpu')
|
||||
gc.collect()
|
||||
model.transformer.wte.to(breakmodel.embedding_device)
|
||||
model.transformer.ln_f.to(breakmodel.layernormfinal_device)
|
||||
if(hasattr(model, 'lm_head')):
|
||||
model.lm_head.to(breakmodel.embedding_device)
|
||||
if(not hasattr(model.config, 'rotary') or not model.config.rotary):
|
||||
model.transformer.wpe.to(breakmodel.positional_device)
|
||||
gc.collect()
|
||||
if(args.breakmodel_layers is not None):
|
||||
breakmodel.ram_blocks = max(0, min(n_layers, args.breakmodel_layers))
|
||||
else:
|
||||
print(colors.CYAN + "\nHow many layers would you like to put into system RAM?")
|
||||
print("The more of them you put into system RAM, the slower it will run,")
|
||||
print("but it will require less VRAM")
|
||||
print("(roughly proportional to number of layers).")
|
||||
print(f"This model has{colors.YELLOW} {n_layers} {colors.CYAN}layers.{colors.END}\n")
|
||||
while(True):
|
||||
layerselect = input("# of layers> ")
|
||||
if(layerselect.isnumeric() and 0 <= int(layerselect) <= n_layers):
|
||||
breakmodel.ram_blocks = int(layerselect)
|
||||
break
|
||||
else:
|
||||
print(f"{colors.RED}Please enter an integer between 0 and {n_layers}.{colors.END}")
|
||||
print(f"{colors.PURPLE}Will commit{colors.YELLOW} {breakmodel.ram_blocks} {colors.PURPLE}of{colors.YELLOW} {n_layers} {colors.PURPLE}layers to system RAM.{colors.END}")
|
||||
GPTNeoModel.forward = breakmodel.new_forward
|
||||
generator = model.generate
|
||||
device_config(model)
|
||||
else:
|
||||
generator = pipeline('text-generation', model=vars.model)
|
||||
else:
|
||||
@ -1245,7 +1269,7 @@ def generate(txt, min, max):
|
||||
# its first argument if we're using breakmodel, otherwise a string
|
||||
# is fine
|
||||
if(vars.hascuda and vars.breakmodel):
|
||||
gen_in = tokenizer.encode(txt, return_tensors="pt", truncation=True).long().to(breakmodel.embedding_device)
|
||||
gen_in = tokenizer.encode(txt, return_tensors="pt", truncation=True).long().to(breakmodel.primary_device)
|
||||
else:
|
||||
gen_in = txt
|
||||
|
||||
|
@ -215,6 +215,8 @@ import torch
|
||||
import torch.cuda.comm
|
||||
import copy
|
||||
import gc
|
||||
import itertools
|
||||
import bisect
|
||||
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||
|
||||
@ -222,23 +224,9 @@ from transformers.utils import logging
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class MaxSharedRamBlocksException(Exception):
|
||||
def __init__(self, i: int):
|
||||
self.corrected_max_shared_ram_blocks = i
|
||||
super().__init__('max_shared_ram_blocks is set too high, please set it to '+str(i))
|
||||
|
||||
|
||||
breakmodel = True
|
||||
devices = ['cpu', 'cuda']
|
||||
total_blocks = 24
|
||||
ram_blocks = 7
|
||||
max_shared_ram_blocks = None
|
||||
|
||||
# I highly suggest these all be set to the same device unless you really know what you're doing!
|
||||
# (They can all be set to any CPU or GPU device, except layernormfinal_device which can only be a GPU device)
|
||||
embedding_device = devices[1] # Dealing with text embedding is computationally expensive, I suggest you set this to your fastest device
|
||||
positional_device = devices[1] # Only used for GPT-Neo (not used for GPT-J)
|
||||
layernormfinal_device = devices[1] # This setting is unique in that this MUST be set to a GPU device, this cannot be set to 'cpu'
|
||||
gpu_blocks = []
|
||||
primary_device = 0
|
||||
|
||||
|
||||
def new_forward(
|
||||
@ -256,21 +244,41 @@ def new_forward(
|
||||
return_dict=None,
|
||||
embs=None,
|
||||
):
|
||||
global max_shared_ram_blocks
|
||||
assert len(gpu_blocks) <= torch.cuda.device_count()
|
||||
assert sum(gpu_blocks) <= len(self.h)
|
||||
ram_blocks = len(self.h) - sum(gpu_blocks)
|
||||
cumulative_gpu_blocks = tuple(itertools.accumulate(gpu_blocks))
|
||||
|
||||
if breakmodel:
|
||||
if max_shared_ram_blocks is None:
|
||||
max_shared_ram_blocks = total_blocks
|
||||
|
||||
if not hasattr(self, 'extrastorage'):
|
||||
setattr(self,"extrastorage",{})
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
for i in range(ram_blocks):
|
||||
self.h[i].to(devices[0])
|
||||
self.h[i].to("cpu")
|
||||
self.extrastorage[i] = copy.deepcopy(self.h[i])
|
||||
smalltensor = torch.tensor(0).to(primary_device)
|
||||
for param1 in self.h[i].parameters():
|
||||
param1.data = smalltensor
|
||||
self.h[i].to(primary_device)
|
||||
for param in self.extrastorage[i].parameters():
|
||||
param.requires_grad = False
|
||||
param.data = param.data.detach().pin_memory()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
for i in range(ram_blocks,len(self.h)):
|
||||
self.h[i].to(devices[1])
|
||||
if ram_blocks:
|
||||
for param1,param2 in zip(self.h[0].parameters(),self.extrastorage[0].parameters()):
|
||||
param1.data = param2.data.to(primary_device, non_blocking=False).detach()
|
||||
|
||||
for param1,param2 in zip(self.h[ram_blocks-1].parameters(),self.extrastorage[ram_blocks-1].parameters()):
|
||||
param1.data = param2.data.to(primary_device, non_blocking=False).detach()
|
||||
|
||||
i = ram_blocks
|
||||
for j in range(len(gpu_blocks)):
|
||||
for _ in range(gpu_blocks[j]):
|
||||
self.h[i].to(j)
|
||||
i += 1
|
||||
|
||||
|
||||
|
||||
@ -306,7 +314,7 @@ def new_forward(
|
||||
else:
|
||||
past_length = past_key_values[0][0].size(-2)
|
||||
|
||||
device = positional_device if breakmodel else input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
device = primary_device if breakmodel else input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
if position_ids is None:
|
||||
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
|
||||
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
|
||||
@ -344,7 +352,7 @@ def new_forward(
|
||||
|
||||
if inputs_embeds is None:
|
||||
if breakmodel:
|
||||
input_ids = input_ids.to(embedding_device)
|
||||
input_ids = input_ids.to(primary_device)
|
||||
inputs_embeds = self.wte(input_ids)
|
||||
|
||||
if embs is not None and not (use_cache is not None and use_cache and past_key_values is not None and len(past_key_values) > 0 and past_key_values[0] is not None):
|
||||
@ -360,10 +368,10 @@ def new_forward(
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
if breakmodel:
|
||||
position_ids = position_ids.to(positional_device)
|
||||
position_ids = position_ids.to(primary_device)
|
||||
position_embeds = self.wpe(position_ids)
|
||||
if breakmodel:
|
||||
position_embeds = position_embeds.to(embedding_device)
|
||||
position_embeds = position_embeds.to(primary_device)
|
||||
hidden_states = inputs_embeds + position_embeds
|
||||
|
||||
if token_type_ids is not None:
|
||||
@ -377,8 +385,21 @@ def new_forward(
|
||||
presents = () if use_cache else None
|
||||
all_self_attentions = () if output_attentions else None
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
|
||||
if breakmodel and ram_blocks:
|
||||
copystream = torch.cuda.Stream(device=0,priority = -1)
|
||||
|
||||
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
|
||||
|
||||
if breakmodel:
|
||||
if i in range(ram_blocks):
|
||||
index1 = (i+1)%ram_blocks
|
||||
for param1,param2 in zip(self.h[index1].parameters(),self.h[(i-1)%ram_blocks].parameters()):
|
||||
param1.data = param2.data
|
||||
for param1,param2 in zip(self.h[index1].parameters(),self.extrastorage[index1].parameters()):
|
||||
with torch.cuda.stream(copystream):
|
||||
torch.cuda.comm.broadcast(param2.data,out = [param1.data])
|
||||
|
||||
attn_type = self.config.attention_layers[i]
|
||||
attn_mask = global_attention_mask
|
||||
|
||||
@ -410,7 +431,7 @@ def new_forward(
|
||||
)
|
||||
else:
|
||||
if breakmodel:
|
||||
device = devices[0] if i < ram_blocks else devices[1]
|
||||
device = primary_device if i < ram_blocks else bisect.bisect_right(cumulative_gpu_blocks, i - ram_blocks)
|
||||
outputs = block(
|
||||
hidden_states.to(device) if breakmodel and hidden_states is not None else hidden_states,
|
||||
layer_past=tuple(v.to(device) for v in layer_past if v is not None) if breakmodel and layer_past is not None else layer_past,
|
||||
@ -428,11 +449,19 @@ def new_forward(
|
||||
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
|
||||
|
||||
|
||||
if breakmodel:
|
||||
if i in range(ram_blocks):
|
||||
torch.cuda.synchronize()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
if breakmodel:
|
||||
hidden_states = hidden_states.to(layernormfinal_device)
|
||||
if ram_blocks:
|
||||
del copystream
|
||||
torch.cuda.empty_cache()
|
||||
hidden_states = hidden_states.to(primary_device)
|
||||
hidden_states = self.ln_f(hidden_states)
|
||||
if breakmodel:
|
||||
hidden_states = hidden_states.to(embedding_device)
|
||||
hidden_states = hidden_states.to(primary_device)
|
||||
|
||||
hidden_states = hidden_states.view(*output_shape)
|
||||
# Add last hidden state
|
||||
|
Loading…
x
Reference in New Issue
Block a user