Source code for opr.modules.eca

"""Implementation of Efficient Channel Attention ECA block.

Wang, Qilong, et al. "ECA-Net: Efficient channel attention for deep convolutional neural networks."
Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2020.

Paper: https://arxiv.org/abs/1910.03151
Code for Mink version adopted from the repository: https://github.com/jac99/MinkLoc3Dv2, MIT License
"""

from __future__ import annotations

import importlib
from typing import Optional

import numpy as np
from torch import nn

from opr.optional_deps import lazy

# Lazy-load MinkowskiEngine so that importing this module does not hard-require it.
# Per the original note, `lazy` hands back either the real module or a stub that
# reports a helpful error on use.
# NOTE(review): `lazy` lives in opr.optional_deps, outside this file - confirm the
# exact stub behavior there.
ME = lazy("MinkowskiEngine", feature="sparse convolutions")

# For inheritance, we need the actual class - this will fail gracefully if ME is a stub
try:
    # Inheritance needs the concrete class object, so resolve it eagerly here.
    BasicBlock = importlib.import_module("MinkowskiEngine.modules.resnet_block").BasicBlock
except (ImportError, RuntimeError):
    # ImportError (which includes ModuleNotFoundError) is what importlib raises when
    # MinkowskiEngine is not installed; without catching it the fallback below was
    # never defined and importing this module crashed.
    # RuntimeError is kept from the original handler - presumably raised by the
    # optional-dependency stub when it is touched; verify against opr.optional_deps.
    class BasicBlock:  # type: ignore[misc]
        """Fallback BasicBlock that gives helpful error when MinkowskiEngine missing.

        It exists only so that subclasses can still be *defined* at import time;
        any attempt to instantiate it (or a subclass that calls ``super().__init__``)
        raises immediately.
        """

        def __init__(self, *args: object, **kwargs: object) -> None:
            """Raise helpful error about missing MinkowskiEngine.

            Args:
                *args: Ignored; accepted so any constructor call shape reaches the error.
                **kwargs: Ignored; accepted for the same reason.

            Raises:
                RuntimeError: Always, because MinkowskiEngine is unavailable.
            """
            raise RuntimeError(
                "MinkowskiEngine required for sparse convolutions.\n"
                "See the documentation for installation instructions"
            )


class MinkECALayer(nn.Module):
    """Channel-attention gate (ECA) operating on MinkowskiEngine sparse tensors."""

    def __init__(self, channels: int, gamma: int = 2, b: int = 1) -> None:
        """Build the pooling, 1D convolution, sigmoid and broadcast-multiply stages.

        Original paper: https://arxiv.org/abs/1910.03151

        Args:
            channels (int): Number of channels in input.
            gamma (int): Gamma parameter, see paper for more details. Defaults to 2.
            b (int): b parameter, see paper for more details. Defaults to 1.
        """
        super().__init__()

        # Kernel size is derived from the channel count and bumped to the next odd
        # number so that the padding below keeps the channel axis length unchanged.
        raw_size = int(abs((np.log2(channels) + b) / gamma))
        if raw_size % 2:
            kernel = raw_size
        else:
            kernel = raw_size + 1
        half_width = (kernel - 1) // 2

        self.avg_pool = ME.MinkowskiGlobalPooling()
        self.conv = nn.Conv1d(1, 1, kernel_size=kernel, padding=half_width, bias=False)
        self.sigmoid = nn.Sigmoid()
        self.broadcast_mul = ME.MinkowskiBroadcastMultiplication()

    def forward(self, x: object) -> object:  # type: ignore[misc]
        """Scale the channels of ``x`` by attention weights computed from its global pool.

        Args:
            x: Input sparse tensor from MinkowskiEngine.

        Returns:
            Output sparse tensor with attention applied.
        """
        # Global descriptor of the input.
        pooled = self.avg_pool(x)

        # Insert a singleton axis so the 1D convolution slides along the channels,
        # then drop it again; the result is a (batch_size, channels) tensor.
        channel_signal = pooled.F.unsqueeze(1)
        weights = self.sigmoid(self.conv(channel_signal).squeeze(1))

        # Wrap the weights in a sparse tensor that reuses the pooled tensor's
        # coordinate manager and map key (features reduced to the origin), which is
        # what the broadcast multiplication is given as its second operand.
        gate = ME.SparseTensor(
            weights,
            coordinate_manager=pooled.coordinate_manager,
            coordinate_map_key=pooled.coordinate_map_key,
        )
        return self.broadcast_mul(x, gate)
class MinkECABasicBlock(BasicBlock):
    """ResNet BasicBlock for MinkowskiEngine with an ECA layer on the residual branch."""

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        dilation: int = 1,
        downsample: Optional[nn.Module] = None,
        dimension: int = 3,
    ) -> None:
        """Initialize the parent BasicBlock and attach an ECA layer over ``planes`` channels.

        Args:
            inplanes (int): Number of input channels.
            planes (int): Number of output channels.
            stride (int): Convolution stride size. Defaults to 1.
            dilation (int): Convolution dilation size. Defaults to 1.
            downsample (nn.Module, optional): Downsample layer, if needed. Defaults to None.
            dimension (int): Number of dimensions. Defaults to 3.
        """
        parent_kwargs = {
            "stride": stride,
            "dilation": dilation,
            "downsample": downsample,
            "dimension": dimension,
        }
        super().__init__(inplanes, planes, **parent_kwargs)
        self.eca = MinkECALayer(planes, gamma=2, b=1)

    def forward(self, x: object) -> object:  # type: ignore[misc]
        """Run conv-norm-relu, conv-norm, ECA, then add the shortcut and apply relu.

        Args:
            x: Input sparse tensor from MinkowskiEngine.

        Returns:
            Output sparse tensor with ECA attention applied.
        """
        out = self.relu(self.norm1(self.conv1(x)))
        out = self.eca(self.norm2(self.conv2(out)))

        # The shortcut is evaluated only after the main branch, as in the original
        # statement order; it is the input itself unless a downsample layer is set.
        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)