prompt_tuner.py always has accelerate

This commit is contained in:
vfbd 2022-08-22 19:52:47 -04:00
parent 8da6893407
commit b1c456ec18
1 changed file with 29 additions and 81 deletions

View File

@@ -27,22 +27,13 @@ import torch.nn.functional as F
from torch.nn import Embedding, CrossEntropyLoss from torch.nn import Embedding, CrossEntropyLoss
import transformers import transformers
from transformers import __version__ as transformers_version from transformers import __version__ as transformers_version
from transformers import AutoTokenizer, GPT2TokenizerFast, AutoConfig, AutoModelForCausalLM, GPTNeoForCausalLM, PreTrainedModel, modeling_utils, GPTNeoModel, GPTJModel from transformers import AutoTokenizer, GPT2TokenizerFast, AutoConfig, AutoModelForCausalLM, GPTNeoForCausalLM, PreTrainedModel, modeling_utils
import accelerate import accelerate
import accelerate.utils import accelerate.utils
from mkultra.tuning import GPTPromptTuningMixin, GPTNeoPromptTuningLM from mkultra.tuning import GPTPromptTuningMixin, GPTNeoPromptTuningLM
from mkultra.soft_prompt import SoftPrompt from mkultra.soft_prompt import SoftPrompt
from typing import Dict, List, Optional, TextIO, Union from typing import Dict, List, Optional, TextIO, Union
try:
from transformers import XGLMModel
except:
pass
try:
from transformers.models.opt.modeling_opt import OPTDecoder
except:
pass
import breakmodel import breakmodel
import torch_lazy_loader import torch_lazy_loader
import utils import utils
@@ -189,7 +180,7 @@ def patch_transformers():
def move_model_to_devices(model, usegpu, gpu_device): def move_model_to_devices(model, usegpu, gpu_device):
global generator global generator
if(not utils.HAS_ACCELERATE and not USE_BREAKMODEL): if(not USE_BREAKMODEL):
if(usegpu): if(usegpu):
model = model.half().to(gpu_device) model = model.half().to(gpu_device)
else: else:
@@ -197,10 +188,6 @@ def move_model_to_devices(model, usegpu, gpu_device):
generator = model.generate generator = model.generate
return return
import breakmodel
if(utils.HAS_ACCELERATE):
import accelerate.utils
for key, value in model.state_dict().items(): for key, value in model.state_dict().items():
target_dtype = torch.float32 if breakmodel.primary_device == "cpu" else torch.float16 target_dtype = torch.float32 if breakmodel.primary_device == "cpu" else torch.float16
if(value.dtype is not target_dtype): if(value.dtype is not target_dtype):
@@ -221,44 +208,6 @@ def move_model_to_devices(model, usegpu, gpu_device):
generator = model.generate generator = model.generate
return return
model.half()
gc.collect()
if(hasattr(model, "transformer")):
model.transformer.wte.to(breakmodel.primary_device)
model.transformer.ln_f.to(breakmodel.primary_device)
if(hasattr(model, 'lm_head')):
model.lm_head.to(breakmodel.primary_device)
if(hasattr(model.transformer, 'wpe')):
model.transformer.wpe.to(breakmodel.primary_device)
elif(not hasattr(model.model, "decoder")):
model.model.embed_tokens.to(breakmodel.primary_device)
model.model.layer_norm.to(breakmodel.primary_device)
model.lm_head.to(breakmodel.primary_device)
model.model.embed_positions.to(breakmodel.primary_device)
else:
model.model.decoder.embed_tokens.to(breakmodel.primary_device)
if(model.model.decoder.project_in is not None):
model.model.decoder.project_in.to(breakmodel.primary_device)
if(model.model.decoder.project_out is not None):
model.model.decoder.project_out.to(breakmodel.primary_device)
model.model.decoder.embed_positions.to(breakmodel.primary_device)
gc.collect()
GPTNeoModel.forward = breakmodel.new_forward_neo
if("GPTJModel" in globals()):
GPTJModel.forward = breakmodel.new_forward_neo # type: ignore
if("XGLMModel" in globals()):
XGLMModel.forward = breakmodel.new_forward_xglm # type: ignore
if("OPTDecoder" in globals()):
OPTDecoder.forward = breakmodel.new_forward_opt # type: ignore
generator = model.generate
if(hasattr(model, "transformer")):
breakmodel.move_hidden_layers(model.transformer)
elif(not hasattr(model.model, "decoder")):
breakmodel.move_hidden_layers(model.model, model.model.layers)
else:
breakmodel.move_hidden_layers(model.model.decoder, model.model.decoder.layers)
_PromptTuningPreTrainedModel = Union["UniversalPromptTuningMixin", GPTPromptTuningMixin, transformers.PreTrainedModel] _PromptTuningPreTrainedModel = Union["UniversalPromptTuningMixin", GPTPromptTuningMixin, transformers.PreTrainedModel]
@@ -785,7 +734,6 @@ class TrainerBase(abc.ABC):
if utils.num_shards is None or utils.current_shard == 0: if utils.num_shards is None or utils.current_shard == 0:
utils.offload_index = {} utils.offload_index = {}
if utils.HAS_ACCELERATE:
if os.path.isdir("accelerate-disk-cache"): if os.path.isdir("accelerate-disk-cache"):
# Delete all of the files in the disk cache folder without deleting the folder itself to allow people to create symbolic links for this folder # Delete all of the files in the disk cache folder without deleting the folder itself to allow people to create symbolic links for this folder
# (the folder doesn't contain any subfolders so os.remove will do just fine) # (the folder doesn't contain any subfolders so os.remove will do just fine)
@@ -835,7 +783,7 @@ class TrainerBase(abc.ABC):
model_dict[key] = model_dict[key].to(torch.float32) model_dict[key] = model_dict[key].to(torch.float32)
if device == "shared": if device == "shared":
model_dict[key] = model_dict[key].to("cpu").detach_() model_dict[key] = model_dict[key].to("cpu").detach_()
if able_to_pin_layers and utils.HAS_ACCELERATE: if able_to_pin_layers:
try: try:
model_dict[key] = model_dict[key].pin_memory() model_dict[key] = model_dict[key].pin_memory()
except: except: