Mirror of https://github.com/KoboldAI/KoboldAI-Client.git
	Allow TPU Colab to load sharded HF models
Author: Gnome Ann

aiserver.py (20 lines changed)
@@ -1127,13 +1127,21 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
         from transformers import __version__ as transformers_version

         from transformers import PreTrainedModel
+        from transformers import modeling_utils
         old_from_pretrained = PreTrainedModel.from_pretrained.__func__
         @classmethod
         def new_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+            utils.num_shards = None
+            utils.current_shard = 0
             if not args.no_aria2:
                 utils.aria2_hook(pretrained_model_name_or_path, **kwargs)
             return old_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
         PreTrainedModel.from_pretrained = new_from_pretrained
+        old_get_checkpoint_shard_files = modeling_utils.get_checkpoint_shard_files
+        def new_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs):
+            utils.num_shards = utils.get_num_shards(index_filename)
+            return old_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs)
+        modeling_utils.get_checkpoint_shard_files = new_get_checkpoint_shard_files

         # Lazy loader
         import torch_lazy_loader
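The new get_checkpoint_shard_files hook calls utils.get_num_shards, which lives in utils.py and is not part of this diff. A minimal sketch of what such a helper might look like, assuming the standard Hugging Face sharded-checkpoint index layout (a JSON file whose "weight_map" maps tensor names to shard filenames):

import json

def get_num_shards(index_filename):
    # A sharded HF checkpoint ships an index JSON (e.g. pytorch_model.bin.index.json)
    # whose "weight_map" maps each tensor name to the shard file containing it.
    # The number of distinct shard files is the shard count.
    with open(index_filename) as f:
        index = json.load(f)
    return len(set(index["weight_map"].values()))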
@@ -1170,7 +1178,9 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
                         last_storage_key = None
                         f = None
                         current_offset = 0
-                        for key in tqdm(sorted(device_map.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)), desc="Loading model tensors"):
+                        if utils.num_shards is not None:
+                            utils.current_shard += 1
+                        for key in tqdm(sorted(device_map.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)), desc="Loading model tensors" + (f" (shard {utils.current_shard}/{utils.num_shards})" if utils.num_shards is not None else "")):
                             storage_key = model_dict[key].key
                             if storage_key != last_storage_key or model_dict[key].seek_offset < current_offset:
                                 last_storage_key = storage_key
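For context, the lazy-loader callback patched above runs once per checkpoint file, so the counter advances each time a new shard's tensors are materialized. A standalone sketch of the resulting progress-bar behaviour (the load_tensors helper and the sleep are illustrative only, not KoboldAI code):

import time
from tqdm import tqdm

def load_tensors(keys, current_shard=None, num_shards=None):
    # Mirrors the diff: the shard suffix is only appended when the checkpoint
    # is actually sharded (num_shards stays None for single-file models).
    desc = "Loading model tensors" + (
        f" (shard {current_shard}/{num_shards})" if num_shards is not None else ""
    )
    for _ in tqdm(keys, desc=desc):
        time.sleep(0.01)

load_tensors(range(8))                                         # single-file checkpoint
for shard in range(1, 4):
    load_tensors(range(8), current_shard=shard, num_shards=3)  # three-shard checkpoint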
@@ -1560,13 +1570,21 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
         tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", revision=vars.revision, cache_dir="cache")
 else:
     from transformers import PreTrainedModel
+    from transformers import modeling_utils
     old_from_pretrained = PreTrainedModel.from_pretrained.__func__
     @classmethod
     def new_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+        utils.num_shards = None
+        utils.current_shard = 0
         if not args.no_aria2:
             utils.aria2_hook(pretrained_model_name_or_path, **kwargs)
         return old_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
     PreTrainedModel.from_pretrained = new_from_pretrained
+    old_get_checkpoint_shard_files = modeling_utils.get_checkpoint_shard_files
+    def new_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs):
+        utils.num_shards = utils.get_num_shards(index_filename)
+        return old_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs)
+    modeling_utils.get_checkpoint_shard_files = new_get_checkpoint_shard_files

     def tpumtjgetsofttokens():
         soft_tokens = None
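The same from_pretrained hook appears in both branches, and it relies on a small Python subtlety: from_pretrained is a classmethod, so the underlying function is taken via .__func__ and the replacement is re-wrapped with @classmethod before being assigned back. A self-contained illustration of that pattern (Widget and its names are made up for the example):

class Widget:
    @classmethod
    def load(cls, name, **kwargs):
        return f"{cls.__name__} loaded {name}"

calls = 0

# Accessing the classmethod through the class yields a bound method;
# .__func__ recovers the plain underlying function.
old_load = Widget.load.__func__

@classmethod
def new_load(cls, name, **kwargs):
    global calls
    calls += 1                              # side effect before delegating
    return old_load(cls, name, **kwargs)    # cls is passed explicitly, as in the diff

# Reinstall the wrapper as a classmethod on the class.
Widget.load = new_load

print(Widget.load("gpt-j-6b"), calls)       # -> Widget loaded gpt-j-6b 1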