Mirror of https://github.com/KoboldAI/KoboldAI-Client.git, synced 2025-06-05 21:59:24 +02:00
Loading a sharded model will now display only one progress bar
@@ -1160,6 +1160,9 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
     import functools
 
     def callback(model_dict, f, **_):
+        if callback.nested:
+            return
+        callback.nested = True
         with zipfile.ZipFile(f, "r") as z:
             try:
                 last_storage_key = None
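
The three added lines make the checkpoint callback re-entrancy-safe: if the hook fires again while a previous invocation is still running, the nested call returns immediately instead of starting a duplicate load (and a duplicate progress bar). Below is a minimal sketch of the same function-attribute guard; on_load is an illustrative name, not something from this codebase. The patch adds the matching initializer after the function body, in the last hunk below.

def on_load(depth=0):
    if on_load.nested:
        return  # a nested invocation is a no-op
    on_load.nested = True
    try:
        print("doing real work at depth", depth)
        on_load(depth + 1)  # re-entrant call: returns immediately
    finally:
        on_load.nested = False  # reset even if the work raises

on_load.nested = False  # initialize the attribute before the first call

on_load()  # prints exactly one line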
@@ -1167,9 +1170,17 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
                 current_offset = 0
+                if utils.current_shard == 0:
+                    print("\n\n\nThis model has ", f"{hk.data_structures.tree_size(network.state['params']):,d}".replace(",", " "), " parameters.\n")
+
+                if utils.num_shards is None or utils.current_shard == 0:
+                    if utils.num_shards is not None:
+                        num_tensors = len(utils.get_sharded_checkpoint_num_tensors(utils.from_pretrained_model_name, utils.from_pretrained_index_filename, **utils.from_pretrained_kwargs))
+                    else:
+                        num_tensors = len(model_dict)
+                    utils.bar = tqdm(total=num_tensors, desc="Loading model tensors")
+
                 if utils.num_shards is not None:
                     utils.current_shard += 1
-                for key in tqdm(sorted(model_dict.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)), desc="Loading model tensors" + (f" (shard {utils.current_shard}/{utils.num_shards})" if utils.num_shards is not None else "")):
+                for key in sorted(model_dict.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)):
 
                     # Some model weights are used by transformers but not by MTJ.
                     # We have to materialize these weights anyways because
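
This hunk is the core of the change. Previously each shard's key loop was wrapped in its own tqdm call (the removed line), so an N-shard checkpoint drew N progress bars. Now a single bar sized to the full tensor count is created on the first shard and kept in utils.bar; for sharded checkpoints the total is read from the checkpoint index via utils.get_sharded_checkpoint_num_tensors. A minimal sketch of the pattern, assuming tqdm is installed; the shard contents here are hypothetical:

from tqdm import tqdm

shards = [["wte", "ln_f.weight"], ["h.0.attn"], ["h.0.mlp", "lm_head"]]
num_tensors = sum(len(s) for s in shards)  # total known before loading starts

bar = tqdm(total=num_tensors, desc="Loading model tensors")  # created once
for shard in shards:        # the real hook runs once per shard file
    for name in shard:
        # ... load tensor `name` here ...
        bar.update(1)       # advance the shared bar by hand
bar.close()                 # closed once, after the final shard

Sizing the bar with total= up front is what lets progress carry smoothly across shard boundaries instead of resetting to zero on each file.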
@@ -1178,6 +1189,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
                     # tensors, which don't take up any actual CPU or TPU memory.
                     if key not in model_spec:
                         model_dict[key] = torch.empty(model_dict[key].shape, dtype=model_dict[key].dtype, device="meta")
+                        utils.bar.update(1)
                         continue
 
                     storage_key = model_dict[key].key
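
Because progress now advances manually rather than by iterating under tqdm, the skip path for weights that transformers supplies but MTJ never uses has to tick the bar as well, or the count would come up short of the total. Those weights are materialized as meta tensors, which cost no memory; a minimal sketch of what device="meta" means, assuming a recent PyTorch:

import torch

# A meta tensor records shape and dtype only; no storage is allocated,
# so "materializing" unused checkpoint weights this way is free.
t = torch.empty((50257, 4096), dtype=torch.float16, device="meta")
print(t.shape, t.dtype, t.device)  # torch.Size([50257, 4096]) torch.float16 meta
# Reading t's values would fail: a meta tensor has no backing data.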
@@ -1230,6 +1242,8 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
                         np.empty(params["cores_per_replica"]),
                     )
 
+                    utils.bar.update(1)
+
                 if utils.num_shards is not None and utils.current_shard < utils.num_shards:
                     return
 
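
Each tensor actually moved onto the TPU ticks the shared bar too, and on every shard except the last the callback returns early, so the missing-tensor check that follows runs only once, after all shards have been processed. A minimal sketch of this defer-until-last-shard pattern, with hypothetical names throughout:

loaded = []

def on_shard(tensors, current_shard, num_shards):
    loaded.extend(tensors)
    if current_shard < num_shards:
        return                        # more shards coming: defer validation
    missing = {"a", "b", "c"} - set(loaded)
    assert not missing, f"missing tensors: {missing}"

for i, tensors in enumerate([["a"], ["b"], ["c"]], start=1):
    on_shard(tensors, i, 3)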
@@ -1250,9 +1264,17 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
                             error = f"{mk} {pk} could not be found in the model checkpoint"
                             print("\n\nERROR: " + error, file=sys.stderr)
                             raise RuntimeError(error)
+            except:
+                import traceback
+                traceback.print_exc()
             finally:
+                if utils.num_shards is None or utils.current_shard >= utils.num_shards:
+                    utils.bar.close()
+                    utils.bar = None
+                callback.nested = False
                 if isinstance(f, zipfile.ZipExtFile):
                     f.close()
+    callback.nested = False
 
     if os.path.isdir(vars.model.replace('/', '_')):
         import shutil
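
The cleanup mirrors the setup: the added except prints the traceback of a failed load, the shared bar is closed (and cleared) only once the final shard is done, and the finally block always resets callback.nested so one failed shard cannot leave the callback permanently disabled. The trailing callback.nested = False after the function body initializes the attribute before the first call. A minimal sketch of the conditional cleanup, again assuming tqdm; the state dict and names are illustrative:

from tqdm import tqdm

state = {"bar": tqdm(total=3, desc="Loading model tensors"), "nested": False}

def finish_shard(current_shard, num_shards):
    state["nested"] = True
    try:
        state["bar"].update(1)        # the per-shard work
    finally:
        if current_shard >= num_shards:
            state["bar"].close()      # release the bar after the last shard
            state["bar"] = None
        state["nested"] = False       # always reset the re-entrancy flag

for i in range(1, 4):
    finish_shard(i, 3)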