Fix a bug in OPTForCausalLM where self.lm_head is the wrong size
commit 29bb3f569b
parent defbb53b68
aiserver.py | 19 +++++++++++++++++++
1 file changed, 19 insertions(+)
@@ -1293,6 +1293,25 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
             pass
 
+    # Fix a bug in OPTForCausalLM where self.lm_head is the wrong size
+    if(transformers_version == "4.19.0"):
+        try:
+            from transformers import OPTForCausalLM, OPTModel
+        except ImportError:
+            pass
+        else:
+            # This is the same as the original __init__ but with
+            # config.hidden_size
+            # replaced with
+            # config.word_embed_proj_dim
+            def new_init(self, config):
+                super(OPTForCausalLM, self).__init__(config)
+                self.model = OPTModel(config)
+                self.lm_head = torch.nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False)
+                self.post_init()
+            OPTForCausalLM.__init__ = new_init
+
+
     # Patch transformers to use our custom logit warpers
     from transformers import LogitsProcessorList, LogitsWarper, LogitsProcessor, TopKLogitsWarper, TopPLogitsWarper, TemperatureLogitsWarper, RepetitionPenaltyLogitsProcessor
     from warpers import AdvancedRepetitionPenaltyLogitsProcessor, TailFreeLogitsWarper, TypicalLogitsWarper
 
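Note: OPT's decoder projects its hidden states down to config.word_embed_proj_dim before the language-model head, so lm_head must take word_embed_proj_dim input features rather than hidden_size. Below is a minimal standalone sketch of the same monkey-patch plus a shape check, assuming transformers 4.19.0 (where the stock constructor still builds lm_head from hidden_size) and torch are installed; the small config values are illustrative only and are not taken from this repository.

import torch
from transformers import OPTConfig, OPTForCausalLM, OPTModel

# Same idea as the patch in the diff above: rebuild __init__ so lm_head uses
# word_embed_proj_dim instead of hidden_size.
def new_init(self, config):
    super(OPTForCausalLM, self).__init__(config)
    self.model = OPTModel(config)
    self.lm_head = torch.nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False)
    self.post_init()
OPTForCausalLM.__init__ = new_init

# Illustrative config where the two widths differ, so the original bug would
# surface (as it does for real checkpoints such as OPT-350m).
config = OPTConfig(hidden_size=512, num_attention_heads=8, num_hidden_layers=2,
                   ffn_dim=2048, word_embed_proj_dim=128, vocab_size=50272)
model = OPTForCausalLM(config)
assert model.lm_head.in_features == config.word_embed_proj_dim   # 128, not 512
assert model.lm_head.out_features == config.vocab_size

The version gate in the diff (transformers_version == "4.19.0") keeps the patch from touching other releases, including later ones where the constructor was corrected upstream.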