From 6f93150e4d85057d77464f93c6a87a43f68cc35d Mon Sep 17 00:00:00 2001
From: somebody
Date: Sun, 28 May 2023 12:25:31 -0500
Subject: [PATCH] Work on lazyload

---
 .../generic_hf_torch/class.py |  46 ++--
 modeling/lazy_loader.py       | 219 +++++-------------
 modeling/patches.py           |   2 +-
 prompt_tuner.py               |   2 +-
 4 files changed, 97 insertions(+), 172 deletions(-)

diff --git a/modeling/inference_models/generic_hf_torch/class.py b/modeling/inference_models/generic_hf_torch/class.py
index 96d5f0c4..539d2018 100644
--- a/modeling/inference_models/generic_hf_torch/class.py
+++ b/modeling/inference_models/generic_hf_torch/class.py
@@ -18,11 +18,11 @@ from modeling.inference_models.hf_torch import HFTorchInferenceModel
 
 model_backend_name = "Huggingface"
 
+
 class model_backend(HFTorchInferenceModel):
-
     def _initialize_model(self):
         return
-
+
     def _load(self, save_model: bool, initial_load: bool) -> None:
         utils.koboldai_vars.allowsp = True
 
@@ -34,9 +34,7 @@ class model_backend(HFTorchInferenceModel):
 
         # utils.koboldai_vars.custmodpth = utils.koboldai_vars.model
         if self.model_name == "NeoCustom":
-            self.model_name = os.path.basename(
-                os.path.normpath(self.path)
-            )
+            self.model_name = os.path.basename(os.path.normpath(self.path))
             utils.koboldai_vars.model = self.model_name
 
         # If we specify a model and it's in the root directory, we need to move
@@ -63,13 +61,18 @@ class model_backend(HFTorchInferenceModel):
 
         # If we're using torch_lazy_loader, we need to get breakmodel config
        # early so that it knows where to load the individual model tensors
-        logger.debug("lazy_load: {} hascuda: {} breakmodel: {} nobreakmode: {}".format(self.lazy_load, utils.koboldai_vars.hascuda, self.breakmodel, self.nobreakmodel))
+        logger.debug(
+            "lazy_load: {} hascuda: {} breakmodel: {} nobreakmode: {}".format(
+                self.lazy_load,
+                utils.koboldai_vars.hascuda,
+                self.breakmodel,
+                self.nobreakmodel,
+            )
+        )
         if self.lazy_load:
             # If we're using lazy loader, we need to figure out what the model's hidden layers are called
-            with lazy_loader.use_lazy_load(
-                dematerialized_modules=True, use_accelerate_init_empty_weights=True
-            ):
+            with lazy_loader.use_lazy_load(dematerialized_modules=True):
                 try:
                     metamodel = AutoModelForCausalLM.from_config(self.model_config)
                     utils.layers_module_names = utils.get_layers_module_names(metamodel)
@@ -195,7 +198,9 @@ class model_backend(HFTorchInferenceModel):
                         pass
 
                 if not any_success:
-                    raise RuntimeError(f"Couldn't find any of {possible_checkpoint_names} in cache for {self.model_name} @ '{utils.koboldai_vars.revisison}'")
+                    raise RuntimeError(
+                        f"Couldn't find any of {possible_checkpoint_names} in cache for {self.model_name} @ '{utils.koboldai_vars.revisison}'"
+                    )
 
             else:
                 # Handle saving sharded models
@@ -233,11 +238,24 @@ class model_backend(HFTorchInferenceModel):
                 shutil.rmtree("cache/")
 
         self.patch_embedding()
-
-
+
         self.model.kai_model = self
         utils.koboldai_vars.modeldim = self.get_hidden_size()
 
     def _save_settings(self):
-        with open("settings/{}.generic_hf_torch.model_backend.settings".format(self.model_name.replace("/", "_")), "w") as f:
-            json.dump({"layers": self.layers if 'layers' in vars(self) else [], "disk_layers": self.disk_layers if 'disk_layers' in vars(self) else 0}, f, indent="")
\ No newline at end of file
+        with open(
+            "settings/{}.generic_hf_torch.model_backend.settings".format(
+                self.model_name.replace("/", "_")
+            ),
+            "w",
+        ) as f:
+            json.dump(
+                {
+                    "layers": self.layers if "layers" in vars(self) else [],
+                    "disk_layers": self.disk_layers
+                    if "disk_layers" in vars(self)
+                    else 0,
+                },
+                f,
+                indent="",
+            )
diff --git a/modeling/lazy_loader.py b/modeling/lazy_loader.py
index a0f67d4a..85ed495d 100644
--- a/modeling/lazy_loader.py
+++ b/modeling/lazy_loader.py
@@ -61,6 +61,7 @@ from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
 # support it yet.
 try:
     import safetensors
+
     HAS_SAFETENSORS = True
 except ModuleNotFoundError:
     HAS_SAFETENSORS = False
@@ -68,9 +69,6 @@ except ModuleNotFoundError:
 
 import utils
 
-_EXTRA_STATE_KEY_SUFFIX = "_extra_state"
-
-
 STORAGE_TYPE_MAP = {
     torch.float64: torch.DoubleStorage,
     torch.float32: torch.FloatStorage,
@@ -84,6 +82,8 @@ STORAGE_TYPE_MAP = {
     torch.bfloat16: torch.BFloat16Storage,
 }
 
+# Storage of zipfile handles for each shard
+torch_checkpoint_file_handles = {}
 
 class LazyTensor:
     pass
@@ -101,7 +101,6 @@ class TorchLazyTensor(LazyTensor):
         stride: Optional[Tuple[int, ...]] = None,
         requires_grad=False,
         backward_hooks: Any = None,
-        file_handle: Any = None
     ):
         self.storage_type = storage_type
         self.key = key
@@ -112,7 +111,7 @@ class TorchLazyTensor(LazyTensor):
         self.stride = stride
         self.requires_grad = requires_grad
         self.backward_hooks = backward_hooks
-        self.file_handle = file_handle
+        self.file_name = None
 
     def __view(self, f: Callable):
         return f"{type(self).__name__}(storage_type={f(self.storage_type)}, key={f(self.key)}, location={f(self.location)}, dtype={f(self.dtype)}, seek_offset={f(self.seek_offset)}, shape={f(self.shape)}, stride={f(self.stride)}, requires_grad={f(self.requires_grad)}, backward_hooks={f(self.backward_hooks)})"
@@ -127,7 +126,29 @@ class TorchLazyTensor(LazyTensor):
         no_grad=True,
         filename="pytorch_model.bin",
     ) -> torch.Tensor:
-        checkpoint = checkpoint or self.file_handle
+
+
+
+        # if f not in torch_tensor_container_file_map:
+        #     torch_tensor_container_file_map[f] = []
+
+        #     with zipfile.ZipFile(f, "r") as z:
+        #         paths = z.namelist()
+
+        #         for name in paths:
+        #             val = name.split("/data/")[-1]
+        #             if not val.isdecimal():
+        #                 continue
+        #             torch_tensor_container_file_map[f].append(int(val))
+        #     torch_tensor_container_file_map[f].sort()
+        # print(torch_tensor_container_file_map)
+
+
+
+        if not checkpoint:
+            checkpoint = torch_checkpoint_file_handles[self.file_name]
+            filename = self.file_name
 
         filename = os.path.basename(os.path.normpath(filename)).split(".")[0]
         size = reduce(lambda x, y: x * y, self.shape, 1)
@@ -141,19 +162,23 @@ class TorchLazyTensor(LazyTensor):
                 >> 3
             )
         )
+
         if isinstance(checkpoint, zipfile.ZipFile):
             try:
                 f = checkpoint.open(f"archive/data/{self.key}", "r")
             except:
                 f = checkpoint.open(f"{filename}/data/{self.key}", "r")
-            f.read(self.seek_offset)
+            f.seek(self.seek_offset, os.SEEK_CUR)
+            # f.read(self.seek_offset)
         else:
             f = checkpoint
+
         try:
             storage = STORAGE_TYPE_MAP[dtype].from_buffer(f.read(nbytes), "little")
         finally:
             if isinstance(checkpoint, zipfile.ZipFile):
                 f.close()
+
         storage = torch.serialization._get_restore_location(map_location)(
             storage, self.location
         )
@@ -242,7 +267,6 @@ class _LazyUnpickler(RestrictedUnpickler):
 
     def __init__(self, *args, **kwargs):
         # print(args, kwargs)
-        self.file_handle = args[0]
         self.lazy_loaded_storages = {}
         return super().__init__(*args, **kwargs)
 
@@ -253,7 +277,9 @@ class _LazyUnpickler(RestrictedUnpickler):
             typename == "storage"
         ), f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'"
         storage_type, key, location, _ = saved_id[1:]
-        return TorchLazyTensor(storage_type, key, location, file_handle=self.file_handle)
+        return TorchLazyTensor(
+            storage_type, key, location
+        )
 
     def load(self, *args, **kwargs):
         retval = super().load(*args, **kwargs)
@@ -277,117 +303,6 @@ def _rebuild_tensor(lazy_storage: LazyTensor, storage_offset, shape, stride):
     return lazy_storage
 
 
-# Modified version of https://github.com/pytorch/pytorch/blob/v1.11.0-rc4/torch/nn/modules/module.py#L1346-L1438
-def _load_from_state_dict(
-    self,
-    state_dict,
-    prefix,
-    local_metadata,
-    strict,
-    missing_keys,
-    unexpected_keys,
-    error_msgs,
-):
-    for hook in self._load_state_dict_pre_hooks.values():
-        hook(
-            state_dict,
-            prefix,
-            local_metadata,
-            strict,
-            missing_keys,
-            unexpected_keys,
-            error_msgs,
-        )
-
-    persistent_buffers = {
-        k: v
-        for k, v in self._buffers.items()
-        if k not in self._non_persistent_buffers_set
-    }
-    local_name_params = itertools.chain(
-        self._parameters.items(), persistent_buffers.items()
-    )
-    local_state = {k: v for k, v in local_name_params if v is not None}
-
-    for name, param in local_state.items():
-        key = prefix + name
-        if key in state_dict:
-            input_param = state_dict[key]
-            if not torch.overrides.is_tensor_like(input_param):
-                error_msgs.append(
-                    'While copying the parameter named "{}", '
-                    "expected torch.Tensor or Tensor-like object from checkpoint but "
-                    "received {}".format(key, type(input_param))
-                )
-                continue
-
-            # This is used to avoid copying uninitialized parameters into
-            # non-lazy modules, since they dont have the hook to do the checks
-            # in such case, it will error when accessing the .shape attribute.
-            is_param_lazy = torch.nn.parameter.is_lazy(param)
-            # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
-            if (
-                not is_param_lazy
-                and len(param.shape) == 0
-                and len(input_param.shape) == 1
-            ):
-                input_param = input_param[0]
-
-            if not is_param_lazy and input_param.shape != param.shape:
-                # local shape should match the one in checkpoint
-                error_msgs.append(
-                    "size mismatch for {}: copying a param with shape {} from checkpoint, "
-                    "the shape in current model is {}.".format(
-                        key, input_param.shape, param.shape
-                    )
-                )
-                continue
-            try:
-                with torch.no_grad():
-                    # param.copy_(input_param)
-                    new_param = torch.nn.Parameter(
-                        input_param, requires_grad=param.requires_grad
-                    )  # This line is new
-                    if name in self._parameters:  # This line is new
-                        self._parameters[name] = new_param  # This line is new
-                    if name in persistent_buffers:  # This line is new
-                        self._buffers[name] = new_param  # This line is new
-            except Exception as ex:
-                error_msgs.append(
-                    'While copying the parameter named "{}", '
-                    "whose dimensions in the model are {} and "
-                    "whose dimensions in the checkpoint are {}, "
-                    "an exception occurred : {}.".format(
-                        key, param.size(), input_param.size(), ex.args
-                    )
-                )
-        elif strict:
-            missing_keys.append(key)
-
-    extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
-    if (
-        hasattr(Module, "set_extra_state")
-        and getattr(self.__class__, "set_extra_state", Module.set_extra_state)
-        is not Module.set_extra_state
-    ):  # if getattr(self.__class__, "set_extra_state", Module.set_extra_state) is not Module.set_extra_state:
-        if extra_state_key in state_dict:
-            self.set_extra_state(state_dict[extra_state_key])
-        elif strict:
-            missing_keys.append(extra_state_key)
-    elif strict and (extra_state_key in state_dict):
-        unexpected_keys.append(extra_state_key)
-
-    if strict:
-        for key in state_dict.keys():
-            if key.startswith(prefix) and key != extra_state_key:
-                input_name = key[len(prefix) :]
-                input_name = input_name.split(".", 1)[
-                    0
-                ]  # get the name of param/buffer/child
-                if input_name not in self._modules and input_name not in local_state:
-                    unexpected_keys.append(key)
-
-
 def safetensors_load_tensor_independently(
     checkpoint_file: str, tensor_key: str, device: Any
 ) -> torch.Tensor:
@@ -418,7 +333,9 @@ def patch_safetensors(callback):
         tensors = {}
 
         with safetensors.safe_open(
-            checkpoint_file, framework="pt", device=intermediary_device,
+            checkpoint_file,
+            framework="pt",
+            device=intermediary_device,
         ) as f:
             for key in f.keys():
                 tensors[key] = None
@@ -426,7 +343,9 @@ def patch_safetensors(callback):
 
         for key in tensors.keys():
             tensors[key] = SafetensorsLazyTensor(
-                checkpoint_file=checkpoint_file, key=key, location=intermediary_device,
+                checkpoint_file=checkpoint_file,
+                key=key,
+                location=intermediary_device,
             )
 
         if callback is not None:
@@ -442,6 +361,13 @@ def patch_safetensors(callback):
     transformers.modeling_utils.safe_load_file = safetensors_load
 
 
+def get_torch_tensor_file(file: str, lazy_tensor: TorchLazyTensor):
+    with zipfile.ZipFile(file, "r") as z:
+        storage_key = lazy_tensor.key
+        ziproot = z.namelist()[0].split("/")[0]
+        f = z.open(f"{ziproot}/data/{storage_key}")
+        # TODO: Maybe some file seeking
+        return f
 
 @contextlib.contextmanager
 def use_custom_unpickler(unpickler: Type[pickle.Unpickler] = RestrictedUnpickler):
@@ -468,7 +394,6 @@ def use_lazy_load(
     enable=True,
     callback: Optional[Callable] = None,
     dematerialized_modules=False,
-    use_accelerate_init_empty_weights=False,
 ):
     if not enable:
         with use_custom_unpickler(RestrictedUnpickler):
@@ -483,22 +408,30 @@ def use_lazy_load(
     old_torch_load = torch.load
 
     def torch_load(f, map_location=None, pickle_module=pickle, **pickle_load_args):
-        retval = old_torch_load(
+        model_dict = old_torch_load(
            f=f,
             map_location=map_location,
             pickle_module=pickle_module,
             **pickle_load_args,
         )
+
+        if f not in torch_checkpoint_file_handles:
+            torch_checkpoint_file_handles[f] = zipfile.ZipFile(f, "r")
+
+        for k,v in model_dict.items():
+            v.file_name = f
+
         if callback is not None:
             callback(
-                retval,
+                model_dict,
                 f=f,
                 map_location=map_location,
                 pickle_module=pickle_module,
                 is_safetensors=False,
                 **pickle_load_args,
             )
-        return retval
+
+        return model_dict
 
     torch.load = torch_load
 
@@ -506,30 +439,10 @@ def use_lazy_load(
         patch_safetensors(callback)
 
     if dematerialized_modules:
-        if use_accelerate_init_empty_weights:
-            import accelerate
+        import accelerate
 
-            init_empty_weights = accelerate.init_empty_weights()
-            init_empty_weights.__enter__()
-        else:
-            old_linear_init = torch.nn.Linear.__init__
-            old_embedding_init = torch.nn.Embedding.__init__
-            old_layernorm_init = torch.nn.LayerNorm.__init__
-
-            def linear_init(self, *args, device=None, **kwargs):
-                return old_linear_init(self, *args, device="meta", **kwargs)
-
-            def embedding_init(self, *args, device=None, **kwargs):
-                return old_embedding_init(self, *args, device="meta", **kwargs)
-
-            def layernorm_init(self, *args, device=None, **kwargs):
-                return old_layernorm_init(self, *args, device="meta", **kwargs)
-
-            torch.nn.Linear.__init__ = linear_init
-            torch.nn.Embedding.__init__ = embedding_init
-            torch.nn.LayerNorm.__init__ = layernorm_init
-            old_load_from_state_dict = torch.nn.Module._load_from_state_dict
-            torch.nn.Module._load_from_state_dict = _load_from_state_dict
+        init_empty_weights = accelerate.init_empty_weights()
+        init_empty_weights.__enter__()
 
     with use_custom_unpickler(_LazyUnpickler):
         yield True
@@ -538,10 +451,4 @@ def use_lazy_load(
     torch._utils._rebuild_tensor = old_rebuild_tensor
     torch.load = old_torch_load
     if dematerialized_modules:
-        if use_accelerate_init_empty_weights:
-            init_empty_weights.__exit__(None, None, None)
-        else:
-            torch.nn.Linear.__init__ = old_linear_init
-            torch.nn.Embedding.__init__ = old_embedding_init
-            torch.nn.LayerNorm.__init__ = old_layernorm_init
-            torch.nn.Module._load_from_state_dict = old_load_from_state_dict
+        init_empty_weights.__exit__(None, None, None)
diff --git a/modeling/patches.py b/modeling/patches.py
index c8c070f8..23d0301c 100644
--- a/modeling/patches.py
+++ b/modeling/patches.py
@@ -190,7 +190,7 @@ def patch_transformers_for_lazyload() -> None:
 
         # BEGIN PATCH
         if isinstance(param, LazyTensor):
-            print("Materializing", param_name)
+            print(".", end="", flush=True)
             param = param.materialize()
         # END PATCH
 
diff --git a/prompt_tuner.py b/prompt_tuner.py
index b1cbf78d..174a2391 100644
--- a/prompt_tuner.py
+++ b/prompt_tuner.py
@@ -855,7 +855,7 @@ class TrainerBase(abc.ABC):
                 lazy_load_callback.nested = False
 
         # Since we're using lazy loader, we need to figure out what the model's hidden layers are called
-        with lazy_loader.use_lazy_load(dematerialized_modules=True, use_accelerate_init_empty_weights=True):
+        with lazy_loader.use_lazy_load(dematerialized_modules=True):
             try:
                 metamodel = AutoModelForCausalLM.from_config(model_config)
             except Exception as e:
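
For orientation, a minimal sketch (not part of the patch) of how the simplified context manager is used by the call sites this patch updates in generic_hf_torch/class.py and prompt_tuner.py. The model id below is a placeholder.

# Illustrative sketch only; mirrors the updated call sites above.
# "EleutherAI/gpt-neo-125m" is a placeholder model id.
from transformers import AutoConfig, AutoModelForCausalLM

from modeling import lazy_loader

model_config = AutoConfig.from_pretrained("EleutherAI/gpt-neo-125m")

# With the use_accelerate_init_empty_weights flag removed,
# dematerialized_modules=True now always routes through
# accelerate.init_empty_weights(), so the model is built on the meta
# device without allocating real parameter storage.
with lazy_loader.use_lazy_load(dematerialized_modules=True):
    metamodel = AutoModelForCausalLM.from_config(model_config)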