Add checkpoint tracking for loading

With the index of checkpoint files here (and with the total_size in the
index json) we could probably have a cleaner per-byte loading bar in the
future, but let's not break anything for now.
This commit is contained in:
somebody
2023-10-13 12:33:00 -05:00
parent c1a96593fd
commit 21e6d84810
3 changed files with 30 additions and 3 deletions

View File

@@ -688,7 +688,8 @@ class settings(object):
class model_settings(settings):
local_only_variables = ['apikey', 'default_preset']
no_save_variables = ['modelconfig', 'custmodpth', 'generated_tkns',
'loaded_layers', 'total_layers', 'total_download_chunks', 'downloaded_chunks', 'presets', 'default_preset',
'loaded_layers', 'total_layers', 'loaded_checkpoints', 'total_checkpoints',
'total_download_chunks', 'downloaded_chunks', 'presets', 'default_preset',
'welcome', 'welcome_default', 'simple_randomness', 'simple_creativity', 'simple_repitition',
'badwordsids', 'uid_presets', 'model', 'model_type', 'lazy_load', 'fp32_model', 'modeldim', 'horde_wait_time', 'horde_queue_position', 'horde_queue_size', 'newlinemode', 'tqdm_progress', 'tqdm_rem_time', '_tqdm']
settings_name = "model"
@@ -705,6 +706,8 @@ class model_settings(settings):
self.generated_tkns = 0 # If using a backend that supports Lua generation modifiers, how many tokens have already been generated, otherwise 0
self.loaded_layers = 0 # Used in UI 2 to show model loading progress
self.total_layers = 0 # Same as above
self.loaded_checkpoints = 0
self.total_checkpoints = 0
self.total_download_chunks = 0 # tracks how much of the model has downloaded for the UI 2
self.downloaded_chunks = 0 #as above
self._tqdm = tqdm.tqdm(total=self.genamt, file=self.ignore_tqdm()) # tqdm agent for generating tokens. This will allow us to calculate the remaining time

View File

@@ -52,7 +52,8 @@ import zipfile
import pickle
import torch
import os
from typing import Any, Callable, Dict, Optional, Tuple, Type
import json
from typing import Any, Callable, Dict, List, Optional, Tuple, Type
from torch.nn import Module
from torch.storage import UntypedStorage
@@ -398,6 +399,18 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, miss
if input_name not in self._modules and input_name not in local_state:
unexpected_keys.append(key)
def get_sharded_torch_checkpoints(dir: str) -> List[str]:
    """Return the unique shard filenames of a sharded PyTorch checkpoint.

    Reads ``pytorch_model.bin.index.json`` in *dir* and collects the
    distinct checkpoint files referenced by its ``weight_map``.

    Returns an empty list on any read or parse problem so callers can
    fall back to treating the model as a single checkpoint.
    """
    index_path = os.path.join(dir, "pytorch_model.bin.index.json")
    try:
        with open(index_path) as file:
            index = json.load(file)
        return list(set(index["weight_map"].values()))
    except FileNotFoundError:
        # No index file -- the model is not sharded (most common case).
        return []
    except (json.JSONDecodeError, KeyError, TypeError, AttributeError):
        # Malformed index JSON, missing "weight_map", or a weight_map of
        # an unexpected shape; treat the shard count as unknown.
        return []
@contextlib.contextmanager
def use_lazy_load(
enable=True,
@@ -410,6 +423,8 @@ def use_lazy_load(
return
begin_time = time.time()
utils.koboldai_vars.total_checkpoints = 0
utils.koboldai_vars.loaded_checkpoints = 0
try:
LazyloadPatches.__enter__()
@@ -421,6 +436,14 @@ def use_lazy_load(
old_torch_load = torch.load
def torch_load(f, map_location=None, pickle_module=pickle, **pickle_load_args):
if not utils.koboldai_vars.total_checkpoints:
checkpoints = get_sharded_torch_checkpoints(os.path.dirname(f))
# `checkpoints` may be empty if there is an error--return 1 in
# this case. Either there was no checkpoint index file (most
# common case), or there was a compatibility issue while reading
# it.
utils.koboldai_vars.total_checkpoints = len(checkpoints) or 1
model_dict = old_torch_load(
f=f,
map_location=map_location,

View File

@@ -326,6 +326,7 @@ class LazyloadPatches:
fp16_statistics=fp16_statistics,
)
utils.koboldai_vars.loaded_checkpoints += 1
return error_msgs, offload_index, state_dict_index