# Source code for opr.models.place_recognition.svtnet

"""SVT-Net: Super Light-Weight Sparse Voxel Transformer for Large Scale Place Recognition.

Citation:
    Fan, Zhaoxin, et al. "Svt-net: Super light-weight sparse voxel transformer
    for large scale place recognition." Proceedings of the AAAI Conference on Artificial Intelligence.
    Vol. 36. No. 1. 2022.

Source: https://github.com/ZhenboSong/SVTNet
Paper: https://arxiv.org/abs/2105.00149
"""
from opr.modules import MinkGeM
from opr.modules.feature_extractors import SVTNetFeatureExtractor

from .base import CloudModel


class SVTNet(CloudModel):
    """Place recognition model built from an SVT-Net backbone and a pooling head.

    SVT-Net: Super Light-Weight Sparse Voxel Transformer for Large Scale Place Recognition.

    Citation:
        Fan, Zhaoxin, et al. "Svt-net: Super light-weight sparse voxel transformer
        for large scale place recognition." Proceedings of the AAAI Conference on Artificial Intelligence.
        Vol. 36. No. 1. 2022.

    Source: https://github.com/ZhenboSong/SVTNet
    Paper: https://arxiv.org/abs/2105.00149
    """

    def __init__(
        self,
        in_channels: int = 1,
        out_channels: int = 256,
        conv0_kernel_size: int = 5,
        block: str = "ECABasicBlock",
        asvt: bool = True,
        csvt: bool = True,
        layers: tuple[int, ...] = (1, 1, 1),
        planes: tuple[int, ...] = (32, 64, 64),
        pooling: str = "gem",
    ) -> None:
        """Assemble the SVT-Net feature extractor and its pooling head.

        All arguments except ``pooling`` are forwarded, in order, to
        ``SVTNetFeatureExtractor``.

        Args:
            in_channels (int): Number of input channels. Defaults to 1.
            out_channels (int): Number of output channels. Defaults to 256.
            conv0_kernel_size (int): Kernel size of the first convolution. Defaults to 5.
            block (str): Type of the network block. Defaults to "ECABasicBlock".
            asvt (bool): Whether to use ASVT. Defaults to True.
            csvt (bool): Whether to use CSVT. Defaults to True.
            layers (tuple[int, ...]): Number of blocks in each layer. Defaults to (1, 1, 1).
            planes (tuple[int, ...]): Number of channels in each layer. Defaults to (32, 64, 64).
            pooling (str): Type of pooling; only "gem" is handled here. Defaults to "gem".

        Raises:
            NotImplementedError: If given pooling method is unknown.
        """
        # The backbone is built before the pooling check, so an invalid
        # backbone configuration surfaces ahead of an unknown pooling name.
        backbone = SVTNetFeatureExtractor(
            in_channels,
            out_channels,
            conv0_kernel_size,
            block,
            asvt,
            csvt,
            layers,
            planes,
        )

        # Guard clause: "gem" is the only pooling method wired up.
        if pooling != "gem":
            raise NotImplementedError("Unknown pooling method: {}".format(pooling))
        head = MinkGeM()

        super().__init__(backbone=backbone, head=head)