From f326fc07e8da6d64b63faac82dd91c101b6bd27b Mon Sep 17 00:00:00 2001
From: somebody
Date: Wed, 31 May 2023 14:42:05 -0500
Subject: [PATCH] Seems to work

---
 modeling/inference_models/hf_torch.py | 38 +-----------
 modeling/patches.py                   | 87 +--------------------------
 2 files changed, 4 insertions(+), 121 deletions(-)

diff --git a/modeling/inference_models/hf_torch.py b/modeling/inference_models/hf_torch.py
index 151857cf..1fb78717 100644
--- a/modeling/inference_models/hf_torch.py
+++ b/modeling/inference_models/hf_torch.py
@@ -278,56 +278,22 @@ class HFTorchInferenceModel(HFInferenceModel):
 
         # Try to determine model type from either AutoModel or falling back to legacy
         try:
-            # with accelerate.init_empty_weights():
-            #     model = AutoModelForCausalLM.from_config(self.model_config)
-
-            # print("[HUGE SKELETON] MAKING DEVICE MAP")
-            # device_map = infer_auto_device_map(
-            #     model,
-            #     no_split_module_classes=model._no_split_modules,
-            #     max_memory={0: "10GiB", 1: "7GiB", "cpu": "20GiB"},
-            #     dtype=torch.float16,
-            # )
-
-            # # TODO: ??
-            # print("[HUGE SKELETON] TYING WEIGHTS")
-            # model.tie_weights()
-
             print("[HUGE SKELETON] LOADING FROM PRETRAINED")
-            # model = load_checkpoint_and_dispatch(
-            #     model,
-            #     location + "/pytorch_model.bin",
-            #     device_map=device_map,
-            #     no_split_module_classes=model._no_split_modules,
-            #     dtype=torch.float16,
-            # )
             with lazy_loader.use_lazy_load(
                 enable=True,
+                # DO NOT DEMATERIALIZE MODULES / INIT WEIGHTS EMPTY!!! IT WILL EXPLODE!!!!!!!
                 # dematerialized_modules=True,
                 dematerialized_modules=False,
             ):
                 model = AutoModelForCausalLM.from_pretrained(
                     location,
                     device_map="auto",
-                    max_memory={0: "10GiB", 1: "7GiB", "cpu": "20GiB"},
+                    # max_memory={0: "10GiB", 1: "7GiB", "cpu": "20GiB"},
                     offload_folder="accelerate-disk-cache",
                     torch_dtype=torch.float16,
                     **tf_kwargs,
                 )
 
-            for name, value in list(model.named_parameters()) + list(
-                model.named_buffers()
-            ):
-                if value.device != torch.device("meta"):
-                    continue
-                print(name, value, value.nelement())
-                # try:
-                #     value.cpu()
-                # except NotImplementedError:
-                #     # Can't be copied out of meta tensor, no data
-                #     print("Bad news at", name)
-                #     # setattr(model, name, torch.zeros(value.size()))
-
             return model
         except Exception as e:
             traceback_string = traceback.format_exc().lower()
diff --git a/modeling/patches.py b/modeling/patches.py
index 7393b6ba..bd5a6325 100644
--- a/modeling/patches.py
+++ b/modeling/patches.py
@@ -148,29 +148,6 @@ def patch_transformers_for_lazyload() -> None:
     limitations under the License.
     """
     import torch
-    import accelerate
-
-    # _old_set_module_tensor_to_device = (
-    #     accelerate.utils.modeling.set_module_tensor_to_device
-    # )
-
-    # def _set_module_tensor_to_device(
-    #     module: torch.nn.Module,
-    #     tensor_name: str,
-    #     device: Union[int, str, torch.device],
-    #     value: Optional[torch.Tensor] = None,
-    #     dtype: Optional[Union[str, torch.dtype]] = None,
-    # ):
-    #     if isinstance(value, LazyTensor):
-    #         value = value.materialize()
-    #     print("HEY!", dtype)
-    #     return _old_set_module_tensor_to_device(
-    #         module, tensor_name, device, value, dtype
-    #     )
-
-    # accelerate.utils.modeling.set_module_tensor_to_device = _set_module_tensor_to_device
-
-    from accelerate.utils.modeling import named_module_tensors
     from accelerate.utils import set_module_tensor_to_device, offload_weight
 
     def _load_state_dict_into_meta_model(
@@ -225,10 +202,7 @@ def patch_transformers_for_lazyload() -> None:
         for old_key, new_key in zip(old_keys, new_keys):
             state_dict[new_key] = state_dict.pop(old_key)
 
-        # BEGIN PATCH
-        # TODO: Based on config
-        # dtype = torch.float16
-
+# BEGIN PATCH
         for param_name, param in sorted(
             state_dict.items(),
             # State dict must be ordered in this manner to make the caching in
@@ -243,7 +217,7 @@ def patch_transformers_for_lazyload() -> None:
             if isinstance(param, LazyTensor):
                 # Should always be true
                 param = param.materialize()
-            # END PATCH
+# END PATCH
 
             # First part of the test is always true as load_state_dict_keys always contains state_dict keys.
             if (
@@ -338,63 +312,6 @@ def patch_transformers_for_lazyload() -> None:
         _load_state_dict_into_meta_model
     )
 
-    # # Patch AlignDevicesHook to hack around OPT lm_head
-    # HACK_ZERO_ON_FAIL_TENSORS = ["lm_head.weight"]
-
-    # def _init_hook(self, module):
-    #     if not self.offload and self.execution_device is not None:
-    #         # BEGIN PATCH
-    #         for name, tensor in named_module_tensors(
-    #             module, recurse=self.place_submodules
-    #         ):
-    #             try:
-    #                 set_module_tensor_to_device(module, name, self.execution_device)
-    #             except ValueError:
-    #                 # ValueError: weight is on the meta device, we need a `value` to put in on 0.
-    #                 # bleuuuuuuuuuuuuuuuhhh
-    #                 if name in HACK_ZERO_ON_FAIL_TENSORS:
-    #                     logger.warning(
-    #                         f"Couldn't find value for weight {name}, zeroing."
-    #                     )
-    #                     set_module_tensor_to_device(
-    #                         module,
-    #                         name,
-    #                         self.execution_device,
-    #                         value=torch.zeros(tensor.shape),
-    #                     )
-    #         # END PATCH
-    #     elif self.offload:
-    #         self.original_devices = {
-    #             name: param.device
-    #             for name, param in named_module_tensors(
-    #                 module, recurse=self.place_submodules
-    #             )
-    #         }
-
-    #         if self.weights_map is None:
-    #             self.weights_map = {
-    #                 name: param.to("cpu")
-    #                 for name, param in named_module_tensors(
-    #                     module,
-    #                     include_buffers=self.offload_buffers,
-    #                     recurse=self.place_submodules,
-    #                 )
-    #             }
-
-    #         for name, _ in named_module_tensors(
-    #             module,
-    #             include_buffers=self.offload_buffers,
-    #             recurse=self.place_submodules,
-    #         ):
-    #             set_module_tensor_to_device(module, name, "meta")
-
-    #         if not self.offload_buffers and self.execution_device is not None:
-    #             for name, _ in module.named_buffers(recurse=self.place_submodules):
-    #                 set_module_tensor_to_device(module, name, self.execution_device)
-    #         return module
-
-    # accelerate.hooks.AlignDevicesHook.init_hook = _init_hook
-
 
 def patch_transformers() -> None:
     patch_transformers_download()