Merge pull request #155 from VE-FORBRYDERNE/accelerate

Initial support for Accelerate
henk717 2022-06-20 01:08:54 +02:00 committed by GitHub
commit efed44ac8d
3 changed files with 168 additions and 82 deletions

aiserver.py

@@ -610,6 +610,24 @@ def move_model_to_devices(model):
     model.half()
     gc.collect()
 
+    if(utils.HAS_ACCELERATE):
+        import accelerate
+        gpu_blocks = breakmodel.gpu_blocks
+        ram_blocks = len(vars.layers_module_names) - sum(gpu_blocks)
+        cumulative_gpu_blocks = tuple(itertools.accumulate(gpu_blocks))
+        device_map = {}
+        for name in vars.layers_module_names:
+            layer = int(name.rsplit(".", 1)[1])
+            device = "cpu" if layer < ram_blocks else bisect.bisect_right(cumulative_gpu_blocks, layer - ram_blocks)
+            device_map[name] = device
+        for name in utils.get_missing_module_names(model, list(device_map.keys())):
+            device_map[name] = breakmodel.primary_device
+        accelerate.dispatch_model(model, device_map, main_device=breakmodel.primary_device)
+        gc.collect()
+        generator = model.generate
+        return
+
     if(hasattr(model, "transformer")):
         model.transformer.wte.to(breakmodel.primary_device)
         model.transformer.ln_f.to(breakmodel.primary_device)
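The device map built above is a plain cumulative-sum lookup: the first ram_blocks layers go to the CPU, and each remaining layer is assigned to the GPU whose cumulative block count covers it. A minimal sketch of the same arithmetic, with hypothetical block counts:

    import bisect
    import itertools

    gpu_blocks = [8, 4]   # hypothetical: 8 layers on GPU 0, 4 layers on GPU 1
    n_layers = 16         # hypothetical total number of transformer layers
    ram_blocks = n_layers - sum(gpu_blocks)                # 4 layers stay in CPU RAM
    cumulative = tuple(itertools.accumulate(gpu_blocks))   # (8, 12)

    for layer in range(n_layers):
        device = "cpu" if layer < ram_blocks else bisect.bisect_right(cumulative, layer - ram_blocks)
        print(layer, device)   # layers 0-3 -> "cpu", 4-11 -> 0, 12-15 -> 1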
@@ -1192,8 +1210,37 @@ def get_oai_models(key):
         print("{0}ERROR!{1}".format(colors.RED, colors.END))
         print(req.json())
         emit('from_server', {'cmd': 'errmsg', 'data': req.json()})
 
+# Function to patch transformers to use our soft prompt
+def patch_causallm(cls):
+    if(getattr(cls, "_koboldai_patch_causallm_patched", False)):
+        return
+    old_forward = cls.forward
+    def new_causallm_forward(self, *args, **kwargs):
+        input_ids = kwargs.get('input_ids').to(self.device)
+        assert input_ids is not None
+        kwargs['input_ids'] = None
+        if(vars.sp is not None):
+            shifted_input_ids = input_ids - self.config.vocab_size
+        input_ids.clamp_(max=self.config.vocab_size-1)
+        inputs_embeds = self.get_input_embeddings()(input_ids)
+        if(vars.sp is not None):
+            vars.sp = vars.sp.to(inputs_embeds.dtype).to(inputs_embeds.device)
+            inputs_embeds = torch.where(
+                (shifted_input_ids >= 0)[..., None],
+                vars.sp[shifted_input_ids.clamp(min=0)],
+                inputs_embeds,
+            )
+        if(hasattr(self, "model") and hasattr(self.model, "embed_scale")):
+            inputs_embeds *= self.model.embed_scale
+        kwargs['inputs_embeds'] = inputs_embeds
+        return old_forward(self, *args, **kwargs)
+    cls.forward = new_causallm_forward
+    cls._koboldai_patch_causallm_patched = True
+    return cls
+
 def patch_transformers():
     global transformers
     old_from_pretrained = PreTrainedModel.from_pretrained.__func__
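The trick in new_causallm_forward is that soft-prompt tokens are encoded as IDs at or above vocab_size: subtracting vocab_size turns them into row indices into vars.sp, and torch.where splices those rows into the ordinary token embeddings. A toy sketch of the substitution, with made-up sizes:

    import torch

    vocab_size, embed_dim = 10, 4
    sp = torch.randn(3, embed_dim)                    # hypothetical soft prompt of 3 virtual tokens
    wte = torch.nn.Embedding(vocab_size, embed_dim)

    input_ids = torch.tensor([[10, 11, 3, 7]])        # IDs >= 10 address soft-prompt rows
    shifted = input_ids - vocab_size                  # [[0, 1, -7, -3]]
    inputs_embeds = wte(input_ids.clamp(max=vocab_size - 1))
    inputs_embeds = torch.where(
        (shifted >= 0)[..., None],                    # mask broadcast over the embedding dim
        sp[shifted.clamp(min=0)],                     # soft-prompt rows where the mask is True
        inputs_embeds,                                # normal token embeddings elsewhere
    )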
@@ -1241,42 +1288,6 @@ def patch_transformers():
             return self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, -1).detach()
         XGLMSinusoidalPositionalEmbedding.forward = new_forward
 
-    # Patch transformers to use our soft prompt
-    def patch_causallm(cls):
-        old_forward = cls.forward
-        def new_causallm_forward(self, *args, **kwargs):
-            input_ids = kwargs.get('input_ids').to(self.device)
-            assert input_ids is not None
-            kwargs['input_ids'] = None
-            if(vars.sp is not None):
-                shifted_input_ids = input_ids - self.config.vocab_size
-            input_ids.clamp_(max=self.config.vocab_size-1)
-            if(hasattr(self, "transformer")):
-                inputs_embeds = self.transformer.wte(input_ids)
-            elif(not hasattr(self.model, "decoder")):
-                inputs_embeds = self.model.embed_tokens(input_ids)
-            else:
-                inputs_embeds = self.model.decoder.embed_tokens(input_ids)
-            if(vars.sp is not None):
-                vars.sp = vars.sp.to(inputs_embeds.dtype).to(inputs_embeds.device)
-                inputs_embeds = torch.where(
-                    (shifted_input_ids >= 0)[..., None],
-                    vars.sp[shifted_input_ids.clamp(min=0)],
-                    inputs_embeds,
-                )
-            if(hasattr(self, "model") and hasattr(self.model, "embed_scale")):
-                inputs_embeds *= self.model.embed_scale
-            kwargs['inputs_embeds'] = inputs_embeds
-            return old_forward(self, *args, **kwargs)
-        cls.forward = new_causallm_forward
-    for cls in (GPT2LMHeadModel, GPTNeoForCausalLM):
-        patch_causallm(cls)
-    for c in ("GPTJForCausalLM", "XGLMForCausalLM", "OPTForCausalLM"):
-        try:
-            patch_causallm(getattr(__import__("transformers"), c))
-        except:
-            pass
-
     # Fix a bug in OPTForCausalLM where self.lm_head is the wrong size
     if(packaging.version.parse("4.19.0.dev0") <= packaging.version.parse(transformers_version) < packaging.version.parse("4.20.0")):
@@ -1563,7 +1574,7 @@ def load_model(use_gpu=True, gpu_layers=None, initial_load=False, online_model="
     loadsettings()
     print("{0}Looking for GPU support...{1}".format(colors.PURPLE, colors.END), end="")
     vars.hascuda = torch.cuda.is_available()
-    vars.bmsupported = vars.model_type in ("gpt_neo", "gptj", "xglm", "opt") and not vars.nobreakmodel
+    vars.bmsupported = (utils.HAS_ACCELERATE or vars.model_type in ("gpt_neo", "gptj", "xglm", "opt")) and not vars.nobreakmodel
     if(args.breakmodel is not None and args.breakmodel):
         print("WARNING: --breakmodel is no longer supported. Breakmodel mode is now automatically enabled when --breakmodel_gpulayers is used (see --help for details).", file=sys.stderr)
     if(args.breakmodel_layers is not None):
@@ -1657,24 +1668,20 @@ def load_model(use_gpu=True, gpu_layers=None, initial_load=False, online_model="
     else:
         ram_blocks = gpu_blocks = cumulative_gpu_blocks = None
 
-    def lazy_load_callback(model_dict, f, **_):
+    def lazy_load_callback(model_dict: Dict[str, Union[torch_lazy_loader.LazyTensor, torch.Tensor]], f, **_):
         if lazy_load_callback.nested:
             return
         lazy_load_callback.nested = True
 
-        device_map = {}
-        for _key, spec in lazy_load_spec.get("layer_weights", {}).items():
-            for layer in range(n_layers):
-                key = _key.format(layer=layer)
-                if key not in model_dict:
-                    continue
-                device = vars.gpu_device if vars.hascuda and vars.usegpu else "cpu" if not vars.hascuda or not vars.breakmodel or layer < ram_blocks else bisect.bisect_right(cumulative_gpu_blocks, layer - ram_blocks)
-                device_map[key] = device
+        device_map: Dict[str, Union[str, int]] = {}
 
         for key, value in model_dict.items():
-            if isinstance(value, torch_lazy_loader.LazyTensor) and key not in device_map:
-                device_map[key] = vars.gpu_device if vars.hascuda and vars.usegpu else "cpu"
+            if isinstance(value, torch_lazy_loader.LazyTensor) and not any(key.startswith(n) or key.startswith(n.split(".", 1)[1]) for n in vars.layers_module_names):
+                device_map[key] = vars.gpu_device if vars.hascuda and vars.usegpu else "cpu" if not vars.hascuda or not vars.breakmodel else breakmodel.primary_device
+            else:
+                layer = int(max((n for n in vars.layers_module_names if key.startswith(n) or key.startswith(n.split(".", 1)[1])), key=len).rsplit(".", 1)[1])
+                device = vars.gpu_device if vars.hascuda and vars.usegpu else "cpu" if not vars.hascuda or not vars.breakmodel else "shared" if layer < ram_blocks else bisect.bisect_right(cumulative_gpu_blocks, layer - ram_blocks)
+                device_map[key] = device
 
         if utils.num_shards is None or utils.current_shard == 0:
             if utils.num_shards is not None:
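The callback no longer needs a per-model maps/*.json spec: each checkpoint key is matched against vars.layers_module_names by prefix (with and without the root module's name), and the longest match wins so that, say, layer 1 doesn't shadow layer 10. Roughly, with hypothetical names:

    layers_module_names = ["transformer.h.0", "transformer.h.1", "transformer.h.10"]

    def layer_of(key):
        # Longest matching prefix, so "transformer.h.10.attn..." matches
        # "transformer.h.10" rather than "transformer.h.1"
        best = max((n for n in layers_module_names
                    if key.startswith(n) or key.startswith(n.split(".", 1)[1])),
                   key=len)
        return int(best.rsplit(".", 1)[1])

    print(layer_of("transformer.h.10.attn.k_proj.weight"))  # 10
    print(layer_of("h.1.mlp.c_fc.weight"))                   # 1 (key lacks the "transformer." root)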
@@ -1689,6 +1696,7 @@ def load_model(use_gpu=True, gpu_layers=None, initial_load=False, online_model="
         last_storage_key = None
         f = None
         current_offset = 0
+        able_to_pin_layers = True
         if utils.num_shards is not None:
             utils.current_shard += 1
         for key in sorted(device_map.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)):
@@ -1714,7 +1722,15 @@ def load_model(use_gpu=True, gpu_layers=None, initial_load=False, online_model="
                     model_dict[key] = model_dict[key].to(torch.float16)
                 if not vars.usegpu and not vars.breakmodel and model_dict[key].dtype is torch.float16:
                     model_dict[key] = model_dict[key].to(torch.float32)
-                model_dict[key] = model_dict[key].to(device)
+                if device == "shared":
+                    model_dict[key] = model_dict[key].to("cpu").detach_()
+                    if able_to_pin_layers and utils.HAS_ACCELERATE:
+                        try:
+                            model_dict[key] = model_dict[key].pin_memory()
+                        except:
+                            able_to_pin_layers = False
+                else:
+                    model_dict[key] = model_dict[key].to(device)
                 #print("OK", flush=True)
                 current_offset += nbytes
                 utils.bar.update(1)
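Layers assigned the new "shared" device stay in ordinary CPU RAM but are pinned (page-locked) when possible, which makes later host-to-GPU copies faster and non-blocking; pinning can fail once lockable memory runs out, so a single failure disables further attempts. The idea in isolation:

    import torch

    t = torch.empty(1024, 1024)       # a CPU tensor that will be copied to the GPU repeatedly
    try:
        t = t.pin_memory()            # page-locked memory allows fast, non-blocking H2D copies
    except RuntimeError:
        pass                          # pinning can fail, e.g. when lockable memory is exhausted

    if torch.cuda.is_available():
        on_gpu = t.to("cuda", non_blocking=True)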
@@ -1729,15 +1745,6 @@ def load_model(use_gpu=True, gpu_layers=None, initial_load=False, online_model="
                 lazy_load_callback.nested = False
         return lazy_load_callback
 
-    lazy_load_config_path = os.path.join("maps", vars.model_type + ".json")
-    if(vars.lazy_load and "model_config" in globals() and os.path.isfile(lazy_load_config_path)):
-        with open(lazy_load_config_path) as f:
-            lazy_load_spec = json.load(f)
-    else:
-        vars.lazy_load = False
-
     def get_hidden_size_from_model(model):
         try:
@@ -1791,6 +1798,7 @@ def load_model(use_gpu=True, gpu_layers=None, initial_load=False, online_model="
             else:
                 model = model.to('cpu').float()
             generator = model.generate
+            patch_causallm(model.__class__)
         # Use the Generic implementation
         else:
             lowmem = maybe_low_cpu_mem_usage()
@@ -1799,6 +1807,7 @@ def load_model(use_gpu=True, gpu_layers=None, initial_load=False, online_model="
             # feature yet
             if(vars.model_type == "gpt2"):
                 lowmem = {}
+                vars.lazy_load = False  # Also, lazy loader doesn't support GPT-2 models
 
             # If we're using torch_lazy_loader, we need to get breakmodel config
             # early so that it knows where to load the individual model tensors
@@ -1812,6 +1821,13 @@ def load_model(use_gpu=True, gpu_layers=None, initial_load=False, online_model="
                     import shutil
                     shutil.move(vars.model.replace('/', '_'), "models/{}".format(vars.model.replace('/', '_')))
                 print("\n", flush=True)
+            if(vars.lazy_load):  # If we're using lazy loader, we need to figure out what the model's hidden layers are called
+                with torch_lazy_loader.use_lazy_torch_load(dematerialized_modules=True, use_accelerate_init_empty_weights=True):
+                    try:
+                        metamodel = AutoModelForCausalLM.from_config(model_config)
+                    except Exception as e:
+                        metamodel = GPTNeoForCausalLM.from_config(model_config)
+                    vars.layers_module_names = utils.get_layers_module_names(metamodel)
             with maybe_use_float16(), torch_lazy_loader.use_lazy_torch_load(enable=vars.lazy_load, callback=get_lazy_load_callback(utils.num_layers(model_config)) if vars.lazy_load else None, dematerialized_modules=True):
                 if(vars.lazy_load): # torch_lazy_loader.py and low_cpu_mem_usage can't be used at the same time
                     lowmem = {}
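Building metamodel with use_accelerate_init_empty_weights=True materializes no weights, so discovering the layer names costs almost nothing even for large checkpoints. The core of the idea in standalone form (the model name is hypothetical; any decoder-only config should behave similarly):

    import accelerate
    from transformers import AutoConfig, AutoModelForCausalLM

    config = AutoConfig.from_pretrained("EleutherAI/gpt-neo-125M")
    with accelerate.init_empty_weights():
        metamodel = AutoModelForCausalLM.from_config(config)   # all parameters on the meta device
    layer_names = utils.get_layers_module_names(metamodel)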
@@ -1910,7 +1926,9 @@ def load_model(use_gpu=True, gpu_layers=None, initial_load=False, online_model="
                     for filename in filenames:
                         shutil.move(transformers.file_utils.get_from_cache(transformers.file_utils.hf_bucket_url(vars.model, filename, revision=vars.revision), cache_dir="cache", local_files_only=True), os.path.join("models/{}".format(vars.model.replace('/', '_')), filename))
                     shutil.rmtree("cache/")
 
+            patch_causallm(model.__class__)
+
             if(vars.hascuda):
                 if(vars.usegpu):
                     vars.modeldim = get_hidden_size_from_model(model)

torch_lazy_loader.py

@@ -50,6 +50,7 @@ import itertools
 import zipfile
 import pickle
 import torch
+import utils
 from torch.nn import Module
 from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
@@ -213,7 +214,7 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, miss
 
 @contextlib.contextmanager
-def use_lazy_torch_load(enable=True, callback: Optional[Callable] = None, dematerialized_modules=False):
+def use_lazy_torch_load(enable=True, callback: Optional[Callable] = None, dematerialized_modules=False, use_accelerate_init_empty_weights=False):
     if not enable:
         yield False
         return
@@ -236,24 +237,29 @@ def use_lazy_torch_load(enable=True, callback: Optional[Callable] = None, demate
         torch.load = torch_load
 
         if dematerialized_modules:
-            old_linear_init = torch.nn.Linear.__init__
-            old_embedding_init = torch.nn.Embedding.__init__
-            old_layernorm_init = torch.nn.LayerNorm.__init__
-            def linear_init(self, *args, device=None, **kwargs):
-                return old_linear_init(self, *args, device="meta", **kwargs)
-            def embedding_init(self, *args, device=None, **kwargs):
-                return old_embedding_init(self, *args, device="meta", **kwargs)
-            def layernorm_init(self, *args, device=None, **kwargs):
-                return old_layernorm_init(self, *args, device="meta", **kwargs)
-            torch.nn.Linear.__init__ = linear_init
-            torch.nn.Embedding.__init__ = embedding_init
-            torch.nn.LayerNorm.__init__ = layernorm_init
-            old_load_from_state_dict = torch.nn.Module._load_from_state_dict
-            torch.nn.Module._load_from_state_dict = _load_from_state_dict
+            if use_accelerate_init_empty_weights and utils.HAS_ACCELERATE:
+                import accelerate
+                init_empty_weights = accelerate.init_empty_weights()
+                init_empty_weights.__enter__()
+            else:
+                old_linear_init = torch.nn.Linear.__init__
+                old_embedding_init = torch.nn.Embedding.__init__
+                old_layernorm_init = torch.nn.LayerNorm.__init__
+                def linear_init(self, *args, device=None, **kwargs):
+                    return old_linear_init(self, *args, device="meta", **kwargs)
+                def embedding_init(self, *args, device=None, **kwargs):
+                    return old_embedding_init(self, *args, device="meta", **kwargs)
+                def layernorm_init(self, *args, device=None, **kwargs):
+                    return old_layernorm_init(self, *args, device="meta", **kwargs)
+                torch.nn.Linear.__init__ = linear_init
+                torch.nn.Embedding.__init__ = embedding_init
+                torch.nn.LayerNorm.__init__ = layernorm_init
+                old_load_from_state_dict = torch.nn.Module._load_from_state_dict
+                torch.nn.Module._load_from_state_dict = _load_from_state_dict
         yield True
@@ -262,7 +268,10 @@ def use_lazy_torch_load(enable=True, callback: Optional[Callable] = None, demate
         torch._utils._rebuild_tensor = old_rebuild_tensor
         torch.load = old_torch_load
         if dematerialized_modules:
-            torch.nn.Linear.__init__ = old_linear_init
-            torch.nn.Embedding.__init__ = old_embedding_init
-            torch.nn.LayerNorm.__init__ = old_layernorm_init
-            torch.nn.Module._load_from_state_dict = old_load_from_state_dict
+            if use_accelerate_init_empty_weights and utils.HAS_ACCELERATE:
+                init_empty_weights.__exit__(None, None, None)
+            else:
+                torch.nn.Linear.__init__ = old_linear_init
+                torch.nn.Embedding.__init__ = old_embedding_init
+                torch.nn.LayerNorm.__init__ = old_layernorm_init
+                torch.nn.Module._load_from_state_dict = old_load_from_state_dict

utils.py

@@ -7,10 +7,19 @@ import tempfile
 import requests
 import requests.adapters
 import time
+from transformers import __version__ as transformers_version
+from transformers import PreTrainedModel
+import packaging.version
 from tqdm.auto import tqdm
 import os
 import itertools
-from typing import Optional
+from typing import List, Optional
+
+HAS_ACCELERATE = packaging.version.parse(transformers_version) >= packaging.version.parse("4.20.0.dev0")
+try:
+    import accelerate
+except ImportError:
+    HAS_ACCELERATE = False
 
 vars = None
 num_shards: Optional[int] = None
@@ -300,3 +309,53 @@ def get_sharded_checkpoint_num_tensors(pretrained_model_name_or_path, filename,
     import torch
     shard_paths, _ = transformers.modeling_utils.get_checkpoint_shard_files(pretrained_model_name_or_path, filename, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, local_files_only=local_files_only, use_auth_token=use_auth_token, user_agent=user_agent, revision=revision, mirror=mirror)
     return list(itertools.chain(*(torch.load(p, map_location="cpu").keys() for p in shard_paths)))
+
+#==================================================================#
+# Given a PreTrainedModel, returns the list of module names that correspond
+# to the model's hidden layers.
+#==================================================================#
+def get_layers_module_names(model: PreTrainedModel) -> List[str]:
+    names: List[str] = []
+    def recurse(module, head=""):
+        for c in module.named_children():
+            name = head + c[0]
+            if c[0].isnumeric() and any(c[1].__class__.__name__.endswith(suffix) for suffix in ("Block", "Layer")):
+                names.append(name)
+            else:
+                recurse(c[1], head=name + ".")
+    recurse(model)
+    return names
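For a GPT-Neo style model, whose decoder blocks live in a numbered ModuleList and have class names ending in "Block", this is expected to return something like the following (the model name is hypothetical):

    from transformers import AutoConfig, AutoModelForCausalLM

    config = AutoConfig.from_pretrained("EleutherAI/gpt-neo-125M")
    model = AutoModelForCausalLM.from_config(config)
    print(get_layers_module_names(model))
    # e.g. ['transformer.h.0', 'transformer.h.1', ..., 'transformer.h.11']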
+
+#==================================================================#
+# Given a PreTrainedModel, returns the module name that corresponds
+# to the model's input embeddings.
+#==================================================================#
+def get_input_embeddings_module_name(model: PreTrainedModel) -> str:
+    embeddings = model.get_input_embeddings()
+    def recurse(module, head=""):
+        for c in module.named_children():
+            name = head + c[0]
+            if c[1] is embeddings:
+                return name
+            # Recurse into this child, but keep searching the remaining
+            # siblings if the embeddings module isn't found under it
+            res = recurse(c[1], head=name + ".")
+            if res is not None:
+                return res
+    return recurse(model)
+
+#==================================================================#
+# Given a PreTrainedModel and a list of module names, returns a list
+# of module names such that the union of the set of modules given as input
+# and the set of modules returned as output contains all modules in the model.
+#==================================================================#
+def get_missing_module_names(model: PreTrainedModel, names: List[str]) -> List[str]:
+    missing_names: List[str] = []
+    def recurse(module, head=""):
+        for c in module.named_children():
+            name = head + c[0]
+            if any(name.startswith(n) for n in names):
+                continue
+            if next(c[1].named_children(), None) is None:
+                missing_names.append(name)
+            else:
+                recurse(c[1], head=name + ".")
+    recurse(model)
+    return missing_names
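Together these helpers produce the complete device map that move_model_to_devices hands to accelerate.dispatch_model: the layer names get per-layer devices, and everything not covered (embeddings, final layer norm, lm_head, and so on) falls back to the primary device. A rough end-to-end sketch, with devices chosen arbitrarily:

    import accelerate

    device_map = {name: 0 for name in get_layers_module_names(model)}   # all blocks on GPU 0
    for name in get_missing_module_names(model, list(device_map.keys())):
        device_map[name] = 0    # embeddings, final layer norm, lm_head, etc.
    accelerate.dispatch_model(model, device_map, main_device=0)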