mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-02-15 11:10:36 +01:00
Merge pull request #19 from VE-FORBRYDERNE/multi-gpu
Multiple GPU support
This commit is contained in:
commit
bd063f7590
202
aiserver.py
202
aiserver.py
@ -178,6 +178,99 @@ def getmodelname():
|
||||
modelname = vars.model
|
||||
return modelname
|
||||
|
||||
#==================================================================#
|
||||
# Breakmodel configuration functions
|
||||
#==================================================================#
|
||||
def device_list(n_layers, primary=None, selected=None):
|
||||
device_count = torch.cuda.device_count()
|
||||
if(device_count < 2):
|
||||
primary = None
|
||||
gpu_blocks = breakmodel.gpu_blocks + (device_count - len(breakmodel.gpu_blocks))*[0]
|
||||
print(f"{colors.YELLOW} DEVICE ID | LAYERS | DEVICE NAME{colors.END}")
|
||||
for i in range(device_count):
|
||||
name = torch.cuda.get_device_name(i)
|
||||
if(len(name) > 47):
|
||||
name = "..." + name[-44:]
|
||||
row_color = colors.END
|
||||
sep_color = colors.YELLOW
|
||||
print(f"{row_color}{colors.YELLOW + '->' + row_color if i == selected else ' '} {'(primary)' if i == primary else ' '*9} {i:3} {sep_color}|{row_color} {gpu_blocks[i]:3} {sep_color}|{row_color} {name}{colors.END}")
|
||||
row_color = colors.END
|
||||
sep_color = colors.YELLOW
|
||||
print(f"{row_color} {' '*9} N/A {sep_color}|{row_color} {n_layers:3} {sep_color}|{row_color} (CPU){colors.END}")
|
||||
|
||||
def device_config(model):
|
||||
global breakmodel, generator
|
||||
import breakmodel
|
||||
n_layers = model.config.num_layers
|
||||
model.half().to('cpu')
|
||||
gc.collect()
|
||||
if(args.layers is not None):
|
||||
try:
|
||||
breakmodel.gpu_blocks = list(map(int, args.layers.split(',')))
|
||||
assert len(breakmodel.gpu_blocks) <= torch.cuda.device_count()
|
||||
assert sum(breakmodel.gpu_blocks) <= n_layers
|
||||
n_layers -= sum(breakmodel.gpu_blocks)
|
||||
except:
|
||||
print("WARNING: --layers is malformatted. Please use the --help option to see correct usage of --layers. Defaulting to all layers on device 0.", file=sys.stderr)
|
||||
breakmodel.gpu_blocks = [n_layers]
|
||||
n_layers = 0
|
||||
elif(args.breakmodel_layers is not None):
|
||||
breakmodel.gpu_blocks = [n_layers - max(0, min(n_layers, args.breakmodel_layers))]
|
||||
n_layers -= sum(breakmodel.gpu_blocks)
|
||||
else:
|
||||
device_count = torch.cuda.device_count()
|
||||
if(device_count > 1):
|
||||
print(colors.CYAN + "\nPlease select one of your GPUs to be your primary GPU.")
|
||||
print("VRAM usage in your primary GPU will be higher than for your other ones.")
|
||||
print("It is recommended you make your fastest GPU your primary GPU.")
|
||||
device_list(n_layers)
|
||||
while(True):
|
||||
primaryselect = input("device ID> ")
|
||||
if(primaryselect.isnumeric() and 0 <= int(primaryselect) < device_count):
|
||||
breakmodel.primary_device = int(primaryselect)
|
||||
else:
|
||||
print(f"{colors.RED}Please enter an integer between 0 and {device_count-1}.{colors.END}")
|
||||
else:
|
||||
breakmodel.primary_device = 0
|
||||
|
||||
print(colors.PURPLE + "\nIf you don't have enough VRAM to run the model on a single GPU")
|
||||
print("you can split the model between your CPU and your GPU(s), or between")
|
||||
print("multiple GPUs if you have more than one.")
|
||||
print("By putting more 'layers' on a GPU or CPU, more computations will be")
|
||||
print("done on that device and more VRAM or RAM will be required on that device")
|
||||
print("(roughly proportional to number of layers).")
|
||||
print("It should be noted that GPUs are orders of magnitude faster than the CPU.")
|
||||
print(f"This model has{colors.YELLOW} {n_layers} {colors.PURPLE}layers.{colors.END}\n")
|
||||
|
||||
for i in range(device_count):
|
||||
device_list(n_layers, primary=breakmodel.primary_device, selected=i)
|
||||
print(f"{colors.CYAN}\nHow many of the remaining{colors.YELLOW} {n_layers} {colors.CYAN}layers would you like to put into device {i}?\nYou can also enter -1 to allocate all remaining layers to this device.{colors.END}\n")
|
||||
while(True):
|
||||
layerselect = input("# of layers> ")
|
||||
if((layerselect.isnumeric() or layerselect.strip() == '-1') and -1 <= int(layerselect) <= n_layers):
|
||||
layerselect = int(layerselect)
|
||||
layerselect = n_layers if layerselect == -1 else layerselect
|
||||
breakmodel.gpu_blocks.append(layerselect)
|
||||
n_layers -= layerselect
|
||||
break
|
||||
else:
|
||||
print(f"{colors.RED}Please enter an integer between -1 and {n_layers}.{colors.END}")
|
||||
if(n_layers == 0):
|
||||
break
|
||||
|
||||
print(colors.PURPLE + "\nFinal device configuration:")
|
||||
device_list(n_layers)
|
||||
|
||||
model.transformer.wte.to(breakmodel.primary_device)
|
||||
model.transformer.ln_f.to(breakmodel.primary_device)
|
||||
if(hasattr(model, 'lm_head')):
|
||||
model.lm_head.to(breakmodel.primary_device)
|
||||
if(not hasattr(model.config, 'rotary') or not model.config.rotary):
|
||||
model.transformer.wpe.to(breakmodel.primary_device)
|
||||
gc.collect()
|
||||
GPTNeoModel.forward = breakmodel.new_forward
|
||||
generator = model.generate
|
||||
|
||||
#==================================================================#
|
||||
# Startup
|
||||
#==================================================================#
|
||||
@ -188,8 +281,9 @@ parser.add_argument("--remote", action='store_true', help="Optimizes KoboldAI fo
|
||||
parser.add_argument("--model", help="Specify the Model Type to skip the Menu")
|
||||
parser.add_argument("--path", help="Specify the Path for local models (For model NeoCustom or GPT2Custom)")
|
||||
parser.add_argument("--cpu", action='store_true', help="By default unattended launches are on the GPU use this option to force CPU usage.")
|
||||
parser.add_argument("--breakmodel", action='store_true', help="For models that support GPU-CPU hybrid generation, use this feature instead of GPU or CPU generation")
|
||||
parser.add_argument("--breakmodel_layers", type=int, help="Specify the number of layers to commit to system RAM if --breakmodel is used")
|
||||
parser.add_argument("--breakmodel", action='store_true', help=argparse.SUPPRESS)
|
||||
parser.add_argument("--breakmodel_layers", type=int, help=argparse.SUPPRESS)
|
||||
parser.add_argument("--layers", type=str, help="If using a model that supports hybrid generation, this is a comma-separated list that specifies how many layers to put on each GPU device. For example to put 8 layers on device 0, 9 layers on device 1 and 11 layers on device 2, use --layers 8,9,11")
|
||||
parser.add_argument("--override_delete", action='store_true', help="Deleting stories from inside the browser is disabled if you are using --remote and enabled otherwise. Using this option will instead allow deleting stories if using --remote and prevent deleting stories otherwise.")
|
||||
parser.add_argument("--override_rename", action='store_true', help="Renaming stories from inside the browser is disabled if you are using --remote and enabled otherwise. Using this option will instead allow renaming stories if using --remote and prevent renaming stories otherwise.")
|
||||
parser.add_argument("--configname", help="Force a fixed configuration name to aid with config management.")
|
||||
@ -222,6 +316,12 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
|
||||
print("{0}Looking for GPU support...{1}".format(colors.PURPLE, colors.END), end="")
|
||||
vars.hascuda = torch.cuda.is_available()
|
||||
vars.bmsupported = vars.model in ("EleutherAI/gpt-neo-1.3B", "EleutherAI/gpt-neo-2.7B", "NeoCustom")
|
||||
if(args.breakmodel is not None and args.breakmodel):
|
||||
print("WARNING: --breakmodel is no longer supported. Breakmodel mode is now automatically enabled when --layers is used (see --help for details).", file=sys.stderr)
|
||||
if(args.breakmodel_layers is not None):
|
||||
print("WARNING: --breakmodel_layers is deprecated. Use --layers instead (see --help for details).", file=sys.stderr)
|
||||
if(not vars.bmsupported and (args.layers is not None or args.breakmodel_layers is not None)):
|
||||
print("WARNING: This model does not support hybrid generation. --layers will be ignored.", file=sys.stderr)
|
||||
if(vars.hascuda):
|
||||
print("{0}FOUND!{1}".format(colors.GREEN, colors.END))
|
||||
else:
|
||||
@ -232,18 +332,21 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
|
||||
genselected = True
|
||||
vars.usegpu = True
|
||||
vars.breakmodel = False
|
||||
if(vars.bmsupported):
|
||||
vars.usegpu = False
|
||||
vars.breakmodel = True
|
||||
if(args.cpu):
|
||||
vars.usegpu = False
|
||||
vars.breakmodel = False
|
||||
if(vars.bmsupported and args.breakmodel):
|
||||
vars.usegpu = False
|
||||
vars.breakmodel = True
|
||||
elif(vars.hascuda):
|
||||
if(vars.bmsupported):
|
||||
print(colors.YELLOW + "You're using a model that supports GPU-CPU hybrid generation!\nCurrently only GPT-Neo models and GPT-J-6B support this feature.")
|
||||
print("{0}Use GPU or CPU for generation?: (Default GPU){1}".format(colors.CYAN, colors.END))
|
||||
if(vars.bmsupported):
|
||||
print(f" 1 - GPU\n 2 - CPU\n 3 - Both (slower than GPU-only but uses less VRAM)\n")
|
||||
print(colors.YELLOW + "You're using a model that supports hybrid generation!")
|
||||
print("This feature allows you to split the model between the CPU and GPU(s)")
|
||||
print("(slower than GPU-only but uses less VRAM) or between multiple GPUs")
|
||||
print("(allowing you to use the combined VRAM of all your GPUs).")
|
||||
print("Currently only GPT-Neo and GPT-J models support this feature.")
|
||||
print("{0}Use hybrid generation or CPU-only generation?: (Default hybrid){1}".format(colors.CYAN, colors.END))
|
||||
print(f" 1 - Hybrid generation\n 2 - CPU\n")
|
||||
else:
|
||||
print(" 1 - GPU\n 2 - CPU\n")
|
||||
genselected = False
|
||||
@ -256,17 +359,18 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
|
||||
vars.usegpu = True
|
||||
genselected = True
|
||||
elif(genselect.isnumeric() and int(genselect) == 1):
|
||||
vars.breakmodel = False
|
||||
vars.usegpu = True
|
||||
genselected = True
|
||||
if(vars.bmsupported):
|
||||
vars.breakmodel = True
|
||||
vars.usegpu = False
|
||||
genselected = True
|
||||
else:
|
||||
vars.breakmodel = False
|
||||
vars.usegpu = True
|
||||
genselected = True
|
||||
elif(genselect.isnumeric() and int(genselect) == 2):
|
||||
vars.breakmodel = False
|
||||
vars.usegpu = False
|
||||
genselected = True
|
||||
elif(vars.bmsupported and genselect.isnumeric() and int(genselect) == 3):
|
||||
vars.breakmodel = True
|
||||
vars.usegpu = False
|
||||
genselected = True
|
||||
else:
|
||||
print("{0}Please enter a valid selection.{1}".format(colors.RED, colors.END))
|
||||
|
||||
@ -405,7 +509,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
|
||||
model_config = open(vars.custmodpth + "/config.json", "r")
|
||||
js = json.load(model_config)
|
||||
if("model_type" in js):
|
||||
model = vars.custmodpth
|
||||
model = AutoModelForCausalLM.from_pretrained(vars.custmodpth)
|
||||
else:
|
||||
model = GPTNeoForCausalLM.from_pretrained(vars.custmodpth)
|
||||
tokenizer = GPT2Tokenizer.from_pretrained(vars.custmodpth)
|
||||
@ -414,36 +518,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
|
||||
if(vars.usegpu):
|
||||
generator = pipeline('text-generation', model=model, tokenizer=tokenizer, device=0)
|
||||
elif(vars.breakmodel): # Use both RAM and VRAM (breakmodel)
|
||||
import breakmodel
|
||||
n_layers = model.config.num_layers
|
||||
breakmodel.total_blocks = n_layers
|
||||
model.half().to('cpu')
|
||||
gc.collect()
|
||||
model.transformer.wte.to(breakmodel.gpu_device)
|
||||
model.transformer.ln_f.to(breakmodel.gpu_device)
|
||||
if(hasattr(model, 'lm_head')):
|
||||
model.lm_head.to(breakmodel.gpu_device)
|
||||
if(not hasattr(model.config, 'rotary') or not model.config.rotary):
|
||||
model.transformer.wpe.to(breakmodel.gpu_device)
|
||||
gc.collect()
|
||||
if(args.breakmodel_layers is not None):
|
||||
breakmodel.ram_blocks = max(0, min(n_layers, args.breakmodel_layers))
|
||||
else:
|
||||
print(colors.CYAN + "\nHow many layers would you like to put into system RAM?")
|
||||
print("The more of them you put into system RAM, the slower it will run,")
|
||||
print("but it will require less VRAM")
|
||||
print("(roughly proportional to number of layers).")
|
||||
print(f"This model has{colors.YELLOW} {n_layers} {colors.CYAN}layers.{colors.END}\n")
|
||||
while(True):
|
||||
layerselect = input("# of layers> ")
|
||||
if(layerselect.isnumeric() and 0 <= int(layerselect) <= n_layers):
|
||||
breakmodel.ram_blocks = int(layerselect)
|
||||
break
|
||||
else:
|
||||
print(f"{colors.RED}Please enter an integer between 0 and {n_layers}.{colors.END}")
|
||||
print(f"{colors.PURPLE}Will commit{colors.YELLOW} {breakmodel.ram_blocks} {colors.PURPLE}of{colors.YELLOW} {n_layers} {colors.PURPLE}layers to system RAM.{colors.END}")
|
||||
GPTNeoModel.forward = breakmodel.new_forward
|
||||
generator = model.generate
|
||||
device_config(model)
|
||||
else:
|
||||
generator = pipeline('text-generation', model=model, tokenizer=tokenizer)
|
||||
else:
|
||||
@ -465,37 +540,8 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
|
||||
if(vars.usegpu):
|
||||
generator = pipeline('text-generation', model=vars.model, device=0)
|
||||
elif(vars.breakmodel): # Use both RAM and VRAM (breakmodel)
|
||||
import breakmodel
|
||||
model = AutoModelForCausalLM.from_pretrained(vars.model)
|
||||
n_layers = model.config.num_layers
|
||||
breakmodel.total_blocks = n_layers
|
||||
model.half().to('cpu')
|
||||
gc.collect()
|
||||
model.transformer.wte.to(breakmodel.gpu_device)
|
||||
model.transformer.ln_f.to(breakmodel.gpu_device)
|
||||
if(hasattr(model, 'lm_head')):
|
||||
model.lm_head.to(breakmodel.gpu_device)
|
||||
if(not hasattr(model.config, 'rotary') or not model.config.rotary):
|
||||
model.transformer.wpe.to(breakmodel.gpu_device)
|
||||
gc.collect()
|
||||
if(args.breakmodel_layers is not None):
|
||||
breakmodel.ram_blocks = max(0, min(n_layers, args.breakmodel_layers))
|
||||
else:
|
||||
print(colors.CYAN + "\nHow many layers would you like to put into system RAM?")
|
||||
print("The more of them you put into system RAM, the slower it will run,")
|
||||
print("but it will require less VRAM")
|
||||
print("(roughly proportional to number of layers).")
|
||||
print(f"This model has{colors.YELLOW} {n_layers} {colors.CYAN}layers.{colors.END}\n")
|
||||
while(True):
|
||||
layerselect = input("# of layers> ")
|
||||
if(layerselect.isnumeric() and 0 <= int(layerselect) <= n_layers):
|
||||
breakmodel.ram_blocks = int(layerselect)
|
||||
break
|
||||
else:
|
||||
print(f"{colors.RED}Please enter an integer between 0 and {n_layers}.{colors.END}")
|
||||
print(f"{colors.PURPLE}Will commit{colors.YELLOW} {breakmodel.ram_blocks} {colors.PURPLE}of{colors.YELLOW} {n_layers} {colors.PURPLE}layers to system RAM.{colors.END}")
|
||||
GPTNeoModel.forward = breakmodel.new_forward
|
||||
generator = model.generate
|
||||
device_config(model)
|
||||
else:
|
||||
generator = pipeline('text-generation', model=vars.model)
|
||||
else:
|
||||
@ -1245,7 +1291,7 @@ def generate(txt, min, max):
|
||||
# its first argument if we're using breakmodel, otherwise a string
|
||||
# is fine
|
||||
if(vars.hascuda and vars.breakmodel):
|
||||
gen_in = tokenizer.encode(txt, return_tensors="pt", truncation=True).long().to(breakmodel.gpu_device)
|
||||
gen_in = tokenizer.encode(txt, return_tensors="pt", truncation=True).long().to(breakmodel.primary_device)
|
||||
else:
|
||||
gen_in = txt
|
||||
|
||||
|
@ -1,10 +1,10 @@
|
||||
'''
|
||||
This is a MODIFIED version of arrmansa's low VRAM patch.
|
||||
https://github.com/arrmansa/Basic-UI-for-GPT-J-6B-with-low-vram/blob/main/GPT-J-6B-Low-Vram-UI.ipynb
|
||||
The ORIGINAL version of the patch is released under the Apache License 2.0
|
||||
Copyright 2021 arrmansa
|
||||
Copyright 2021 finetuneanon
|
||||
Copyright 2018 The Hugging Face team
|
||||
Released under the Apache License 2.0
|
||||
|
||||
|
||||
Apache License
|
||||
@ -215,6 +215,8 @@ import torch
|
||||
import torch.cuda.comm
|
||||
import copy
|
||||
import gc
|
||||
import itertools
|
||||
import bisect
|
||||
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||
|
||||
@ -222,17 +224,9 @@ from transformers.utils import logging
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class MaxSharedRamBlocksException(Exception):
|
||||
def __init__(self, i: int):
|
||||
self.corrected_max_shared_ram_blocks = i
|
||||
super().__init__('max_shared_ram_blocks is set too high, please set it to '+str(i))
|
||||
|
||||
|
||||
breakmodel = True
|
||||
gpu_device = 'cuda'
|
||||
total_blocks = 24
|
||||
ram_blocks = 7
|
||||
max_shared_ram_blocks = None
|
||||
gpu_blocks = []
|
||||
primary_device = 0
|
||||
|
||||
|
||||
def new_forward(
|
||||
@ -250,54 +244,43 @@ def new_forward(
|
||||
return_dict=None,
|
||||
embs=None,
|
||||
):
|
||||
global max_shared_ram_blocks
|
||||
assert len(gpu_blocks) <= torch.cuda.device_count()
|
||||
assert sum(gpu_blocks) <= len(self.h)
|
||||
ram_blocks = len(self.h) - sum(gpu_blocks)
|
||||
cumulative_gpu_blocks = tuple(itertools.accumulate(gpu_blocks))
|
||||
|
||||
if breakmodel:
|
||||
if max_shared_ram_blocks is None:
|
||||
max_shared_ram_blocks = total_blocks
|
||||
|
||||
if not hasattr(self, 'extrastorage'):
|
||||
setattr(self,"extrastorage",{})
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
for i in range(ram_blocks,len(self.h)):
|
||||
self.h[i].to(gpu_device)
|
||||
|
||||
for i in range(ram_blocks):
|
||||
self.h[i].to("cpu")
|
||||
self.extrastorage[i] = copy.deepcopy(self.h[i])
|
||||
smalltensor = torch.tensor(0).to(gpu_device)
|
||||
smalltensor = torch.tensor(0).to(primary_device)
|
||||
for param1 in self.h[i].parameters():
|
||||
param1.data = smalltensor
|
||||
self.h[i].to(gpu_device)
|
||||
|
||||
for i in range(len(self.h)):
|
||||
for param in self.h[i].parameters():
|
||||
param.requires_grad = False
|
||||
param.data = param.data.detach()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
for i in range(ram_blocks):
|
||||
self.h[i].to(primary_device)
|
||||
for param in self.extrastorage[i].parameters():
|
||||
param.requires_grad = False
|
||||
if i < max_shared_ram_blocks:
|
||||
try:
|
||||
param.data = param.data.detach().pin_memory()
|
||||
except:
|
||||
raise MaxSharedRamBlocksException(i)
|
||||
else:
|
||||
param.data = param.data.detach()
|
||||
param.data = param.data.detach().pin_memory()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
if ram_blocks:
|
||||
for param1,param2 in zip(self.h[0].parameters(),self.extrastorage[0].parameters()):
|
||||
param1.data = param2.data.to(gpu_device, non_blocking=False).detach()
|
||||
param1.data = param2.data.to(primary_device, non_blocking=False).detach()
|
||||
|
||||
for param1,param2 in zip(self.h[ram_blocks-1].parameters(),self.extrastorage[ram_blocks-1].parameters()):
|
||||
param1.data = param2.data.to(gpu_device, non_blocking=False).detach()
|
||||
#END MODEL BREAK EDITS
|
||||
param1.data = param2.data.to(primary_device, non_blocking=False).detach()
|
||||
|
||||
i = ram_blocks
|
||||
for j in range(len(gpu_blocks)):
|
||||
for _ in range(gpu_blocks[j]):
|
||||
self.h[i].to(j)
|
||||
i += 1
|
||||
|
||||
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
@ -331,7 +314,7 @@ def new_forward(
|
||||
else:
|
||||
past_length = past_key_values[0][0].size(-2)
|
||||
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
device = primary_device if breakmodel else input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
if position_ids is None:
|
||||
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
|
||||
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
|
||||
@ -368,6 +351,8 @@ def new_forward(
|
||||
head_mask = self.get_head_mask(head_mask, self.config.num_layers)
|
||||
|
||||
if inputs_embeds is None:
|
||||
if breakmodel:
|
||||
input_ids = input_ids.to(primary_device)
|
||||
inputs_embeds = self.wte(input_ids)
|
||||
|
||||
if embs is not None and not (use_cache is not None and use_cache and past_key_values is not None and len(past_key_values) > 0 and past_key_values[0] is not None):
|
||||
@ -382,7 +367,11 @@ def new_forward(
|
||||
if hasattr(self, 'rotary') and self.rotary:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
if breakmodel:
|
||||
position_ids = position_ids.to(primary_device)
|
||||
position_embeds = self.wpe(position_ids)
|
||||
if breakmodel:
|
||||
position_embeds = position_embeds.to(primary_device)
|
||||
hidden_states = inputs_embeds + position_embeds
|
||||
|
||||
if token_type_ids is not None:
|
||||
@ -397,9 +386,8 @@ def new_forward(
|
||||
all_self_attentions = () if output_attentions else None
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
|
||||
|
||||
if breakmodel:
|
||||
copystream = torch.cuda.Stream(device=0,priority = -1)
|
||||
if breakmodel and ram_blocks:
|
||||
copystream = torch.cuda.Stream(device=primary_device, priority=-1)
|
||||
|
||||
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
|
||||
|
||||
@ -412,7 +400,6 @@ def new_forward(
|
||||
with torch.cuda.stream(copystream):
|
||||
torch.cuda.comm.broadcast(param2.data,out = [param1.data])
|
||||
|
||||
|
||||
attn_type = self.config.attention_layers[i]
|
||||
attn_mask = global_attention_mask
|
||||
|
||||
@ -443,11 +430,13 @@ def new_forward(
|
||||
head_mask[i],
|
||||
)
|
||||
else:
|
||||
if breakmodel:
|
||||
device = primary_device if i < ram_blocks else bisect.bisect_right(cumulative_gpu_blocks, i - ram_blocks)
|
||||
outputs = block(
|
||||
hidden_states,
|
||||
layer_past=layer_past,
|
||||
attention_mask=attn_mask,
|
||||
head_mask=head_mask[i],
|
||||
hidden_states.to(device) if breakmodel and hidden_states is not None else hidden_states,
|
||||
layer_past=tuple(v.to(device) for v in layer_past if v is not None) if breakmodel and layer_past is not None and i >= ram_blocks and len(layer_past) and layer_past[0].device.index != device else layer_past,
|
||||
attention_mask=attn_mask.to(device) if breakmodel and attn_mask is not None else attn_mask,
|
||||
head_mask=head_mask[i].to(device) if breakmodel and head_mask[i] is not None else head_mask[i],
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
@ -466,12 +455,13 @@ def new_forward(
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
if breakmodel:
|
||||
del copystream
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
if ram_blocks:
|
||||
del copystream
|
||||
torch.cuda.empty_cache()
|
||||
hidden_states = hidden_states.to(primary_device)
|
||||
hidden_states = self.ln_f(hidden_states)
|
||||
if breakmodel:
|
||||
hidden_states = hidden_states.to(primary_device)
|
||||
|
||||
hidden_states = hidden_states.view(*output_shape)
|
||||
# Add last hidden state
|
||||
@ -480,7 +470,6 @@ def new_forward(
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
|
||||
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=presents,
|
||||
|
Loading…
x
Reference in New Issue
Block a user