Source code for opr.modules.feature_extractors.resnet
"""ResNet-based image feature extractors."""
from torch import Tensor, nn
from torchvision.models import ResNet18_Weights, ResNet50_Weights, resnet18, resnet50
[docs]
class ResNetFeatureExtractor(nn.Module):
    """ResNet-based image feature extractor.

    Keeps every child module of the given model except the last two and applies
    them one after another in ``forward``.
    """

    def __init__(
        self,
        model: nn.Module,
        in_channels: int = 3,
        pretrained: bool = True,
    ) -> None:
        """ResNet-based image feature extractor.

        Args:
            model (nn.Module): ResNet model to use as feature extractor. Its first child
                must be the input ``nn.Conv2d`` when `in_channels` is not 3.
            in_channels (int): Number of input channels. Defaults to 3.
            pretrained (bool): Whether `model` is an ImageNet-pretrained model. In this class
                the flag is only used to validate `in_channels`. Defaults to True.

        Raises:
            ValueError: If `in_channels` is not 3 and `pretrained` is True.
        """
        super().__init__()
        if in_channels != 3 and pretrained:
            raise ValueError("Pretrained models are only available for 3-channel images")

        # Drop the last two children (AdaptiveAvgPool2d and Linear in torchvision ResNets).
        self.resnet_fe = nn.ModuleList(list(model.children())[:-2])

        # Change the input conv to accept n-channel images; the new layer is freshly initialised.
        if in_channels != 3:
            stem = self.resnet_fe[0]
            self.resnet_fe[0] = nn.Conv2d(
                in_channels=in_channels,
                out_channels=stem.out_channels,
                kernel_size=stem.kernel_size,
                stride=stem.stride,
                padding=stem.padding,
                dilation=stem.dilation,
                groups=stem.groups,
                # nn.Conv2d expects a bool here; passing the bias Parameter itself
                # raises "Boolean value of Tensor ... is ambiguous" for biased convs.
                bias=stem.bias is not None,
                padding_mode=stem.padding_mode,
                device=stem.weight.device,
                dtype=stem.weight.dtype,
            )

    def forward(self, image: Tensor) -> Tensor:
        """Run the image through all retained layers in order.

        Args:
            image (Tensor): Input image batch.

        Returns:
            Tensor: Output of the last retained layer.
        """
        x = image
        for layer in self.resnet_fe:
            x = layer(x)
        return x
[docs]
class ResNet18FeatureExtractor(ResNetFeatureExtractor):
    """Image feature extractor built on a torchvision ResNet18."""

    def __init__(self, in_channels: int = 3, pretrained: bool = True) -> None:
        """Create a ResNet18 and pass it to the base feature extractor.

        Args:
            in_channels (int): Number of input channels. Defaults to 3.
            pretrained (bool): Whether to load ImageNet-pretrained model. Defaults to True.
        """
        # Pick the ImageNet weights only when a pretrained backbone is requested.
        if pretrained:
            weights = ResNet18_Weights.IMAGENET1K_V1
        else:
            weights = None
        backbone = resnet18(weights=weights)
        super().__init__(model=backbone, in_channels=in_channels, pretrained=pretrained)
[docs]
class ResNet50FeatureExtractor(ResNetFeatureExtractor):
    """Image feature extractor built on a torchvision ResNet50."""

    def __init__(self, in_channels: int = 3, pretrained: bool = True) -> None:
        """Create a ResNet50 and pass it to the base feature extractor.

        Args:
            in_channels (int): Number of input channels. Defaults to 3.
            pretrained (bool): Whether to load ImageNet-pretrained model. Defaults to True.
        """
        # Pick the ImageNet weights only when a pretrained backbone is requested.
        if pretrained:
            weights = ResNet50_Weights.IMAGENET1K_V1
        else:
            weights = None
        backbone = resnet50(weights=weights)
        super().__init__(model=backbone, in_channels=in_channels, pretrained=pretrained)
[docs]
class ResNetFPNFeatureExtractor(nn.Module):
    """ResNet-based image feature extractor with FPN block.

    The code is adopted from the repository: https://github.com/jac99/MinkLocMultimodal, MIT License
    """

    def __init__(
        self,
        model: nn.Module,
        layers: tuple[int, int, int, int, int],
        in_channels: int = 3,
        lateral_dim: int = 256,
        fh_num_bottom_up: int = 4,
        fh_num_top_down: int = 0,
        pretrained: bool = True,
    ) -> None:
        """ResNet-based image feature extractor with FPN block.

        Args:
            model (nn.Module): ResNet model to use as feature extractor.
            layers (tuple[int, int, int, int, int]): Number of channels in each layer of the ResNet model.
                Entry ``k - 1`` is the channel count of feature level ``k``.
            in_channels (int): Number of input channels. Defaults to 3.
            lateral_dim (int): Output dimension for lateral connections. Defaults to 256.
            fh_num_bottom_up (int): Number of bottom-up steps. Defaults to 4.
            fh_num_top_down (int): Number of top-down steps. Defaults to 0.
            pretrained (bool): Whether `model` is an ImageNet-pretrained model. In this class
                the flag is only used to validate `in_channels`. Defaults to True.

        Raises:
            ValueError: If `fh_num_bottom_up` is not in range [1, 5], if `fh_num_top_down` is not
                in range [0, fh_num_bottom_up), or if `in_channels` is not 3 and `pretrained` is True.
        """
        super().__init__()
        if not (0 < fh_num_bottom_up <= 5):
            raise ValueError("Number of bottom-up steps must be in range [1, 5]")
        if not (0 <= fh_num_top_down < fh_num_bottom_up):
            raise ValueError("Number of top-down steps must be in range [0, fh_num_bottom_up)")
        if in_channels != 3 and pretrained:
            raise ValueError("Pretrained models are only available for 3-channel images")

        self.lateral_dim = lateral_dim
        self.fh_num_bottom_up = fh_num_bottom_up
        self.fh_num_top_down = fh_num_top_down

        # Keep the four stem modules (level 1) plus the next fh_num_bottom_up - 1 children,
        # one per additional feature level; everything after that is discarded.
        self.resnet_fe = nn.ModuleList(list(model.children())[: 3 + self.fh_num_bottom_up])

        # Change the input conv to accept n-channel images; the new layer is freshly initialised.
        if in_channels != 3:
            stem = self.resnet_fe[0]
            self.resnet_fe[0] = nn.Conv2d(
                in_channels=in_channels,
                out_channels=stem.out_channels,
                kernel_size=stem.kernel_size,
                stride=stem.stride,
                padding=stem.padding,
                dilation=stem.dilation,
                groups=stem.groups,
                # nn.Conv2d expects a bool here; passing the bias Parameter itself
                # raises "Boolean value of Tensor ... is ambiguous" for biased convs.
                bias=stem.bias is not None,
                padding_mode=stem.padding_mode,
                device=stem.weight.device,
                dtype=stem.weight.dtype,
            )

        # Lateral connections and top-down pass for the feature extraction head.
        # Both dicts are keyed by the 1-based feature level, stored as a string.
        self.fh_tconvs = nn.ModuleDict()  # Top-down transposed convolutions in feature head
        self.fh_conv1x1 = nn.ModuleDict()  # 1x1 convolutions in lateral connections to the feature head
        lowest_level = self.fh_num_bottom_up - self.fh_num_top_down
        for i in range(lowest_level, self.fh_num_bottom_up):
            self.fh_conv1x1[str(i + 1)] = nn.Conv2d(
                in_channels=layers[i], out_channels=self.lateral_dim, kernel_size=1
            )
            self.fh_tconvs[str(i + 1)] = nn.ConvTranspose2d(
                in_channels=self.lateral_dim, out_channels=self.lateral_dim, kernel_size=2, stride=2
            )

        # One more lateral connection, for the lowest level reached by the top-down pass.
        self.fh_conv1x1[str(lowest_level)] = nn.Conv2d(
            in_channels=layers[lowest_level - 1], out_channels=self.lateral_dim, kernel_size=1
        )

    def forward(self, image: Tensor) -> Tensor:
        """Run the bottom-up ResNet pass followed by the top-down feature head.

        Args:
            image (Tensor): Input image batch.

        Returns:
            Tensor: Feature map with `lateral_dim` channels, taken at level
                ``fh_num_bottom_up - fh_num_top_down``.

        Raises:
            RuntimeError: If the number of collected feature maps differs from `fh_num_bottom_up`.
        """
        x = image
        feature_maps = {}

        # 0, 1, 2, 3 = first layers: Conv2d, BatchNorm, ReLu, MaxPool2d
        for idx in range(4):
            x = self.resnet_fe[idx](x)
        feature_maps["1"] = x

        # Sequential blocks, built from BasicBlock or Bottleneck blocks; child i yields level i - 2.
        for i in range(4, self.fh_num_bottom_up + 3):
            x = self.resnet_fe[i](x)
            feature_maps[str(i - 2)] = x

        if len(feature_maps) != self.fh_num_bottom_up:
            raise RuntimeError(
                f"Number of feature maps ({len(feature_maps)}) does not match fh_num_bottom_up"
            )

        # FEATURE HEAD TOP-DOWN PASS
        # NOTE(review): each step upsamples by 2, so adjacent levels must differ by exactly 2x in
        # resolution. Levels 1 and 2 of a torchvision ResNet presumably share a resolution, so a
        # top-down pass reaching level 1 would likely fail on the addition; confirm before use.
        top = str(self.fh_num_bottom_up)
        xf = self.fh_conv1x1[top](feature_maps[top])
        for i in range(self.fh_num_bottom_up, self.fh_num_bottom_up - self.fh_num_top_down, -1):
            xf = self.fh_tconvs[str(i)](xf)  # Upsample using transposed convolution
            xf = xf + self.fh_conv1x1[str(i - 1)](feature_maps[str(i - 1)])

        return xf
[docs]
class ResNet18FPNFeatureExtractor(ResNetFPNFeatureExtractor):
    """FPN image feature extractor built on a torchvision ResNet18.

    The code is adopted from the repository: https://github.com/jac99/MinkLocMultimodal, MIT License
    """

    def __init__(
        self,
        in_channels: int = 3,
        lateral_dim: int = 256,
        fh_num_bottom_up: int = 4,
        fh_num_top_down: int = 0,
        pretrained: bool = True,
    ) -> None:
        """Create a ResNet18 and pass it, with its channel layout, to the FPN base class.

        Args:
            in_channels (int): Number of input channels. Defaults to 3.
            lateral_dim (int): Output dimension for lateral connections. Defaults to 256.
            fh_num_bottom_up (int): Number of bottom-up steps. Defaults to 4.
            fh_num_top_down (int): Number of top-down steps. Defaults to 0.
            pretrained (bool): Whether to load ImageNet-pretrained model. Defaults to True.
        """
        # Pick the ImageNet weights only when a pretrained backbone is requested.
        weights = ResNet18_Weights.IMAGENET1K_V1 if pretrained else None
        backbone = resnet18(weights=weights)

        # Channel counts of the ResNet18 feature levels, passed on as `layers`.
        resnet18_channels = (64, 64, 128, 256, 512)

        super().__init__(
            model=backbone,
            layers=resnet18_channels,
            in_channels=in_channels,
            lateral_dim=lateral_dim,
            fh_num_bottom_up=fh_num_bottom_up,
            fh_num_top_down=fh_num_top_down,
            pretrained=pretrained,
        )
[docs]
class ResNet50FPNFeatureExtractor(ResNetFPNFeatureExtractor):
    """FPN image feature extractor built on a torchvision ResNet50.

    The code is adopted from the repository: https://github.com/jac99/MinkLocMultimodal, MIT License
    """

    def __init__(
        self,
        in_channels: int = 3,
        lateral_dim: int = 256,
        fh_num_bottom_up: int = 4,
        fh_num_top_down: int = 0,
        pretrained: bool = True,
    ) -> None:
        """Create a ResNet50 and pass it, with its channel layout, to the FPN base class.

        Args:
            in_channels (int): Number of input channels. Defaults to 3.
            lateral_dim (int): Output dimension for lateral connections. Defaults to 256.
            fh_num_bottom_up (int): Number of bottom-up steps. Defaults to 4.
            fh_num_top_down (int): Number of top-down steps. Defaults to 0.
            pretrained (bool): Whether to load ImageNet-pretrained model. Defaults to True.
        """
        # Pick the ImageNet weights only when a pretrained backbone is requested.
        weights = ResNet50_Weights.IMAGENET1K_V1 if pretrained else None
        backbone = resnet50(weights=weights)

        # Channel counts of the ResNet50 feature levels, passed on as `layers`.
        resnet50_channels = (64, 256, 512, 1024, 2048)

        super().__init__(
            model=backbone,
            layers=resnet50_channels,
            in_channels=in_channels,
            lateral_dim=lateral_dim,
            fh_num_bottom_up=fh_num_bottom_up,
            fh_num_top_down=fh_num_top_down,
            pretrained=pretrained,
        )