Source code for opr.modules.feature_extractors.svtnet

"""Implementation of feature extraction model from SVT-Net.

Citation:
    Fan, Zhaoxin, et al. "Svt-net: Super light-weight sparse voxel transformer
    for large scale place recognition." Proceedings of the AAAI Conference on Artificial Intelligence.
    Vol. 36. No. 1. 2022.

Source: https://github.com/ZhenboSong/SVTNet
Paper: https://arxiv.org/abs/2105.00149
"""
from __future__ import annotations

from loguru import logger
from torch import nn

from opr.modules.eca import MinkECABasicBlock as ECABasicBlock
from opr.modules.feature_extractors.mink_resnet import MinkResNetBase
from opr.modules.svt import ASVT, CSVT

try:
    import MinkowskiEngine as ME  # type: ignore
    from MinkowskiEngine.modules.resnet_block import BasicBlock, Bottleneck

    minkowski_available = True
except ImportError:
    logger.warning("MinkowskiEngine is not installed. Some features may not be available.")
    BasicBlock = Bottleneck = nn.Module
    minkowski_available = False


[docs] class SVTNetFeatureExtractor(MinkResNetBase): """Feature extraction model from SVT-Net. Source: https://github.com/ZhenboSong/SVTNet """ sparse = True def __init__( self, in_channels: int = 1, out_channels: int = 256, conv0_kernel_size: int = 5, block: str = "ECABasicBlock", asvt: bool = True, csvt: bool = True, layers: tuple[int, ...] = (1, 1, 1), planes: tuple[int, ...] = (32, 64, 64), ) -> None: """Feature extraction model from SVT-Net. Args: in_channels (int): Number of input channels. Defaults to 1. out_channels (int): Number of output channels. Defaults to 256. conv0_kernel_size (int): Kernel size of the first convolution. Defaults to 5. block (str): Type of the network block. Defaults to "ECABasicBlock". asvt (bool): Whether to use ASVT. Defaults to True. csvt (bool): Whether to use CSVT. Defaults to True. layers (tuple[int, ...]): Number of blocks in each layer. Defaults to (1, 1, 1). planes (tuple[int, ...]): Number of channels in each layer. Defaults to (32, 64, 64). Raises: RuntimeError: If MinkowskiEngine is not installed. ValueError: If the number of layers and planes is not the same. ValueError: If the number of layers is less than 1. """ if not minkowski_available: raise RuntimeError( "MinkowskiEngine is not installed. SVTNetFeatureExtractor requires MinkowskiEngine." ) if not len(layers) == len(planes): raise ValueError("The number of layers and planes must be the same.") if not 1 <= len(layers): raise ValueError("The number of layers must be at least 1.") self.num_bottom_up = len(layers) self.conv0_kernel_size = conv0_kernel_size self.block = self._create_resnet_block(block_name=block) self.layers = layers self.planes = planes self.lateral_dim = out_channels self.init_dim = planes[0] self.is_asvt = asvt self.is_csvt = csvt MinkResNetBase.__init__(self, in_channels, out_channels, dimension=3) def _create_resnet_block(self, block_name: str) -> nn.Module: if block_name == "BasicBlock": block_module = BasicBlock elif block_name == "Bottleneck": block_module = Bottleneck elif block_name == "ECABasicBlock": block_module = ECABasicBlock else: raise NotImplementedError(f"Unsupported network block: {block_name}") return block_module def _network_initialization(self, in_channels: int, out_channels: int, dimension: int) -> None: self.convs = nn.ModuleList() # Bottom-up convolutional blocks with stride=2 self.bn = nn.ModuleList() # Bottom-up BatchNorms self.blocks = nn.ModuleList() # Bottom-up blocks self.tconvs = nn.ModuleList() # Top-down tranposed convolutions self.conv1x1 = nn.ModuleList() # 1x1 convolutions in lateral connections # The first convolution is special case, with kernel size = 5 self.inplanes = self.planes[0] self.conv0 = ME.MinkowskiConvolution( in_channels, self.inplanes, kernel_size=self.conv0_kernel_size, dimension=dimension ) self.bn0 = ME.MinkowskiBatchNorm(self.inplanes) for plane, layer in zip(self.planes, self.layers): self.convs.append( ME.MinkowskiConvolution( self.inplanes, self.inplanes, kernel_size=2, stride=2, dimension=dimension ) ) self.bn.append(ME.MinkowskiBatchNorm(self.inplanes)) self.blocks.append(self._make_layer(self.block, plane, layer)) self.conv1x1.append( ME.MinkowskiConvolution( self.inplanes, self.lateral_dim, kernel_size=1, stride=1, dimension=dimension ) ) # before_lateral_dim=plane after_reduction = max(self.lateral_dim / 8, 8) reduction = int(self.lateral_dim // after_reduction) if self.is_asvt: self.asvt = ASVT(self.lateral_dim, reduction) if self.is_csvt: self.csvt = CSVT(self.lateral_dim, 8) self.relu = ME.MinkowskiReLU(inplace=True)
[docs] def forward(self, x: ME.SparseTensor) -> ME.SparseTensor: # noqa: D102 # *** BOTTOM-UP PASS *** # First bottom-up convolution is special (with bigger stride) x = self.conv0(x) x = self.bn0(x) x = self.relu(x) # BOTTOM-UP PASS for _, (conv, bn, block) in enumerate(zip(self.convs, self.bn, self.blocks)): x = conv(x) # Decreases spatial resolution (conv stride=2) x = bn(x) x = self.relu(x) x = block(x) x = self.conv1x1[0](x) if self.is_csvt: x_csvt = self.csvt(x) if self.is_asvt: x_asvt = self.asvt(x) if self.is_csvt and self.is_asvt: x = x_csvt + x_asvt elif self.is_csvt: x = x_csvt elif self.is_asvt: x = x_asvt return x