Source code for opr.losses.batch_hard_contrastive

"""Multimodal contrastive loss implementation.

Code adopted from repository: https://github.com/jac99/MinkLocMultimodal, MIT License
"""

from typing import Dict, Tuple

from pytorch_metric_learning.distances import LpDistance
from pytorch_metric_learning.losses import ContrastiveLoss
from pytorch_metric_learning.reducers import AvgNonZeroReducer
from torch import Tensor, nn

from opr.miners import BatchHardTripletMiner


class BatchHardContrastiveLoss(nn.Module):
    """Contrastive loss with batch hard triplet miner.

    Code adopted from repository: https://github.com/jac99/MinkLocMultimodal, MIT License
    """

    def __init__(self, pos_margin: float = 0.2, neg_margin: float = 0.2):
        """Initializes the BatchHardContrastiveLoss module.

        Args:
            pos_margin (float): Margin value for positive pairs in ContrastiveLoss. Defaults to 0.2.
            neg_margin (float): Margin value for negative pairs in ContrastiveLoss. Defaults to 0.2.
        """
        super().__init__()
        self.pos_margin = pos_margin
        self.neg_margin = neg_margin
        # The same distance object is handed to both the miner and the loss, so both
        # work with identical distances. collect_stats=True is passed (here and on the
        # reducer/loss) presumably so that the statistics read in ``forward``
        # (``final_avg_query_norm``, ``num_past_filter``) get populated.
        self.distance = LpDistance(normalize_embeddings=False, collect_stats=True)
        self.miner_fn = BatchHardTripletMiner(distance=self.distance)
        reducer_fn = AvgNonZeroReducer(collect_stats=True)
        self.loss_fn = ContrastiveLoss(
            pos_margin=self.pos_margin,
            neg_margin=self.neg_margin,
            distance=self.distance,
            reducer=reducer_fn,
            collect_stats=True,
        )

    def forward(
        self, embeddings: Tensor, positives_mask: Tensor, negatives_mask: Tensor
    ) -> Tuple[Tensor, Dict[str, float]]:
        """Mine the hardest triplets in the batch and compute the contrastive loss on them.

        Args:
            embeddings (Tensor): Batch of embeddings (assumed ``(batch_size, dim)`` — TODO confirm).
            positives_mask (Tensor): Positive-pair mask, forwarded to the miner.
            negatives_mask (Tensor): Negative-pair mask, forwarded to the miner.

        Returns:
            Tuple[Tensor, Dict[str, float]]: The loss tensor and a statistics dict with the keys
            ``loss``, ``avg_embedding_norm``, ``pos_pairs_above_threshold``,
            ``neg_pairs_above_threshold``, ``num_pairs`` and ``non_zero_rate``, plus every
            entry of the miner's ``stats`` dict. ``non_zero_rate`` is 1.0 when no pairs
            were mined.
        """
        hard_triplets = self.miner_fn(embeddings, positives_mask, negatives_mask)
        # Snapshot the miner statistics for this batch before computing the loss.
        miner_stats = self.miner_fn.stats
        loss = self.loss_fn(embeddings, indices_tuple=hard_triplets)

        reducers = self.loss_fn.reducer.reducers
        stats = {
            "loss": loss.item(),
            "avg_embedding_norm": self.loss_fn.distance.final_avg_query_norm,
            "pos_pairs_above_threshold": reducers["pos_loss"].num_past_filter,
            "neg_pairs_above_threshold": reducers["neg_loss"].num_past_filter,
            # hard_triplets is assumed to be an (anchors, positives, negatives) index tuple;
            # each triplet contributes one positive and one negative pair, hence the factor 2.
            "num_pairs": 2 * len(hard_triplets[0]),
        }

        # Guard explicitly instead of relying on ZeroDivisionError: that exception is only
        # raised for Python numbers; tensor-valued counters would silently yield inf/nan.
        if stats["num_pairs"] == 0:
            print("WARNING: encountered batch with 'num_pairs' == 0.")
            stats["non_zero_rate"] = 1.0
        else:
            stats["non_zero_rate"] = (
                stats["pos_pairs_above_threshold"] + stats["neg_pairs_above_threshold"]
            ) / stats["num_pairs"]

        stats.update(miner_stats)
        return loss, stats