Mirror of https://github.com/KoboldAI/KoboldAI-Client.git
Fix loading bar
modeling/lazy_loader.py:

@@ -159,7 +159,7 @@ class TorchLazyTensor(LazyTensor):
         ):
             # Cache miss. Assuming weights are loaded in order of
             # (key, seek_offset), this means we need to invalidate the cache.
-            print("!", end="", flush=True)
+            # print("!", end="", flush=True)
             CheckpointChunkCache.hit_data["misses"] += 1

             CheckpointChunkCache.clear()
@@ -176,7 +176,7 @@ class TorchLazyTensor(LazyTensor):
             )
         else:
             # Cache hit. Hip hip hooray! :^)
-            print(".", end="", flush=True)
+            # print(".", end="", flush=True)
             CheckpointChunkCache.hit_data["hits"] += 1

         size = reduce(lambda x, y: x * y, self.shape, 1)
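The two print calls commented out above were per-tensor console markers ("!" on a chunk-cache miss, "." on a hit). With tqdm about to own the console (see the hunks below), the hit_data counters carry the same signal without interleaving stray characters into the bar. A minimal sketch of that counter-only accounting pattern (only hit_data and clear() mirror the diff; the key/chunk fields and the get() shape are assumptions):

    class CheckpointChunkCache:
        """Sketch: cache one checkpoint chunk, counting hits and misses."""

        hit_data = {"hits": 0, "misses": 0}
        key = None    # assumed: identity of the currently cached chunk
        chunk = None  # assumed: the cached data itself

        @classmethod
        def clear(cls, unload_model: bool = False) -> None:
            # Drop the cached chunk; the counters survive so they can be
            # reported once at the end via logger.debug (next hunk).
            cls.key = None
            cls.chunk = None

        @classmethod
        def get(cls, key, read_chunk):
            if key != cls.key:
                # Cache miss: count silently instead of print("!", ...).
                cls.hit_data["misses"] += 1
                cls.clear()
                cls.key = key
                cls.chunk = read_chunk(key)
            else:
                # Cache hit: count silently instead of print(".", ...).
                cls.hit_data["hits"] += 1
            return cls.chunk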
@@ -481,6 +481,11 @@ def post_load_cleanup() -> None:
     logger.debug(f"[lazy_load] CheckpointChunkCache Hit Data: {CheckpointChunkCache.hit_data}")
     CheckpointChunkCache.clear(unload_model=True)

+    # Bar is initialized in
+    # patches.patch_transformers_for_lazyload._load_state_dict_into_meta_model,
+    # as it has access to the state dict (for getting tensor count)
+    utils.bar = None
+
     for v in torch_checkpoint_file_handles.values():
         v.close()
     torch_checkpoint_file_handles = {}
|
modeling/patches.py:

@@ -1,9 +1,8 @@
 from __future__ import annotations

 import copy
-from ctypes import Union
 import requests
-from typing import Iterable, List, Optional
+from typing import List
 from tqdm.auto import tqdm

 import transformers
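(The dropped `from ctypes import Union` imported the C-union base class from ctypes rather than `typing.Union`, which looks like a mistaken autocomplete; it and the narrowed `typing` import appear to have been unused, so this is dead-import cleanup rather than a behavior change.)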
@@ -14,7 +13,6 @@ from transformers import (
 from modeling.lazy_loader import LazyTensor

 import utils
-from logger import logger


 def patch_transformers_download():
@@ -203,6 +201,8 @@ def patch_transformers_for_lazyload() -> None:
             state_dict[new_key] = state_dict.pop(old_key)

         # BEGIN PATCH
+        utils.bar = tqdm(total=len(state_dict), desc="Loading model tensors", file=utils.UIProgressBarFile())
+
         for param_name, param in sorted(
             state_dict.items(),
             # State dict must be ordered in this manner to make the caching in
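`utils.UIProgressBarFile` is not defined in this diff. tqdm only requires that the object passed as `file=` expose `write()` and `flush()`, so an adapter along these lines would be enough; the `emit` callback below is a hypothetical stand-in for however KoboldAI actually pushes text to its web UI, not the project's real API:

    class UIProgressBarFile:
        """Sketch of a file-like sink tqdm can write its status line to."""

        def __init__(self, emit=print):
            # `emit` is a hypothetical stand-in for the UI transport.
            self.emit = emit

        def write(self, text: str) -> None:
            # tqdm writes carriage-return-framed status strings; strip the
            # framing and forward anything non-empty to the UI.
            text = text.strip("\r\n ")
            if text:
                self.emit(text)

        def flush(self) -> None:
            # Nothing is buffered here, but tqdm may call flush(), so the
            # method has to exist.
            pass

With such a sink, `tqdm(..., file=UIProgressBarFile())` renders through the UI instead of stdout, which is exactly why the raw print(".")/print("!") markers above had to go.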
@@ -217,6 +217,7 @@ def patch_transformers_for_lazyload() -> None:
             if isinstance(param, LazyTensor):
                 # Should always be true
                 param = param.materialize()
+                utils.bar.update(1)
             # END PATCH

             # First part of the test is always true as load_state_dict_keys always contains state_dict keys.
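Taken together, the hunks give the bar a three-step lifecycle: create it where the tensor count is known (`len(state_dict)`), tick it once per materialized tensor, and drop the handle in `post_load_cleanup()`. A condensed, self-contained sketch of that flow (materialization is faked with a no-op and the module-level `bar` stands in for `utils.bar`):

    from tqdm.auto import tqdm

    bar = None  # module-level handle, like utils.bar

    def load_tensors(state_dict):
        global bar
        # 1. Initialize here: only this function holds the state dict,
        #    so only it knows the total tensor count up front.
        bar = tqdm(total=len(state_dict), desc="Loading model tensors")
        for name, param in sorted(state_dict.items()):
            _ = param  # stand-in for param.materialize()
            bar.update(1)  # 2. One tick per tensor loaded.

    def post_load_cleanup():
        global bar
        # 3. Drop the handle so a stale bar never leaks into the next
        #    load (the diff just sets utils.bar = None; closing first
        #    is slightly tidier).
        if bar is not None:
            bar.close()
        bar = None

    load_tensors({"a.weight": 1, "b.weight": 2})
    post_load_cleanup()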
|