A Simple Framework for Contrastive Learning of Visual Representations1

Note: All quotes and images used in this page belong to the authors of the paper. If you intend to use them, please give the deserved credits. Furthermore, all image and table captions correspond to the captions in the paper.

This work1 shows:

  • Composition of data augmentations plays a critical role in defining effective contrastive prediction tasks
  • The NT-Xent (Normalized Temperature-scaled Cross-Entropy) loss and a learnable nonlinear transformation between the representation and the contrastive loss substantially improve the quality of the learned representations
  • Contrastive learning benefits from larger batch sizes and more training steps than supervised learning

The method

Composition of data augmentations

A stochastic data augmentation module that transforms any given data example randomly resulting in two correlated views of the same example, denoted $\tilde{x}_i$ and $\tilde{x}_j$, which we consider as a positive pair.

Composition of multiple data augmentation operations is crucial in defining the contrastive prediction tasks that yield effective representations. (...)

Figure 2. (partial) A simple framework for contrastive learning of visual representations

  • $x$ is the input image

  • $\tilde{x}_i$ and $\tilde{x}_j$ are the two transformed (augmented) views of $x$

  • $t \sim \mathcal{T}$ is a randomly sampled augmentation, applied as a sequence of simple augmentations

  • $\mathcal{T}$ is the family of simple augmentations (a minimal sampling sketch follows below)
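
As a rough sketch of how two correlated views can be produced, the snippet below composes a few simple tf.image augmentations; the function name and parameter values are illustrative (not the paper's exact pipeline, which is shown later in the Code samples section).

import tensorflow as tf

def sample_augmentation(image, height=224, width=224):
    # One draw of t ~ T: a chain of simple stochastic augmentations
    # (illustrative choices, not the paper's exact pipeline).
    image = tf.image.random_crop(image, [height // 2, width // 2, 3])
    image = tf.image.resize(image, [height, width])
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_brightness(image, max_delta=0.4)
    image = tf.image.random_saturation(image, lower=0.6, upper=1.4)
    return image

x = tf.random.uniform([224, 224, 3])                         # stand-in for an input image
x_i, x_j = sample_augmentation(x), sample_augmentation(x)    # the positive pair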

The base encoder

Figure 2. (partial) A simple framework for contrastive learning of visual representations

  • $f(\cdot)$ is a neural network base encoder

  • $h_i$ and $h_j$ are the outputs after the average pooling layer, where $h_i, h_j \in \mathbb{R}^{d}$
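
For concreteness, here is a sketch of such a base encoder using the stock Keras ResNet-50 (not the authors' exact implementation): include_top=False drops the classifier and pooling='avg' returns the 2048-dimensional vector after average pooling.

import tensorflow as tf

# Base encoder f(.): ResNet-50 without its classification head;
# global average pooling yields h in R^2048.
base_encoder = tf.keras.applications.ResNet50(
    include_top=False, weights=None, pooling='avg')

h = base_encoder(tf.random.uniform([2, 224, 224, 3]))  # shape (2, 2048)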

The projection head

Figure 2. A simple framework for contrastive learning of visual representations

  • A small neural network projection head $g(\cdot)$ that maps representations to the space where contrastive loss is applied.

The contrastive loss function

$$ \mathsf{sim}(u,v) = \dfrac{u^\intercal v}{\|u\|\|v\|} $$

$$ \ell_{i,j} = -\log \dfrac{\exp \left( \mathsf{sim}(z_i, z_j) / \tau \right)}{ \sum_{k=1}^{2N} \mathbb{1}_{[k \ne i]} \exp \left( \mathsf{sim}(z_i, z_k) / \tau \right)} , \ \ \ \ (1) $$

Given a set $\{\tilde{x}_k\}$ including a positive pair of examples $\tilde{x}_i$ and $\tilde{x}_j$, the contrastive prediction task aims to identify $\tilde{x}_j$ in $\{\tilde{x}_k\}_{k \neq i}$ for a given $\tilde{x}_i$.

The final loss is computed across all positive pairs, both $(i, j)$ and $(j, i)$, in a mini-batch.
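
As a minimal sketch of Eq. (1) for a single pair (the batched implementation used in practice appears in the Code samples section); the helper pair_loss and its negatives argument are illustrative, not from the repository:

import tensorflow as tf

def sim(u, v):
    # Cosine similarity sim(u, v) = u^T v / (||u|| ||v||).
    u = tf.math.l2_normalize(u, axis=-1)
    v = tf.math.l2_normalize(v, axis=-1)
    return tf.reduce_sum(u * v, axis=-1)

def pair_loss(z_i, z_j, negatives, temperature=0.5):
    # l_{i,j} for one positive pair; `negatives` holds the remaining
    # 2N - 2 projections of the mini-batch.
    pos = tf.exp(sim(z_i, z_j) / temperature)
    neg = tf.reduce_sum(tf.exp(sim(z_i[None, :], negatives) / temperature))
    return -tf.math.log(pos / (pos + neg))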

A tale of large batch sizes

We do not train the model with a memory bank (Wu et al., 2018). Instead, we vary the training batch size N from 256 to 8192. (...) Training with large batch size may be unstable when using standard SGD/Momentum with linear learning rate scaling (Goyal et al., 2017). To stabilize the training, we use the LARS2 optimizer (You et al., 2017) for all batch sizes.


Dataset and metrics

Most of our study for unsupervised pretraining (learning encoder network f without labels) is done using the ImageNet ILSVRC-2012 dataset (...)

To evaluate the learned representations, we follow the widely used linear evaluation protocol (Zhang et al., 2016; Oord et al., 2018; Bachman et al., 2019; Kolesnikov et al., 2019), where a linear classifier is trained on top of the frozen base network, and test accuracy is used as a proxy for representation quality.

Beyond linear evaluation, we also compare against state-of-the-art on semi-supervised and transfer learning.
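
A hedged sketch of the linear evaluation protocol with Keras (illustrative code, not the authors' evaluation script): the base encoder is frozen and only a linear classifier on top of it is trained.

import tensorflow as tf

def linear_eval_model(base_encoder, num_classes=1000):
    # Frozen encoder f(.) + a single trainable linear layer.
    base_encoder.trainable = False
    inputs = tf.keras.Input(shape=(224, 224, 3))
    h = base_encoder(inputs, training=False)
    logits = tf.keras.layers.Dense(num_classes)(h)
    return tf.keras.Model(inputs, logits)

model = linear_eval_model(tf.keras.applications.ResNet50(
    include_top=False, weights=None, pooling='avg'))
model.compile(optimizer='sgd',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])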

Default setting

(...) for data augmentation we use random crop and resize (with random flip), color distortions, and Gaussian blur (...)

  • $\mathcal T = \{\text{crop and resize, color distortions, and Gaussian blur}\}$

We use ResNet-50 as the base encoder network (...)

  • $h_i = f(\tilde x_i) = \text{ResNet}(\tilde{x}_i)$ where $h_i \in \mathbb{R}^{2048}$

(...) 2-layer MLP projection head to project the representation to a 128-dimensional latent space (...)

  • $z_i = g(h_i) = W^{(2)}\sigma(W^{(1)}h_i),\ \sigma = \text{ReLU}$ where $z_i \in \mathbb{R}^{128}$
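
A minimal Keras sketch of this projection head (the hidden width of 2048 matches the dimensionality of $h$; the code itself is illustrative):

import tensorflow as tf

# g(h) = W2 * ReLU(W1 * h), mapping R^2048 -> R^128.
projection_head = tf.keras.Sequential([
    tf.keras.layers.Dense(2048, activation='relu'),  # W1 followed by ReLU
    tf.keras.layers.Dense(128),                      # W2, no activation
])

z = projection_head(tf.random.uniform([2, 2048]))    # shape (2, 128)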

(...) optimized using LARS with linear learning rate scaling (i.e. $\text{LearningRate = 0.3 × BatchSize/256}$) and weight decay of $10^{-6}$. We train at batch size 4096 for 100 epochs (...)
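
A small sketch of this learning-rate rule (LARS is not part of core TensorFlow; the simclr repository ships its own implementation, so the SGD optimizer below is only a stand-in for experimentation):

import tensorflow as tf

batch_size = 4096
learning_rate = 0.3 * batch_size / 256     # linear scaling -> 4.8
weight_decay = 1e-6

# Stand-in optimizer; the paper uses LARS as provided by the simclr repository.
optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate, momentum=0.9)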

Data Augmentation for Contrastive Representation Learning

Composition of data augmentation operations is crucial for learning good representations

The authors decided to consider several common augmentations:

  • spatial/geometric transformation - cropping, resizing, horizontal flipping, rotation, and cutout
  • appearance transformation - color distortion (including color dropping, brightness, contrast, saturation, hue), Gaussian blur, and Sobel filtering

Figure 5. Linear evaluation (ImageNet top-1 accuracy) under individual or composition of data augmentations, applied only to one branch.

We observe that no single transformation suffices to learn good representations (...) When composing augmentations, the contrastive prediction task becomes harder, but the quality of representation improves dramatically

Random cropping and random color distortion

Figure 6. Histograms of pixel intensities (over all channels) for different crops of two different images.

Neural nets may exploit this shortcut to solve the predictive task. Therefore, it is critical to compose cropping with color distortion in order to learn generalizable features.
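
To make the shortcut concrete, one can compare pixel-intensity histograms of two random crops of the same image; the helper below is an illustrative sketch, not the paper's plotting code.

import tensorflow as tf

def intensity_histogram(image, nbins=64):
    # Histogram of pixel intensities over all channels, values in [0, 1].
    return tf.histogram_fixed_width(tf.reshape(image, [-1]),
                                    value_range=[0.0, 1.0], nbins=nbins)

image = tf.random.uniform([224, 224, 3])            # stand-in image
crop_a = tf.image.random_crop(image, [112, 112, 3])
crop_b = tf.image.random_crop(image, [112, 112, 3])
# Crops of the same image share very similar histograms, so a model could
# match views by color statistics alone unless color is also distorted.
hist_a, hist_b = intensity_histogram(crop_a), intensity_histogram(crop_b)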

Findings

A nonlinear projection head improves the representation quality of the layer before it

Figure 8. Linear evaluation of representations with different projection heads $g(\cdot)$ and various dimensions of $z = g(h)$.

Normalized cross entropy loss with adjustable temperature works better than alternatives

Table 4. Linear evaluation (top-1) for models trained with different loss functions. “sh” means using semi-hard negative mining.

Table 5. Linear evaluation for models trained with different choices of $\ell_{2}$ norm and temperature $\tau$ for NT-Xent loss.

Contrastive learning benefits (more) from larger batch sizes and longer training

Figure 9. Linear evaluation models (ResNet-50) trained with different batch size and epochs. Each bar is a single run from scratch.

Comparison with State-of-the-art

Table 8. Comparison of transfer learning performance of our self-supervised approach with supervised baselines across 12 natural image classification datasets, for ResNet-50 (4×) models pretrained on ImageNet.

Not only does SimCLR outperform previous work (Figure 1), but it is also simpler, requiring neither specialized architectures (...)

Figure 1. ImageNet Top-1 accuracy of linear classifiers trained on representations learned with different self-supervised methods (pretrained on ImageNet).

Conclusions

In this work, we present a simple framework and its instantiation for contrastive visual representation learning. We carefully study its components, and show the effects of different design choices.

Our approach differs from standard supervised learning on ImageNet only in the choice of data augmentation, the use of a nonlinear head at the end of the network, and the loss function.

The strength of this simple framework suggests that, despite a recent surge in interest, self-supervised learning remains undervalued.

Code samples

Note: In this section I am highlighting some rather important blocks of code from the SimCLR framework using TensorFlow. These blocks are taken from google-research/simclr, with some minor changes for simplicity.

Image preprocessing

$$ \begin{aligned} \textbf{for } & \text{sampled minibatch } \{x_k\}_{k=1}^{N} \textbf{ do} \\ \textbf{for } & \text{all } k \in \{1, \dots, N\} \textbf{ do} \\ &\text{draw two augmentation functions } t \sim \mathcal{T},\ t' \sim \mathcal{T} \\ &\tilde{x}_{2k-1} = t(x_k), \; \; \tilde{x}_{2k} = t'(x_k) \end{aligned} $$

preprocess_for_train can be seen as applying an augmentation $t$ sampled from $\mathcal{T}$ in the training setting. For additional implementation details, see simclr/data_util.py3.

import tensorflow as tf
from data_util import random_crop_with_resize, random_color_jitter

def preprocess_for_train(image, height, width,
                         color_distort=True, crop=True, flip=True):
    """Preprocesses the given image for training.

    Args:
      image: `Tensor` representing an image of arbitrary size.
      height: Height of output image.
      width: Width of output image.
      color_distort: Whether to apply the color distortion.
      crop: Whether to crop the image.
      flip: Whether or not to flip left and right of an image.

    Returns:
      A preprocessed image `Tensor`.
    """
    if crop:
        image = random_crop_with_resize(image, height, width)
    if flip:
        image = tf.image.random_flip_left_right(image)
    if color_distort:
        image = random_color_jitter(image)
    # Reshape unconditionally so the output always has a static shape.
    image = tf.reshape(image, [height, width, 3])
    return image
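
For example (illustrative usage), the positive pair from Figure 2 is obtained by calling the function twice on the same image, which corresponds to drawing two augmentations t and t':

# Two independent calls draw two augmentations t and t'.
image = tf.random.uniform([300, 300, 3])     # stand-in for a decoded image
x_i = preprocess_for_train(image, height=224, width=224)
x_j = preprocess_for_train(image, height=224, width=224)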

NT-Xent loss

$$ \mathsf{sim}(u,v) = \dfrac{u^\intercal v}{\|u\|\|v\|} $$

$$ \ell_{i,j} = -\log \dfrac{\exp \left( \mathsf{sim}(z_i, z_j) / \tau \right)}{ \sum_{k=1}^{2N} \mathbb{1}_{[k \ne i]} \exp \left( \mathsf{sim}(z_i, z_k) / \tau \right)} , \ \ \ \ (1) $$

Here x and v are the two batches of projections ($z_i$ and $z_j$ for the two augmented views), in no particular order. For additional implementation details, see simclr/objective.py.

LARGE_NUM = 1e9

def nt_xent_loss(x, v, temperature=1.0):
    """NT-Xent loss for two batches of projections.

    x[i] and v[i] are the projections z of the two augmented views of
    example i, so (x[i], v[i]) is a positive pair.
    """
    # Normalize so the dot products below are the cosine similarities of Eq. (1).
    x = tf.math.l2_normalize(x, axis=-1)
    v = tf.math.l2_normalize(v, axis=-1)

    batch_size = tf.shape(x)[0]
    # `masks` removes self-similarities; `labels` marks the positive pair
    # (column i of the cross-view logits) among the 2N - 1 candidates.
    masks = tf.one_hot(tf.range(batch_size), batch_size)
    labels = tf.one_hot(tf.range(batch_size), batch_size * 2)

    # Within-view similarities (negatives only; the diagonal is masked out).
    logits_x_x = tf.matmul(x, x, transpose_b=True) / temperature
    logits_x_x = logits_x_x - masks * LARGE_NUM

    logits_v_v = tf.matmul(v, v, transpose_b=True) / temperature
    logits_v_v = logits_v_v - masks * LARGE_NUM

    # Cross-view similarities; the diagonal holds the positive pairs.
    logits_x_v = tf.matmul(x, v, transpose_b=True) / temperature
    logits_v_x = tf.matmul(v, x, transpose_b=True) / temperature

    # Losses for (i, j) and (j, i); the final loss sums both directions and
    # averages over the batch.
    loss_x = tf.nn.softmax_cross_entropy_with_logits(
        labels, tf.concat([logits_x_v, logits_x_x], 1))
    loss_v = tf.nn.softmax_cross_entropy_with_logits(
        labels, tf.concat([logits_v_x, logits_v_v], 1))

    loss = tf.reduce_mean(loss_x + loss_v)

    return loss
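
Putting the pieces together, a single training step could look like the sketch below; the encoder, projection head, and hyperparameter values are illustrative (reusing the stock Keras ResNet-50 rather than the authors' training loop), and nt_xent_loss is the function defined above.

import tensorflow as tf

encoder = tf.keras.applications.ResNet50(
    include_top=False, weights=None, pooling='avg')          # f(.)
head = tf.keras.Sequential([
    tf.keras.layers.Dense(2048, activation='relu'),
    tf.keras.layers.Dense(128)])                             # g(.)
optimizer = tf.keras.optimizers.SGD(learning_rate=4.8, momentum=0.9)

def train_step(x_i, x_j, temperature=0.1):
    # One SimCLR-style step on a batch of positive pairs (sketch only).
    with tf.GradientTape() as tape:
        z_i = head(encoder(x_i, training=True))
        z_j = head(encoder(x_j, training=True))
        loss = nt_xent_loss(z_i, z_j, temperature)
    variables = encoder.trainable_variables + head.trainable_variables
    optimizer.apply_gradients(zip(tape.gradient(loss, variables), variables))
    return loss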

Acknowledgements

I would like to thank the authors for this incredibly well written and clear contribution.

References

1. A Simple Framework for Contrastive Learning of Visual Representations. Chen, T., Kornblith, S., Norouzi, M., & Hinton, G. - 2020

2. Large Batch Training of Convolutional Networks (LARS). You, Y., Gitman, I., & Ginsburg, B. - 2017