Source code for opr.losses.batch_hard_triplet_margin

"""Multimodal triplet margin loss implementation.

Code adopted from repository: https://github.com/jac99/MinkLocMultimodal, MIT License
"""

from typing import Dict, Tuple

from pytorch_metric_learning.distances import LpDistance
from pytorch_metric_learning.losses import TripletMarginLoss
from pytorch_metric_learning.reducers import AvgNonZeroReducer
from torch import Tensor, nn

from opr.miners import BatchHardTripletMiner


class BatchHardTripletMarginLoss(nn.Module):
    """Triplet margin loss with batch hard triplet miner.

    Triplets are selected by ``BatchHardTripletMiner`` and scored by
    ``pytorch_metric_learning``'s ``TripletMarginLoss``; both share a single
    ``LpDistance`` instance.

    Code adopted from repository: https://github.com/jac99/MinkLocMultimodal, MIT License
    """

    def __init__(self, margin: float = 0.2) -> None:
        """Triplet margin loss with batch hard triplet miner.

        Args:
            margin (float): Margin value for TripletMarginLoss. Defaults to 0.2.
        """
        super().__init__()
        self.margin = margin
        # The same distance object is handed to both the miner and the loss, so
        # triplet selection and loss computation use one metric. Embeddings are
        # deliberately not normalized by the distance (normalize_embeddings=False).
        distance = LpDistance(normalize_embeddings=False, collect_stats=True)
        # collect_stats=True makes the library objects record the statistics
        # that forward() reads back (final_avg_query_norm, num_past_filter).
        reducer_fn = AvgNonZeroReducer(collect_stats=True)
        self.miner_fn = BatchHardTripletMiner(distance=distance)
        self.loss_fn = TripletMarginLoss(
            margin=self.margin,
            swap=True,
            distance=distance,
            reducer=reducer_fn,
            collect_stats=True,
        )

    def forward(
        self, embeddings: Tensor, positives_mask: Tensor, negatives_mask: Tensor
    ) -> Tuple[Tensor, Dict[str, float]]:
        """Mine hard triplets in the batch and compute the triplet margin loss.

        Args:
            embeddings (Tensor): Batch of embeddings passed to the miner and the loss.
            positives_mask (Tensor): Mask forwarded to the miner to select positives.
                NOTE(review): presumably a boolean (batch, batch) matrix - confirm
                against BatchHardTripletMiner.
            negatives_mask (Tensor): Mask forwarded to the miner to select negatives.
                NOTE(review): presumably a boolean (batch, batch) matrix - confirm
                against BatchHardTripletMiner.

        Returns:
            Tuple[Tensor, Dict[str, float]]: The loss tensor and a dictionary of
                statistics with keys "loss", "avg_embedding_norm", "num_triplets",
                "num_non_zero_triplets" and "non_zero_rate", extended (and possibly
                overridden) by whatever the miner exposes in its ``stats`` attribute.
        """
        hard_triplets = self.miner_fn(embeddings, positives_mask, negatives_mask)
        # Read the miner statistics right after mining, before the loss call.
        miner_stats = self.miner_fn.stats
        loss = self.loss_fn(embeddings, indices_tuple=hard_triplets)

        num_triplets = len(hard_triplets[0])
        num_non_zero_triplets = float(self.loss_fn.reducer.num_past_filter)
        if num_triplets:
            non_zero_rate = num_non_zero_triplets / num_triplets
        else:
            # No triplets were mined: the rate is undefined, so report 1.0 as a
            # fallback instead of dividing by zero.
            print("WARNING: encountered batch with 'num_triplets' == 0.")
            non_zero_rate = 1.0

        stats = {
            "loss": loss.item(),
            "avg_embedding_norm": self.loss_fn.distance.final_avg_query_norm,
            "num_triplets": num_triplets,
            "num_non_zero_triplets": num_non_zero_triplets,
            "non_zero_rate": non_zero_rate,
        }
        stats.update(miner_stats)
        return loss, stats