# Source code for opr.models.place_recognition.apgem

"""Implementation of APGeM Image Model."""
from typing import Literal

from opr.modules import GeM
from opr.modules.feature_extractors import (
    ResNet18FPNFeatureExtractor,
    ResNet50FPNFeatureExtractor,
    VGG16FeatureExtractor,
)

from .base import ImageModel


class APGeMModel(ImageModel):
    """APGeM: 'Learning with Average Precision: Training Image Retrieval with a Listwise Loss'.

    Paper: https://arxiv.org/abs/1906.07589
    """

    def __init__(self, backbone: Literal["resnet18", "resnet50", "vgg16"] = "resnet50") -> None:
        """Initialize APGeM Image Model.

        Builds the requested feature-extractor backbone and a ``GeM`` head and
        passes both to ``ImageModel.__init__``.

        Args:
            backbone (str): Backbone architecture name. One of ``"resnet18"``
                (``ResNet18FPNFeatureExtractor``), ``"resnet50"``
                (``ResNet50FPNFeatureExtractor``) or ``"vgg16"``
                (``VGG16FeatureExtractor``). Defaults to ``"resnet50"``.

        Raises:
            NotImplementedError: If given backbone is unknown.
        """
        # Name -> extractor class. Only the selected class is instantiated, so
        # unused backbones are never constructed.
        backbone_classes = {
            "resnet18": ResNet18FPNFeatureExtractor,
            "resnet50": ResNet50FPNFeatureExtractor,
            "vgg16": VGG16FeatureExtractor,
        }
        backbone_cls = backbone_classes.get(backbone)
        if backbone_cls is None:
            raise NotImplementedError(f"Backbone {backbone} is not supported.")

        # Use a separate name instead of rebinding `backbone`. That keeps the
        # argument's declared Literal[str] type intact for static type checkers.
        feature_extractor = backbone_cls()
        head = GeM()
        super().__init__(
            backbone=feature_extractor,
            head=head,
        )