Not quite

This commit is contained in:
somebody
2023-05-28 14:57:45 -05:00
parent ed0728188a
commit ceaefa9f5e
4 changed files with 25 additions and 22 deletions

View File

@@ -46,6 +46,7 @@ POSSIBILITY OF SUCH DAMAGE.
import contextlib
from functools import reduce
import time
import zipfile
import pickle
import torch
@@ -55,6 +56,8 @@ import _codecs
import os
from typing import Any, Callable, Dict, Optional, Tuple, Type
from torch.storage import UntypedStorage
# Safetensors is a dependency for the local version, TPU/Colab doesn't
# support it yet.
try:
@@ -65,21 +68,9 @@ except ModuleNotFoundError:
HAS_SAFETENSORS = False
import utils
from logger import logger
STORAGE_TYPE_MAP = {
torch.float64: torch.DoubleStorage,
torch.float32: torch.FloatStorage,
torch.float16: torch.HalfStorage,
torch.int64: torch.LongStorage,
torch.int32: torch.IntStorage,
torch.int16: torch.ShortStorage,
torch.int8: torch.CharStorage,
torch.uint8: torch.ByteStorage,
torch.bool: torch.BoolStorage,
torch.bfloat16: torch.BFloat16Storage,
}
# Storage of zipfile handles for each shard
torch_checkpoint_file_handles = {}
@@ -205,8 +196,8 @@ class TorchLazyTensor(LazyTensor):
assert isinstance(checkpoint, zipfile.ZipFile)
CheckpointChunkCache.handle.seek(self.seek_offset, os.SEEK_SET)
storage = STORAGE_TYPE_MAP[self.dtype].from_buffer(
CheckpointChunkCache.handle.read(nbytes), "little"
storage = UntypedStorage.from_buffer(
CheckpointChunkCache.handle.read(nbytes), "little", dtype=self.dtype
)
storage = torch.serialization._get_restore_location(map_location)(
@@ -421,6 +412,8 @@ def use_lazy_load(
yield False
return
begin_time = time.time()
try:
old_rebuild_tensor = torch._utils._rebuild_tensor
torch._utils._rebuild_tensor = _rebuild_tensor
@@ -471,17 +464,23 @@ def use_lazy_load(
finally:
torch._utils._rebuild_tensor = old_rebuild_tensor
torch.load = old_torch_load
post_load_cleanup()
logger.debug(
f"[lazy_load] Context closed in {round(time.time() - begin_time, 2)} seconds."
)
if dematerialized_modules:
init_empty_weights.__exit__(None, None, None)
def post_load_cleanup() -> None:
"""Close dangling file pointers and clear caches after the load is complete."""
global torch_checkpoint_file_handles
print("CheckpointChunkCache Hit Data:", CheckpointChunkCache.hit_data)
CheckpointChunkCache.clear()
logger.debug(f"[lazy_load] CheckpointChunkCache Hit Data: {CheckpointChunkCache.hit_data}")
CheckpointChunkCache.clear(unload_model=True)
for v in torch_checkpoint_file_handles.values():
v.close()
torch_checkpoint_file_handles = {}