diff --git a/prompt_tuner.py b/prompt_tuner.py index 99320861..1ce3e210 100644 --- a/prompt_tuner.py +++ b/prompt_tuner.py @@ -78,8 +78,10 @@ class UniversalPromptTuningMixin: Embedding.__call__ = old_embedding_call for k in dir(GPTPromptTuningMixin): - if not hasattr(UniversalPromptTuningMixin, k): - setattr(UniversalPromptTuningMixin, k, getattr(GPTPromptTuningMixin, k)) + v = getattr(GPTPromptTuningMixin, k) + _v = getattr(UniversalPromptTuningMixin, k, None) + if _v is None or (_v is getattr(object, k, None) and callable(_v) and not isinstance(_v, type)): + setattr(UniversalPromptTuningMixin, k, v) class AutoPromptTuningLM(UniversalPromptTuningMixin, transformers.AutoModelForCausalLM):