prompt_tuner.py always has accelerate
parent 8da6893407
commit b1c456ec18

prompt_tuner.py | 110
@@ -27,22 +27,13 @@ import torch.nn.functional as F
from torch.nn import Embedding, CrossEntropyLoss
import transformers
from transformers import __version__ as transformers_version
from transformers import AutoTokenizer, GPT2TokenizerFast, AutoConfig, AutoModelForCausalLM, GPTNeoForCausalLM, PreTrainedModel, modeling_utils, GPTNeoModel, GPTJModel
from transformers import AutoTokenizer, GPT2TokenizerFast, AutoConfig, AutoModelForCausalLM, GPTNeoForCausalLM, PreTrainedModel, modeling_utils
import accelerate
import accelerate.utils
from mkultra.tuning import GPTPromptTuningMixin, GPTNeoPromptTuningLM
from mkultra.soft_prompt import SoftPrompt
from typing import Dict, List, Optional, TextIO, Union

try:
    from transformers import XGLMModel
except:
    pass
try:
    from transformers.models.opt.modeling_opt import OPTDecoder
except:
    pass

import breakmodel
import torch_lazy_loader
import utils
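Per the commit title, accelerate is treated here as a hard dependency and imported unconditionally at the top of the file (the `import accelerate` / `import accelerate.utils` lines above). For context, the optional-dependency guard that the `utils.HAS_ACCELERATE` flag used elsewhere in this diff implies usually looks like the sketch below. This is a minimal illustration of the pattern, not the repository's exact code; the real flag lives in utils.py.

# Illustrative sketch (not part of the commit): optional-import guard being retired.
try:
    import accelerate
    import accelerate.utils
    HAS_ACCELERATE = True
except ImportError:
    accelerate = None
    HAS_ACCELERATE = False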
@@ -189,7 +180,7 @@ def patch_transformers():
def move_model_to_devices(model, usegpu, gpu_device):
    global generator

    if(not utils.HAS_ACCELERATE and not USE_BREAKMODEL):
    if(not USE_BREAKMODEL):
        if(usegpu):
            model = model.half().to(gpu_device)
        else:
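Consistent with the commit message, the `utils.HAS_ACCELERATE` term is the one being dropped from the guard above, leaving `if(not USE_BREAKMODEL):`. This guard covers the single-device path: when breakmodel layer splitting is not in use, the whole model goes to one device. The hunk cuts off at the `else:` branch; the sketch below restates the overall pattern with a hypothetical helper name and an assumed CPU branch.

# Illustrative sketch (not part of the commit): single-device placement.
import torch

def place_on_single_device(model: torch.nn.Module, usegpu: bool, gpu_device: str = "cuda:0"):
    if usegpu:
        model = model.half().to(gpu_device)  # fp16 on the GPU, as in the hunk above
    else:
        model = model.float().to("cpu")      # assumed: full precision on CPU
    return model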
@@ -197,67 +188,25 @@ def move_model_to_devices(model, usegpu, gpu_device):
        generator = model.generate
        return

    import breakmodel

    if(utils.HAS_ACCELERATE):
        import accelerate.utils
        for key, value in model.state_dict().items():
            target_dtype = torch.float32 if breakmodel.primary_device == "cpu" else torch.float16
            if(value.dtype is not target_dtype):
                accelerate.utils.set_module_tensor_to_device(model, key, target_dtype)
        disk_blocks = breakmodel.disk_blocks
        gpu_blocks = breakmodel.gpu_blocks
        ram_blocks = len(utils.layers_module_names) - sum(gpu_blocks)
        cumulative_gpu_blocks = tuple(itertools.accumulate(gpu_blocks))
        device_map = {}
        for name in utils.layers_module_names:
            layer = int(name.rsplit(".", 1)[1])
            device = ("disk" if layer < disk_blocks else "cpu") if layer < ram_blocks else bisect.bisect_right(cumulative_gpu_blocks, layer - ram_blocks)
            device_map[name] = device
        for name in utils.get_missing_module_names(model, list(device_map.keys())):
            device_map[name] = breakmodel.primary_device
        breakmodel.dispatch_model_ex(model, device_map, main_device=breakmodel.primary_device, offload_buffers=True, offload_dir="accelerate-disk-cache")
        gc.collect()
        generator = model.generate
        return

    model.half()
    for key, value in model.state_dict().items():
        target_dtype = torch.float32 if breakmodel.primary_device == "cpu" else torch.float16
        if(value.dtype is not target_dtype):
            accelerate.utils.set_module_tensor_to_device(model, key, target_dtype)
    disk_blocks = breakmodel.disk_blocks
    gpu_blocks = breakmodel.gpu_blocks
    ram_blocks = len(utils.layers_module_names) - sum(gpu_blocks)
    cumulative_gpu_blocks = tuple(itertools.accumulate(gpu_blocks))
    device_map = {}
    for name in utils.layers_module_names:
        layer = int(name.rsplit(".", 1)[1])
        device = ("disk" if layer < disk_blocks else "cpu") if layer < ram_blocks else bisect.bisect_right(cumulative_gpu_blocks, layer - ram_blocks)
        device_map[name] = device
    for name in utils.get_missing_module_names(model, list(device_map.keys())):
        device_map[name] = breakmodel.primary_device
    breakmodel.dispatch_model_ex(model, device_map, main_device=breakmodel.primary_device, offload_buffers=True, offload_dir="accelerate-disk-cache")
    gc.collect()

    if(hasattr(model, "transformer")):
        model.transformer.wte.to(breakmodel.primary_device)
        model.transformer.ln_f.to(breakmodel.primary_device)
        if(hasattr(model, 'lm_head')):
            model.lm_head.to(breakmodel.primary_device)
        if(hasattr(model.transformer, 'wpe')):
            model.transformer.wpe.to(breakmodel.primary_device)
    elif(not hasattr(model.model, "decoder")):
        model.model.embed_tokens.to(breakmodel.primary_device)
        model.model.layer_norm.to(breakmodel.primary_device)
        model.lm_head.to(breakmodel.primary_device)
        model.model.embed_positions.to(breakmodel.primary_device)
    else:
        model.model.decoder.embed_tokens.to(breakmodel.primary_device)
        if(model.model.decoder.project_in is not None):
            model.model.decoder.project_in.to(breakmodel.primary_device)
        if(model.model.decoder.project_out is not None):
            model.model.decoder.project_out.to(breakmodel.primary_device)
        model.model.decoder.embed_positions.to(breakmodel.primary_device)
    gc.collect()
    GPTNeoModel.forward = breakmodel.new_forward_neo
    if("GPTJModel" in globals()):
        GPTJModel.forward = breakmodel.new_forward_neo # type: ignore
    if("XGLMModel" in globals()):
        XGLMModel.forward = breakmodel.new_forward_xglm # type: ignore
    if("OPTDecoder" in globals()):
        OPTDecoder.forward = breakmodel.new_forward_opt # type: ignore
    generator = model.generate
    if(hasattr(model, "transformer")):
        breakmodel.move_hidden_layers(model.transformer)
    elif(not hasattr(model.model, "decoder")):
        breakmodel.move_hidden_layers(model.model, model.model.layers)
    else:
        breakmodel.move_hidden_layers(model.model.decoder, model.model.decoder.layers)
    return


_PromptTuningPreTrainedModel = Union["UniversalPromptTuningMixin", GPTPromptTuningMixin, transformers.PreTrainedModel]
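Both copies of the dispatch code in this hunk (the one indented under `if(utils.HAS_ACCELERATE):` and the flattened one that follows) build a device map keyed by layer module name: the first `ram_blocks` layers go to disk or CPU, the remaining layers are assigned a GPU index by bisecting the cumulative per-GPU block counts, and the model is then dispatched with disk offload. The sketch below restates that logic in a standalone form, using stock `accelerate.dispatch_model` as a stand-in for the repository's `breakmodel.dispatch_model_ex` wrapper; the function names, arguments, and example values are illustrative.

# Illustrative sketch (not part of the commit): device-map construction and dispatch.
import bisect
import itertools
import accelerate

def build_device_map(layer_names, gpu_blocks, disk_blocks):
    # layer_names: module names such as "transformer.h.0", "transformer.h.1", ...
    # gpu_blocks:  layers assigned to each GPU, e.g. [8, 8] for two GPUs
    # disk_blocks: how many of the leading non-GPU layers are offloaded to disk
    ram_blocks = len(layer_names) - sum(gpu_blocks)
    cumulative_gpu_blocks = tuple(itertools.accumulate(gpu_blocks))
    device_map = {}
    for name in layer_names:
        layer = int(name.rsplit(".", 1)[1])
        if layer < ram_blocks:
            device_map[name] = "disk" if layer < disk_blocks else "cpu"
        else:
            device_map[name] = bisect.bisect_right(cumulative_gpu_blocks, layer - ram_blocks)
    return device_map

def dispatch(model, device_map, primary_device="cuda:0"):
    # Stand-in for breakmodel.dispatch_model_ex; stock accelerate provides dispatch_model.
    return accelerate.dispatch_model(
        model,
        device_map=device_map,
        main_device=primary_device,
        offload_dir="accelerate-disk-cache",
        offload_buffers=True,
    )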
@@ -785,16 +734,15 @@ class TrainerBase(abc.ABC):

        if utils.num_shards is None or utils.current_shard == 0:
            utils.offload_index = {}
            if utils.HAS_ACCELERATE:
                if os.path.isdir("accelerate-disk-cache"):
                    # Delete all of the files in the disk cache folder without deleting the folder itself to allow people to create symbolic links for this folder
                    # (the folder doesn't contain any subfolders so os.remove will do just fine)
                    for filename in os.listdir("accelerate-disk-cache"):
                        try:
                            os.remove(os.path.join("accelerate-disk-cache", filename))
                        except OSError:
                            pass
                os.makedirs("accelerate-disk-cache", exist_ok=True)
            if os.path.isdir("accelerate-disk-cache"):
                # Delete all of the files in the disk cache folder without deleting the folder itself to allow people to create symbolic links for this folder
                # (the folder doesn't contain any subfolders so os.remove will do just fine)
                for filename in os.listdir("accelerate-disk-cache"):
                    try:
                        os.remove(os.path.join("accelerate-disk-cache", filename))
                    except OSError:
                        pass
            os.makedirs("accelerate-disk-cache", exist_ok=True)
            if utils.num_shards is not None:
                num_tensors = len(utils.get_sharded_checkpoint_num_tensors(utils.from_pretrained_model_name, utils.from_pretrained_index_filename, **utils.from_pretrained_kwargs))
            else:
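The duplicated block in this hunk is the disk-cache cleanup being un-nested from the `if utils.HAS_ACCELERATE:` check. Restated as a standalone helper (the function name is illustrative; the logic mirrors the lines above):

# Illustrative sketch (not part of the commit): reset the accelerate disk cache.
import os

def reset_disk_cache(path="accelerate-disk-cache"):
    # Delete the files inside the cache folder without deleting the folder itself,
    # so user-created symbolic links to it keep working; the folder holds no
    # subfolders, so os.remove is sufficient.
    if os.path.isdir(path):
        for filename in os.listdir(path):
            try:
                os.remove(os.path.join(path, filename))
            except OSError:
                pass
    os.makedirs(path, exist_ok=True)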
@@ -835,7 +783,7 @@ class TrainerBase(abc.ABC):
                model_dict[key] = model_dict[key].to(torch.float32)
            if device == "shared":
                model_dict[key] = model_dict[key].to("cpu").detach_()
                if able_to_pin_layers and utils.HAS_ACCELERATE:
                if able_to_pin_layers:
                    try:
                        model_dict[key] = model_dict[key].pin_memory()
                    except:
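Consistent with the commit message, `utils.HAS_ACCELERATE` is dropped from the pinning condition above, so `able_to_pin_layers` alone now decides whether tensors placed on the shared (CPU) device get page-locked. The `except` body is cut off by the hunk boundary; the sketch below shows the overall pattern with a hypothetical helper name and an assumed fallback.

# Illustrative sketch (not part of the commit): opportunistic memory pinning.
import torch

def pin_if_possible(tensor: torch.Tensor, able_to_pin: bool):
    # Page-locked host memory speeds up later host-to-GPU copies; if pinning fails
    # (for example when pinned memory is exhausted), keep the unpinned tensor and
    # report that further pinning attempts should stop.
    if able_to_pin:
        try:
            return tensor.pin_memory(), True
        except Exception:
            return tensor, False
    return tensor, able_to_pin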