Source code for opr.modules.feature_extractors.convnext
"""ConvNeXt-based image feature extractors."""
from torch import Tensor, nn
from torchvision.models import ConvNeXt_Tiny_Weights, convnext_tiny
[docs]
class ConvNeXtTinyFeatureExtractor(nn.Module):
"""ConvNeXt-Tiny image feature extractor."""
def __init__(
self,
in_channels: int = 3,
pretrained: bool = True,
) -> None:
"""ConvNeXt-Tiny image feature extractor.
Args:
in_channels (int): Number of input channels. Defaults to 3.
pretrained (bool): Whether to load ImageNet-pretrained model. Defaults to True.
Raises:
ValueError: If `in_channels` is not 3 and `pretrained` is True.
"""
super().__init__()
if in_channels != 3 and pretrained:
raise ValueError("Pretrained models are only available for 3-channel images")
model = convnext_tiny(weights=(ConvNeXt_Tiny_Weights.IMAGENET1K_V1 if pretrained else None))
self.feature_extractor = model.features
# change input conv to accept n-channel images
if in_channels != 3:
self.feature_extractor[0][0] = nn.Conv2d(
in_channels=in_channels,
out_channels=self.feature_extractor[0][0].out_channels,
kernel_size=self.feature_extractor[0][0].kernel_size,
stride=self.feature_extractor[0][0].stride,
padding=self.feature_extractor[0][0].padding,
dilation=self.feature_extractor[0][0].dilation,
groups=self.feature_extractor[0][0].groups,
bias=True,
padding_mode=self.feature_extractor[0][0].padding_mode,
device=next(self.feature_extractor[0][0].parameters()).device,
dtype=next(self.feature_extractor[0][0].parameters()).dtype,
)
[docs]
def forward(self, image: Tensor) -> Tensor: # noqa: D102
return self.feature_extractor(image)