Merge pull request #155 from VE-FORBRYDERNE/accelerate

Initial support for Accelerate
This commit is contained in:
henk717 2022-06-20 01:08:54 +02:00 committed by GitHub
commit efed44ac8d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 168 additions and 82 deletions

View File

@ -610,6 +610,24 @@ def move_model_to_devices(model):
model.half()
gc.collect()
if(utils.HAS_ACCELERATE):
import accelerate
gpu_blocks = breakmodel.gpu_blocks
ram_blocks = len(vars.layers_module_names) - sum(gpu_blocks)
cumulative_gpu_blocks = tuple(itertools.accumulate(gpu_blocks))
device_map = {}
for name in vars.layers_module_names:
layer = int(name.rsplit(".", 1)[1])
device = "cpu" if layer < ram_blocks else bisect.bisect_right(cumulative_gpu_blocks, layer - ram_blocks)
device_map[name] = device
for name in utils.get_missing_module_names(model, list(device_map.keys())):
device_map[name] = breakmodel.primary_device
accelerate.dispatch_model(model, device_map, main_device=breakmodel.primary_device)
gc.collect()
generator = model.generate
return
if(hasattr(model, "transformer")):
model.transformer.wte.to(breakmodel.primary_device)
model.transformer.ln_f.to(breakmodel.primary_device)
@ -1192,8 +1210,37 @@ def get_oai_models(key):
print("{0}ERROR!{1}".format(colors.RED, colors.END))
print(req.json())
emit('from_server', {'cmd': 'errmsg', 'data': req.json()})
# Function to patch transformers to use our soft prompt
def patch_causallm(cls):
if(getattr(cls, "_koboldai_patch_causallm_patched", False)):
return
old_forward = cls.forward
def new_causallm_forward(self, *args, **kwargs):
input_ids = kwargs.get('input_ids').to(self.device)
assert input_ids is not None
kwargs['input_ids'] = None
if(vars.sp is not None):
shifted_input_ids = input_ids - self.config.vocab_size
input_ids.clamp_(max=self.config.vocab_size-1)
inputs_embeds = self.get_input_embeddings()(input_ids)
if(vars.sp is not None):
vars.sp = vars.sp.to(inputs_embeds.dtype).to(inputs_embeds.device)
inputs_embeds = torch.where(
(shifted_input_ids >= 0)[..., None],
vars.sp[shifted_input_ids.clamp(min=0)],
inputs_embeds,
)
if(hasattr(self, "model") and hasattr(self.model, "embed_scale")):
inputs_embeds *= self.model.embed_scale
kwargs['inputs_embeds'] = inputs_embeds
return old_forward(self, *args, **kwargs)
cls.forward = new_causallm_forward
cls._koboldai_patch_causallm_patched = True
return cls
def patch_transformers():
global transformers
old_from_pretrained = PreTrainedModel.from_pretrained.__func__
@ -1241,42 +1288,6 @@ def patch_transformers():
return self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, -1).detach()
XGLMSinusoidalPositionalEmbedding.forward = new_forward
# Patch transformers to use our soft prompt
def patch_causallm(cls):
old_forward = cls.forward
def new_causallm_forward(self, *args, **kwargs):
input_ids = kwargs.get('input_ids').to(self.device)
assert input_ids is not None
kwargs['input_ids'] = None
if(vars.sp is not None):
shifted_input_ids = input_ids - self.config.vocab_size
input_ids.clamp_(max=self.config.vocab_size-1)
if(hasattr(self, "transformer")):
inputs_embeds = self.transformer.wte(input_ids)
elif(not hasattr(self.model, "decoder")):
inputs_embeds = self.model.embed_tokens(input_ids)
else:
inputs_embeds = self.model.decoder.embed_tokens(input_ids)
if(vars.sp is not None):
vars.sp = vars.sp.to(inputs_embeds.dtype).to(inputs_embeds.device)
inputs_embeds = torch.where(
(shifted_input_ids >= 0)[..., None],
vars.sp[shifted_input_ids.clamp(min=0)],
inputs_embeds,
)
if(hasattr(self, "model") and hasattr(self.model, "embed_scale")):
inputs_embeds *= self.model.embed_scale
kwargs['inputs_embeds'] = inputs_embeds
return old_forward(self, *args, **kwargs)
cls.forward = new_causallm_forward
for cls in (GPT2LMHeadModel, GPTNeoForCausalLM):
patch_causallm(cls)
for c in ("GPTJForCausalLM", "XGLMForCausalLM", "OPTForCausalLM"):
try:
patch_causallm(getattr(__import__("transformers"), c))
except:
pass
# Fix a bug in OPTForCausalLM where self.lm_head is the wrong size
if(packaging.version.parse("4.19.0.dev0") <= packaging.version.parse(transformers_version) < packaging.version.parse("4.20.0")):
@ -1563,7 +1574,7 @@ def load_model(use_gpu=True, gpu_layers=None, initial_load=False, online_model="
loadsettings()
print("{0}Looking for GPU support...{1}".format(colors.PURPLE, colors.END), end="")
vars.hascuda = torch.cuda.is_available()
vars.bmsupported = vars.model_type in ("gpt_neo", "gptj", "xglm", "opt") and not vars.nobreakmodel
vars.bmsupported = (utils.HAS_ACCELERATE or vars.model_type in ("gpt_neo", "gptj", "xglm", "opt")) and not vars.nobreakmodel
if(args.breakmodel is not None and args.breakmodel):
print("WARNING: --breakmodel is no longer supported. Breakmodel mode is now automatically enabled when --breakmodel_gpulayers is used (see --help for details).", file=sys.stderr)
if(args.breakmodel_layers is not None):
@ -1657,24 +1668,20 @@ def load_model(use_gpu=True, gpu_layers=None, initial_load=False, online_model="
else:
ram_blocks = gpu_blocks = cumulative_gpu_blocks = None
def lazy_load_callback(model_dict, f, **_):
def lazy_load_callback(model_dict: Dict[str, Union[torch_lazy_loader.LazyTensor, torch.Tensor]], f, **_):
if lazy_load_callback.nested:
return
lazy_load_callback.nested = True
device_map = {}
for _key, spec in lazy_load_spec.get("layer_weights", {}).items():
for layer in range(n_layers):
key = _key.format(layer=layer)
if key not in model_dict:
continue
device = vars.gpu_device if vars.hascuda and vars.usegpu else "cpu" if not vars.hascuda or not vars.breakmodel or layer < ram_blocks else bisect.bisect_right(cumulative_gpu_blocks, layer - ram_blocks)
device_map[key] = device
device_map: Dict[str, Union[str, int]] = {}
for key, value in model_dict.items():
if isinstance(value, torch_lazy_loader.LazyTensor) and key not in device_map:
device_map[key] = vars.gpu_device if vars.hascuda and vars.usegpu else "cpu"
if isinstance(value, torch_lazy_loader.LazyTensor) and not any(key.startswith(n) or key.startswith(n.split(".", 1)[1]) for n in vars.layers_module_names):
device_map[key] = vars.gpu_device if vars.hascuda and vars.usegpu else "cpu" if not vars.hascuda or not vars.breakmodel else breakmodel.primary_device
else:
layer = int(max((n for n in vars.layers_module_names if key.startswith(n) or key.startswith(n.split(".", 1)[1])), key=len).rsplit(".", 1)[1])
device = vars.gpu_device if vars.hascuda and vars.usegpu else "cpu" if not vars.hascuda or not vars.breakmodel else "shared" if layer < ram_blocks else bisect.bisect_right(cumulative_gpu_blocks, layer - ram_blocks)
device_map[key] = device
if utils.num_shards is None or utils.current_shard == 0:
if utils.num_shards is not None:
@ -1689,6 +1696,7 @@ def load_model(use_gpu=True, gpu_layers=None, initial_load=False, online_model="
last_storage_key = None
f = None
current_offset = 0
able_to_pin_layers = True
if utils.num_shards is not None:
utils.current_shard += 1
for key in sorted(device_map.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)):
@ -1714,7 +1722,15 @@ def load_model(use_gpu=True, gpu_layers=None, initial_load=False, online_model="
model_dict[key] = model_dict[key].to(torch.float16)
if not vars.usegpu and not vars.breakmodel and model_dict[key].dtype is torch.float16:
model_dict[key] = model_dict[key].to(torch.float32)
model_dict[key] = model_dict[key].to(device)
if device == "shared":
model_dict[key] = model_dict[key].to("cpu").detach_()
if able_to_pin_layers and utils.HAS_ACCELERATE:
try:
model_dict[key] = model_dict[key].pin_memory()
except:
able_to_pin_layers = False
else:
model_dict[key] = model_dict[key].to(device)
#print("OK", flush=True)
current_offset += nbytes
utils.bar.update(1)
@ -1729,15 +1745,6 @@ def load_model(use_gpu=True, gpu_layers=None, initial_load=False, online_model="
lazy_load_callback.nested = False
return lazy_load_callback
lazy_load_config_path = os.path.join("maps", vars.model_type + ".json")
if(vars.lazy_load and "model_config" in globals() and os.path.isfile(lazy_load_config_path)):
with open(lazy_load_config_path) as f:
lazy_load_spec = json.load(f)
else:
vars.lazy_load = False
def get_hidden_size_from_model(model):
try:
@ -1791,6 +1798,7 @@ def load_model(use_gpu=True, gpu_layers=None, initial_load=False, online_model="
else:
model = model.to('cpu').float()
generator = model.generate
patch_causallm(model.__class__)
# Use the Generic implementation
else:
lowmem = maybe_low_cpu_mem_usage()
@ -1799,6 +1807,7 @@ def load_model(use_gpu=True, gpu_layers=None, initial_load=False, online_model="
# feature yet
if(vars.model_type == "gpt2"):
lowmem = {}
vars.lazy_load = False # Also, lazy loader doesn't support GPT-2 models
# If we're using torch_lazy_loader, we need to get breakmodel config
# early so that it knows where to load the individual model tensors
@ -1812,6 +1821,13 @@ def load_model(use_gpu=True, gpu_layers=None, initial_load=False, online_model="
import shutil
shutil.move(vars.model.replace('/', '_'), "models/{}".format(vars.model.replace('/', '_')))
print("\n", flush=True)
if(vars.lazy_load): # If we're using lazy loader, we need to figure out what the model's hidden layers are called
with torch_lazy_loader.use_lazy_torch_load(dematerialized_modules=True, use_accelerate_init_empty_weights=True):
try:
metamodel = AutoModelForCausalLM.from_config(model_config)
except Exception as e:
metamodel = GPTNeoForCausalLM.from_config(model_config)
vars.layers_module_names = utils.get_layers_module_names(metamodel)
with maybe_use_float16(), torch_lazy_loader.use_lazy_torch_load(enable=vars.lazy_load, callback=get_lazy_load_callback(utils.num_layers(model_config)) if vars.lazy_load else None, dematerialized_modules=True):
if(vars.lazy_load): # torch_lazy_loader.py and low_cpu_mem_usage can't be used at the same time
lowmem = {}
@ -1910,7 +1926,9 @@ def load_model(use_gpu=True, gpu_layers=None, initial_load=False, online_model="
for filename in filenames:
shutil.move(transformers.file_utils.get_from_cache(transformers.file_utils.hf_bucket_url(vars.model, filename, revision=vars.revision), cache_dir="cache", local_files_only=True), os.path.join("models/{}".format(vars.model.replace('/', '_')), filename))
shutil.rmtree("cache/")
patch_causallm(model.__class__)
if(vars.hascuda):
if(vars.usegpu):
vars.modeldim = get_hidden_size_from_model(model)

View File

@ -50,6 +50,7 @@ import itertools
import zipfile
import pickle
import torch
import utils
from torch.nn import Module
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
@ -213,7 +214,7 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, miss
@contextlib.contextmanager
def use_lazy_torch_load(enable=True, callback: Optional[Callable] = None, dematerialized_modules=False):
def use_lazy_torch_load(enable=True, callback: Optional[Callable] = None, dematerialized_modules=False, use_accelerate_init_empty_weights=False):
if not enable:
yield False
return
@ -236,24 +237,29 @@ def use_lazy_torch_load(enable=True, callback: Optional[Callable] = None, demate
torch.load = torch_load
if dematerialized_modules:
old_linear_init = torch.nn.Linear.__init__
old_embedding_init = torch.nn.Embedding.__init__
old_layernorm_init = torch.nn.LayerNorm.__init__
if use_accelerate_init_empty_weights and utils.HAS_ACCELERATE:
import accelerate
init_empty_weights = accelerate.init_empty_weights()
init_empty_weights.__enter__()
else:
old_linear_init = torch.nn.Linear.__init__
old_embedding_init = torch.nn.Embedding.__init__
old_layernorm_init = torch.nn.LayerNorm.__init__
def linear_init(self, *args, device=None, **kwargs):
return old_linear_init(self, *args, device="meta", **kwargs)
def linear_init(self, *args, device=None, **kwargs):
return old_linear_init(self, *args, device="meta", **kwargs)
def embedding_init(self, *args, device=None, **kwargs):
return old_embedding_init(self, *args, device="meta", **kwargs)
def embedding_init(self, *args, device=None, **kwargs):
return old_embedding_init(self, *args, device="meta", **kwargs)
def layernorm_init(self, *args, device=None, **kwargs):
return old_layernorm_init(self, *args, device="meta", **kwargs)
def layernorm_init(self, *args, device=None, **kwargs):
return old_layernorm_init(self, *args, device="meta", **kwargs)
torch.nn.Linear.__init__ = linear_init
torch.nn.Embedding.__init__ = embedding_init
torch.nn.LayerNorm.__init__ = layernorm_init
old_load_from_state_dict = torch.nn.Module._load_from_state_dict
torch.nn.Module._load_from_state_dict = _load_from_state_dict
torch.nn.Linear.__init__ = linear_init
torch.nn.Embedding.__init__ = embedding_init
torch.nn.LayerNorm.__init__ = layernorm_init
old_load_from_state_dict = torch.nn.Module._load_from_state_dict
torch.nn.Module._load_from_state_dict = _load_from_state_dict
yield True
@ -262,7 +268,10 @@ def use_lazy_torch_load(enable=True, callback: Optional[Callable] = None, demate
torch._utils._rebuild_tensor = old_rebuild_tensor
torch.load = old_torch_load
if dematerialized_modules:
torch.nn.Linear.__init__ = old_linear_init
torch.nn.Embedding.__init__ = old_embedding_init
torch.nn.LayerNorm.__init__ = old_layernorm_init
torch.nn.Module._load_from_state_dict = old_load_from_state_dict
if use_accelerate_init_empty_weights and utils.HAS_ACCELERATE:
init_empty_weights.__exit__(None, None, None)
else:
torch.nn.Linear.__init__ = old_linear_init
torch.nn.Embedding.__init__ = old_embedding_init
torch.nn.LayerNorm.__init__ = old_layernorm_init
torch.nn.Module._load_from_state_dict = old_load_from_state_dict

View File

@ -7,10 +7,19 @@ import tempfile
import requests
import requests.adapters
import time
from transformers import __version__ as transformers_version
from transformers import PreTrainedModel
import packaging.version
from tqdm.auto import tqdm
import os
import itertools
from typing import Optional
from typing import List, Optional
HAS_ACCELERATE = packaging.version.parse(transformers_version) >= packaging.version.parse("4.20.0.dev0")
try:
import accelerate
except ImportError:
HAS_ACCELERATE = False
vars = None
num_shards: Optional[int] = None
@ -300,3 +309,53 @@ def get_sharded_checkpoint_num_tensors(pretrained_model_name_or_path, filename,
import torch
shard_paths, _ = transformers.modeling_utils.get_checkpoint_shard_files(pretrained_model_name_or_path, filename, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, local_files_only=local_files_only, use_auth_token=use_auth_token, user_agent=user_agent, revision=revision, mirror=mirror)
return list(itertools.chain(*(torch.load(p, map_location="cpu").keys() for p in shard_paths)))
#==================================================================#
# Given a PreTrainedModel, returns the list of module names that correspond
# to the model's hidden layers.
#==================================================================#
def get_layers_module_names(model: PreTrainedModel) -> List[str]:
names: List[str] = []
def recurse(module, head=""):
for c in module.named_children():
name = head + c[0]
if c[0].isnumeric() and any(c[1].__class__.__name__.endswith(suffix) for suffix in ("Block", "Layer")):
names.append(name)
else:
recurse(c[1], head=name + ".")
recurse(model)
return names
#==================================================================#
# Given a PreTrainedModel, returns the module name that corresponds
# to the model's input embeddings.
#==================================================================#
def get_input_embeddings_module_name(model: PreTrainedModel) -> str:
embeddings = model.get_input_embeddings()
def recurse(module, head=""):
for c in module.named_children():
name = head + c[0]
if c[1] is embeddings:
return name
else:
return recurse(c[1], head=name + ".")
return recurse(model)
#==================================================================#
# Given a PreTrainedModel and a list of module names, returns a list
# of module names such that the union of the set of modules given as input
# and the set of modules returned as output contains all modules in the model.
#==================================================================#
def get_missing_module_names(model: PreTrainedModel, names: List[str]) -> List[str]:
missing_names: List[str] = []
def recurse(module, head=""):
for c in module.named_children():
name = head + c[0]
if any(name.startswith(n) for n in names):
continue
if next(c[1].named_children(), None) is None:
missing_names.append(name)
else:
recurse(c[1], head=name + ".")
recurse(model)
return missing_names