Don't crash if XGLMSinusoidalPositionalEmbedding doesn't exist

Gnome Ann
2022-02-20 17:41:00 -05:00
parent 5dc4969173
commit da10e2dc1d

@@ -768,7 +768,11 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
     # Temporary fix for XGLM positional embedding issues until
     # https://github.com/huggingface/transformers/issues/15736
     # is resolved
-    from transformers.models.xglm.modeling_xglm import XGLMSinusoidalPositionalEmbedding
-    @torch.no_grad()
-    def new_forward(self, input_ids: torch.Tensor = None, inputs_embeds: torch.Tensor = None, past_key_values_length: int = 0):
-        bsz, seq_len = inputs_embeds.size()[:-1]
+    try:
+        from transformers.models.xglm.modeling_xglm import XGLMSinusoidalPositionalEmbedding
+    except ImportError:
+        pass
+    else:
+        @torch.no_grad()
+        def new_forward(self, input_ids: torch.Tensor = None, inputs_embeds: torch.Tensor = None, past_key_values_length: int = 0):
+            bsz, seq_len = inputs_embeds.size()[:-1]
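
The pattern this commit introduces is a try/except/else guard around an optional import: the monkey-patch is applied only when the private class can actually be imported, so transformers releases that lack XGLMSinusoidalPositionalEmbedding no longer crash at startup. Below is a minimal, self-contained sketch of that pattern; the body of new_forward here is a placeholder that delegates to the original method, not the actual workaround from this commit (the diff above is truncated).

```python
import torch

# Guarded monkey-patch: only apply the workaround when the private
# class exists in the installed transformers version.
try:
    from transformers.models.xglm.modeling_xglm import XGLMSinusoidalPositionalEmbedding
except ImportError:
    # Class is missing (removed, renamed, or transformers too old/new);
    # skip the patch instead of crashing.
    pass
else:
    old_forward = XGLMSinusoidalPositionalEmbedding.forward

    @torch.no_grad()
    def new_forward(self, input_ids: torch.Tensor = None, inputs_embeds: torch.Tensor = None, past_key_values_length: int = 0):
        # Placeholder body: the real commit recomputes the sinusoidal
        # positions here to work around
        # https://github.com/huggingface/transformers/issues/15736.
        return old_forward(self, input_ids, inputs_embeds, past_key_values_length)

    XGLMSinusoidalPositionalEmbedding.forward = new_forward
```

Using else: rather than putting the patch code inside the try: block keeps the guard narrow: only a failed import of the class suppresses the patch, and an unrelated ImportError raised while installing the patch is not accidentally swallowed.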