From 09bb1021ddc548e4422d6426fe2c1867b6d152b8 Mon Sep 17 00:00:00 2001 From: 0cc4m Date: Sun, 23 Jul 2023 07:14:23 +0200 Subject: [PATCH] Fallback to transformers if hf_bleeding_edge not available --- modeling/inference_models/generic_hf_torch/class.py | 5 ++++- modeling/inference_models/gptq_hf_torch/class.py | 7 +++++-- modeling/inference_models/hf.py | 5 ++++- modeling/inference_models/hf_torch.py | 5 ++++- 4 files changed, 17 insertions(+), 5 deletions(-) diff --git a/modeling/inference_models/generic_hf_torch/class.py b/modeling/inference_models/generic_hf_torch/class.py index de89034b..5471ae43 100644 --- a/modeling/inference_models/generic_hf_torch/class.py +++ b/modeling/inference_models/generic_hf_torch/class.py @@ -7,7 +7,10 @@ import shutil from typing import Union from transformers import GPTNeoForCausalLM, GPT2LMHeadModel, BitsAndBytesConfig -from hf_bleeding_edge import AutoModelForCausalLM +try: + from hf_bleeding_edge import AutoModelForCausalLM +except ImportError: + from transformers import AutoModelForCausalLM from transformers.utils import WEIGHTS_NAME, WEIGHTS_INDEX_NAME, TF2_WEIGHTS_NAME, TF2_WEIGHTS_INDEX_NAME, TF_WEIGHTS_NAME, FLAX_WEIGHTS_NAME, FLAX_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME diff --git a/modeling/inference_models/gptq_hf_torch/class.py b/modeling/inference_models/gptq_hf_torch/class.py index 157ebdbe..0819c8ae 100644 --- a/modeling/inference_models/gptq_hf_torch/class.py +++ b/modeling/inference_models/gptq_hf_torch/class.py @@ -10,8 +10,11 @@ import sys from typing import Union from transformers import GPTNeoForCausalLM, AutoTokenizer, LlamaTokenizer -import hf_bleeding_edge -from hf_bleeding_edge import AutoModelForCausalLM +try: + import hf_bleeding_edge + from hf_bleeding_edge import AutoModelForCausalLM +except ImportError: + from transformers import AutoModelForCausalLM import utils import modeling.lazy_loader as lazy_loader diff --git a/modeling/inference_models/hf.py b/modeling/inference_models/hf.py index cd55c3ef..be0fb059 100644 --- a/modeling/inference_models/hf.py +++ b/modeling/inference_models/hf.py @@ -1,6 +1,9 @@ import os, sys from typing import Optional -from hf_bleeding_edge import AutoConfig +try: + from hf_bleeding_edge import AutoConfig +except ImportError: + from transformers import AutoConfig import warnings import utils diff --git a/modeling/inference_models/hf_torch.py b/modeling/inference_models/hf_torch.py index f7bd7a0b..6372858f 100644 --- a/modeling/inference_models/hf_torch.py +++ b/modeling/inference_models/hf_torch.py @@ -19,7 +19,10 @@ from transformers import ( GPT2LMHeadModel, LogitsProcessorList, ) -from hf_bleeding_edge import AutoModelForCausalLM +try: + from hf_bleeding_edge import AutoModelForCausalLM +except ImportError: + from transformers import AutoModelForCausalLM import utils import modeling.lazy_loader as lazy_loader