Mirror of https://github.com/KoboldAI/KoboldAI-Client.git, synced 2025-06-05 21:59:24 +02:00
	Loading a sharded model will now display only one progress bar
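Before this change, every shard of a sharded checkpoint got its own tqdm bar (with a "(shard x/y)" suffix); now a single module-level bar in utils is created when the first shard starts, sized to the tensor count across all shards, advanced once per tensor, and closed only after the last shard. A minimal sketch of that pattern, illustrative only and not the actual KoboldAI code; load_shard, tensor_names and total_tensors are hypothetical stand-ins for the lazy-load callbacks patched in the diffs below:

from typing import Optional
from tqdm.auto import tqdm

num_shards: Optional[int] = None   # set from the index file for sharded checkpoints
current_shard = 0
bar: Optional[tqdm] = None         # one bar shared by every shard

def load_shard(tensor_names, total_tensors):
    global current_shard, bar
    if bar is None:
        # created once, on the first shard, sized to the count across all shards
        bar = tqdm(total=total_tensors, desc="Loading model tensors")
    if num_shards is not None:
        current_shard += 1
    try:
        for name in tensor_names:
            ...  # materialize the tensor here
            bar.update(1)
    finally:
        # closed only after the final shard, so no per-shard bars appear
        if num_shards is None or current_shard >= num_shards:
            bar.close()
            bar = None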
		
							
								
								
									
aiserver.py (29 changed lines)
@@ -1170,6 +1170,10 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
         def new_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
             utils.num_shards = None
             utils.current_shard = 0
+            utils.from_pretrained_model_name = pretrained_model_name_or_path
+            utils.from_pretrained_index_filename = None
+            utils.from_pretrained_kwargs = kwargs
+            utils.bar = None
             if not args.no_aria2:
                 utils.aria2_hook(pretrained_model_name_or_path, **kwargs)
             return old_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
@@ -1177,6 +1181,7 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
         old_get_checkpoint_shard_files = modeling_utils.get_checkpoint_shard_files
         def new_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs):
             utils.num_shards = utils.get_num_shards(index_filename)
+            utils.from_pretrained_index_filename = index_filename
             return old_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs)
         modeling_utils.get_checkpoint_shard_files = new_get_checkpoint_shard_files

@@ -1196,6 +1201,10 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
                 ram_blocks = gpu_blocks = cumulative_gpu_blocks = None

             def lazy_load_callback(model_dict, f, **_):
+                if lazy_load_callback.nested:
+                    return
+                lazy_load_callback.nested = True
+
                 device_map = {}

                 for _key, spec in lazy_load_spec.get("layer_weights", {}).items():
@@ -1210,6 +1219,13 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
                     if isinstance(value, torch_lazy_loader.LazyTensor) and key not in device_map:
                         device_map[key] = vars.gpu_device if vars.hascuda and vars.usegpu else "cpu"

+                if utils.num_shards is None or utils.current_shard == 0:
+                    if utils.num_shards is not None:
+                        num_tensors = len(utils.get_sharded_checkpoint_num_tensors(utils.from_pretrained_model_name, utils.from_pretrained_index_filename, **utils.from_pretrained_kwargs))
+                    else:
+                        num_tensors = len(device_map)
+                    utils.bar = tqdm(total=num_tensors, desc="Loading model tensors")
+
                 with zipfile.ZipFile(f, "r") as z:
                     try:
                         last_storage_key = None
@@ -1217,7 +1233,7 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
                         current_offset = 0
                         if utils.num_shards is not None:
                             utils.current_shard += 1
-                        for key in tqdm(sorted(device_map.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)), desc="Loading model tensors" + (f" (shard {utils.current_shard}/{utils.num_shards})" if utils.num_shards is not None else "")):
+                        for key in sorted(device_map.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)):
                             storage_key = model_dict[key].key
                             if storage_key != last_storage_key or model_dict[key].seek_offset < current_offset:
                                 last_storage_key = storage_key
@@ -1241,10 +1257,16 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
                             model_dict[key] = model_dict[key].to(device)
                             #print("OK", flush=True)
                             current_offset += nbytes
+                            utils.bar.update(1)
                     finally:
+                        if utils.num_shards is None or utils.current_shard >= utils.num_shards:
+                            utils.bar.close()
+                            utils.bar = None
+                        lazy_load_callback.nested = False
                         if isinstance(f, zipfile.ZipExtFile):
                             f.close()

+            lazy_load_callback.nested = False
             return lazy_load_callback

         lazy_load_config_path = os.path.join("maps", vars.model_type + ".json")
@@ -1640,6 +1662,10 @@ else:
     def new_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
         utils.num_shards = None
         utils.current_shard = 0
+        utils.from_pretrained_model_name = pretrained_model_name_or_path
+        utils.from_pretrained_index_filename = None
+        utils.from_pretrained_kwargs = kwargs
+        utils.bar = None
         if not args.no_aria2:
             utils.aria2_hook(pretrained_model_name_or_path, **kwargs)
         return old_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
@@ -1647,6 +1673,7 @@ else:
     old_get_checkpoint_shard_files = modeling_utils.get_checkpoint_shard_files
     def new_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs):
         utils.num_shards = utils.get_num_shards(index_filename)
+        utils.from_pretrained_index_filename = index_filename
         return old_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs)
     modeling_utils.get_checkpoint_shard_files = new_get_checkpoint_shard_files
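The nested attribute set on lazy_load_callback is a re-entrancy guard: sizing the bar calls utils.get_sharded_checkpoint_num_tensors, which torch.loads the shards and, with the lazy loader active, would invoke the same callback again. A minimal sketch of the guard idea, with a hypothetical callback:

def callback(model_dict, f, **_):
    if callback.nested:      # a nested invocation does nothing
        return
    callback.nested = True
    try:
        ...                  # the real loading work goes here
    finally:
        callback.nested = False

callback.nested = False      # initialise the flag before the first call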
@@ -1160,6 +1160,9 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
     import functools

     def callback(model_dict, f, **_):
+        if callback.nested:
+            return
+        callback.nested = True
         with zipfile.ZipFile(f, "r") as z:
             try:
                 last_storage_key = None
@@ -1167,9 +1170,17 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
                 current_offset = 0
                 if utils.current_shard == 0:
                     print("\n\n\nThis model has  ", f"{hk.data_structures.tree_size(network.state['params']):,d}".replace(",", " "), "  parameters.\n")
+
+                if utils.num_shards is None or utils.current_shard == 0:
+                    if utils.num_shards is not None:
+                        num_tensors = len(utils.get_sharded_checkpoint_num_tensors(utils.from_pretrained_model_name, utils.from_pretrained_index_filename, **utils.from_pretrained_kwargs))
+                    else:
+                        num_tensors = len(model_dict)
+                    utils.bar = tqdm(total=num_tensors, desc="Loading model tensors")
+
                 if utils.num_shards is not None:
                     utils.current_shard += 1
-                for key in tqdm(sorted(model_dict.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)), desc="Loading model tensors" + (f" (shard {utils.current_shard}/{utils.num_shards})" if utils.num_shards is not None else "")):
+                for key in sorted(model_dict.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)):

                     # Some model weights are used by transformers but not by MTJ.
                     # We have to materialize these weights anyways because
@@ -1178,6 +1189,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
                     # tensors, which don't take up any actual CPU or TPU memory.
                     if key not in model_spec:
                         model_dict[key] = torch.empty(model_dict[key].shape, dtype=model_dict[key].dtype, device="meta")
+                        utils.bar.update(1)
                         continue

                     storage_key = model_dict[key].key
@@ -1230,6 +1242,8 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
                         np.empty(params["cores_per_replica"]),
                     )

+                    utils.bar.update(1)
+
                 if utils.num_shards is not None and utils.current_shard < utils.num_shards:
                     return

@@ -1250,9 +1264,17 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
                                 error = f"{mk} {pk} could not be found in the model checkpoint"
                                 print("\n\nERROR:  " + error, file=sys.stderr)
                                 raise RuntimeError(error)
+            except:
+                import traceback
+                traceback.print_exc()
             finally:
+                if utils.num_shards is None or utils.current_shard >= utils.num_shards:
+                    utils.bar.close()
+                    utils.bar = None
+                callback.nested = False
                 if isinstance(f, zipfile.ZipExtFile):
                     f.close()
+    callback.nested = False

     if os.path.isdir(vars.model.replace('/', '_')):
         import shutil
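The TPU backend materializes weights that transformers supplies but MTJ does not use as tensors on PyTorch's "meta" device, which carry shape and dtype but no storage; the commit now ticks the shared bar for these too so the total stays accurate. For example:

import torch

# A tensor on the "meta" device has a shape and dtype but no backing storage.
t = torch.empty((4096, 4096), dtype=torch.float16, device="meta")
print(t.shape, t.dtype, t.is_meta)   # torch.Size([4096, 4096]) torch.float16 True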
							
								
								
									
utils.py (16 changed lines)
@@ -9,11 +9,16 @@ import requests.adapters
 import time
 from tqdm.auto import tqdm
 import os
+import itertools
 from typing import Optional

 vars = None
 num_shards: Optional[int] = None
 current_shard = 0
+from_pretrained_model_name = ""
+from_pretrained_index_filename: Optional[str] = None
+from_pretrained_kwargs = {}
+bar = None

 #==================================================================#
 # Decorator to prevent a function's actions from being run until
@@ -280,3 +285,14 @@ def get_num_shards(filename):
     with open(filename) as f:
         map_data = json.load(f)
     return len(set(map_data["weight_map"].values()))
+
+#==================================================================#
+#  Given the name/path of a sharded model and the path to a
+#  pytorch_model.bin.index.json, returns a list of weight names in the
+#  sharded model.  Requires lazy loader to be enabled to work properly
+#==================================================================#
+def get_sharded_checkpoint_num_tensors(pretrained_model_name_or_path, filename, cache_dir=None, force_download=False, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, mirror=None, **kwargs):
+    import transformers.modeling_utils
+    import torch
+    shard_paths, _ = transformers.modeling_utils.get_checkpoint_shard_files(pretrained_model_name_or_path, filename, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, local_files_only=local_files_only, use_auth_token=use_auth_token, user_agent=user_agent, revision=revision, mirror=mirror)
+    return list(itertools.chain(*(torch.load(p, map_location="cpu").keys() for p in shard_paths)))
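get_num_shards counts the distinct shard files referenced by a checkpoint's weight map, and the new get_sharded_checkpoint_num_tensors flattens the tensor names across those shards so the single progress bar can be sized before any shard is read. A hedged usage sketch; the index layout and paths below are made up:

# A pytorch_model.bin.index.json maps every weight name to the shard file
# that contains it (abridged, hypothetical model):
#   {
#     "metadata": {"total_size": 13476839424},
#     "weight_map": {
#       "transformer.wte.weight": "pytorch_model-00001-of-00002.bin",
#       "transformer.h.0.attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
#       "lm_head.weight": "pytorch_model-00002-of-00002.bin"
#     }
#   }

import utils

index_path = "models/my-model/pytorch_model.bin.index.json"   # hypothetical path
print(utils.get_num_shards(index_path))                        # 2 distinct shard files above

# Sizing the single bar up front; this torch.loads every shard, which is why
# the comment in utils.py says the lazy loader should be enabled, so the
# returned values are cheap placeholder tensors rather than real weights.
names = utils.get_sharded_checkpoint_num_tensors("models/my-model", index_path)
print(len(names))                                              # total tensors across all shards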
Gnome Ann