prompt_tuner.py always has accelerate

commit b1c456ec18
parent 8da6893407
--- a/prompt_tuner.py
+++ b/prompt_tuner.py
@@ -27,22 +27,13 @@ import torch.nn.functional as F
 from torch.nn import Embedding, CrossEntropyLoss
 import transformers
 from transformers import __version__ as transformers_version
-from transformers import AutoTokenizer, GPT2TokenizerFast, AutoConfig, AutoModelForCausalLM, GPTNeoForCausalLM, PreTrainedModel, modeling_utils, GPTNeoModel, GPTJModel
+from transformers import AutoTokenizer, GPT2TokenizerFast, AutoConfig, AutoModelForCausalLM, GPTNeoForCausalLM, PreTrainedModel, modeling_utils
 import accelerate
 import accelerate.utils
 from mkultra.tuning import GPTPromptTuningMixin, GPTNeoPromptTuningLM
 from mkultra.soft_prompt import SoftPrompt
 from typing import Dict, List, Optional, TextIO, Union
 
-try:
-    from transformers import XGLMModel
-except:
-    pass
-try:
-    from transformers.models.opt.modeling_opt import OPTDecoder
-except:
-    pass
-
 import breakmodel
 import torch_lazy_loader
 import utils
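Note: the deleted try/except blocks were optional-import guards. XGLMModel and OPTDecoder only exist in newer transformers releases, and they were needed solely by the breakmodel forward patches removed further down, so the guards leave with them. A minimal sketch of the pattern (the original used a bare except and later tested `"XGLMModel" in globals()`; `except ImportError` is the narrower spelling):

    # Optional import: the symbol only exists in newer transformers releases,
    # so older installs must not fail at import time.
    try:
        from transformers import XGLMModel
    except ImportError:
        XGLMModel = None  # the original leaves the name undefined and checks globals()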
@@ -189,7 +180,7 @@ def patch_transformers():
 def move_model_to_devices(model, usegpu, gpu_device):
     global generator
 
-    if(not utils.HAS_ACCELERATE and not USE_BREAKMODEL):
+    if(not USE_BREAKMODEL):
         if(usegpu):
             model = model.half().to(gpu_device)
         else:
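Note: with accelerate imported unconditionally at the top of the file, utils.HAS_ACCELERATE is always true here, so only the breakmodel flag is left to pick the path. A sketch of the resulting control flow, assuming the CPU branch that the diff elides is the usual `model.to('cpu').float()`:

    def move_model_to_devices_sketch(model, usegpu, gpu_device, use_breakmodel):
        # No breakmodel: the whole model lives on a single device.
        if not use_breakmodel:
            if usegpu:
                model = model.half().to(gpu_device)  # fp16 on the chosen GPU
            else:
                model = model.to("cpu").float()      # fp32 on the CPU (assumed)
            return model
        # Breakmodel: layers are split across devices via accelerate (next hunks).
        ...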
@@ -197,10 +188,6 @@ def move_model_to_devices(model, usegpu, gpu_device):
         generator = model.generate
         return
 
-    import breakmodel
-
-    if(utils.HAS_ACCELERATE):
-        import accelerate.utils
     for key, value in model.state_dict().items():
         target_dtype = torch.float32 if breakmodel.primary_device == "cpu" else torch.float16
         if(value.dtype is not target_dtype):
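Note: the local imports were already redundant (breakmodel and accelerate.utils are now imported at module level), and the surviving loop normalizes every state-dict tensor to the dtype the primary device expects; the rewrite itself (below the truncation point) goes through accelerate.utils. A plain-torch sketch of the same idea, illustrative rather than the file's exact mechanism:

    import torch

    def normalize_dtypes(model: torch.nn.Module, primary_device: str) -> None:
        # fp32 when layers gather on the CPU, fp16 when a GPU is primary.
        target_dtype = torch.float32 if primary_device == "cpu" else torch.float16
        for param in model.parameters():
            if param.is_floating_point() and param.dtype is not target_dtype:
                param.data = param.data.to(target_dtype)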
@@ -221,44 +208,6 @@ def move_model_to_devices(model, usegpu, gpu_device):
     generator = model.generate
     return
 
-    model.half()
-    gc.collect()
-
-    if(hasattr(model, "transformer")):
-        model.transformer.wte.to(breakmodel.primary_device)
-        model.transformer.ln_f.to(breakmodel.primary_device)
-        if(hasattr(model, 'lm_head')):
-            model.lm_head.to(breakmodel.primary_device)
-        if(hasattr(model.transformer, 'wpe')):
-            model.transformer.wpe.to(breakmodel.primary_device)
-    elif(not hasattr(model.model, "decoder")):
-        model.model.embed_tokens.to(breakmodel.primary_device)
-        model.model.layer_norm.to(breakmodel.primary_device)
-        model.lm_head.to(breakmodel.primary_device)
-        model.model.embed_positions.to(breakmodel.primary_device)
-    else:
-        model.model.decoder.embed_tokens.to(breakmodel.primary_device)
-        if(model.model.decoder.project_in is not None):
-            model.model.decoder.project_in.to(breakmodel.primary_device)
-        if(model.model.decoder.project_out is not None):
-            model.model.decoder.project_out.to(breakmodel.primary_device)
-        model.model.decoder.embed_positions.to(breakmodel.primary_device)
-    gc.collect()
-    GPTNeoModel.forward = breakmodel.new_forward_neo
-    if("GPTJModel" in globals()):
-        GPTJModel.forward = breakmodel.new_forward_neo # type: ignore
-    if("XGLMModel" in globals()):
-        XGLMModel.forward = breakmodel.new_forward_xglm # type: ignore
-    if("OPTDecoder" in globals()):
-        OPTDecoder.forward = breakmodel.new_forward_opt # type: ignore
-    generator = model.generate
-    if(hasattr(model, "transformer")):
-        breakmodel.move_hidden_layers(model.transformer)
-    elif(not hasattr(model.model, "decoder")):
-        breakmodel.move_hidden_layers(model.model, model.model.layers)
-    else:
-        breakmodel.move_hidden_layers(model.model.decoder, model.model.decoder.layers)
-
 
 _PromptTuningPreTrainedModel = Union["UniversalPromptTuningMixin", GPTPromptTuningMixin, transformers.PreTrainedModel]
 
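Note: the block removed above was the pre-accelerate breakmodel path: move embeddings, layer norms, and heads to the primary device by hand (with separate hasattr branches for GPT-style, XGLM-style, and OPT-style module layouts), spread the hidden layers with breakmodel.move_hidden_layers, and monkey-patch each architecture's forward with a device-hopping variant. Accelerate's hook machinery covers all of that, which is what this commit relies on. A hedged sketch of the replacement idea, using accelerate's real dispatch_model API with an illustrative device map:

    import accelerate

    # Place modules per the map and let accelerate install forward hooks that
    # move activations between devices, replacing the hand-written patches.
    # The map is illustrative; the real one would come from the user's
    # per-GPU layer counts (breakmodel.gpu_blocks).
    device_map = {
        "transformer.wte": 0,
        "transformer.h.0": 0,
        "transformer.h.1": 1,
        "transformer.ln_f": 0,
        "lm_head": 0,
    }
    model = accelerate.dispatch_model(model, device_map=device_map)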
@@ -785,7 +734,6 @@ class TrainerBase(abc.ABC):
 
         if utils.num_shards is None or utils.current_shard == 0:
             utils.offload_index = {}
-            if utils.HAS_ACCELERATE:
-                if os.path.isdir("accelerate-disk-cache"):
-                    # Delete all of the files in the disk cache folder without deleting the folder itself to allow people to create symbolic links for this folder
-                    # (the folder doesn't contain any subfolders so os.remove will do just fine)
+            if os.path.isdir("accelerate-disk-cache"):
+                # Delete all of the files in the disk cache folder without deleting the folder itself to allow people to create symbolic links for this folder
+                # (the folder doesn't contain any subfolders so os.remove will do just fine)
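Note: the comments state the cleanup contract: empty the folder, keep the folder itself so user-created symlinks survive, and rely on the folder being flat. The loop that implements it sits past the hunk boundary; a minimal sketch matching that contract:

    import os

    cache_dir = "accelerate-disk-cache"
    # Remove the contents but not the directory, so a symlinked cache folder
    # keeps working; the folder holds no subfolders, so os.remove() suffices.
    for filename in os.listdir(cache_dir):
        os.remove(os.path.join(cache_dir, filename))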
@@ -835,7 +783,7 @@ class TrainerBase(abc.ABC):
                     model_dict[key] = model_dict[key].to(torch.float32)
                 if device == "shared":
                     model_dict[key] = model_dict[key].to("cpu").detach_()
-                    if able_to_pin_layers and utils.HAS_ACCELERATE:
+                    if able_to_pin_layers:
                         try:
                             model_dict[key] = model_dict[key].pin_memory()
                         except:
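Note: pinning is now attempted whenever able_to_pin_layers holds, with no accelerate check. The try/except exists because page-locked allocation can fail once the OS refuses more pinned pages; the handler past the hunk boundary presumably clears the flag. A sketch of the pattern (the fallback is an assumption):

    import torch

    able_to_pin_layers = True

    def maybe_pin(tensor: torch.Tensor) -> torch.Tensor:
        # Pinned (page-locked) host memory enables async, overlapped H2D copies;
        # stop trying after the first failure (assumed fallback behavior).
        global able_to_pin_layers
        if able_to_pin_layers:
            try:
                return tensor.pin_memory()
            except RuntimeError:
                able_to_pin_layers = False
        return tensor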