diff --git a/aiserver.py b/aiserver.py
index e926ae46..05f7a85a 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -1015,27 +1015,28 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
     import transformers.generation_utils
     from transformers import __version__ as transformers_version
 
-    # Temporary fix for XGLM positional embedding issues until
+    # Some versions of transformers 4.17.0.dev0 are affected by
     # https://github.com/huggingface/transformers/issues/15736
-    # is resolved
-    try:
-        from transformers.models.xglm.modeling_xglm import XGLMSinusoidalPositionalEmbedding
-    except ImportError:
-        pass
-    else:
-        @torch.no_grad()
-        def new_forward(self, input_ids: torch.Tensor = None, inputs_embeds: torch.Tensor = None, past_key_values_length: int = 0):
-            bsz, seq_len = inputs_embeds.size()[:-1]
-            input_shape = inputs_embeds.size()[:-1]
-            sequence_length = input_shape[1]
-            position_ids = torch.arange(
-                past_key_values_length + self.padding_idx + 1, past_key_values_length + sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
-            ).unsqueeze(0).expand(input_shape).contiguous()
-            max_pos = self.padding_idx + 1 + seq_len + past_key_values_length
-            if max_pos > self.weights.size(0):
-                self.make_weights(max_pos + self.offset, self.embedding_dim, self.padding_idx)
-            return self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, -1).detach()
-        XGLMSinusoidalPositionalEmbedding.forward = new_forward
+    # This is a workaround for those versions of transformers.
+    if(transformers_version == "4.17.0.dev0"):
+        try:
+            from transformers.models.xglm.modeling_xglm import XGLMSinusoidalPositionalEmbedding
+        except ImportError:
+            pass
+        else:
+            @torch.no_grad()
+            def new_forward(self, input_ids: torch.Tensor = None, inputs_embeds: torch.Tensor = None, past_key_values_length: int = 0):
+                bsz, seq_len = inputs_embeds.size()[:-1]
+                input_shape = inputs_embeds.size()[:-1]
+                sequence_length = input_shape[1]
+                position_ids = torch.arange(
+                    past_key_values_length + self.padding_idx + 1, past_key_values_length + sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
+                ).unsqueeze(0).expand(input_shape).contiguous()
+                max_pos = self.padding_idx + 1 + seq_len + past_key_values_length
+                if max_pos > self.weights.size(0):
+                    self.make_weights(max_pos + self.offset, self.embedding_dim, self.padding_idx)
+                return self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, -1).detach()
+            XGLMSinusoidalPositionalEmbedding.forward = new_forward
 
     # Patch transformers to use our soft prompt
     def patch_causallm(cls):
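
Note (not part of the patch): the replacement new_forward above is the workaround for https://github.com/huggingface/transformers/issues/15736; it derives XGLM position ids directly from inputs_embeds, the past_key_values_length offset, and the module's padding_idx. The standalone sketch below reproduces only that position-id arithmetic with plain tensors so the offsets are easy to see; the padding_idx value, the cache length, and the tensor shape are illustrative assumptions, not values taken from aiserver.py or transformers.

# Standalone sketch of the position-id arithmetic used by new_forward above.
# All concrete numbers here are assumed for illustration only.
import torch

padding_idx = 1                        # XGLM-style padding offset (assumed)
past_key_values_length = 4             # pretend 4 tokens are already cached
inputs_embeds = torch.zeros(2, 3, 8)   # (batch, new tokens, hidden size)

bsz, seq_len = inputs_embeds.size()[:-1]
position_ids = torch.arange(
    past_key_values_length + padding_idx + 1,
    past_key_values_length + seq_len + padding_idx + 1,
    dtype=torch.long,
    device=inputs_embeds.device,
).unsqueeze(0).expand(bsz, seq_len).contiguous()

print(position_ids)
# tensor([[6, 7, 8],
#         [6, 7, 8]])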