Source code for opr.models.place_recognition.netvlad

"""Implementation of NetVLAD model."""
from typing import Literal

from opr.modules.feature_extractors import (
    ResNet18FPNFeatureExtractor,
    ResNet50FPNFeatureExtractor,
    VGG16FeatureExtractor,
)
from opr.modules.netvlad import NetVLAD

from .base import ImageModel


[docs] class NetVLADModel(ImageModel): """NetVLAD: CNN architecture for weakly supervised place recognition. Paper: https://arxiv.org/abs/1511.07247v3 Code is adopted from the repository: https://github.com/Nanne/pytorch-NetVlad """ def __init__( self, backbone: Literal["resnet18", "resnet50", "vgg16"] = "resnet18", num_clusters: int = 64, normalize_input: bool = True, vladv2: bool = False, ) -> None: """Initialize NetVLAD Image Model. Args: backbone (str): Backbone architecture. Defaults to "resnet18". num_clusters (int): Number of VLAD clusters. Defaults to 64. normalize_input (bool): Whether to normalize input data or not. Defaults to True. vladv2 (bool): Use vladv2 init params method. Defaults to False. Raises: NotImplementedError: If given backbone is unknown. """ if backbone == "resnet18": backbone = ResNet18FPNFeatureExtractor() dim = 256 elif backbone == "resnet50": backbone = ResNet50FPNFeatureExtractor() dim = 256 elif backbone == "vgg16": backbone = VGG16FeatureExtractor() dim = 512 else: raise NotImplementedError(f"Backbone {backbone} is not supported.") head = NetVLAD(num_clusters, dim, normalize_input, vladv2) super().__init__( backbone=backbone, head=head, )