mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Allow TPU Colab to load sharded HF models
This commit is contained in:
20
aiserver.py
20
aiserver.py
@ -1127,13 +1127,21 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
|
||||
from transformers import __version__ as transformers_version
|
||||
|
||||
from transformers import PreTrainedModel
|
||||
from transformers import modeling_utils
|
||||
old_from_pretrained = PreTrainedModel.from_pretrained.__func__
|
||||
@classmethod
|
||||
def new_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
||||
utils.num_shards = None
|
||||
utils.current_shard = 0
|
||||
if not args.no_aria2:
|
||||
utils.aria2_hook(pretrained_model_name_or_path, **kwargs)
|
||||
return old_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
PreTrainedModel.from_pretrained = new_from_pretrained
|
||||
old_get_checkpoint_shard_files = modeling_utils.get_checkpoint_shard_files
|
||||
def new_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs):
|
||||
utils.num_shards = utils.get_num_shards(index_filename)
|
||||
return old_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs)
|
||||
modeling_utils.get_checkpoint_shard_files = new_get_checkpoint_shard_files
|
||||
|
||||
# Lazy loader
|
||||
import torch_lazy_loader
|
||||
@ -1170,7 +1178,9 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
|
||||
last_storage_key = None
|
||||
f = None
|
||||
current_offset = 0
|
||||
for key in tqdm(sorted(device_map.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)), desc="Loading model tensors"):
|
||||
if utils.num_shards is not None:
|
||||
utils.current_shard += 1
|
||||
for key in tqdm(sorted(device_map.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)), desc="Loading model tensors" + (f" (shard {utils.current_shard}/{utils.num_shards})" if utils.num_shards is not None else "")):
|
||||
storage_key = model_dict[key].key
|
||||
if storage_key != last_storage_key or model_dict[key].seek_offset < current_offset:
|
||||
last_storage_key = storage_key
|
||||
@ -1560,13 +1570,21 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
|
||||
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", revision=vars.revision, cache_dir="cache")
|
||||
else:
|
||||
from transformers import PreTrainedModel
|
||||
from transformers import modeling_utils
|
||||
old_from_pretrained = PreTrainedModel.from_pretrained.__func__
|
||||
@classmethod
|
||||
def new_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
||||
utils.num_shards = None
|
||||
utils.current_shard = 0
|
||||
if not args.no_aria2:
|
||||
utils.aria2_hook(pretrained_model_name_or_path, **kwargs)
|
||||
return old_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
PreTrainedModel.from_pretrained = new_from_pretrained
|
||||
old_get_checkpoint_shard_files = modeling_utils.get_checkpoint_shard_files
|
||||
def new_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs):
|
||||
utils.num_shards = utils.get_num_shards(index_filename)
|
||||
return old_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs)
|
||||
modeling_utils.get_checkpoint_shard_files = new_get_checkpoint_shard_files
|
||||
|
||||
def tpumtjgetsofttokens():
|
||||
soft_tokens = None
|
||||
|
Reference in New Issue
Block a user