Shortcuts

Source code for quaterion.loss.online_contrastive_loss

from typing import Optional

import torch
import torch.nn.functional as F

from quaterion.distances import Distance
from quaterion.loss.group_loss import GroupLoss
from quaterion.utils import (
    get_anchor_negative_mask,
    get_anchor_positive_mask,
    max_value_of_dtype,
)


[docs]class OnlineContrastiveLoss(GroupLoss): """Implements Contrastive Loss as defined in http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf Unlike :class:`~quaterion.loss.contrastive_loss.ContrastiveLoss`, this one supports online pair mining, i.e., it makes positive and negative pairs on-the-fly, so you don't need to form such pairs yourself. Instead, it first calculates all possible pairs in a batch, and then filters valid positive pairs and valid negative pairs separately. Batch-all and batch-hard strategies for online pair mining are supported. Args: margin: Margin value to push negative examples apart. Optional, defaults to `0.5`. distance_metric_name: Name of the distance function, e.g., :class:`~quaterion.distances.Distance`. Optional, defaults to :attr:`~quaterion.distances.Distance.COSINE`. mining (str, optional): Pair mining strategy. One of `"all"`, `"hard"`. Defaults to `"hard"`. """ def __init__( self, margin: Optional[float] = 0.5, distance_metric_name: Distance = Distance.COSINE, mining: Optional[str] = "hard", ): mining_types = ["all", "hard"] if mining not in mining_types: raise ValueError( f"Unrecognized mining strategy: {mining}. Must be one of {', '.join(mining_types)}" ) super(OnlineContrastiveLoss, self).__init__( distance_metric_name=distance_metric_name ) self._margin = margin self._mining = mining
[docs] def get_config_dict(self): config = super().get_config_dict() config.update( { "margin": self._margin, "mining": self._mining, } ) return config
[docs] def forward( self, embeddings: torch.Tensor, groups: torch.LongTensor ) -> torch.Tensor: """Calculates Contrastive Loss by making pairs on-the-fly. Args: embeddings (torch.Tensor): Batch of embeddings. Shape: (batch_size, embedding_dim) groups (torch.LongTensor): Batch of labels associated with `embeddings`. Shape: (batch_size,) Returns: torch.Tensor: Scalar loss value. """ # Shape: (batch_size, batch_size) dists = self.distance_metric.distance_matrix(embeddings) # get a mask for valid anchor-positive pairs and apply it to the distance matrix # to set invalid ones to 0 anchor_positive_mask = get_anchor_positive_mask(groups).float() anchor_positive_dists = anchor_positive_mask * dists # invalid pairs set to 0 # get a mask for valid anchor-negative pairs, and apply it to distance matrix # # to set invalid ones to a maximum value of dtype anchor_negative_mask = get_anchor_negative_mask(groups) anchor_negative_dists = dists anchor_negative_dists[~anchor_negative_mask] = max_value_of_dtype( anchor_negative_dists.dtype ) if self._mining == "all": num_positive_pairs = anchor_positive_mask.sum() positive_loss = anchor_positive_dists.sum() / torch.max( num_positive_pairs, torch.tensor(1e-16) ) num_negative_pairs = anchor_negative_mask.float().sum() negative_loss = F.relu( self._margin - anchor_negative_dists ).sum() / torch.max(num_negative_pairs, torch.tensor(1e-16)) else: # batch-hard pair mining # get the hardest positive for each anchor # shape: (batch_size,) hardest_positive_dists = anchor_positive_dists.max(dim=1)[0] num_positive_pairs = torch.count_nonzero(hardest_positive_dists) positive_loss = hardest_positive_dists.sum() / torch.max( num_positive_pairs, torch.tensor(1e-16) ) # get the hardest negative for each anchor # shape (batch_size,) hardest_negative_dists = anchor_negative_dists.min(dim=1)[0] num_negative_pairs = torch.sum( ( hardest_negative_dists < max_value_of_dtype( hardest_negative_dists.dtype ) # It's True where we didn't set to this maximum value to mark them invalid ).float() ) negative_loss = F.relu( self._margin - hardest_negative_dists ).sum() / torch.max(num_negative_pairs, torch.tensor(1e-16)) total_loss = 0.5 * (positive_loss + negative_loss) return total_loss

Qdrant

Learn more about Qdrant vector search project and ecosystem

Discover Qdrant

Similarity Learning

Explore practical problem solving with Similarity Learning

Learn Similarity Learning

Community

Find people dealing with similar problems and get answers to your questions

Join Community