diff --git a/prompt_tuner.py b/prompt_tuner.py
index 6d0c907e..c13a8a53 100644
--- a/prompt_tuner.py
+++ b/prompt_tuner.py
@@ -27,22 +27,13 @@ import torch.nn.functional as F
 from torch.nn import Embedding, CrossEntropyLoss
 import transformers
 from transformers import __version__ as transformers_version
-from transformers import AutoTokenizer, GPT2TokenizerFast, AutoConfig, AutoModelForCausalLM, GPTNeoForCausalLM, PreTrainedModel, modeling_utils, GPTNeoModel, GPTJModel
+from transformers import AutoTokenizer, GPT2TokenizerFast, AutoConfig, AutoModelForCausalLM, GPTNeoForCausalLM, PreTrainedModel, modeling_utils
 import accelerate
 import accelerate.utils
 from mkultra.tuning import GPTPromptTuningMixin, GPTNeoPromptTuningLM
 from mkultra.soft_prompt import SoftPrompt
 from typing import Dict, List, Optional, TextIO, Union
 
-try:
-    from transformers import XGLMModel
-except:
-    pass
-try:
-    from transformers.models.opt.modeling_opt import OPTDecoder
-except:
-    pass
-
 import breakmodel
 import torch_lazy_loader
 import utils
@@ -189,7 +180,7 @@ def patch_transformers():
 def move_model_to_devices(model, usegpu, gpu_device):
     global generator
 
-    if(not utils.HAS_ACCELERATE and not USE_BREAKMODEL):
+    if(not USE_BREAKMODEL):
         if(usegpu):
             model = model.half().to(gpu_device)
         else:
@@ -197,67 +188,25 @@ def move_model_to_devices(model, usegpu, gpu_device):
         generator = model.generate
         return
 
-    import breakmodel
-
-    if(utils.HAS_ACCELERATE):
-        import accelerate.utils
-        for key, value in model.state_dict().items():
-            target_dtype = torch.float32 if breakmodel.primary_device == "cpu" else torch.float16
-            if(value.dtype is not target_dtype):
-                accelerate.utils.set_module_tensor_to_device(model, key, target_dtype)
-        disk_blocks = breakmodel.disk_blocks
-        gpu_blocks = breakmodel.gpu_blocks
-        ram_blocks = len(utils.layers_module_names) - sum(gpu_blocks)
-        cumulative_gpu_blocks = tuple(itertools.accumulate(gpu_blocks))
-        device_map = {}
-        for name in utils.layers_module_names:
-            layer = int(name.rsplit(".", 1)[1])
-            device = ("disk" if layer < disk_blocks else "cpu") if layer < ram_blocks else bisect.bisect_right(cumulative_gpu_blocks, layer - ram_blocks)
-            device_map[name] = device
-        for name in utils.get_missing_module_names(model, list(device_map.keys())):
-            device_map[name] = breakmodel.primary_device
-        breakmodel.dispatch_model_ex(model, device_map, main_device=breakmodel.primary_device, offload_buffers=True, offload_dir="accelerate-disk-cache")
-        gc.collect()
-        generator = model.generate
-        return
-
-    model.half()
+    for key, value in model.state_dict().items():
+        target_dtype = torch.float32 if breakmodel.primary_device == "cpu" else torch.float16
+        if(value.dtype is not target_dtype):
+            accelerate.utils.set_module_tensor_to_device(model, key, target_dtype)
+    disk_blocks = breakmodel.disk_blocks
+    gpu_blocks = breakmodel.gpu_blocks
+    ram_blocks = len(utils.layers_module_names) - sum(gpu_blocks)
+    cumulative_gpu_blocks = tuple(itertools.accumulate(gpu_blocks))
+    device_map = {}
+    for name in utils.layers_module_names:
+        layer = int(name.rsplit(".", 1)[1])
+        device = ("disk" if layer < disk_blocks else "cpu") if layer < ram_blocks else bisect.bisect_right(cumulative_gpu_blocks, layer - ram_blocks)
+        device_map[name] = device
+    for name in utils.get_missing_module_names(model, list(device_map.keys())):
+        device_map[name] = breakmodel.primary_device
+    breakmodel.dispatch_model_ex(model, device_map, main_device=breakmodel.primary_device, offload_buffers=True, offload_dir="accelerate-disk-cache")
     gc.collect()
-
-    if(hasattr(model, "transformer")):
-        model.transformer.wte.to(breakmodel.primary_device)
-        model.transformer.ln_f.to(breakmodel.primary_device)
-        if(hasattr(model, 'lm_head')):
-            model.lm_head.to(breakmodel.primary_device)
-        if(hasattr(model.transformer, 'wpe')):
-            model.transformer.wpe.to(breakmodel.primary_device)
-    elif(not hasattr(model.model, "decoder")):
-        model.model.embed_tokens.to(breakmodel.primary_device)
-        model.model.layer_norm.to(breakmodel.primary_device)
-        model.lm_head.to(breakmodel.primary_device)
-        model.model.embed_positions.to(breakmodel.primary_device)
-    else:
-        model.model.decoder.embed_tokens.to(breakmodel.primary_device)
-        if(model.model.decoder.project_in is not None):
-            model.model.decoder.project_in.to(breakmodel.primary_device)
-        if(model.model.decoder.project_out is not None):
-            model.model.decoder.project_out.to(breakmodel.primary_device)
-        model.model.decoder.embed_positions.to(breakmodel.primary_device)
-    gc.collect()
-    GPTNeoModel.forward = breakmodel.new_forward_neo
-    if("GPTJModel" in globals()):
-        GPTJModel.forward = breakmodel.new_forward_neo  # type: ignore
-    if("XGLMModel" in globals()):
-        XGLMModel.forward = breakmodel.new_forward_xglm  # type: ignore
-    if("OPTDecoder" in globals()):
-        OPTDecoder.forward = breakmodel.new_forward_opt  # type: ignore
     generator = model.generate
-    if(hasattr(model, "transformer")):
-        breakmodel.move_hidden_layers(model.transformer)
-    elif(not hasattr(model.model, "decoder")):
-        breakmodel.move_hidden_layers(model.model, model.model.layers)
-    else:
-        breakmodel.move_hidden_layers(model.model.decoder, model.model.decoder.layers)
+    return
 
 
 _PromptTuningPreTrainedModel = Union["UniversalPromptTuningMixin", GPTPromptTuningMixin, transformers.PreTrainedModel]
@@ -785,16 +734,15 @@ class TrainerBase(abc.ABC):
 
                 if utils.num_shards is None or utils.current_shard == 0:
                     utils.offload_index = {}
-                    if utils.HAS_ACCELERATE:
-                        if os.path.isdir("accelerate-disk-cache"):
-                            # Delete all of the files in the disk cache folder without deleting the folder itself to allow people to create symbolic links for this folder
-                            # (the folder doesn't contain any subfolders so os.remove will do just fine)
-                            for filename in os.listdir("accelerate-disk-cache"):
-                                try:
-                                    os.remove(os.path.join("accelerate-disk-cache", filename))
-                                except OSError:
-                                    pass
-                        os.makedirs("accelerate-disk-cache", exist_ok=True)
+                    if os.path.isdir("accelerate-disk-cache"):
+                        # Delete all of the files in the disk cache folder without deleting the folder itself to allow people to create symbolic links for this folder
+                        # (the folder doesn't contain any subfolders so os.remove will do just fine)
+                        for filename in os.listdir("accelerate-disk-cache"):
+                            try:
+                                os.remove(os.path.join("accelerate-disk-cache", filename))
+                            except OSError:
+                                pass
+                    os.makedirs("accelerate-disk-cache", exist_ok=True)
                     if utils.num_shards is not None:
                         num_tensors = len(utils.get_sharded_checkpoint_num_tensors(utils.from_pretrained_model_name, utils.from_pretrained_index_filename, **utils.from_pretrained_kwargs))
                     else:
@@ -835,7 +783,7 @@ class TrainerBase(abc.ABC):
                             model_dict[key] = model_dict[key].to(torch.float32)
                         if device == "shared":
                             model_dict[key] = model_dict[key].to("cpu").detach_()
-                            if able_to_pin_layers and utils.HAS_ACCELERATE:
+                            if able_to_pin_layers:
                                 try:
                                     model_dict[key] = model_dict[key].pin_memory()
                                 except:
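For context on the layer-placement logic this patch now runs unconditionally (it was previously gated behind `utils.HAS_ACCELERATE`): the retained loop assigns each transformer block to disk, CPU RAM, or a GPU index before handing the map to `breakmodel.dispatch_model_ex`. Below is a minimal standalone sketch of that assignment; `build_device_map` and the toy layer names are hypothetical illustrations, not part of the patch, which reads `gpu_blocks`/`disk_blocks` from `breakmodel` and layer names from `utils.layers_module_names`.

```python
import bisect
import itertools

def build_device_map(layer_names, gpu_blocks, disk_blocks):
    # Layers that don't fit on any GPU stay off-GPU ("disk" or "cpu").
    ram_blocks = len(layer_names) - sum(gpu_blocks)
    # Running totals, e.g. gpu_blocks=[3, 2] -> (3, 5).
    cumulative_gpu_blocks = tuple(itertools.accumulate(gpu_blocks))
    device_map = {}
    for name in layer_names:
        layer = int(name.rsplit(".", 1)[1])  # "transformer.h.11" -> 11
        if layer < ram_blocks:
            # The first disk_blocks layers are offloaded to disk, the rest to CPU.
            device_map[name] = "disk" if layer < disk_blocks else "cpu"
        else:
            # bisect_right finds the first GPU whose cumulative block count
            # exceeds this layer's GPU-relative index.
            device_map[name] = bisect.bisect_right(cumulative_gpu_blocks, layer - ram_blocks)
    return device_map

# 6 layers: 1 offloaded to disk, 2 kept in CPU RAM, 3 placed on GPU 0.
names = [f"transformer.h.{i}" for i in range(6)]
print(build_device_map(names, gpu_blocks=[3], disk_blocks=1))
# {'transformer.h.0': 'disk', 'transformer.h.1': 'cpu', 'transformer.h.2': 'cpu',
#  'transformer.h.3': 0, 'transformer.h.4': 0, 'transformer.h.5': 0}
```

The `bisect_right` over the cumulative block counts is what spreads the GPU-resident layers across multiple devices in order. Note that with the `HAS_ACCELERATE` checks removed, accelerate becomes a hard dependency of this path: `accelerate.utils.set_module_tensor_to_device` and `breakmodel.dispatch_model_ex` are now called whenever `USE_BREAKMODEL` is set.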