# Source code for opr.models.place_recognition.pointnetvlad

"""Implementation of PointNetVLAD model."""
from __future__ import print_function

import math

import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torch.autograd import Variable


class NetVLADLoupe(nn.Module):
    """NetVLAD aggregation layer with gating mechanism.

    Per-sample features are softly assigned to ``cluster_size`` clusters, the
    assignment-weighted residuals are L2-normalised, flattened and projected
    to ``output_dim`` by a learned matrix, then batch-normalised and
    (optionally) passed through a context-gating layer.
    """

    def __init__(
        self,
        feature_size: int,
        max_samples: int,
        cluster_size: int,
        output_dim: int,
        gating: bool = True,
        add_batch_norm: bool = True,
        is_training: bool = True,
    ) -> None:
        """Initialize NetVLADLoupe layer.

        Args:
            feature_size (int): Size of each local feature vector.
            max_samples (int): Number of local features per input element.
            cluster_size (int): Number of soft-assignment clusters.
            output_dim (int): Size of the produced descriptor.
            gating (bool): Whether to append a context-gating layer. Defaults to True.
            add_batch_norm (bool): Batch-normalise the cluster logits instead of
                adding a learned bias. Also forwarded to the gating layer. Defaults to True.
            is_training (bool): Stored on the instance; not read by this class.
                Defaults to True.
        """
        super().__init__()

        # Plain configuration attributes.
        self.feature_size = feature_size
        self.max_samples = max_samples
        self.output_dim = output_dim
        self.is_training = is_training
        self.gating = gating
        self.add_batch_norm = add_batch_norm
        self.cluster_size = cluster_size

        self.softmax = nn.Softmax(dim=-1)

        # Every random tensor is divided by sqrt(feature_size).
        init_div = math.sqrt(feature_size)
        self.cluster_weights = nn.Parameter(torch.randn(feature_size, cluster_size) / init_div)
        self.cluster_weights2 = nn.Parameter(torch.randn(1, feature_size, cluster_size) / init_div)
        self.hidden1_weights = nn.Parameter(
            torch.randn(cluster_size * feature_size, output_dim) / init_div
        )

        # Either batch norm or a learned bias is applied to the cluster logits.
        if not add_batch_norm:
            self.cluster_biases = nn.Parameter(torch.randn(cluster_size) / init_div)
            self.bn1 = None
        else:
            self.cluster_biases = None
            self.bn1 = nn.BatchNorm1d(cluster_size)

        self.bn2 = nn.BatchNorm1d(output_dim)

        if gating:
            self.context_gating = GatingContext(output_dim, add_batch_norm=add_batch_norm)

    def forward(self, x: Tensor) -> Tensor:
        """Aggregate local features into one descriptor per input element.

        Args:
            x (Tensor): 4-D tensor. Dims 1 and 3 are swapped and the result is
                viewed as ``(-1, max_samples, feature_size)``.

        Returns:
            Tensor: Descriptor of shape ``(-1, output_dim)``.
        """
        n_samples = self.max_samples
        n_feat = self.feature_size
        n_clusters = self.cluster_size

        feats = x.transpose(1, 3).contiguous().view(-1, n_samples, n_feat)

        # Soft-assignment logits of every sample to every cluster.
        logits = torch.matmul(feats, self.cluster_weights)
        if self.add_batch_norm:
            # BatchNorm1d works on 2-D input, so fold the sample axis into the batch.
            logits = self.bn1(logits.view(-1, n_clusters)).view(-1, n_samples, n_clusters)
        else:
            logits = logits + self.cluster_biases
        assignment = self.softmax(logits)

        # Cluster term subtracted from the aggregated features: (-1, n_feat, n_clusters).
        centres = assignment.sum(-2, keepdim=True) * self.cluster_weights2

        aggregated = torch.matmul(assignment.transpose(1, 2), feats)
        vlad = aggregated.transpose(1, 2) - centres

        # Normalise along dim 1 first, then the whole flattened descriptor.
        vlad = F.normalize(vlad, dim=1, p=2)
        vlad = vlad.contiguous().view(-1, n_clusters * n_feat)
        vlad = F.normalize(vlad, dim=1, p=2)

        vlad = self.bn2(torch.matmul(vlad, self.hidden1_weights))
        return self.context_gating(vlad) if self.gating else vlad
class GatingContext(nn.Module):
    """Gating context layer.

    Scales the input element-wise by sigmoid gates computed from a learned
    ``dim x dim`` projection of that same input.
    """

    def __init__(self, dim: int, add_batch_norm: bool = True) -> None:
        """Initialize GatingContext layer.

        Args:
            dim (int): Size of the last dimension of the input.
            add_batch_norm (bool): Batch-normalise the gate pre-activations
                instead of adding a learned bias. Defaults to True.
        """
        super().__init__()
        self.dim = dim
        self.add_batch_norm = add_batch_norm

        init_div = math.sqrt(dim)
        self.gating_weights = nn.Parameter(torch.randn(dim, dim) / init_div)
        self.sigmoid = nn.Sigmoid()

        # Exactly one of (bias, batch norm) is active; the other is set to None.
        if not add_batch_norm:
            self.gating_biases = nn.Parameter(torch.randn(dim) / init_div)
            self.bn1 = None
        else:
            self.gating_biases = None
            self.bn1 = nn.BatchNorm1d(dim)

    def forward(self, x: Tensor) -> Tensor:
        """Return ``x`` multiplied element-wise by its sigmoid gates.

        Args:
            x (Tensor): Input whose last dimension has size ``dim``.

        Returns:
            Tensor: Gated tensor with the same shape as ``x``.
        """
        pre_gate = torch.matmul(x, self.gating_weights)
        pre_gate = self.bn1(pre_gate) if self.add_batch_norm else pre_gate + self.gating_biases
        return x * self.sigmoid(pre_gate)
class Flatten(nn.Module):
    """Flatten layer.

    Collapses every dimension after the first into a single one.
    """

    def __init__(self) -> None:
        """Initialize Flatten layer."""
        # nn.Module.__init__ accepts no positional arguments; passing ``self``
        # explicitly (as the previous code did) raises TypeError on construction.
        super().__init__()

    def forward(self, input: Tensor) -> Tensor:  # noqa: A002
        """Reshape ``input`` to ``(input.size(0), -1)``.

        Args:
            input (Tensor): Tensor with at least one dimension.

        Returns:
            Tensor: 2-D tensor that keeps the first dimension.
        """
        # reshape (unlike view) also accepts non-contiguous inputs.
        return input.reshape(input.size(0), -1)
class STN3d(nn.Module):
    """Spatial Transformer Network for 3D data.

    Predicts one ``k x k`` matrix per input element. The last linear layer is
    zero-initialised and an identity matrix is added to its output, so a
    freshly constructed network returns the identity transform.
    """

    def __init__(self, num_points: int = 2500, k: int = 3, use_bn: bool = True) -> None:
        """Initialize STN3d.

        Args:
            num_points (int): Height of the max-pooling window applied after the
                convolutions. Defaults to 2500.
            k (int): Side length of the predicted matrix. Defaults to 3.
            use_bn (bool): Whether to apply batch normalisation after the
                convolutions and the first two linear layers. Defaults to True.
        """
        super().__init__()
        self.k = k
        # k == 3: one input channel, kernel spans the 3 coordinates.
        # otherwise: k input channels with a 1x1 kernel.
        self.kernel_size = 3 if k == 3 else 1
        self.channels = 1 if k == 3 else k
        self.num_points = num_points
        self.use_bn = use_bn

        self.conv1 = torch.nn.Conv2d(self.channels, 64, (1, self.kernel_size))
        self.conv2 = torch.nn.Conv2d(64, 128, (1, 1))
        self.conv3 = torch.nn.Conv2d(128, 1024, (1, 1))
        self.mp1 = torch.nn.MaxPool2d((num_points, 1), 1)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, k * k)
        # Zero the last layer so the initial prediction is exactly the identity.
        nn.init.zeros_(self.fc3.weight)
        nn.init.zeros_(self.fc3.bias)
        self.relu = nn.ReLU()

        if use_bn:
            self.bn1 = nn.BatchNorm2d(64)
            self.bn2 = nn.BatchNorm2d(128)
            self.bn3 = nn.BatchNorm2d(1024)
            self.bn4 = nn.BatchNorm1d(512)
            self.bn5 = nn.BatchNorm1d(256)

    def forward(self, x: Tensor) -> Tensor:
        """Predict a ``k x k`` transform for every element of the batch.

        Args:
            x (Tensor): 4-D input. Judging by the conv/pool kernel shapes this is
                presumably ``(B, 1, num_points, 3)`` when ``k == 3`` and
                ``(B, k, num_points, 1)`` otherwise -- confirm against callers.

        Returns:
            Tensor: Tensor of shape ``(-1, k, k)``.
        """
        if self.use_bn:
            x = F.relu(self.bn1(self.conv1(x)))
            x = F.relu(self.bn2(self.conv2(x)))
            x = F.relu(self.bn3(self.conv3(x)))
        else:
            x = F.relu(self.conv1(x))
            x = F.relu(self.conv2(x))
            x = F.relu(self.conv3(x))

        x = self.mp1(x)
        x = x.view(-1, 1024)

        if self.use_bn:
            x = F.relu(self.bn4(self.fc1(x)))
            x = F.relu(self.bn5(self.fc2(x)))
        else:
            x = F.relu(self.fc1(x))
            x = F.relu(self.fc2(x))
        x = self.fc3(x)

        # Build the identity directly on x's device and dtype. The previous
        # ``Variable(np.eye(...)).cuda()`` moved it to the *current* CUDA device,
        # which need not be the device x lives on.
        identity = torch.eye(self.k, dtype=x.dtype, device=x.device).view(1, self.k * self.k)
        x = x + identity
        return x.view(-1, self.k, self.k)
class PointNetFeat(nn.Module):
    """PointNet feature extractor.

    Applies a learned 3x3 input transform, a stack of 1x1 convolutions with
    batch normalisation, and optionally a learned 64x64 feature transform and a
    max-pool over the point axis.
    """

    def __init__(
        self,
        num_points: int = 2500,
        global_feat: bool = True,
        feature_transform: bool = False,
        max_pool: bool = True,
    ) -> None:
        """Initialize PointNetFeat.

        Args:
            num_points (int): Number of points per cloud; sets the max-pool window
                height. Defaults to 2500.
            global_feat (bool): With ``max_pool=True``, return only the pooled
                feature (True) or the pooled feature concatenated with the
                per-point features (False). Defaults to True.
            feature_transform (bool): Whether to apply the 64x64 feature transform.
                Defaults to False.
            max_pool (bool): Whether to max-pool over the point axis. When False the
                per-point feature map is returned. Defaults to True.
        """
        super().__init__()
        self.stn = STN3d(num_points=num_points, k=3, use_bn=False)
        self.feature_trans = STN3d(num_points=num_points, k=64, use_bn=False)
        self.apply_feature_trans = feature_transform
        self.conv1 = torch.nn.Conv2d(1, 64, (1, 3))
        self.conv2 = torch.nn.Conv2d(64, 64, (1, 1))
        self.conv3 = torch.nn.Conv2d(64, 64, (1, 1))
        self.conv4 = torch.nn.Conv2d(64, 128, (1, 1))
        self.conv5 = torch.nn.Conv2d(128, 1024, (1, 1))
        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(64)
        self.bn3 = nn.BatchNorm2d(64)
        self.bn4 = nn.BatchNorm2d(128)
        self.bn5 = nn.BatchNorm2d(1024)
        self.mp1 = torch.nn.MaxPool2d((num_points, 1), 1)
        self.num_points = num_points
        self.global_feat = global_feat
        self.max_pool = max_pool

    def forward(self, x: Tensor) -> Tensor:
        """Extract features from a batch of point clouds.

        Args:
            x (Tensor): Input of shape ``(B, 1, N, 3)`` (``conv1`` takes one input
                channel with a ``(1, 3)`` kernel).

        Returns:
            Tensor: ``(B, 1024, N, 1)`` when ``max_pool`` is False; ``(-1, 1024)``
            when ``max_pool`` and ``global_feat`` are True; otherwise the pooled
            feature broadcast over the points and concatenated with the 64-channel
            per-point features, ``(B, 1088, num_points, 1)``.
        """
        batch_size = x.size(0)

        # Input transform. Squeeze only the channel axis: an argument-less
        # squeeze would also drop a size-1 batch or point axis.
        trans = self.stn(x)
        x = torch.matmul(x.squeeze(1), trans)
        x = x.view(batch_size, 1, -1, 3)

        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        pointfeat = x  # (B, 64, N, 1)

        if self.apply_feature_trans:
            f_trans = self.feature_trans(x)
            # (B, 64, N, 1) -> (B, N, 64) @ (B, 64, 64) -> back to (B, 64, N, 1).
            x = torch.matmul(x.squeeze(3).transpose(1, 2), f_trans)
            x = x.transpose(1, 2).contiguous()
            x = x.view(batch_size, 64, -1, 1)

        x = F.relu(self.bn3(self.conv3(x)))
        x = F.relu(self.bn4(self.conv4(x)))
        x = self.bn5(self.conv5(x))

        if not self.max_pool:
            return x

        x = self.mp1(x)
        x = x.view(-1, 1024)
        if self.global_feat:
            return x

        # Keep the global feature 4-D so it can be concatenated with the 4-D
        # ``pointfeat``; the former 3-D tensor made torch.cat raise.
        x = x.view(-1, 1024, 1, 1).repeat(1, 1, self.num_points, 1)
        return torch.cat([x, pointfeat], 1)
class PointNetVLAD(nn.Module):
    """PointNetVLAD: Deep Point Cloud Based Retrieval for Large-Scale Place Recognition.

    Paper: https://arxiv.org/abs/1804.03492
    Original repository: https://github.com/mikacuy/pointnetvlad
    Code is adopted from repository: https://github.com/cattaneod/PointNetVlad-Pytorch
    """

    def __init__(
        self,
        num_points: int = 2500,
        global_feat: bool = True,
        feature_transform: bool = False,
        max_pool: bool = False,
        output_dim: int = 1024,
    ) -> None:
        """Initialize PointNetVLAD model.

        Args:
            num_points (int): Number of points in the input point cloud. Defaults to 2500.
            global_feat (bool): Whether to use global feature or not. Defaults to True.
            feature_transform (bool): Whether to apply feature transform or not. Defaults to False.
            max_pool (bool): Whether to use max pooling or not. Defaults to False.
                NOTE(review): the aggregation layer swaps dims 1 and 3 of its input,
                so it appears to need the 4-D un-pooled features -- confirm that
                ``max_pool=True`` is actually supported here.
            output_dim (int): Output dimension of the model. Defaults to 1024.
        """
        super().__init__()

        # Local feature extractor.
        backbone_cfg = {
            "num_points": num_points,
            "global_feat": global_feat,
            "feature_transform": feature_transform,
            "max_pool": max_pool,
        }
        self.point_net = PointNetFeat(**backbone_cfg)

        # Aggregation head: 1024-d local features, 64 clusters, gating and batch norm on.
        head_cfg = {
            "feature_size": 1024,
            "max_samples": num_points,
            "cluster_size": 64,
            "output_dim": output_dim,
            "gating": True,
            "add_batch_norm": True,
            "is_training": True,
        }
        self.net_vlad = NetVLADLoupe(**head_cfg)

    def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
        """Compute the descriptor for a batch of point clouds.

        Args:
            batch (dict[str, Tensor]): Must contain the key ``"pointclouds_lidar_coords"``.

        Returns:
            dict[str, Tensor]: Dictionary with the single key ``"final_descriptor"``.
        """
        local_features = self.point_net(batch["pointclouds_lidar_coords"])
        descriptor = self.net_vlad(local_features)
        return {"final_descriptor": descriptor}