Mirror of https://github.com/KoboldAI/KoboldAI-Client.git
Fix patching code of PreTrainedModel.from_pretrained()
aiserver.py: 14 lines changed
@@ -1114,10 +1114,11 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
     from transformers import __version__ as transformers_version
 
     from transformers import PreTrainedModel
-    old_from_pretrained = PreTrainedModel.from_pretrained
-    def new_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs):
+    old_from_pretrained = PreTrainedModel.from_pretrained.__func__
+    @classmethod
+    def new_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
         utils.aria2_hook(pretrained_model_name_or_path, **kwargs)
-        return old_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
+        return old_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
     PreTrainedModel.from_pretrained = new_from_pretrained
 
     # Lazy loader
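For context, a minimal, self-contained sketch of the problem this hunk fixes, using toy classes rather than the transformers API (Base, Child and load below are illustrative stand-ins, not KoboldAI code). PreTrainedModel.from_pretrained is a classmethod, so grabbing it straight off the class yields a method already bound to PreTrainedModel, and replacing it with a plain function loses whichever subclass the caller actually used. Unwrapping the original with .__func__ and re-wrapping the replacement with @classmethod keeps the subclass flowing through as cls:

class Base:
    @classmethod
    def load(cls, name):
        # Stands in for PreTrainedModel.from_pretrained: which subclass
        # the call was made on is the part that matters.
        return f"{cls.__name__} loaded {name}"

class Child(Base):
    pass

# Old, broken pattern: Base.load is already bound to Base, so every call
# routed through the wrapper ends up on the base class.
old_bound = Base.load
def broken_wrapper(name, *args, **kwargs):
    return old_bound(name, *args, **kwargs)
Base.load = broken_wrapper
print(Child.load("x"))   # "Base loaded x"  (subclass lost)

# Fixed pattern, as in the commit: keep the plain underlying function and
# re-wrap the replacement as a classmethod so cls is forwarded untouched.
old_func = old_bound.__func__
@classmethod
def fixed_wrapper(cls, name, *args, **kwargs):
    return old_func(cls, name, *args, **kwargs)
Base.load = fixed_wrapper
print(Child.load("x"))   # "Child loaded x"  (subclass preserved)

The second hunk makes the identical change to the copy of this patch in the else branch of the same model-loading block: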
@@ -1545,10 +1546,11 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
         tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", revision=vars.revision, cache_dir="cache")
 else:
     from transformers import PreTrainedModel
-    old_from_pretrained = PreTrainedModel.from_pretrained
-    def new_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs):
+    old_from_pretrained = PreTrainedModel.from_pretrained.__func__
+    @classmethod
+    def new_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
         utils.aria2_hook(pretrained_model_name_or_path, **kwargs)
-        return old_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
+        return old_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
     PreTrainedModel.from_pretrained = new_from_pretrained
 
 def tpumtjgetsofttokens():
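One practical consequence of saving the unwrapped function: if the download hook ever needed to be removed again, the original behaviour could be restored by re-wrapping that plain function as a classmethod. A sketch only, reusing the names from the hunks above; the commit itself does not do this:

# Hypothetical cleanup, not part of this commit: old_from_pretrained holds the
# plain function, so re-wrapping it restores the unpatched classmethod.
PreTrainedModel.from_pretrained = classmethod(old_from_pretrained)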
|