Source code for opr.models.place_recognition.overlaptransformer

"""Implementation of OverlapTransformer model."""
import torch
import torch.nn.functional as F
from torch import Tensor, nn

from opr.modules import NetVLAD


[docs] class OverlapTransformer(nn.Module): """OverlapTransformer: An Efficient and Yaw-Angle-Invariant Transformer Network for LiDAR-Based Place Recognition. Paper: https://arxiv.org/abs/2203.03397 Adapted from original repository: https://github.com/haomo-ai/OverlapTransformer """ def __init__( self, height: int = 64, width: int = 900, channels: int = 1, norm_layer: nn.Module = None, use_transformer: bool = True, ) -> None: """Initialize the OverlapTransformer model. Args: height (int): Height of the input tensor. Defaults to 64. width (int): Width of the input tensor. Defaults to 900. channels (int): Number of channels in the input tensor. Defaults to 1. norm_layer (nn.Module): Normalization layer to use. Defaults to None. use_transformer (bool): Whether to use the transformer encoder. Defaults to True. """ super().__init__() if norm_layer is None: norm_layer = nn.BatchNorm2d self.use_transformer = use_transformer self.conv1 = nn.Conv2d(channels, 16, kernel_size=(5, 1), stride=(1, 1), bias=False) self.bn1 = norm_layer(16) self.conv2 = nn.Conv2d(16, 32, kernel_size=(3, 1), stride=(2, 1), bias=False) self.bn2 = norm_layer(32) self.conv3 = nn.Conv2d(32, 64, kernel_size=(3, 1), stride=(2, 1), bias=False) self.bn3 = norm_layer(64) self.conv4 = nn.Conv2d(64, 64, kernel_size=(3, 1), stride=(2, 1), bias=False) self.bn4 = norm_layer(64) self.conv5 = nn.Conv2d(64, 128, kernel_size=(2, 1), stride=(2, 1), bias=False) self.bn5 = norm_layer(128) self.conv6 = nn.Conv2d(128, 128, kernel_size=(1, 1), stride=(2, 1), bias=False) self.bn6 = norm_layer(128) self.conv7 = nn.Conv2d(128, 128, kernel_size=(1, 1), stride=(2, 1), bias=False) self.bn7 = norm_layer(128) self.conv8 = nn.Conv2d(128, 128, kernel_size=(1, 1), stride=(2, 1), bias=False) self.bn8 = norm_layer(128) self.conv9 = nn.Conv2d(128, 128, kernel_size=(1, 1), stride=(2, 1), bias=False) self.bn9 = norm_layer(128) self.conv10 = nn.Conv2d(128, 128, kernel_size=(1, 1), stride=(2, 1), bias=False) self.bn10 = norm_layer(128) self.conv11 = nn.Conv2d(128, 128, kernel_size=(1, 1), stride=(2, 1), bias=False) self.bn11 = norm_layer(128) self.relu = nn.ReLU(inplace=True) """ MHSA num_layers=1 is suggested in our work. """ encoder_layer = nn.TransformerEncoderLayer( d_model=256, nhead=4, dim_feedforward=1024, activation="relu", batch_first=False, dropout=0.0 ) self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=1) self.convLast1 = nn.Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False) self.bnLast1 = norm_layer(256) self.convLast2 = nn.Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False) self.bnLast2 = norm_layer(1024) self.linear = nn.Linear(128 * 900, 256) self.sigmoid = nn.Sigmoid() self.softmax = nn.Softmax() self.net_vlad = NetVLAD(num_clusters=64, dim=1024) # TODO: implement with 'NetVLADLoupe'? self.linear1 = nn.Linear(1 * 256, 256) self.bnl1 = norm_layer(256) self.linear2 = nn.Linear(1 * 256, 256) self.bnl2 = norm_layer(256) self.linear3 = nn.Linear(1 * 256, 256) self.bnl3 = norm_layer(256)
[docs] def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: # noqa: D102 for key, value in batch.items(): if key.startswith("range_image"): x_l = value out_l = self.relu(self.conv1(x_l)) out_l = self.relu(self.conv2(out_l)) out_l = self.relu(self.conv3(out_l)) out_l = self.relu(self.conv4(out_l)) out_l = self.relu(self.conv5(out_l)) out_l = self.relu(self.conv6(out_l)) out_l = self.relu(self.conv7(out_l)) out_l = self.relu(self.conv8(out_l)) out_l = self.relu(self.conv9(out_l)) out_l = self.relu(self.conv10(out_l)) out_l = self.relu(self.conv11(out_l)) out_l_1 = out_l.permute(0, 1, 3, 2) out_l_1 = self.relu(self.convLast1(out_l_1)) # Using transformer needs to decide whether batch_size first if self.use_transformer: out_l = out_l_1.squeeze(3) out_l = out_l.permute(2, 0, 1) out_l = self.transformer_encoder(out_l) out_l = out_l.permute(1, 2, 0) out_l = out_l.unsqueeze(3) out_l = torch.cat((out_l_1, out_l), dim=1) out_l = self.relu(self.convLast2(out_l)) out_l = F.normalize(out_l, dim=1) out_l = self.net_vlad(out_l) out_l = F.normalize(out_l, dim=1) else: out_l = torch.cat((out_l_1, out_l_1), dim=1) out_l = F.normalize(out_l, dim=1) out_l = self.net_vlad(out_l) out_l = F.normalize(out_l, dim=1) out_dict: dict[str, Tensor] = {"final_descriptor": out_l} return out_dict