mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Add checkpoint tracking for loading
Now that we read the index of checkpoint files (and the index JSON also contains total_size), we could probably build a cleaner per-byte loading bar in the future, but let's not break anything for now.
This commit is contained in:
@@ -687,8 +687,9 @@ class settings(object):
|
||||
|
||||
class model_settings(settings):
|
||||
local_only_variables = ['apikey', 'default_preset']
|
||||
no_save_variables = ['modelconfig', 'custmodpth', 'generated_tkns',
|
||||
'loaded_layers', 'total_layers', 'total_download_chunks', 'downloaded_chunks', 'presets', 'default_preset',
|
||||
no_save_variables = ['modelconfig', 'custmodpth', 'generated_tkns',
|
||||
'loaded_layers', 'total_layers', 'loaded_checkpoints', 'total_checkpoints',
|
||||
'total_download_chunks', 'downloaded_chunks', 'presets', 'default_preset',
|
||||
'welcome', 'welcome_default', 'simple_randomness', 'simple_creativity', 'simple_repitition',
|
||||
'badwordsids', 'uid_presets', 'model', 'model_type', 'lazy_load', 'fp32_model', 'modeldim', 'horde_wait_time', 'horde_queue_position', 'horde_queue_size', 'newlinemode', 'tqdm_progress', 'tqdm_rem_time', '_tqdm']
|
||||
settings_name = "model"
|
||||
@@ -705,6 +706,8 @@ class model_settings(settings):
|
||||
self.generated_tkns = 0 # If using a backend that supports Lua generation modifiers, how many tokens have already been generated, otherwise 0
|
||||
self.loaded_layers = 0 # Used in UI 2 to show model loading progress
|
||||
self.total_layers = 0 # Same as above
|
||||
self.loaded_checkpoints = 0
|
||||
self.total_checkpoints = 0
|
||||
self.total_download_chunks = 0 # tracks how much of the model has downloaded for the UI 2
|
||||
self.downloaded_chunks = 0 #as above
|
||||
self._tqdm = tqdm.tqdm(total=self.genamt, file=self.ignore_tqdm()) # tqdm agent for generating tokens. This will allow us to calculate the remaining time
|
||||
|
@@ -52,7 +52,8 @@ import zipfile
|
||||
import pickle
|
||||
import torch
|
||||
import os
|
||||
from typing import Any, Callable, Dict, Optional, Tuple, Type
|
||||
import json
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Type
|
||||
|
||||
from torch.nn import Module
|
||||
from torch.storage import UntypedStorage
|
||||
@@ -398,6 +399,18 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, miss
|
||||
if input_name not in self._modules and input_name not in local_state:
|
||||
unexpected_keys.append(key)
|
||||
|
||||
def get_sharded_torch_checkpoints(dir: str) -> List[str]:
    """Return the unique shard filenames of a sharded PyTorch checkpoint.

    Reads ``pytorch_model.bin.index.json`` from *dir* and returns the distinct
    checkpoint file names referenced by its ``"weight_map"``. Returns an empty
    list on any problem -- missing index file (the common, non-sharded case),
    malformed JSON, or a missing/ill-typed ``"weight_map"`` -- so callers can
    treat "unknown" uniformly instead of crashing mid-load.
    """
    # NOTE(review): `dir` shadows the builtin, but the parameter name is kept
    # for backward compatibility with existing callers.
    index_path = os.path.join(dir, "pytorch_model.bin.index.json")
    try:
        with open(index_path) as file:
            index = json.load(file)
    except FileNotFoundError:
        # Most models are not sharded and ship no index file at all.
        return []
    except json.JSONDecodeError:
        # A corrupt or truncated index should not abort model loading;
        # report "no shards known" instead.
        return []

    try:
        return list(set(index["weight_map"].values()))
    except (KeyError, AttributeError):
        # Index exists but is schema-incompatible (no weight_map, or
        # weight_map is not a mapping).
        return []
|
||||
|
||||
@contextlib.contextmanager
|
||||
def use_lazy_load(
|
||||
enable=True,
|
||||
@@ -410,6 +423,8 @@ def use_lazy_load(
|
||||
return
|
||||
|
||||
begin_time = time.time()
|
||||
utils.koboldai_vars.total_checkpoints = 0
|
||||
utils.koboldai_vars.loaded_checkpoints = 0
|
||||
|
||||
try:
|
||||
LazyloadPatches.__enter__()
|
||||
@@ -421,6 +436,14 @@ def use_lazy_load(
|
||||
old_torch_load = torch.load
|
||||
|
||||
def torch_load(f, map_location=None, pickle_module=pickle, **pickle_load_args):
|
||||
if not utils.koboldai_vars.total_checkpoints:
|
||||
checkpoints = get_sharded_torch_checkpoints(os.path.dirname(f))
|
||||
# `checkpoints` may be empty if there is an error--return 1 in
|
||||
# this case. Either there was no checkpoint index file (most
|
||||
# common case), or there was a compatibility issue while reading
|
||||
# it.
|
||||
utils.koboldai_vars.total_checkpoints = len(checkpoints) or 1
|
||||
|
||||
model_dict = old_torch_load(
|
||||
f=f,
|
||||
map_location=map_location,
|
||||
|
@@ -326,6 +326,7 @@ class LazyloadPatches:
|
||||
fp16_statistics=fp16_statistics,
|
||||
)
|
||||
|
||||
utils.koboldai_vars.loaded_checkpoints += 1
|
||||
return error_msgs, offload_index, state_dict_index
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user