Merge pull request #39 from VE-FORBRYDERNE/breakmodel

Official transformers 6B breakmodel support and more RAM-efficient model loading

commit 6008d4f3a5
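For readers unfamiliar with the feature: "breakmodel" keeps only as many transformer hidden layers on the GPU(s) as fit in VRAM, parks the remaining layers in system RAM, and streams each RAM-resident layer's weights onto the GPU just before that layer runs. The sketch below is illustrative only and is not code from this pull request; the names split_layers and gpu_budget are invented for the example.

# Illustrative sketch of the breakmodel idea -- not code from this pull request.
import torch
import torch.nn as nn

def split_layers(layers: nn.ModuleList, gpu_budget: int) -> int:
    """Keep the last `gpu_budget` layers on GPU 0 and park the rest on the CPU.

    Returns the number of CPU-resident ("RAM") blocks that the forward pass
    will have to stream to the GPU one block at a time.
    """
    ram_blocks = len(layers) - gpu_budget
    for i, layer in enumerate(layers):
        layer.to("cpu" if i < ram_blocks else torch.device("cuda", 0))
    return ram_blocks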
aiserver.py | 52

@@ -14,7 +14,9 @@ from tkinter import messagebox
 import json
 import collections
 import zipfile
-from typing import Union, Dict, Set, List
+import packaging
+import contextlib
+from typing import Any, Union, Dict, Set, List

 import requests
 import html
@@ -298,11 +300,14 @@ def device_config(model):
     model.transformer.ln_f.to(breakmodel.primary_device)
     if(hasattr(model, 'lm_head')):
         model.lm_head.to(breakmodel.primary_device)
-    if(not hasattr(model.config, 'rotary') or not model.config.rotary):
+    if(hasattr(model.transformer, 'wpe')):
         model.transformer.wpe.to(breakmodel.primary_device)
     gc.collect()
     GPTNeoModel.forward = breakmodel.new_forward
+    if("GPTJModel" in globals()):
+        GPTJModel.forward = breakmodel.new_forward
     generator = model.generate
+    breakmodel.move_hidden_layers(model.transformer)

 #==================================================================#
 # Startup
@@ -537,7 +542,12 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
     if(not vars.noai):
         print("{0}Initializing transformers, please wait...{1}".format(colors.PURPLE, colors.END))
         from transformers import StoppingCriteria, GPT2Tokenizer, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM
+        try:
+            from transformers import GPTJModel
+        except:
+            pass
         import transformers.generation_utils
+        from transformers import __version__ as transformers_version

         # Patch transformers to use our soft prompt
         def patch_causallm(cls):
@@ -698,15 +708,32 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
                 return int(model.transformer.embed_dim)
             except:
                 return int(model.lm_head.in_features)

+        def maybe_low_cpu_mem_usage() -> Dict[str, Any]:
+            if(packaging.version.parse(transformers_version) < packaging.version.parse("4.11.0")):
+                print(f"\nWARNING: Please upgrade to transformers 4.11.0 for lower RAM usage. You have transformers {transformers_version}.", file=sys.stderr)
+                return {}
+            return {"low_cpu_mem_usage": True}
+
+        @contextlib.contextmanager
+        def maybe_use_float16(always_use=False):
+            if(always_use or (vars.hascuda and (vars.usegpu or vars.breakmodel))):
+                original_dtype = torch.get_default_dtype()
+                torch.set_default_dtype(torch.float16)
+                yield True
+                torch.set_default_dtype(original_dtype)
+            else:
+                yield False
+
         # If custom GPT Neo model was chosen
         if(vars.model == "NeoCustom"):
             model_config = open(vars.custmodpth + "/config.json", "r")
             js = json.load(model_config)
-            if("model_type" in js):
-                model = AutoModelForCausalLM.from_pretrained(vars.custmodpth, cache_dir="cache/")
-            else:
-                model = GPTNeoForCausalLM.from_pretrained(vars.custmodpth, cache_dir="cache/")
+            with(maybe_use_float16()):
+                if("model_type" in js):
+                    model = AutoModelForCausalLM.from_pretrained(vars.custmodpth, cache_dir="cache/", **maybe_low_cpu_mem_usage())
+                else:
+                    model = GPTNeoForCausalLM.from_pretrained(vars.custmodpth, cache_dir="cache/", **maybe_low_cpu_mem_usage())
             vars.modeldim = get_hidden_size_from_model(model)
             tokenizer = GPT2Tokenizer.from_pretrained(vars.custmodpth, cache_dir="cache/")
         # Is CUDA available? If so, use GPU, otherwise fall back to CPU
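Taken together, the two helpers introduced in this hunk wrap every from_pretrained call later in the diff: the context manager makes float16 the default dtype while the weights are materialized on CUDA/breakmodel setups, and maybe_low_cpu_mem_usage() adds the low_cpu_mem_usage keyword only when the installed transformers version supports it. A usage sketch (the model name is an example chosen for illustration, not something this PR hard-codes):

# Usage sketch of the helpers above (example model name only).
with maybe_use_float16():                 # float16 becomes the default dtype on CUDA/breakmodel setups
    model = AutoModelForCausalLM.from_pretrained(
        "EleutherAI/gpt-j-6B",            # example only
        cache_dir="cache/",
        **maybe_low_cpu_mem_usage(),      # expands to low_cpu_mem_usage=True on transformers >= 4.11.0
    )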
@@ -724,7 +751,8 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
         elif(vars.model == "GPT2Custom"):
             model_config = open(vars.custmodpth + "/config.json", "r")
             js = json.load(model_config)
-            model = GPT2LMHeadModel.from_pretrained(vars.custmodpth, cache_dir="cache/")
+            with(maybe_use_float16()):
+                model = GPT2LMHeadModel.from_pretrained(vars.custmodpth, cache_dir="cache/", **maybe_low_cpu_mem_usage())
             tokenizer = GPT2Tokenizer.from_pretrained(vars.custmodpth, cache_dir="cache/")
             vars.modeldim = get_hidden_size_from_model(model)
         # Is CUDA available? If so, use GPU, otherwise fall back to CPU
@@ -739,20 +767,22 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
             tokenizer = GPT2Tokenizer.from_pretrained(vars.model, cache_dir="cache/")
             if(vars.hascuda):
                 if(vars.usegpu):
-                    model = AutoModelForCausalLM.from_pretrained(vars.model, cache_dir="cache/")
+                    with(maybe_use_float16()):
+                        model = AutoModelForCausalLM.from_pretrained(vars.model, cache_dir="cache/", **maybe_low_cpu_mem_usage())
                     vars.modeldim = get_hidden_size_from_model(model)
                     model = model.half().to(0)
                     generator = model.generate
                 elif(vars.breakmodel): # Use both RAM and VRAM (breakmodel)
-                    model = AutoModelForCausalLM.from_pretrained(vars.model, cache_dir="cache/")
+                    with(maybe_use_float16()):
+                        model = AutoModelForCausalLM.from_pretrained(vars.model, cache_dir="cache/", **maybe_low_cpu_mem_usage())
                     vars.modeldim = get_hidden_size_from_model(model)
                     device_config(model)
                 else:
-                    model = AutoModelForCausalLM.from_pretrained(vars.model, cache_dir="cache/")
+                    model = AutoModelForCausalLM.from_pretrained(vars.model, cache_dir="cache/", **maybe_low_cpu_mem_usage())
                     vars.modeldim = get_hidden_size_from_model(model)
                     generator = model.generate
             else:
-                model = AutoModelForCausalLM.from_pretrained(vars.model, cache_dir="cache/")
+                model = AutoModelForCausalLM.from_pretrained(vars.model, cache_dir="cache/", **maybe_low_cpu_mem_usage())
                 vars.modeldim = get_hidden_size_from_model(model)
                 generator = model.generate

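Net effect of the aiserver.py changes: on CUDA (or breakmodel) setups the weights are materialized directly as float16 instead of float32, and on transformers 4.11.0+ the low_cpu_mem_usage flag asks Hugging Face to avoid holding a second full copy of the weights in RAM while loading. A minimal, self-contained demonstration of the default-dtype trick (illustrative only; the try/finally is an extra safety net, not something this diff adds):

# Minimal demonstration of temporarily switching torch's default dtype (illustrative).
import contextlib
import torch

@contextlib.contextmanager
def default_dtype(dtype):
    previous = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        torch.set_default_dtype(previous)

with default_dtype(torch.float16):
    x = torch.zeros(4)                          # allocated as float16 inside the block
assert x.dtype == torch.float16
assert torch.zeros(4).dtype == torch.float32    # default restored afterwards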
breakmodel.py | 105
@@ -215,6 +215,7 @@ import torch
 import torch.cuda.comm
 import copy
 import gc
+import sys
 import itertools
 import bisect

@@ -229,6 +230,48 @@ gpu_blocks = []
 primary_device = 0


+def move_hidden_layers(transformer):
+    assert len(gpu_blocks) <= torch.cuda.device_count()
+    assert sum(gpu_blocks) <= len(transformer.h)
+    ram_blocks = len(transformer.h) - sum(gpu_blocks)
+
+    transformer.extrastorage = {}
+    torch.cuda.empty_cache()
+
+    able_to_pin_layers = True
+    for i in range(ram_blocks):
+        transformer.h[i].to("cpu")
+        transformer.extrastorage[i] = copy.deepcopy(transformer.h[i])
+        smalltensor = torch.tensor(0).to(primary_device)
+        for param1 in transformer.h[i].parameters():
+            param1.data = smalltensor
+        transformer.h[i].to(primary_device)
+        for param in transformer.extrastorage[i].parameters():
+            param.requires_grad = False
+            param.data = param.data.detach()
+            if able_to_pin_layers:
+                try:
+                    param.data = param.data.pin_memory()
+                except:
+                    able_to_pin_layers = False
+                    print(f"WARNING: You only have enough shared GPU memory for {i} out of {ram_blocks} CPU layers. Expect suboptimal speed.", file=sys.stderr)
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    if ram_blocks:
+        for param1,param2 in zip(transformer.h[0].parameters(),transformer.extrastorage[0].parameters()):
+            param1.data = param2.data.to(primary_device, non_blocking=False).detach()
+
+        for param1,param2 in zip(transformer.h[ram_blocks-1].parameters(),transformer.extrastorage[ram_blocks-1].parameters()):
+            param1.data = param2.data.to(primary_device, non_blocking=False).detach()
+
+    i = ram_blocks
+    for j in range(len(gpu_blocks)):
+        for _ in range(gpu_blocks[j]):
+            transformer.h[i].to(j)
+            i += 1
+
+
 def new_forward(
     self,
     input_ids=None,
@@ -249,38 +292,6 @@ def new_forward(
     ram_blocks = len(self.h) - sum(gpu_blocks)
     cumulative_gpu_blocks = tuple(itertools.accumulate(gpu_blocks))

-    if breakmodel:
-        if not hasattr(self, 'extrastorage'):
-            setattr(self,"extrastorage",{})
-            torch.cuda.empty_cache()
-
-            for i in range(ram_blocks):
-                self.h[i].to("cpu")
-                self.extrastorage[i] = copy.deepcopy(self.h[i])
-                smalltensor = torch.tensor(0).to(primary_device)
-                for param1 in self.h[i].parameters():
-                    param1.data = smalltensor
-                self.h[i].to(primary_device)
-                for param in self.extrastorage[i].parameters():
-                    param.requires_grad = False
-                    param.data = param.data.detach().pin_memory()
-                gc.collect()
-                torch.cuda.empty_cache()
-
-            if ram_blocks:
-                for param1,param2 in zip(self.h[0].parameters(),self.extrastorage[0].parameters()):
-                    param1.data = param2.data.to(primary_device, non_blocking=False).detach()
-
-                for param1,param2 in zip(self.h[ram_blocks-1].parameters(),self.extrastorage[ram_blocks-1].parameters()):
-                    param1.data = param2.data.to(primary_device, non_blocking=False).detach()
-
-            i = ram_blocks
-            for j in range(len(gpu_blocks)):
-                for _ in range(gpu_blocks[j]):
-                    self.h[i].to(j)
-                    i += 1
-
-

     output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
@@ -322,33 +333,27 @@ def new_forward(
     # Attention mask.
     if attention_mask is not None:
         assert batch_size > 0, "batch_size has to be defined and > 0"
-        global_attention_mask = attention_mask.view(batch_size, -1)
+        attention_mask = attention_mask.view(batch_size, -1)
         # We create a 3D attention mask from a 2D tensor mask.
         # Sizes are [batch_size, 1, 1, to_seq_length]
         # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
         # this attention mask is more simple than the triangular masking of causal attention
         # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
-        global_attention_mask = global_attention_mask[:, None, None, :]
+        attention_mask = attention_mask[:, None, None, :]

-        # Since global_attention_mask is 1.0 for positions we want to attend and 0.0 for
+        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
         # masked positions, this operation will create a tensor which is 0.0 for
         # positions we want to attend and -10000.0 for masked positions.
         # Since we are adding it to the raw scores before the softmax, this is
         # effectively the same as removing these entirely.
-        global_attention_mask = global_attention_mask.to(dtype=self.dtype) # fp16 compatibility
-        global_attention_mask = (1.0 - global_attention_mask) * -10000.0
-    else:
-        global_attention_mask = None
-
-    # Local causal attention mask
-    batch_size, seq_length = input_shape
-    full_seq_length = seq_length + past_length
+        attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
+        attention_mask = (1.0 - attention_mask) * -10000.0

     # Prepare head mask if needed
     # 1.0 in head_mask indicate we keep the head
     # attention_probs has shape bsz x num_heads x N x N
     # head_mask has shape n_layer x batch x num_heads x N x N
-    head_mask = self.get_head_mask(head_mask, self.config.num_layers)
+    head_mask = self.get_head_mask(head_mask, getattr(self.config, "num_layers", None) or self.config.n_layer)

     if inputs_embeds is None:
         if breakmodel:
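This hunk drops GPT-Neo's separate global/local attention-mask bookkeeping and keeps only the standard additive mask, since GPT-J (unlike GPT-Neo) does not use local attention layers. For context, the retained mask conversion in isolation (illustrative):

# The additive attention-mask conversion from the hunk above, in isolation.
import torch

mask = torch.tensor([[1, 1, 1, 0]])              # 1 = attend, 0 = padding
mask = mask[:, None, None, :].to(torch.float16)  # broadcastable over heads and query positions
mask = (1.0 - mask) * -10000.0                   # 0 where attended, -10000 where masked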
@@ -364,7 +369,7 @@ def new_forward(
             inputs_embeds[:, pos:pos+emb.shape[1]] = emb
             offset += emb.shape[1]

-    if hasattr(self, 'rotary') and self.rotary:
+    if getattr(self, "wpe", None) is None:
         hidden_states = inputs_embeds
     else:
         if breakmodel:
@@ -400,9 +405,6 @@ def new_forward(
                 with torch.cuda.stream(copystream):
                     torch.cuda.comm.broadcast(param2.data,out = [param1.data])

-        attn_type = self.config.attention_layers[i]
-        attn_mask = global_attention_mask
-
         if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states.cpu(),)

@@ -410,8 +412,7 @@ def new_forward(

         if use_cache:
             logger.warning(
-                "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
-                "`use_cache=False`..."
+                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
             )
             use_cache = False

@@ -426,7 +427,7 @@ def new_forward(
                 create_custom_forward(block),
                 hidden_states,
                 None,
-                attn_mask,
+                attention_mask,
                 head_mask[i],
             )
         else:
@@ -435,7 +436,7 @@ def new_forward(
             outputs = block(
                 hidden_states.to(device) if breakmodel and hidden_states is not None else hidden_states,
                 layer_past=tuple(v.to(device) for v in layer_past if v is not None) if breakmodel and layer_past is not None and i >= ram_blocks and len(layer_past) and layer_past[0].device.index != device else layer_past,
-                attention_mask=attn_mask.to(device) if breakmodel and attn_mask is not None else attn_mask,
+                attention_mask=attention_mask.to(device) if breakmodel and attention_mask is not None else attention_mask,
                 head_mask=head_mask[i].to(device) if breakmodel and head_mask[i] is not None else head_mask[i],
                 use_cache=use_cache,
                 output_attentions=output_attentions,