Source code for opr.models.registration.geotransformer

"""GeoTransformer model for registration.

Paper: https://arxiv.org/abs/2202.06688

Code is adapted from the original repository: https://github.com/qinzheng93/GeoTransformer, MIT License
"""
from time import time
from typing import Any, Dict, List, Optional

import torch
import torch.nn.functional as F
from omegaconf import DictConfig
from torch import Tensor, nn

try:
    from geotransformer.modules.geotransformer import (
        GeometricTransformer,
        LocalGlobalRegistration,
        SuperPointMatching,
        SuperPointTargetGenerator,
    )
    from geotransformer.modules.kpconv import (
        ConvBlock,
        LastUnaryBlock,
        ResidualBlock,
        UnaryBlock,
        nearest_upsample,
    )
    from geotransformer.modules.ops import index_select, point_to_node_partition
    from geotransformer.modules.registration import get_node_correspondences
    from geotransformer.modules.sinkhorn import LearnableLogOptimalTransport
    from geotransformer.utils.data import (
        calibrate_neighbors_stack_mode,
        registration_collate_fn_stack_mode,
    )
    from geotransformer.utils.torch import to_cuda
except ImportError as err:
    raise ImportError(
        "To use the GeoTransformer model, please install the geotransformer package first."
    ) from err


class KPConvFPN(nn.Module):
    """KPConv encoder with a feature-pyramid style decoder.

    The encoder has five stages. Stages 2-5 each start with a strided residual
    block, and the KPConv radius/sigma double from one stage to the next. The
    decoder upsamples from stage 5 back to stage 2, concatenating the encoder
    features of the matching stage at every step.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        init_dim: int,
        kernel_size: int,
        init_radius: float,
        init_sigma: float,
        group_norm: int,
    ) -> None:
        """Create the encoder and decoder blocks.

        Args:
            input_dim: The input feature dimension.
            output_dim: The output feature dimension.
            init_dim: The initial feature dimension.
            kernel_size: The kernel size of KPConv.
            init_radius: The initial radius of KPConv.
            init_sigma: The initial sigma of KPConv.
            group_norm: The number of groups in group normalization.
        """
        super().__init__()

        # Channel widths used across the pyramid, as multiples of ``init_dim``.
        c1 = init_dim
        c2 = init_dim * 2
        c4 = init_dim * 4
        c8 = init_dim * 8
        c16 = init_dim * 16
        c32 = init_dim * 32

        # KPConv radius / sigma per stage: doubled at every stage.
        r1, s1 = init_radius, init_sigma
        r2, s2 = init_radius * 2, init_sigma * 2
        r4, s4 = init_radius * 4, init_sigma * 4
        r8, s8 = init_radius * 8, init_sigma * 8
        r16, s16 = init_radius * 16, init_sigma * 16

        def residual(in_dim: int, out_dim: int, radius: float, sigma: float, **extra: bool) -> nn.Module:
            # Every residual block shares the kernel size and group-norm setting.
            return ResidualBlock(in_dim, out_dim, kernel_size, radius, sigma, group_norm, **extra)

        # NOTE: attribute names and creation order are kept stable on purpose,
        # they determine the state_dict keys and the parameter ordering.

        # Stage 1
        self.encoder1_1 = ConvBlock(input_dim, c1, kernel_size, r1, s1, group_norm)
        self.encoder1_2 = residual(c1, c2, r1, s1)

        # Stage 2
        self.encoder2_1 = residual(c2, c2, r1, s1, strided=True)
        self.encoder2_2 = residual(c2, c4, r2, s2)
        self.encoder2_3 = residual(c4, c4, r2, s2)

        # Stage 3
        self.encoder3_1 = residual(c4, c4, r2, s2, strided=True)
        self.encoder3_2 = residual(c4, c8, r4, s4)
        self.encoder3_3 = residual(c8, c8, r4, s4)

        # Stage 4
        self.encoder4_1 = residual(c8, c8, r4, s4, strided=True)
        self.encoder4_2 = residual(c8, c16, r8, s8)
        self.encoder4_3 = residual(c16, c16, r8, s8)

        # Stage 5
        self.encoder5_1 = residual(c16, c16, r8, s8, strided=True)
        self.encoder5_2 = residual(c16, c32, r16, s16)
        self.encoder5_3 = residual(c32, c32, r16, s16)

        # Decoder heads. Each input width is the upsampled coarser width plus
        # the skip-connection width (32 + 16, 16 + 8 and 8 + 4 times init_dim).
        self.decoder4 = UnaryBlock(init_dim * 48, c16, group_norm)
        self.decoder3 = UnaryBlock(init_dim * 24, c8, group_norm)
        self.decoder2 = LastUnaryBlock(init_dim * 12, output_dim)

    @staticmethod
    def _fuse(coarse: Tensor, skip: Tensor, upsample_indices: Any, head: nn.Module) -> Tensor:
        """Upsample ``coarse``, concatenate ``skip`` along dim 1 and apply ``head``."""
        upsampled = nearest_upsample(coarse, upsample_indices)
        return head(torch.cat([upsampled, skip], dim=1))

    def forward(self, feats: Tensor, data_dict: Dict[str, List]) -> List[Tensor]:
        """Run the encoder/decoder and collect multi-scale features.

        Args:
            feats: Input features for the first (finest) level.
            data_dict: Must contain the lists ``"points"``, ``"neighbors"``,
                ``"subsampling"`` and ``"upsampling"``; entries 0-4 of
                ``"points"``/``"neighbors"``, 0-3 of ``"subsampling"`` and
                1-3 of ``"upsampling"`` are read.

        Returns:
            Four feature tensors ordered from fine to coarse: the outputs of
            ``decoder2``, ``decoder3``, ``decoder4`` and the stage-5 encoder.
        """
        pts = data_dict["points"]
        nbrs = data_dict["neighbors"]
        down = data_dict["subsampling"]
        up = data_dict["upsampling"]

        # ----- encoder -----
        enc1 = self.encoder1_1(feats, pts[0], pts[0], nbrs[0])
        enc1 = self.encoder1_2(enc1, pts[0], pts[0], nbrs[0])

        enc2 = self.encoder2_1(enc1, pts[1], pts[0], down[0])
        enc2 = self.encoder2_2(enc2, pts[1], pts[1], nbrs[1])
        enc2 = self.encoder2_3(enc2, pts[1], pts[1], nbrs[1])

        enc3 = self.encoder3_1(enc2, pts[2], pts[1], down[1])
        enc3 = self.encoder3_2(enc3, pts[2], pts[2], nbrs[2])
        enc3 = self.encoder3_3(enc3, pts[2], pts[2], nbrs[2])

        enc4 = self.encoder4_1(enc3, pts[3], pts[2], down[2])
        enc4 = self.encoder4_2(enc4, pts[3], pts[3], nbrs[3])
        enc4 = self.encoder4_3(enc4, pts[3], pts[3], nbrs[3])

        enc5 = self.encoder5_1(enc4, pts[4], pts[3], down[3])
        enc5 = self.encoder5_2(enc5, pts[4], pts[4], nbrs[4])
        enc5 = self.encoder5_3(enc5, pts[4], pts[4], nbrs[4])

        # ----- decoder (coarse -> fine, with skip connections) -----
        dec4 = self._fuse(enc5, enc4, up[3], self.decoder4)
        dec3 = self._fuse(dec4, enc3, up[2], self.decoder3)
        dec2 = self._fuse(dec3, enc2, up[1], self.decoder2)

        return [dec2, dec3, dec4, enc5]
class GeoTransformer(nn.Module):
    """GeoTransformer model for registration.

    The model takes two raw point clouds, builds the multi-level inputs the
    KPConv backbone needs, matches coarse nodes ("superpoints") and then
    refines the matches on the fine level to estimate a transform.

    Wall-clock timings of the individual stages are appended to
    ``stats_history`` on every ``forward`` call.
    """

    def __init__(
        self,
        model: DictConfig,
        backbone: DictConfig,
        geotransformer: DictConfig,
        coarse_matching: DictConfig,
        fine_matching: DictConfig,
    ) -> None:
        """Geotransformer model for registration.

        Args:
            model: The model configuration.
            backbone: The backbone configuration.
            geotransformer: The geotransformer configuration.
            coarse_matching: The coarse matching configuration.
            fine_matching: The fine matching configuration.
        """
        super().__init__()

        self.num_points_in_patch = model.num_points_in_patch
        self.matching_radius = model.ground_truth_matching_radius

        # NOTE: the derived radius/sigma are written back into the config
        # object that was passed in; ``_preprocess_input`` reads them from
        # ``self.backbone_cfg`` later.
        backbone.init_radius = backbone.base_radius * backbone.init_voxel_size
        backbone.init_sigma = backbone.base_sigma * backbone.init_voxel_size
        self.backbone_cfg = backbone

        self.backbone = KPConvFPN(
            backbone.input_dim,
            backbone.output_dim,
            backbone.init_dim,
            backbone.kernel_size,
            backbone.init_radius,
            backbone.init_sigma,
            backbone.group_norm,
        )

        self.transformer = GeometricTransformer(
            geotransformer.input_dim,
            geotransformer.output_dim,
            geotransformer.hidden_dim,
            geotransformer.num_heads,
            geotransformer.blocks,
            geotransformer.sigma_d,
            geotransformer.sigma_a,
            geotransformer.angle_k,
            reduction_a=geotransformer.reduction_a,
        )

        self.coarse_target = SuperPointTargetGenerator(
            coarse_matching.num_targets, coarse_matching.overlap_threshold
        )

        self.coarse_matching = SuperPointMatching(
            coarse_matching.num_correspondences, coarse_matching.dual_normalization
        )

        self.fine_matching = LocalGlobalRegistration(
            fine_matching.topk,
            fine_matching.acceptance_radius,
            mutual=fine_matching.mutual,
            confidence_threshold=fine_matching.confidence_threshold,
            use_dustbin=fine_matching.use_dustbin,
            use_global_score=fine_matching.use_global_score,
            correspondence_threshold=fine_matching.correspondence_threshold,
            correspondence_limit=fine_matching.correspondence_limit,
            num_refinement_steps=fine_matching.num_refinement_steps,
        )

        self.optimal_transport = LearnableLogOptimalTransport(model.num_sinkhorn_iterations)

        # Per-stage wall-clock durations in seconds, one entry per forward call.
        # NOTE(review): these lists are never cleared here, so they grow for
        # as long as the model is used.
        self.stats_history = {
            "preprocessing": [],
            "generate_gt": [],
            "encoder": [],
            "transformer": [],
            "coarse_matching": [],
            "optimal_transport": [],
            "fine_matching": [],
        }

    @property
    def _is_cuda(self) -> bool:
        """Whether at least one parameter of the model lives on a CUDA device."""
        return any(param.is_cuda for param in self.parameters())

    def _preprocess_input(
        self, query_pc: Tensor, db_pc: Tensor, gt_transform: Optional[Tensor] = None
    ) -> Dict[str, Any]:
        """Build the collated, multi-level input dictionary for one point cloud pair.

        The database cloud is used as the reference ("ref") and the query cloud
        as the source ("src"). Every point gets a constant feature of 1.

        Args:
            query_pc: Query (source) point cloud; the first dimension is the
                number of points.
            db_pc: Database (reference) point cloud; the first dimension is
                the number of points.
            gt_transform: Optional ground-truth transform. When omitted, a
                4x4 identity matrix is used.

        Returns:
            The dictionary produced by ``registration_collate_fn_stack_mode``,
            moved to CUDA when the model has CUDA parameters.
        """
        data_dict: Dict[str, Any] = {
            "ref_points": db_pc.cpu(),
            "src_points": query_pc.cpu(),
            # Feature counts must follow the corresponding point clouds:
            # ref <-> db_pc, src <-> query_pc.
            "ref_feats": torch.ones((db_pc.shape[0], 1), dtype=torch.float32),
            "src_feats": torch.ones((query_pc.shape[0], 1), dtype=torch.float32),
        }

        # ``if gt_transform:`` would raise for any tensor with more than one
        # element, so test explicitly against ``None``.
        if gt_transform is not None:
            # Keep the transform on the same device as the point clouds above.
            data_dict["transform"] = gt_transform.cpu()
        else:
            data_dict["transform"] = torch.eye(4, dtype=torch.float32)

        neighbor_limits = calibrate_neighbors_stack_mode(
            [data_dict],
            registration_collate_fn_stack_mode,
            self.backbone_cfg.num_stages,
            self.backbone_cfg.init_voxel_size,
            self.backbone_cfg.init_radius,
        )
        data_dict = registration_collate_fn_stack_mode(
            [data_dict],
            self.backbone_cfg.num_stages,
            self.backbone_cfg.init_voxel_size,
            self.backbone_cfg.init_radius,
            neighbor_limits,
        )

        if self._is_cuda:
            data_dict = to_cuda(data_dict)
        return data_dict

    def forward(
        self, query_pc: Tensor, db_pc: Tensor, gt_transform: Optional[Tensor] = None
    ) -> Dict[str, Any]:
        """Register ``query_pc`` (source) against ``db_pc`` (reference).

        Args:
            query_pc: Query (source) point cloud.
            db_pc: Database (reference) point cloud.
            gt_transform: Optional ground-truth transform. It is used to
                generate ground-truth node correspondences, which replace the
                predicted coarse correspondences in training mode.

        Returns:
            A dictionary with a single key, ``"estimated_transform"``, holding
            the transform returned by the fine matching module.
        """
        # NOTE(review): timings use ``time()`` without device synchronisation;
        # on CUDA they may not reflect the actual kernel run time -- confirm
        # before relying on them.
        t_s = time()
        data_dict = self._preprocess_input(query_pc, db_pc, gt_transform)

        output_dict = {}

        # Split the stacked coarse (last level) and fine (level 1) points
        # into their reference and source parts.
        feats = data_dict["features"].detach()
        transform = data_dict["transform"].detach()

        ref_length_c = data_dict["lengths"][-1][0].item()
        ref_length_f = data_dict["lengths"][1][0].item()
        points_c = data_dict["points"][-1].detach()
        points_f = data_dict["points"][1].detach()

        ref_points_c = points_c[:ref_length_c]
        src_points_c = points_c[ref_length_c:]
        ref_points_f = points_f[:ref_length_f]
        src_points_f = points_f[ref_length_f:]
        self.stats_history["preprocessing"].append(time() - t_s)

        # 1. Generate ground truth node correspondences
        t_s = time()
        _, ref_node_masks, ref_node_knn_indices, ref_node_knn_masks = point_to_node_partition(
            ref_points_f, ref_points_c, self.num_points_in_patch
        )
        _, src_node_masks, src_node_knn_indices, src_node_knn_masks = point_to_node_partition(
            src_points_f, src_points_c, self.num_points_in_patch
        )

        # One extra all-zero row is appended before gathering with the kNN
        # indices (presumably the target of padded indices -- TODO confirm).
        ref_padded_points_f = torch.cat([ref_points_f, torch.zeros_like(ref_points_f[:1])], dim=0)
        src_padded_points_f = torch.cat([src_points_f, torch.zeros_like(src_points_f[:1])], dim=0)
        ref_node_knn_points = index_select(ref_padded_points_f, ref_node_knn_indices, dim=0)
        src_node_knn_points = index_select(src_padded_points_f, src_node_knn_indices, dim=0)

        gt_node_corr_indices, gt_node_corr_overlaps = get_node_correspondences(
            ref_points_c,
            src_points_c,
            ref_node_knn_points,
            src_node_knn_points,
            transform,
            self.matching_radius,
            ref_masks=ref_node_masks,
            src_masks=src_node_masks,
            ref_knn_masks=ref_node_knn_masks,
            src_knn_masks=src_node_knn_masks,
        )
        self.stats_history["generate_gt"].append(time() - t_s)

        # 2. KPConv encoder: the backbone returns features ordered fine -> coarse.
        t_s = time()
        feats_list = self.backbone(feats, data_dict)
        feats_c = feats_list[-1]
        feats_f = feats_list[0]
        self.stats_history["encoder"].append(time() - t_s)

        # 3. Geometric transformer on the coarse level
        t_s = time()
        ref_feats_c = feats_c[:ref_length_c]
        src_feats_c = feats_c[ref_length_c:]
        # The transformer is fed a leading batch dimension of size 1.
        ref_feats_c, src_feats_c = self.transformer(
            ref_points_c.unsqueeze(0),
            src_points_c.unsqueeze(0),
            ref_feats_c.unsqueeze(0),
            src_feats_c.unsqueeze(0),
        )
        ref_feats_c_norm = F.normalize(ref_feats_c.squeeze(0), p=2, dim=1)
        src_feats_c_norm = F.normalize(src_feats_c.squeeze(0), p=2, dim=1)
        self.stats_history["transformer"].append(time() - t_s)

        # 4. Fine-level features for the reference and source clouds
        ref_feats_f = feats_f[:ref_length_f]
        src_feats_f = feats_f[ref_length_f:]

        # 5. Select top-k node correspondences
        t_s = time()
        with torch.no_grad():
            ref_node_corr_indices, src_node_corr_indices, node_corr_scores = self.coarse_matching(
                ref_feats_c_norm, src_feats_c_norm, ref_node_masks, src_node_masks
            )

            # During training, the predicted correspondences are replaced by
            # ones sampled from the ground truth.
            if self.training:
                ref_node_corr_indices, src_node_corr_indices, node_corr_scores = self.coarse_target(
                    gt_node_corr_indices, gt_node_corr_overlaps
                )
        self.stats_history["coarse_matching"].append(time() - t_s)

        # 6. Gather per-correspondence patch indices, masks, points and features
        ref_node_corr_knn_indices = ref_node_knn_indices[ref_node_corr_indices]  # (P, K)
        src_node_corr_knn_indices = src_node_knn_indices[src_node_corr_indices]  # (P, K)
        ref_node_corr_knn_masks = ref_node_knn_masks[ref_node_corr_indices]  # (P, K)
        src_node_corr_knn_masks = src_node_knn_masks[src_node_corr_indices]  # (P, K)
        ref_node_corr_knn_points = ref_node_knn_points[ref_node_corr_indices]  # (P, K, 3)
        src_node_corr_knn_points = src_node_knn_points[src_node_corr_indices]  # (P, K, 3)

        ref_padded_feats_f = torch.cat([ref_feats_f, torch.zeros_like(ref_feats_f[:1])], dim=0)
        src_padded_feats_f = torch.cat([src_feats_f, torch.zeros_like(src_feats_f[:1])], dim=0)
        ref_node_corr_knn_feats = index_select(
            ref_padded_feats_f, ref_node_corr_knn_indices, dim=0
        )  # (P, K, C)
        src_node_corr_knn_feats = index_select(
            src_padded_feats_f, src_node_corr_knn_indices, dim=0
        )  # (P, K, C)

        # 7. Optimal transport on the patch-level similarity matrices
        t_s = time()
        matching_scores = torch.einsum(
            "bnd,bmd->bnm", ref_node_corr_knn_feats, src_node_corr_knn_feats
        )  # (P, K, K)
        # Scale the dot products by sqrt(feature dimension).
        matching_scores = matching_scores / feats_f.shape[1] ** 0.5
        matching_scores = self.optimal_transport(
            matching_scores, ref_node_corr_knn_masks, src_node_corr_knn_masks
        )
        self.stats_history["optimal_transport"].append(time() - t_s)

        # 8. Generate final correspondences and estimate the transform
        t_s = time()
        with torch.no_grad():
            if not self.fine_matching.use_dustbin:
                # Drop the last row and column of the score matrices when the
                # fine matching module does not use the dustbin.
                matching_scores = matching_scores[:, :-1, :-1]

            # Only the estimated transform is exposed; the correspondence
            # points and scores are discarded.
            _, _, _, estimated_transform = self.fine_matching(
                ref_node_corr_knn_points,
                src_node_corr_knn_points,
                ref_node_corr_knn_masks,
                src_node_corr_knn_masks,
                matching_scores,
                node_corr_scores,
            )

        output_dict["estimated_transform"] = estimated_transform
        self.stats_history["fine_matching"].append(time() - t_s)

        return output_dict