Mirror of https://github.com/KoboldAI/KoboldAI-Client.git
Loading a sharded model will now display only one progress bar
aiserver.py | 29 changed lines (28 additions, 1 deletion)
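For context: before this change the lazy loader created a fresh tqdm bar for every checkpoint shard; the commit creates one bar before the first shard and closes it only after the last. A minimal standalone sketch of that pattern (all names here are illustrative, not KoboldAI code):

    # One progress bar across all shards, ticked once per tensor.
    from tqdm.auto import tqdm

    def load_checkpoint(shards):
        total = sum(len(shard) for shard in shards)  # total tensors, all shards
        bar = tqdm(total=total, desc="Loading model tensors")
        try:
            for shard in shards:
                for name, tensor in shard.items():
                    # ... move/convert the tensor here ...
                    bar.update(1)
        finally:
            bar.close()  # closed once, after the final shard

    load_checkpoint([{"a": 1, "b": 2}, {"c": 3}])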
--- a/aiserver.py
+++ b/aiserver.py
@@ -1170,6 +1170,10 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
     def new_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
         utils.num_shards = None
         utils.current_shard = 0
+        utils.from_pretrained_model_name = pretrained_model_name_or_path
+        utils.from_pretrained_index_filename = None
+        utils.from_pretrained_kwargs = kwargs
+        utils.bar = None
         if not args.no_aria2:
             utils.aria2_hook(pretrained_model_name_or_path, **kwargs)
         return old_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
@@ -1177,6 +1181,7 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
     old_get_checkpoint_shard_files = modeling_utils.get_checkpoint_shard_files
     def new_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs):
         utils.num_shards = utils.get_num_shards(index_filename)
+        utils.from_pretrained_index_filename = index_filename
         return old_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs)
     modeling_utils.get_checkpoint_shard_files = new_get_checkpoint_shard_files
 
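The two hunks above wrap transformers entry points so that the arguments needed later (the model name, the shard index filename, and the from_pretrained kwargs) are recorded in utils before any tensors load. A rough sketch of that wrap-and-record pattern, with hypothetical stand-in names (captured, fake_loader):

    # Wrap a function so its arguments are recorded before delegating.
    captured = {}

    def fake_loader(name, index_filename=None):
        return f"loaded {name}"

    old_loader = fake_loader

    def new_loader(name, index_filename=None):
        captured["name"] = name                    # record for later use...
        captured["index_filename"] = index_filename
        return old_loader(name, index_filename)    # ...then delegate unchanged

    fake_loader = new_loader
    fake_loader("gpt-j-6B", "pytorch_model.bin.index.json")
    print(captured)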
@@ -1196,6 +1201,10 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
             ram_blocks = gpu_blocks = cumulative_gpu_blocks = None
 
         def lazy_load_callback(model_dict, f, **_):
+            if lazy_load_callback.nested:
+                return
+            lazy_load_callback.nested = True
+
             device_map = {}
 
             for _key, spec in lazy_load_spec.get("layer_weights", {}).items():
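The nested attribute is a re-entrancy guard stored on the callback itself. A plausible reason (an inference from the code, not stated in the commit): counting tensors across shards loads each shard file, which would re-trigger the lazy-load callback, so the guard turns those inner invocations into no-ops. A small sketch of the device:

    # A function attribute as a re-entrancy guard.
    def callback(data):
        if callback.nested:       # nested invocations do nothing
            return
        callback.nested = True
        try:
            print("processing", data)
            callback(data)        # this inner call returns immediately
        finally:
            callback.nested = False

    callback.nested = False
    callback("shard-0")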
@@ -1210,6 +1219,13 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
                 if isinstance(value, torch_lazy_loader.LazyTensor) and key not in device_map:
                     device_map[key] = vars.gpu_device if vars.hascuda and vars.usegpu else "cpu"
 
+            if utils.num_shards is None or utils.current_shard == 0:
+                if utils.num_shards is not None:
+                    num_tensors = len(utils.get_sharded_checkpoint_num_tensors(utils.from_pretrained_model_name, utils.from_pretrained_index_filename, **utils.from_pretrained_kwargs))
+                else:
+                    num_tensors = len(device_map)
+                utils.bar = tqdm(total=num_tensors, desc="Loading model tensors")
+
             with zipfile.ZipFile(f, "r") as z:
                 try:
                     last_storage_key = None
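The bar is created only on the first shard, with a total spanning every shard: utils.get_sharded_checkpoint_num_tensors enumerates the tensor names in all shard files, while unsharded models fall back to len(device_map). For sharded Hugging Face checkpoints, the same total can in principle be read from the index file's weight_map, which maps each tensor name to the shard holding it; a lighter-weight sketch along those lines (an assumption, not the project's code):

    import json

    def num_tensors_from_index(index_filename):
        # A sharded-checkpoint index maps tensor name -> shard filename,
        # so the number of keys is the tensor count across all shards.
        with open(index_filename) as f:
            return len(json.load(f)["weight_map"])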
@@ -1217,7 +1233,7 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
                     current_offset = 0
                     if utils.num_shards is not None:
                         utils.current_shard += 1
-                    for key in tqdm(sorted(device_map.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)), desc="Loading model tensors" + (f" (shard {utils.current_shard}/{utils.num_shards})" if utils.num_shards is not None else "")):
+                    for key in sorted(device_map.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)):
                         storage_key = model_dict[key].key
                         if storage_key != last_storage_key or model_dict[key].seek_offset < current_offset:
                             last_storage_key = storage_key
@@ -1241,10 +1257,16 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
                         model_dict[key] = model_dict[key].to(device)
                         #print("OK", flush=True)
                         current_offset += nbytes
+                        utils.bar.update(1)
                 finally:
+                    if utils.num_shards is None or utils.current_shard >= utils.num_shards:
+                        utils.bar.close()
+                        utils.bar = None
+                    lazy_load_callback.nested = False
                     if isinstance(f, zipfile.ZipExtFile):
                         f.close()
 
+        lazy_load_callback.nested = False
         return lazy_load_callback
 
     lazy_load_config_path = os.path.join("maps", vars.model_type + ".json")
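Two details in this hunk carry the fix: utils.bar.update(1) ticks the shared bar once per tensor regardless of which shard it came from, and the finally block closes the bar only once current_shard has reached num_shards, so the bar survives the per-shard invocations of the callback. Resetting lazy_load_callback.nested in the same finally ensures an exception while loading one shard cannot leave the guard stuck.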
@@ -1640,6 +1662,10 @@ else:
     def new_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
         utils.num_shards = None
         utils.current_shard = 0
+        utils.from_pretrained_model_name = pretrained_model_name_or_path
+        utils.from_pretrained_index_filename = None
+        utils.from_pretrained_kwargs = kwargs
+        utils.bar = None
         if not args.no_aria2:
             utils.aria2_hook(pretrained_model_name_or_path, **kwargs)
         return old_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
@@ -1647,6 +1673,7 @@ else:
     old_get_checkpoint_shard_files = modeling_utils.get_checkpoint_shard_files
     def new_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs):
         utils.num_shards = utils.get_num_shards(index_filename)
+        utils.from_pretrained_index_filename = index_filename
         return old_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs)
     modeling_utils.get_checkpoint_shard_files = new_get_checkpoint_shard_files
 
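The last two hunks repeat the bookkeeping from the first two in the other model-loading branch (the else: path), so both code paths record the same state before delegating to transformers.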