Fix a bug in OPTForCausalLM where self.lm_head is the wrong size

Gnome Ann 2022-05-13 01:37:17 -04:00
parent defbb53b68
commit 29bb3f569b
1 changed file with 19 additions and 0 deletions

@@ -1293,6 +1293,25 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
        pass

    # Fix a bug in OPTForCausalLM where self.lm_head is the wrong size
    if(transformers_version == "4.19.0"):
        try:
            from transformers import OPTForCausalLM, OPTModel
        except ImportError:
            pass
        else:
            # This is the same as the original __init__ but with
            # config.hidden_size
            # replaced with
            # config.word_embed_proj_dim
            def new_init(self, config):
                super(OPTForCausalLM, self).__init__(config)
                self.model = OPTModel(config)
                self.lm_head = torch.nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False)
                self.post_init()
            OPTForCausalLM.__init__ = new_init
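
For context, a minimal sketch (not part of this commit, assuming transformers 4.19.0 with the patch above applied) of why the fix matters: configs like OPT-350m set word_embed_proj_dim smaller than hidden_size, the decoder projects its final hidden states down to word_embed_proj_dim, and lm_head is weight-tied to the input embeddings of that width, so building lm_head from hidden_size produces a size mismatch. With the patched __init__, the shapes line up:

import torch
from transformers import OPTConfig, OPTForCausalLM

# Tiny OPT-350m-style config: projection dim differs from hidden size.
# The small layer/head counts are arbitrary, just to keep the demo cheap.
config = OPTConfig(
    vocab_size=50272,
    hidden_size=1024,
    word_embed_proj_dim=512,
    num_hidden_layers=2,
    num_attention_heads=4,
    ffn_dim=256,
)
model = OPTForCausalLM(config)

# With the patch, lm_head accepts the projected width, matching both the
# decoder output and the tied embedding matrix.
assert model.lm_head.in_features == config.word_embed_proj_dim
assert model.lm_head.weight.shape == (config.vocab_size, config.word_embed_proj_dim)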
    # Patch transformers to use our custom logit warpers
    from transformers import LogitsProcessorList, LogitsWarper, LogitsProcessor, TopKLogitsWarper, TopPLogitsWarper, TemperatureLogitsWarper, RepetitionPenaltyLogitsProcessor
    from warpers import AdvancedRepetitionPenaltyLogitsProcessor, TailFreeLogitsWarper, TypicalLogitsWarper
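
The custom warpers imported above all follow the standard transformers LogitsWarper interface. A minimal sketch of that interface (a hypothetical warper for illustration, not the actual implementation in warpers.py):

import torch
from transformers import LogitsWarper

class FixedTemperatureWarper(LogitsWarper):
    # Hypothetical stand-in: scales logits by a constant temperature.
    # It uses the same call signature the sampler invokes on each warper.
    def __init__(self, temperature: float):
        self.temperature = temperature

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        return scores / self.temperature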