Merge pull request #39 from VE-FORBRYDERNE/breakmodel

Official transformers 6B breakmodel support and more RAM-efficient model loading
Authored by henk717 on 2021-11-27 01:11:48 +01:00; committed via GitHub
commit 6008d4f3a5
2 changed files with 94 additions and 63 deletions

File 1 of 2:

@@ -14,7 +14,9 @@ from tkinter import messagebox
 import json
 import collections
 import zipfile
-from typing import Union, Dict, Set, List
+import packaging
+import contextlib
+from typing import Any, Union, Dict, Set, List
 import requests
 import html
@@ -298,11 +300,14 @@ def device_config(model):
     model.transformer.ln_f.to(breakmodel.primary_device)
     if(hasattr(model, 'lm_head')):
         model.lm_head.to(breakmodel.primary_device)
-    if(not hasattr(model.config, 'rotary') or not model.config.rotary):
+    if(hasattr(model.transformer, 'wpe')):
         model.transformer.wpe.to(breakmodel.primary_device)
     gc.collect()
     GPTNeoModel.forward = breakmodel.new_forward
+    if("GPTJModel" in globals()):
+        GPTJModel.forward = breakmodel.new_forward
     generator = model.generate
+    breakmodel.move_hidden_layers(model.transformer)
 
 #==================================================================#
 # Startup
@@ -537,7 +542,12 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
     if(not vars.noai):
         print("{0}Initializing transformers, please wait...{1}".format(colors.PURPLE, colors.END))
         from transformers import StoppingCriteria, GPT2Tokenizer, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM
+        try:
+            from transformers import GPTJModel
+        except:
+            pass
         import transformers.generation_utils
+        from transformers import __version__ as transformers_version
 
         # Patch transformers to use our soft prompt
         def patch_causallm(cls):
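The guarded import above is what lets the same code run on transformers builds with and without GPT-J: the class is imported if it exists, and the forward patch in device_config() is only applied when the import succeeded. A minimal sketch of that pattern; it assumes only this repository's breakmodel.new_forward, everything else is a stand-in:

    # Import GPT-J support only if the installed transformers provides it.
    from transformers import GPTNeoModel
    try:
        from transformers import GPTJModel   # only present in newer transformers releases
    except ImportError:
        GPTJModel = None

    import breakmodel  # this repository's layer-splitting module

    # Route both architectures through the breakmodel forward pass.
    GPTNeoModel.forward = breakmodel.new_forward
    if GPTJModel is not None:
        GPTJModel.forward = breakmodel.new_forward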
@@ -698,15 +708,32 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
                     return int(model.transformer.embed_dim)
                 except:
                     return int(model.lm_head.in_features)
 
+        def maybe_low_cpu_mem_usage() -> Dict[str, Any]:
+            if(packaging.version.parse(transformers_version) < packaging.version.parse("4.11.0")):
+                print(f"\nWARNING: Please upgrade to transformers 4.11.0 for lower RAM usage. You have transformers {transformers_version}.", file=sys.stderr)
+                return {}
+            return {"low_cpu_mem_usage": True}
+
+        @contextlib.contextmanager
+        def maybe_use_float16(always_use=False):
+            if(always_use or (vars.hascuda and (vars.usegpu or vars.breakmodel))):
+                original_dtype = torch.get_default_dtype()
+                torch.set_default_dtype(torch.float16)
+                yield True
+                torch.set_default_dtype(original_dtype)
+            else:
+                yield False
+
         # If custom GPT Neo model was chosen
         if(vars.model == "NeoCustom"):
             model_config = open(vars.custmodpth + "/config.json", "r")
             js = json.load(model_config)
-            if("model_type" in js):
-                model = AutoModelForCausalLM.from_pretrained(vars.custmodpth, cache_dir="cache/")
-            else:
-                model = GPTNeoForCausalLM.from_pretrained(vars.custmodpth, cache_dir="cache/")
+            with(maybe_use_float16()):
+                if("model_type" in js):
+                    model = AutoModelForCausalLM.from_pretrained(vars.custmodpth, cache_dir="cache/", **maybe_low_cpu_mem_usage())
+                else:
+                    model = GPTNeoForCausalLM.from_pretrained(vars.custmodpth, cache_dir="cache/", **maybe_low_cpu_mem_usage())
             vars.modeldim = get_hidden_size_from_model(model)
             tokenizer = GPT2Tokenizer.from_pretrained(vars.custmodpth, cache_dir="cache/")
             # Is CUDA available? If so, use GPU, otherwise fall back to CPU
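The two helpers above are self-contained enough to reuse: maybe_low_cpu_mem_usage() only passes low_cpu_mem_usage=True when the installed transformers is at least 4.11.0 (the cutoff the diff itself warns about), and maybe_use_float16() temporarily makes float16 the default torch dtype so the checkpoint is materialised in half precision from the start. A standalone sketch of the same idea; the model name and the enabled flag are illustrative, not taken from this commit:

    import contextlib
    import sys

    import packaging.version   # imported explicitly; `import packaging` alone does not pull in the submodule
    import torch
    from transformers import AutoModelForCausalLM, __version__ as transformers_version

    def maybe_low_cpu_mem_usage():
        # Only enable low_cpu_mem_usage on transformers releases that support it reliably.
        if packaging.version.parse(transformers_version) < packaging.version.parse("4.11.0"):
            print(f"WARNING: transformers {transformers_version} uses roughly twice the RAM while "
                  "loading; upgrade to 4.11.0+ for lower RAM usage.", file=sys.stderr)
            return {}
        return {"low_cpu_mem_usage": True}

    @contextlib.contextmanager
    def maybe_use_float16(enabled=True):
        # Temporarily switch the default dtype so new tensors are created in fp16.
        if not enabled:
            yield False
            return
        original_dtype = torch.get_default_dtype()
        torch.set_default_dtype(torch.float16)
        try:
            yield True
        finally:
            torch.set_default_dtype(original_dtype)

    # Usage: load straight into fp16 without first building a full fp32 copy in RAM.
    with maybe_use_float16(enabled=torch.cuda.is_available()):
        model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-125M",
                                                     **maybe_low_cpu_mem_usage())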
@@ -724,7 +751,8 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
         elif(vars.model == "GPT2Custom"):
             model_config = open(vars.custmodpth + "/config.json", "r")
             js = json.load(model_config)
-            model = GPT2LMHeadModel.from_pretrained(vars.custmodpth, cache_dir="cache/")
+            with(maybe_use_float16()):
+                model = GPT2LMHeadModel.from_pretrained(vars.custmodpth, cache_dir="cache/", **maybe_low_cpu_mem_usage())
             tokenizer = GPT2Tokenizer.from_pretrained(vars.custmodpth, cache_dir="cache/")
             vars.modeldim = get_hidden_size_from_model(model)
             # Is CUDA available? If so, use GPU, otherwise fall back to CPU
@@ -739,20 +767,22 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
             tokenizer = GPT2Tokenizer.from_pretrained(vars.model, cache_dir="cache/")
             if(vars.hascuda):
                 if(vars.usegpu):
-                    model = AutoModelForCausalLM.from_pretrained(vars.model, cache_dir="cache/")
+                    with(maybe_use_float16()):
+                        model = AutoModelForCausalLM.from_pretrained(vars.model, cache_dir="cache/", **maybe_low_cpu_mem_usage())
                     vars.modeldim = get_hidden_size_from_model(model)
                     model = model.half().to(0)
                     generator = model.generate
                 elif(vars.breakmodel):  # Use both RAM and VRAM (breakmodel)
-                    model = AutoModelForCausalLM.from_pretrained(vars.model, cache_dir="cache/")
+                    with(maybe_use_float16()):
+                        model = AutoModelForCausalLM.from_pretrained(vars.model, cache_dir="cache/", **maybe_low_cpu_mem_usage())
                     vars.modeldim = get_hidden_size_from_model(model)
                     device_config(model)
                 else:
-                    model = AutoModelForCausalLM.from_pretrained(vars.model, cache_dir="cache/")
+                    model = AutoModelForCausalLM.from_pretrained(vars.model, cache_dir="cache/", **maybe_low_cpu_mem_usage())
                     vars.modeldim = get_hidden_size_from_model(model)
                     generator = model.generate
             else:
-                model = AutoModelForCausalLM.from_pretrained(vars.model, cache_dir="cache/")
+                model = AutoModelForCausalLM.from_pretrained(vars.model, cache_dir="cache/", **maybe_low_cpu_mem_usage())
                 vars.modeldim = get_hidden_size_from_model(model)
                 generator = model.generate

File 2 of 2:

@@ -215,6 +215,7 @@ import torch
 import torch.cuda.comm
 import copy
 import gc
+import sys
 import itertools
 import bisect
@@ -229,6 +230,48 @@ gpu_blocks = []
 primary_device = 0
 
+def move_hidden_layers(transformer):
+    assert len(gpu_blocks) <= torch.cuda.device_count()
+    assert sum(gpu_blocks) <= len(transformer.h)
+    ram_blocks = len(transformer.h) - sum(gpu_blocks)
+
+    transformer.extrastorage = {}
+    torch.cuda.empty_cache()
+
+    able_to_pin_layers = True
+    for i in range(ram_blocks):
+        transformer.h[i].to("cpu")
+        transformer.extrastorage[i] = copy.deepcopy(transformer.h[i])
+        smalltensor = torch.tensor(0).to(primary_device)
+        for param1 in transformer.h[i].parameters():
+            param1.data = smalltensor
+        transformer.h[i].to(primary_device)
+        for param in transformer.extrastorage[i].parameters():
+            param.requires_grad = False
+            param.data = param.data.detach()
+            if able_to_pin_layers:
+                try:
+                    param.data = param.data.pin_memory()
+                except:
+                    able_to_pin_layers = False
+                    print(f"WARNING: You only have enough shared GPU memory for {i} out of {ram_blocks} CPU layers. Expect suboptimal speed.", file=sys.stderr)
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    if ram_blocks:
+        for param1,param2 in zip(transformer.h[0].parameters(),transformer.extrastorage[0].parameters()):
+            param1.data = param2.data.to(primary_device, non_blocking=False).detach()
+
+        for param1,param2 in zip(transformer.h[ram_blocks-1].parameters(),transformer.extrastorage[ram_blocks-1].parameters()):
+            param1.data = param2.data.to(primary_device, non_blocking=False).detach()
+
+    i = ram_blocks
+    for j in range(len(gpu_blocks)):
+        for _ in range(gpu_blocks[j]):
+            transformer.h[i].to(j)
+            i += 1
+
 def new_forward(
     self,
     input_ids=None,
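move_hidden_layers() does the expensive one-time setup that used to happen inside the forward pass: each CPU-resident block is deep-copied into transformer.extrastorage, the live block's parameters are pointed at a dummy tensor on the GPU, and the stored copy is pinned in page-locked host memory so it can later be streamed to the GPU asynchronously. Pinning fails once the driver runs out of lockable memory, hence the try/except downgrade. A minimal sketch of just the pinning step, with a made-up helper name and a stand-in layer:

    import sys
    import torch
    import torch.nn as nn

    def pin_module_weights(module: nn.Module) -> bool:
        """Move a module's parameters into page-locked host memory so later
        .to(device, non_blocking=True) copies can overlap with GPU compute."""
        module.to("cpu")
        for param in module.parameters():
            param.requires_grad = False
            param.data = param.data.detach()
            try:
                param.data = param.data.pin_memory()
            except RuntimeError:
                # The driver refused to lock more memory; fall back to pageable RAM.
                return False
        return True

    layer = nn.Linear(4096, 4096)   # stand-in for one transformer block
    if not pin_module_weights(layer):
        print("WARNING: could not pin this layer; host-to-GPU copies will be slower.", file=sys.stderr)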
@@ -249,38 +292,6 @@ def new_forward(
     ram_blocks = len(self.h) - sum(gpu_blocks)
     cumulative_gpu_blocks = tuple(itertools.accumulate(gpu_blocks))
 
-    if breakmodel:
-        if not hasattr(self, 'extrastorage'):
-            setattr(self,"extrastorage",{})
-            torch.cuda.empty_cache()
-
-            for i in range(ram_blocks):
-                self.h[i].to("cpu")
-                self.extrastorage[i] = copy.deepcopy(self.h[i])
-                smalltensor = torch.tensor(0).to(primary_device)
-                for param1 in self.h[i].parameters():
-                    param1.data = smalltensor
-                self.h[i].to(primary_device)
-                for param in self.extrastorage[i].parameters():
-                    param.requires_grad = False
-                    param.data = param.data.detach().pin_memory()
-                gc.collect()
-                torch.cuda.empty_cache()
-
-            if ram_blocks:
-                for param1,param2 in zip(self.h[0].parameters(),self.extrastorage[0].parameters()):
-                    param1.data = param2.data.to(primary_device, non_blocking=False).detach()
-
-                for param1,param2 in zip(self.h[ram_blocks-1].parameters(),self.extrastorage[ram_blocks-1].parameters()):
-                    param1.data = param2.data.to(primary_device, non_blocking=False).detach()
-
-            i = ram_blocks
-            for j in range(len(gpu_blocks)):
-                for _ in range(gpu_blocks[j]):
-                    self.h[i].to(j)
-                    i += 1
-
     output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
@@ -322,33 +333,27 @@ def new_forward(
     # Attention mask.
     if attention_mask is not None:
         assert batch_size > 0, "batch_size has to be defined and > 0"
-        global_attention_mask = attention_mask.view(batch_size, -1)
+        attention_mask = attention_mask.view(batch_size, -1)
         # We create a 3D attention mask from a 2D tensor mask.
         # Sizes are [batch_size, 1, 1, to_seq_length]
         # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
         # this attention mask is more simple than the triangular masking of causal attention
        # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
-        global_attention_mask = global_attention_mask[:, None, None, :]
+        attention_mask = attention_mask[:, None, None, :]
 
-        # Since global_attention_mask is 1.0 for positions we want to attend and 0.0 for
+        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
         # masked positions, this operation will create a tensor which is 0.0 for
         # positions we want to attend and -10000.0 for masked positions.
         # Since we are adding it to the raw scores before the softmax, this is
         # effectively the same as removing these entirely.
-        global_attention_mask = global_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
-        global_attention_mask = (1.0 - global_attention_mask) * -10000.0
-    else:
-        global_attention_mask = None
-
-    # Local causal attention mask
-    batch_size, seq_length = input_shape
-    full_seq_length = seq_length + past_length
+        attention_mask = attention_mask.to(dtype=self.dtype)  # fp16 compatibility
+        attention_mask = (1.0 - attention_mask) * -10000.0
 
     # Prepare head mask if needed
     # 1.0 in head_mask indicate we keep the head
     # attention_probs has shape bsz x num_heads x N x N
     # head_mask has shape n_layer x batch x num_heads x N x N
-    head_mask = self.get_head_mask(head_mask, self.config.num_layers)
+    head_mask = self.get_head_mask(head_mask, getattr(self.config, "num_layers", None) or self.config.n_layer)
 
     if inputs_embeds is None:
         if breakmodel:
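The mask handling above only renames global_attention_mask back to attention_mask and drops leftover local-attention bookkeeping; the conversion itself is the usual one, reshaping a [batch, seq] padding mask into a broadcastable additive mask. The head_mask change is needed because GPT-Neo configs call the layer count num_layers while GPT-J configs call it n_layer. A small worked example of the mask conversion:

    import torch

    # 1 = attend, 0 = padding; one sequence of length 4 with the last token padded.
    attention_mask = torch.tensor([[1, 1, 1, 0]])

    extended = attention_mask[:, None, None, :]    # [batch, 1, 1, seq], broadcastable over heads
    extended = extended.to(dtype=torch.float16)    # match the model dtype (fp16 here)
    extended = (1.0 - extended) * -10000.0         # 0.0 where attended, -10000.0 where masked

    # Adding `extended` to the raw attention scores before the softmax effectively
    # removes the masked positions.
    print(extended)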
@@ -364,7 +369,7 @@ def new_forward(
             inputs_embeds[:, pos:pos+emb.shape[1]] = emb
             offset += emb.shape[1]
 
-    if hasattr(self, 'rotary') and self.rotary:
+    if getattr(self, "wpe", None) is None:
         hidden_states = inputs_embeds
     else:
         if breakmodel:
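The old check relied on a rotary attribute that stock transformers models do not define (it came from the custom GPT-J builds used previously); asking whether the module has a learned position-embedding table (wpe) works for official GPT-Neo and GPT-J alike, since GPT-J keeps its rotary position encoding inside the attention layers and has no wpe. A quick way to probe this on a GPT-Neo- or GPT-J-style model, with a placeholder checkpoint:

    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-125M")   # placeholder checkpoint
    if getattr(model.transformer, "wpe", None) is None:
        print("no wpe table: positions are encoded inside the attention layers (rotary-style)")
    else:
        print("wpe table present: learned position embeddings are added to the input embeddings")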
@@ -400,9 +405,6 @@ def new_forward(
                 with torch.cuda.stream(copystream):
                     torch.cuda.comm.broadcast(param2.data,out = [param1.data])
 
-        attn_type = self.config.attention_layers[i]
-        attn_mask = global_attention_mask
-
         if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states.cpu(),)
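For context, the copystream block kept by this hunk is what hides the cost of keeping layers in RAM: while layer i runs on the default stream, the next CPU-hosted layer's pinned weights are broadcast into a reused GPU buffer on a separate stream. The two removed lines were leftovers from the older local-attention code path. A minimal sketch of the copy/compute overlap with illustrative tensors, using Tensor.copy_ instead of torch.cuda.comm.broadcast for simplicity; it assumes a CUDA device is available:

    import torch

    device = torch.device("cuda:0")
    copystream = torch.cuda.Stream(device=device)

    gpu_buffer = torch.empty(4096, 4096, dtype=torch.float16, device=device)        # reused GPU slot
    pinned_cpu_weight = torch.randn(4096, 4096, dtype=torch.float16).pin_memory()   # next layer's weights

    with torch.cuda.stream(copystream):
        # Asynchronous host-to-device copy into the existing buffer.
        gpu_buffer.copy_(pinned_cpu_weight, non_blocking=True)

    # ... run the current layer's work on the default stream here ...

    torch.cuda.current_stream(device).wait_stream(copystream)   # ensure the copy finished before using gpu_buffer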
@@ -410,8 +412,7 @@ def new_forward(
             if use_cache:
                 logger.warning(
-                    "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
-                    "`use_cache=False`..."
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False
@@ -426,7 +427,7 @@ def new_forward(
                     create_custom_forward(block),
                     hidden_states,
                     None,
-                    attn_mask,
+                    attention_mask,
                     head_mask[i],
                 )
             else:
@@ -435,7 +436,7 @@ def new_forward(
                 outputs = block(
                     hidden_states.to(device) if breakmodel and hidden_states is not None else hidden_states,
                     layer_past=tuple(v.to(device) for v in layer_past if v is not None) if breakmodel and layer_past is not None and i >= ram_blocks and len(layer_past) and layer_past[0].device.index != device else layer_past,
-                    attention_mask=attn_mask.to(device) if breakmodel and attn_mask is not None else attn_mask,
+                    attention_mask=attention_mask.to(device) if breakmodel and attention_mask is not None else attention_mask,
                     head_mask=head_mask[i].to(device) if breakmodel and head_mask[i] is not None else head_mask[i],
                     use_cache=use_cache,
                     output_attentions=output_attentions,
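The .to(device) calls above are the heart of splitting a model across devices: each block's inputs (hidden states, cached key/values, masks) are moved to whichever GPU that block lives on just before the call. A stripped-down sketch of the same dispatch loop with made-up modules; it assumes at least one CUDA device:

    import torch
    import torch.nn as nn

    n_devices = torch.cuda.device_count()
    blocks = [nn.Linear(1024, 1024).to(i % n_devices) for i in range(4)]   # stand-ins for transformer blocks
    block_device = [next(b.parameters()).device for b in blocks]

    hidden_states = torch.randn(1, 8, 1024).to(block_device[0])
    for block, device in zip(blocks, block_device):
        # Move the activations to the block's device, then run the block there.
        hidden_states = block(hidden_states.to(device))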