# Source code for opr.modules.fusion

"""Basic fusion modules implementation."""
from typing import Dict

import torch
from torch import Tensor, nn

from .gem import SeqGeM


class Concat(nn.Module):
    """Concatenation module.

    Fuses the available descriptors by joining them into a single tensor.
    """

    def __init__(self) -> None:
        """Concatenation module."""
        super().__init__()

    def forward(self, data: Dict[str, Tensor]) -> Tensor:
        """Concatenate every non-``None`` tensor found in ``data``.

        Args:
            data: Mapping from a name to a descriptor tensor or ``None``.
                ``None`` entries are skipped; the remaining tensors are joined
                in the mapping's iteration order.

        Returns:
            The tensors concatenated along ``dim=1``. When every tensor is
            1-D, they are joined end to end and a leading dimension of size 1
            is added, which matches what ``Add`` and ``GeMFusion`` return for
            1-D inputs.
        """
        descriptors = [value for value in data.values() if value is not None]

        # 1-D tensors have no dim=1 to concatenate on; join them along their
        # only axis and add the leading dimension, as the sibling modules do.
        if descriptors and all(tensor.dim() == 1 for tensor in descriptors):
            return torch.cat(descriptors, dim=0).unsqueeze(0)

        # An empty list is passed through so torch reports the error itself,
        # exactly as before.
        return torch.cat(descriptors, dim=1)
class Add(nn.Module):
    """Addition module.

    Fuses the available descriptors by summing them element-wise.
    """

    def __init__(self) -> None:
        """Addition module."""
        super().__init__()

    def forward(self, data: Dict[str, Tensor]) -> Tensor:
        """Sum every non-``None`` tensor found in ``data``.

        Args:
            data: Mapping from a name to a descriptor tensor or ``None``.
                ``None`` entries are skipped. The remaining tensors are
                stacked, so they must all have the same shape.

        Returns:
            The element-wise sum of the tensors. A 1-D result gets a leading
            dimension of size 1 added; any other result is returned as is.
        """
        present = [tensor for tensor in data.values() if tensor is not None]

        # Stack on a new leading axis, then collapse that axis by summation.
        stacked = torch.stack(present, dim=0)
        summed = stacked.sum(dim=0)

        if summed.dim() != 1:
            return summed
        return summed.unsqueeze(0)
class GeMFusion(nn.Module):
    """GeM fusion module.

    Fuses the available descriptors by stacking them on a new trailing axis
    and passing the result through a ``SeqGeM`` layer.
    """

    def __init__(self, p: int = 3, eps: float = 1e-6) -> None:
        """Generalized-Mean fusion module.

        Args:
            p (int): Initial value of learnable parameter 'p', see paper for more details.
                Defaults to 3.
            eps (float): Negative values will be clamped to `eps` (ReLU). Defaults to 1e-6.
        """
        super().__init__()
        self.gem = SeqGeM(p=p, eps=eps)

    def forward(self, data: Dict[str, Tensor]) -> Tensor:
        """Pool every non-``None`` tensor found in ``data`` with ``self.gem``.

        Args:
            data: Mapping from a name to a descriptor tensor or ``None``.
                ``None`` entries are skipped. The remaining tensors are
                stacked, so they must all have the same shape.

        Returns:
            The output of ``self.gem`` applied to the stacked tensors. A 1-D
            output gets a leading dimension of size 1 added; any other output
            is returned as is.
        """
        present = [tensor for tensor in data.values() if tensor is not None]

        # The new axis is appended after the existing ones, so its index
        # equals the rank of a single input tensor.
        # NOTE(review): presumably SeqGeM pools over this trailing axis - confirm.
        new_axis = present[0].dim()
        pooled = self.gem(torch.stack(present, dim=new_axis))

        return pooled.unsqueeze(0) if pooled.dim() == 1 else pooled