• 文章
  • 三重态损失 - 进阶介绍
返回机器学习

三重态损失 - 进阶介绍

Yusuf Sarıgöz

·

2022 年 3 月 24 日

Triplet Loss - Advanced Intro

什么是三重态损失?

三重态损失于 2015 年首次在 FaceNet: 人脸识别和聚类的统一嵌入 中提出,自那时以来,它一直是监督相似性或度量学习中最流行的损失函数之一。最简单的解释是,三重态损失鼓励不同类别的样本对与任何同类别的样本对至少保持一定的间隔。数学上,损失值可以计算为 $L=max(d(a,p) - d(a,n) + m, 0)$,其中

  • $p$,即正例,是与 $a$,即锚点,具有相同标签的样本,
  • $n$,即负例,是与 $a$ 标签不同的另一个样本,
  • $d$ 是衡量这三个样本之间距离的函数,
  • 而 $m$ 是一个间隔值,用于保持负例样本之间足够远。

该论文使用欧氏距离,但使用任何其他距离度量(例如余弦距离)同样有效。

该函数有一个学习目标,可以如下所示进行可视化

Triplet Loss learning objective

三重态损失学习目标

请注意,三重态损失没有像对比损失那样强制将锚点和正例样本编码到向量空间中同一点的副作用。这使得三重态损失能够容忍一定的类内方差,而对比损失则强制将锚点与任何正例之间的距离基本上变为 $0$。换句话说,三重态损失允许以某种方式拉伸聚类,以包含离群点,同时仍然确保不同聚类(例如负例对)之间的样本保持一定的间隔。

此外,三重态损失不那么“贪婪”。与对比损失不同,当不同类别的样本很容易与同类别的样本区分开时,三重态损失就满足了。如果没有负例样本的干扰,它不会改变正例聚类中的距离。这是因为三重态损失试图确保负例对之间的距离与正例对之间的距离之间存在一个间隔。然而,对比损失仅在比较不同类别的样本对时考虑间隔值,它根本不关心同类别的样本对当时在哪里。这意味着对比损失可能更早达到局部最优,而三重态损失可能会继续将向量空间组织得更好。

让我们通过动画演示两种损失函数如何组织向量空间。为了简化可视化,向量由二维空间中的点表示,并从正态分布中随机选择。

Animation that shows how Contrastive Loss moves points in the course of training.

显示对比损失在训练过程中如何移动点的动画。

Animation that shows how Triplet Loss moves points in the course of training.

显示三重态损失在训练过程中如何移动点的动画。

从这两种损失函数的数学解释来看,很明显三重态损失在理论上更强大,但三重态损失还有一些额外的技巧可以帮助它更好地工作。最重要的是,三重态损失引入了在线三元组挖掘策略,例如,自动形成最有用的三元组。

为什么三元组挖掘很重要?

三重态损失的公式表明它一次处理三个对象

  • 锚点,
  • 正例 - 与锚点具有相同标签的样本,
  • 负例 - 与锚点和正例具有不同标签的样本。

在一个朴素的实现中,我们可以在每个训练周期的开始形成这样的样本三元组,然后在整个周期中将这些三元组的批次馈送到模型中。这被称为“离线策略”。然而,出于几个原因,这种方法效率不高

  • 需要传递 $3n$ 个样本来获得 $n$ 个三元组的损失值。
  • 并非所有这些三元组都对模型学习有用,例如,产生一个正的损失值。
  • 即使我们在每个训练周期的开始使用本系列中我将要实现的方法之一形成“有用的”三元组,但随着模型权重不断更新,它们可能在周期的某个时刻变得“无用”。

相反,我们可以获取一批 $n$ 个样本及其关联的标签,然后即时形成三元组。这称为“在线策略”。通常情况下,这会产生 $n^3$ 个可能的三元组,但实际上只有这些可能三元组的一个子集是有效的。即使在这种情况下,我们计算损失值所使用的三元组数量也远多于离线策略。

给定一个三元组 (a, p, n),它仅在满足以下条件时有效:

  • ap 具有相同的标签,
  • ap 是不同的样本,
  • 并且 nap 的标签不同。

这些约束条件可能看起来需要使用嵌套循环进行昂贵的计算,但可以通过距离矩阵、掩码和广播等技巧高效地实现。本系列的其余部分将重点介绍这些技巧的实现。

距离矩阵

距离矩阵是一个形状为 $(n, n)$ 的矩阵,用于存储来自两个大小为 $n$ 的集合中所有可能对之间的距离值。这个矩阵可以用于向量化计算,否则需要低效的循环。它的计算也可以优化,我们将实现 Samuel Albanie 解释的 欧氏距离矩阵技巧 (PDF)。您可能希望阅读这份三页的文档以获得该技巧的完整直观理解,但简要解释如下

  • 计算两组向量的点积,例如,在我们这里是嵌入(embeddings)。
  • 从该矩阵中提取对角线,其中包含每个嵌入的欧氏范数平方。
  • 基于以下公式计算欧氏距离平方矩阵:$||a - b||^2 = ||a||^2 - 2 ⟨a, b⟩ + ||b||^2$
  • 对此矩阵取平方根以获得非平方距离。

我们将在 PyTorch 中实现它,所以我们先从导入开始。

import torch
import torch.nn as nn
import torch.nn.functional as F

eps = 1e-8 # an arbitrary small value to be used for numerical stability tricks

def euclidean_distance_matrix(x):
  """Efficient computation of Euclidean distance matrix

  Args:
    x: Input tensor of shape (batch_size, embedding_dim)
    
  Returns:
    Distance matrix of shape (batch_size, batch_size)
  """
  # step 1 - compute the dot product

  # shape: (batch_size, batch_size)
  dot_product = torch.mm(x, x.t())

  # step 2 - extract the squared Euclidean norm from the diagonal

  # shape: (batch_size,)
  squared_norm = torch.diag(dot_product)

  # step 3 - compute squared Euclidean distances

  # shape: (batch_size, batch_size)
  distance_matrix = squared_norm.unsqueeze(0) - 2 * dot_product + squared_norm.unsqueeze(1)

  # get rid of negative distances due to numerical instabilities
  distance_matrix = F.relu(distance_matrix)

  # step 4 - compute the non-squared distances
  
  # handle numerical stability
  # derivative of the square root operation applied to 0 is infinite
  # we need to handle by setting any 0 to eps
  mask = (distance_matrix == 0.0).float()

  # use this mask to set indices with a value of 0 to eps
  distance_matrix += mask * eps

  # now it is safe to get the square root
  distance_matrix = torch.sqrt(distance_matrix)

  # undo the trick for numerical stability
  distance_matrix *= (1.0 - mask)

  return distance_matrix

无效三元组掩码

现在我们可以计算批次中所有可能的嵌入对的距离矩阵了,我们可以应用广播来枚举所有可能三元组的距离差,并将它们表示为一个形状为 (batch_size, batch_size, batch_size) 的张量。然而,如前所述,这 $n^3$ 个三元组中只有一部分是实际有效的,我们需要一个相应的掩码来正确计算损失值。我们将分三个步骤实现这样一个辅助函数

  • 计算不同索引的掩码,例如 (i != j and j != k)
  • 计算有效锚点-正例-负例三元组的掩码,例如 labels[i] == labels[j] and labels[j] != labels[k]
  • 合并两个掩码。
def get_triplet_mask(labels):
  """compute a mask for valid triplets

  Args:
    labels: Batch of integer labels. shape: (batch_size,)

  Returns:
    Mask tensor to indicate which triplets are actually valid. Shape: (batch_size, batch_size, batch_size)
    A triplet is valid if:
    `labels[i] == labels[j] and labels[i] != labels[k]`
    and `i`, `j`, `k` are different.
  """
  # step 1 - get a mask for distinct indices

  # shape: (batch_size, batch_size)
  indices_equal = torch.eye(labels.size()[0], dtype=torch.bool, device=labels.device)
  indices_not_equal = torch.logical_not(indices_equal)
  # shape: (batch_size, batch_size, 1)
  i_not_equal_j = indices_not_equal.unsqueeze(2)
  # shape: (batch_size, 1, batch_size)
  i_not_equal_k = indices_not_equal.unsqueeze(1)
  # shape: (1, batch_size, batch_size)
  j_not_equal_k = indices_not_equal.unsqueeze(0)
  # Shape: (batch_size, batch_size, batch_size)
  distinct_indices = torch.logical_and(torch.logical_and(i_not_equal_j, i_not_equal_k), j_not_equal_k)

  # step 2 - get a mask for valid anchor-positive-negative triplets

  # shape: (batch_size, batch_size)
  labels_equal = labels.unsqueeze(0) == labels.unsqueeze(1)
  # shape: (batch_size, batch_size, 1)
  i_equal_j = labels_equal.unsqueeze(2)
  # shape: (batch_size, 1, batch_size)
  i_equal_k = labels_equal.unsqueeze(1)
  # shape: (batch_size, batch_size, batch_size)
  valid_indices = torch.logical_and(i_equal_j, torch.logical_not(i_equal_k))

  # step 3 - combine two masks
  mask = torch.logical_and(distinct_indices, valid_indices)

  return mask

在线三元组挖掘的 Batch-all 策略

现在我们准备好实际实现三重态损失本身了。三重态损失涉及几种形成或选择三元组的策略,最简单的一种是使用从批次样本中形成的所有有效三元组。借助我们已经实现的实用函数,这可以通过四个简单的步骤实现

  • 获取批次中嵌入可以形成的所有可能对的距离矩阵。
  • 对该矩阵应用广播以计算所有可能三元组的损失值。
  • 将无效或简单的三元组的损失值设置为 $0$。
  • 平均剩余的正值以返回一个标量损失。

我将从实现这个策略开始,更复杂的策略将在后续的独立文章中介绍。

class BatchAllTtripletLoss(nn.Module):
  """Uses all valid triplets to compute Triplet loss

  Args:
    margin: Margin value in the Triplet Loss equation
  """
  def __init__(self, margin=1.):
    super().__init__()
    self.margin = margin
    
  def forward(self, embeddings, labels):
    """computes loss value.

    Args:
      embeddings: Batch of embeddings, e.g., output of the encoder. shape: (batch_size, embedding_dim)
      labels: Batch of integer labels associated with embeddings. shape: (batch_size,)

    Returns:
      Scalar loss value.
    """
    # step 1 - get distance matrix
    # shape: (batch_size, batch_size)
    distance_matrix = euclidean_distance_matrix(embeddings)

    # step 2 - compute loss values for all triplets by applying broadcasting to distance matrix

    # shape: (batch_size, batch_size, 1)
    anchor_positive_dists = distance_matrix.unsqueeze(2)
    # shape: (batch_size, 1, batch_size)
    anchor_negative_dists = distance_matrix.unsqueeze(1)
    # get loss values for all possible n^3 triplets
    # shape: (batch_size, batch_size, batch_size)
    triplet_loss = anchor_positive_dists - anchor_negative_dists + self.margin

    # step 3 - filter out invalid or easy triplets by setting their loss values to 0

    # shape: (batch_size, batch_size, batch_size)
    mask = get_triplet_mask(labels)
    triplet_loss *= mask
    # easy triplets have negative loss values
    triplet_loss = F.relu(triplet_loss)

    # step 4 - compute scalar loss value by averaging positive losses
    num_positive_losses = (triplet_loss > eps).float().sum()
    triplet_loss = triplet_loss.sum() / (num_positive_losses + eps)

    return triplet_loss

结论

我提到三重态损失与对比损失不仅在数学上有所不同,而且在样本选择策略上也有所不同。本文中,我通过使用几种技巧高效地实现了在线三元组挖掘的 Batch-all 策略。

还有其他更复杂的策略,例如 batch-hard 和 batch-semihard 挖掘,但它们的实现以及本文中我用于提高效率的技巧的讨论都值得单独撰写文章。

未来的文章将涵盖这些主题,并进一步讨论一些避免向量坍缩以及控制类内和类间方差的技巧。

此页面有用吗?

感谢您的反馈! 🙏

很抱歉听到此。😔 您可以在 GitHub 上编辑此页面,或者创建一个 GitHub Issue。