diff --git a/src/transformers/models/pixio/modeling_pixio.py b/src/transformers/models/pixio/modeling_pixio.py index e3f9e2626f27..985e6b51216f 100644 --- a/src/transformers/models/pixio/modeling_pixio.py +++ b/src/transformers/models/pixio/modeling_pixio.py @@ -132,10 +132,12 @@ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: return torch.cat((class_pos_embed, patch_pos_embed), dim=1) - def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor: batch_size, _, height, width = pixel_values.shape target_dtype = self.patch_embeddings.projection.weight.dtype - embeddings = self.patch_embeddings(pixel_values.to(dtype=target_dtype)) + embeddings = self.patch_embeddings( + pixel_values.to(dtype=target_dtype), interpolate_pos_encoding=interpolate_pos_encoding + ) cls_tokens = self.cls_token.expand(batch_size, -1, -1) embeddings = torch.cat((cls_tokens, embeddings), dim=1) @@ -407,12 +409,13 @@ def get_input_embeddings(self) -> PixioPatchEmbeddings: def forward( self, pixel_values: torch.Tensor | None = None, + interpolate_pos_encoding: bool = False, **kwargs: Unpack[TransformersKwargs], ) -> BaseModelOutputWithPooling: if pixel_values is None: raise ValueError("You have to specify pixel_values") - embedding_output = self.embeddings(pixel_values) + embedding_output = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) encoder_outputs: BaseModelOutput = self.encoder(embedding_output, **kwargs) sequence_output = encoder_outputs.last_hidden_state diff --git a/src/transformers/models/pixio/modular_pixio.py b/src/transformers/models/pixio/modular_pixio.py index 3446ed222283..6e5e9a65375f 100644 --- a/src/transformers/models/pixio/modular_pixio.py +++ b/src/transformers/models/pixio/modular_pixio.py @@ -172,10 +172,12 @@ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: return torch.cat((class_pos_embed, patch_pos_embed), dim=1) - def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor: batch_size, _, height, width = pixel_values.shape target_dtype = self.patch_embeddings.projection.weight.dtype - embeddings = self.patch_embeddings(pixel_values.to(dtype=target_dtype)) + embeddings = self.patch_embeddings( + pixel_values.to(dtype=target_dtype), interpolate_pos_encoding=interpolate_pos_encoding + ) cls_tokens = self.cls_token.expand(batch_size, -1, -1) embeddings = torch.cat((cls_tokens, embeddings), dim=1) @@ -274,12 +276,13 @@ def get_input_embeddings(self) -> PixioPatchEmbeddings: def forward( self, pixel_values: torch.Tensor | None = None, + interpolate_pos_encoding: bool = False, **kwargs: Unpack[TransformersKwargs], ) -> BaseModelOutputWithPooling: if pixel_values is None: raise ValueError("You have to specify pixel_values") - embedding_output = self.embeddings(pixel_values) + embedding_output = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) encoder_outputs: BaseModelOutput = self.encoder(embedding_output, **kwargs) sequence_output = encoder_outputs.last_hidden_state