mirror of https://github.com/KoboldAI/KoboldAI-Client.git
Commit: Not quite
@@ -46,6 +46,7 @@ POSSIBILITY OF SUCH DAMAGE.
 
 import contextlib
 from functools import reduce
+import time
 import zipfile
 import pickle
 import torch
@@ -55,6 +56,8 @@ import _codecs
 import os
 from typing import Any, Callable, Dict, Optional, Tuple, Type
 
+from torch.storage import UntypedStorage
+
 # Safetensors is a dependency for the local version, TPU/Colab doesn't
 # support it yet.
 try:
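For context, the lines around this hunk belong to the optional safetensors import guard (the `try:` above and the `HAS_SAFETENSORS = False` in the next hunk). A minimal sketch of that pattern, assuming the success branch simply imports safetensors and flips the flag (that branch is not visible in the scrape):

try:
    import safetensors  # optional dependency; not available on TPU/Colab yet
    HAS_SAFETENSORS = True
except ModuleNotFoundError:
    HAS_SAFETENSORS = False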
@@ -65,21 +68,9 @@ except ModuleNotFoundError:
     HAS_SAFETENSORS = False
 
 import utils
 from logger import logger
 
 
-STORAGE_TYPE_MAP = {
-    torch.float64: torch.DoubleStorage,
-    torch.float32: torch.FloatStorage,
-    torch.float16: torch.HalfStorage,
-    torch.int64: torch.LongStorage,
-    torch.int32: torch.IntStorage,
-    torch.int16: torch.ShortStorage,
-    torch.int8: torch.CharStorage,
-    torch.uint8: torch.ByteStorage,
-    torch.bool: torch.BoolStorage,
-    torch.bfloat16: torch.BFloat16Storage,
-}
 
 # Storage of zipfile handles for each shard
 torch_checkpoint_file_handles = {}
@@ -205,8 +196,8 @@ class TorchLazyTensor(LazyTensor):
         assert isinstance(checkpoint, zipfile.ZipFile)
 
         CheckpointChunkCache.handle.seek(self.seek_offset, os.SEEK_SET)
-        storage = STORAGE_TYPE_MAP[self.dtype].from_buffer(
-            CheckpointChunkCache.handle.read(nbytes), "little"
+        storage = UntypedStorage.from_buffer(
+            CheckpointChunkCache.handle.read(nbytes), "little", dtype=self.dtype
         )
 
         storage = torch.serialization._get_restore_location(map_location)(
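The hunk above is the point of the dropped STORAGE_TYPE_MAP: instead of looking up a dtype-specific storage class, the raw shard bytes go straight into UntypedStorage with the dtype passed as an argument. A rough, self-contained sketch of the difference, using dummy bytes in place of CheckpointChunkCache.handle.read(nbytes); from_buffer signatures differ between torch versions, and the dtype keyword here mirrors the call shown in the diff rather than a guaranteed-stable API:

import torch

raw = bytes(8)  # stand-in for the nbytes read from the shard: four float16 zeros

# Removed approach: pick the typed storage class that matches the dtype.
typed = torch.HalfStorage.from_buffer(raw, "little")

# Replacement approach: a single untyped storage class, dtype supplied per call.
untyped = torch.UntypedStorage.from_buffer(raw, "little", dtype=torch.float16)

# Either way, the same raw bytes back the tensor that gets materialized later.
print(typed.nbytes(), untyped.nbytes())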
@@ -421,6 +412,8 @@ def use_lazy_load(
         yield False
         return
 
+    begin_time = time.time()
+
     try:
         old_rebuild_tensor = torch._utils._rebuild_tensor
         torch._utils._rebuild_tensor = _rebuild_tensor
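This hunk records a begin_time stamp just before the patching starts; the finally block in the next hunk uses it to log how long the lazy-load context stayed open. A condensed sketch of the overall pattern (assumed structure and a hypothetical name, not the repository's exact code): save the original torch._utils._rebuild_tensor, install a replacement for the duration of the context, and restore it and log the elapsed time in finally.

import contextlib
import time

import torch

@contextlib.contextmanager
def timed_rebuild_patch():
    begin_time = time.time()
    old_rebuild_tensor = torch._utils._rebuild_tensor  # keep the original

    def _rebuild_tensor(storage, storage_offset, size, stride):
        # A lazy loader would hand back a placeholder here instead of materializing.
        return old_rebuild_tensor(storage, storage_offset, size, stride)

    try:
        torch._utils._rebuild_tensor = _rebuild_tensor
        yield
    finally:
        torch._utils._rebuild_tensor = old_rebuild_tensor
        print(f"Context closed in {round(time.time() - begin_time, 2)} seconds.")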
@@ -471,17 +464,23 @@ def use_lazy_load(
     finally:
         torch._utils._rebuild_tensor = old_rebuild_tensor
         torch.load = old_torch_load
 
+        post_load_cleanup()
+        logger.debug(
+            f"[lazy_load] Context closed in {round(time.time() - begin_time, 2)} seconds."
+        )
+
         if dematerialized_modules:
             init_empty_weights.__exit__(None, None, None)
 
 
 def post_load_cleanup() -> None:
     """Close dangling file pointers and clear caches after the load is complete."""
     global torch_checkpoint_file_handles
 
-    print("CheckpointChunkCache Hit Data:", CheckpointChunkCache.hit_data)
-    CheckpointChunkCache.clear()
+    logger.debug(f"[lazy_load] CheckpointChunkCache Hit Data: {CheckpointChunkCache.hit_data}")
+    CheckpointChunkCache.clear(unload_model=True)
 
     for v in torch_checkpoint_file_handles.values():
         v.close()
 
     torch_checkpoint_file_handles = {}
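To illustrate the cleanup the last hunk touches, here is a small self-contained sketch of the handle-tracking idea (the open_shard helper and any file names are hypothetical; only the torch_checkpoint_file_handles dict and post_load_cleanup shape come from the diff): shard ZIP files stay open in a module-level dict while tensors are materialized, then post_load_cleanup closes every handle and resets the dict so no file descriptors are left dangling.

import zipfile

# One zipfile.ZipFile handle per checkpoint shard, kept open across tensor reads.
torch_checkpoint_file_handles = {}

def open_shard(filename: str) -> zipfile.ZipFile:
    # Hypothetical helper: reuse an existing handle, opening the shard on first access.
    if filename not in torch_checkpoint_file_handles:
        torch_checkpoint_file_handles[filename] = zipfile.ZipFile(filename, "r")
    return torch_checkpoint_file_handles[filename]

def post_load_cleanup() -> None:
    """Close dangling file pointers and clear caches after the load is complete."""
    global torch_checkpoint_file_handles
    for handle in torch_checkpoint_file_handles.values():
        handle.close()
    torch_checkpoint_file_handles = {}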