Mirror of https://github.com/KoboldAI/KoboldAI-Client.git (synced 2025-06-05 21:59:24 +02:00)
Loading a sharded model will now display only one progress bar
aiserver.py (29 changed lines)
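Before this commit, the lazy loader opened a fresh tqdm bar for every shard of a sharded checkpoint; the patch keeps a single bar (stored in utils.bar) that is opened when the first shard starts loading, advanced once per tensor, and closed after the last shard. A minimal sketch of that pattern, assuming hypothetical in-memory shards rather than the real lazy-loader callback:

from tqdm.auto import tqdm

def load_all_shards(shards):
    # "shards" is a hypothetical list of dicts mapping tensor names to tensors.
    bar = tqdm(total=sum(len(s) for s in shards), desc="Loading model tensors")
    loaded = {}
    try:
        for shard in shards:            # previously: a new bar per shard
            for name, tensor in shard.items():
                loaded[name] = tensor   # stand-in for the real materialisation work
                bar.update(1)           # one shared bar advances across all shards
    finally:
        bar.close()
    return loaded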
@@ -1170,6 +1170,10 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
     def new_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
         utils.num_shards = None
         utils.current_shard = 0
+        utils.from_pretrained_model_name = pretrained_model_name_or_path
+        utils.from_pretrained_index_filename = None
+        utils.from_pretrained_kwargs = kwargs
+        utils.bar = None
         if not args.no_aria2:
             utils.aria2_hook(pretrained_model_name_or_path, **kwargs)
         return old_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
@@ -1177,6 +1181,7 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
     old_get_checkpoint_shard_files = modeling_utils.get_checkpoint_shard_files
     def new_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs):
         utils.num_shards = utils.get_num_shards(index_filename)
+        utils.from_pretrained_index_filename = index_filename
         return old_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs)
     modeling_utils.get_checkpoint_shard_files = new_get_checkpoint_shard_files
 
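utils.get_num_shards (its body appears as context in the utils.py hunk below) derives the shard count from the checkpoint's pytorch_model.bin.index.json: the index's weight_map maps every tensor name to the shard file that stores it, so the number of distinct file names is the number of shards. A small illustration with made-up entries:

import json  # the real helper reads the index file from disk; here we inline it

# Illustrative index contents, not taken from a real checkpoint.
index = {
    "weight_map": {
        "transformer.wte.weight": "pytorch_model-00001-of-00002.bin",
        "transformer.h.0.attn.weight": "pytorch_model-00001-of-00002.bin",
        "lm_head.weight": "pytorch_model-00002-of-00002.bin",
    }
}
num_shards = len(set(index["weight_map"].values()))  # == 2, same logic as get_num_shards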
@@ -1196,6 +1201,10 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
     ram_blocks = gpu_blocks = cumulative_gpu_blocks = None
 
     def lazy_load_callback(model_dict, f, **_):
+        if lazy_load_callback.nested:
+            return
+        lazy_load_callback.nested = True
+
         device_map = {}
 
         for _key, spec in lazy_load_spec.get("layer_weights", {}).items():
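The new lazy_load_callback.nested flag is a re-entrancy guard kept as a function attribute: if the callback fires again while it is already running (for instance via the torch.load calls used to enumerate tensor names across shards), the inner call returns immediately instead of opening a second progress bar. A self-contained sketch of the guard, with toy work standing in for the loader:

def callback(items):
    if callback.nested:              # a nested invocation bails out immediately
        return
    callback.nested = True
    try:
        for item in items:
            callback([item * 2])     # recursive call is ignored thanks to the guard
            print("handled", item)
    finally:
        callback.nested = False      # always reset, even if an exception is raised

callback.nested = False              # initialise the attribute before first use
callback([1, 2, 3])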
@@ -1210,6 +1219,13 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
             if isinstance(value, torch_lazy_loader.LazyTensor) and key not in device_map:
                 device_map[key] = vars.gpu_device if vars.hascuda and vars.usegpu else "cpu"
 
+        if utils.num_shards is None or utils.current_shard == 0:
+            if utils.num_shards is not None:
+                num_tensors = len(utils.get_sharded_checkpoint_num_tensors(utils.from_pretrained_model_name, utils.from_pretrained_index_filename, **utils.from_pretrained_kwargs))
+            else:
+                num_tensors = len(device_map)
+            utils.bar = tqdm(total=num_tensors, desc="Loading model tensors")
+
         with zipfile.ZipFile(f, "r") as z:
             try:
                 last_storage_key = None
@@ -1217,7 +1233,7 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
                 current_offset = 0
                 if utils.num_shards is not None:
                     utils.current_shard += 1
-                for key in tqdm(sorted(device_map.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)), desc="Loading model tensors" + (f" (shard {utils.current_shard}/{utils.num_shards})" if utils.num_shards is not None else "")):
+                for key in sorted(device_map.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)):
                     storage_key = model_dict[key].key
                     if storage_key != last_storage_key or model_dict[key].seek_offset < current_offset:
                         last_storage_key = storage_key
@@ -1241,10 +1257,16 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
                     model_dict[key] = model_dict[key].to(device)
                     #print("OK", flush=True)
                     current_offset += nbytes
+                    utils.bar.update(1)
             finally:
+                if utils.num_shards is None or utils.current_shard >= utils.num_shards:
+                    utils.bar.close()
+                    utils.bar = None
+                lazy_load_callback.nested = False
                 if isinstance(f, zipfile.ZipExtFile):
                     f.close()
 
+    lazy_load_callback.nested = False
     return lazy_load_callback
 
 lazy_load_config_path = os.path.join("maps", vars.model_type + ".json")
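Across shards the bar's lifetime is tied to the counters kept in utils: it is created only when the first shard's callback runs (current_shard == 0, or num_shards is None for an unsharded model) and closed in the finally block only once current_shard has caught up with num_shards. A condensed, stand-alone sketch of that bookkeeping, using plain globals in place of the utils module:

from tqdm.auto import tqdm

num_shards = 3       # stand-ins for utils.num_shards / utils.current_shard / utils.bar
current_shard = 0
bar = None

def on_shard(tensor_names):
    global current_shard, bar
    if num_shards is None or current_shard == 0:
        # in the patch the total comes from get_sharded_checkpoint_num_tensors()
        bar = tqdm(total=9, desc="Loading model tensors")
    try:
        if num_shards is not None:
            current_shard += 1
        for _name in tensor_names:
            bar.update(1)
    finally:
        if num_shards is None or current_shard >= num_shards:
            bar.close()
            bar = None

for shard in (["a", "b", "c"], ["d", "e", "f"], ["g", "h", "i"]):
    on_shard(shard)  # the bar survives the first two calls and closes after the last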
@@ -1640,6 +1662,10 @@ else:
     def new_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
         utils.num_shards = None
         utils.current_shard = 0
+        utils.from_pretrained_model_name = pretrained_model_name_or_path
+        utils.from_pretrained_index_filename = None
+        utils.from_pretrained_kwargs = kwargs
+        utils.bar = None
         if not args.no_aria2:
             utils.aria2_hook(pretrained_model_name_or_path, **kwargs)
         return old_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
@@ -1647,6 +1673,7 @@ else:
     old_get_checkpoint_shard_files = modeling_utils.get_checkpoint_shard_files
     def new_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs):
         utils.num_shards = utils.get_num_shards(index_filename)
+        utils.from_pretrained_index_filename = index_filename
         return old_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs)
     modeling_utils.get_checkpoint_shard_files = new_get_checkpoint_shard_files
 
tpu_mtj_backend.py

@@ -1160,6 +1160,9 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
     import functools
 
     def callback(model_dict, f, **_):
+        if callback.nested:
+            return
+        callback.nested = True
         with zipfile.ZipFile(f, "r") as z:
             try:
                 last_storage_key = None
@@ -1167,9 +1170,17 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
                 current_offset = 0
                 if utils.current_shard == 0:
                     print("\n\n\nThis model has ", f"{hk.data_structures.tree_size(network.state['params']):,d}".replace(",", " "), " parameters.\n")
 
+                if utils.num_shards is None or utils.current_shard == 0:
+                    if utils.num_shards is not None:
+                        num_tensors = len(utils.get_sharded_checkpoint_num_tensors(utils.from_pretrained_model_name, utils.from_pretrained_index_filename, **utils.from_pretrained_kwargs))
+                    else:
+                        num_tensors = len(model_dict)
+                    utils.bar = tqdm(total=num_tensors, desc="Loading model tensors")
+
                 if utils.num_shards is not None:
                     utils.current_shard += 1
-                for key in tqdm(sorted(model_dict.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)), desc="Loading model tensors" + (f" (shard {utils.current_shard}/{utils.num_shards})" if utils.num_shards is not None else "")):
+                for key in sorted(model_dict.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)):
 
                     # Some model weights are used by transformers but not by MTJ.
                     # We have to materialize these weights anyways because
@@ -1178,6 +1189,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
                     # tensors, which don't take up any actual CPU or TPU memory.
                     if key not in model_spec:
                         model_dict[key] = torch.empty(model_dict[key].shape, dtype=model_dict[key].dtype, device="meta")
+                        utils.bar.update(1)
                         continue
 
                     storage_key = model_dict[key].key
@@ -1230,6 +1242,8 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
                         np.empty(params["cores_per_replica"]),
                     )
 
+                    utils.bar.update(1)
+
                 if utils.num_shards is not None and utils.current_shard < utils.num_shards:
                     return
 
@@ -1250,9 +1264,17 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
                             error = f"{mk} {pk} could not be found in the model checkpoint"
                             print("\n\nERROR: " + error, file=sys.stderr)
                             raise RuntimeError(error)
+            except:
+                import traceback
+                traceback.print_exc()
             finally:
+                if utils.num_shards is None or utils.current_shard >= utils.num_shards:
+                    utils.bar.close()
+                    utils.bar = None
+                callback.nested = False
                 if isinstance(f, zipfile.ZipExtFile):
                     f.close()
+    callback.nested = False
 
     if os.path.isdir(vars.model.replace('/', '_')):
         import shutil
utils.py (16 changed lines)
@@ -9,11 +9,16 @@ import requests.adapters
 import time
 from tqdm.auto import tqdm
 import os
+import itertools
 from typing import Optional
 
 vars = None
 num_shards: Optional[int] = None
 current_shard = 0
+from_pretrained_model_name = ""
+from_pretrained_index_filename: Optional[str] = None
+from_pretrained_kwargs = {}
+bar = None
 
 #==================================================================#
 # Decorator to prevent a function's actions from being run until
@@ -280,3 +285,14 @@ def get_num_shards(filename):
     with open(filename) as f:
         map_data = json.load(f)
     return len(set(map_data["weight_map"].values()))
+
+#==================================================================#
+#  Given the name/path of a sharded model and the path to a
+#  pytorch_model.bin.index.json, returns a list of weight names in the
+#  sharded model. Requires lazy loader to be enabled to work properly.
+#==================================================================#
+def get_sharded_checkpoint_num_tensors(pretrained_model_name_or_path, filename, cache_dir=None, force_download=False, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, mirror=None, **kwargs):
+    import transformers.modeling_utils
+    import torch
+    shard_paths, _ = transformers.modeling_utils.get_checkpoint_shard_files(pretrained_model_name_or_path, filename, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, local_files_only=local_files_only, use_auth_token=use_auth_token, user_agent=user_agent, revision=revision, mirror=mirror)
+    return list(itertools.chain(*(torch.load(p, map_location="cpu").keys() for p in shard_paths)))
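For completeness, a hedged usage sketch of the new helper: it resolves every shard listed in the index via transformers' get_checkpoint_shard_files, loads each one with torch.load (cheap only while the lazy loader is active, as the comment above notes), and returns the concatenated tensor names; the length of that list becomes the progress bar's total. The model name and index filename below are placeholders, not values from the patch:

import utils
from tqdm.auto import tqdm

tensor_names = utils.get_sharded_checkpoint_num_tensors(
    "some-org/some-sharded-model",       # placeholder model name
    "pytorch_model.bin.index.json",      # placeholder index filename
)
utils.bar = tqdm(total=len(tensor_names), desc="Loading model tensors")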