# Source code for opr.modules.cosplace

"""CosPlace aggregation layer implementation."""
from torch import Tensor, nn
from torch.nn import functional as F

from .gem import GeM


class CosPlace(nn.Module):
    """CosPlace aggregation layer.

    Mirrors the aggregation head of the reference implementation at
    https://github.com/gmberton/CosPlace/blob/main/model/network.py

    The pipeline is: L2-normalize along dim 1 -> GeM pooling -> flatten
    everything after dim 0 -> linear projection -> L2-normalize along dim 1.
    """

    def __init__(self, in_dim: int, out_dim: int) -> None:
        """Build the GeM pooling stage and the linear projection.

        Args:
            in_dim (int): Size of the flattened GeM output fed into the
                linear layer.
            out_dim (int): Size of the produced descriptor.
        """
        super().__init__()
        # Registration order (gem, then fc) is kept so parameter order and
        # state_dict keys stay the same for existing checkpoints.
        self.gem = GeM()
        self.fc = nn.Linear(in_dim, out_dim)

    def forward(self, x: Tensor) -> Tensor:
        """Aggregate a feature tensor into an L2-normalized descriptor.

        Args:
            x (Tensor): Input features, normalized along dim 1 before pooling.
                NOTE(review): presumably a (B, C, H, W) backbone feature map
                that GeM pools spatially — confirm against the GeM module.

        Returns:
            Tensor: Projected descriptor of size ``out_dim`` along dim 1,
            with unit L2 norm along dim 1.
        """
        features = F.normalize(x, p=2, dim=1)
        # Pool, then collapse every non-batch dimension so the linear layer
        # sees a 2-D (batch, in_dim) input.
        pooled = self.gem(features).flatten(1)
        descriptor = self.fc(pooled)
        return F.normalize(descriptor, p=2, dim=1)