Notes on SimCLR - A Simple Framework for Contrastive Learning of Visual Representations
Ting Chen et al. propose a new framework, SimCLR, for self-supervised contrastive learning of visual representations.
A Simple Framework for Contrastive Learning of Visual Representations [1]
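This work [1] shows that:
- Composition of data augmentations plays a critical role in defining effective contrastive prediction tasks
- A learnable nonlinear transformation between the representation and the contrastive loss substantially improves the quality of the learned representations
- Learning with NT-Xent (Normalized Temperature-scaled Cross-Entropy), the contrastive loss, benefits from normalized embeddings and an appropriately adjusted temperature
- Contrastive learning benefits from larger batch sizes and longer training than supervised learning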
The method
Composition of data augmentations
A stochastic data augmentation module transforms any given data example randomly, resulting in two correlated views of the same example, denoted $\tilde{x}_i$ and $\tilde{x}_j$, which are considered a positive pair.
Composition of multiple data augmentation operations is crucial in defining the contrastive prediction tasks that yield effective representations. (...)
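- $x$ is the input image
- $\tilde{x}_i = t(x)$ and $\tilde{x}_j = t'(x)$ are the two transformed views of $x$
- $t, t' \sim \mathcal{T}$ are augmentations drawn independently, each a composition of simple augmentations applied sequentially
- $\mathcal{T}$ is the family of augmentations from which $t$ and $t'$ are sampled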
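Drawing $t, t' \sim \mathcal{T}$ and applying them to one image can be sketched in plain Python. This is a toy under stated assumptions: the "image" is an 8×8 grid of floats, the crop is a fixed-size window without the random rescaling used in practice, and brightness scaling stands in for color distortion. The helper `sample_augmentation` is hypothetical. Its point is that sampling $t$ fixes all random parameters, after which $t$ is a deterministic composition of simple operations.

```python
import random

def sample_augmentation(height, width, crop_size, rng):
    """Draw t ~ T: sample all random parameters now, return a deterministic t."""
    top = rng.randrange(height - crop_size + 1)
    left = rng.randrange(width - crop_size + 1)
    flip = rng.random() < 0.5
    brightness = 1.0 + rng.uniform(-0.4, 0.4)

    def t(img):
        # 1) random crop (spatial), 2) optional horizontal flip,
        # 3) brightness scaling clipped to [0, 1] (stand-in for color distortion).
        out = [row[left:left + crop_size] for row in img[top:top + crop_size]]
        if flip:
            out = [row[::-1] for row in out]
        return [[min(1.0, max(0.0, px * brightness)) for px in row] for row in out]

    return t

rng = random.Random(0)
x = [[(r * 8 + c) / 63 for c in range(8)] for r in range(8)]  # toy 8x8 grayscale image
t, t_prime = sample_augmentation(8, 8, 4, rng), sample_augmentation(8, 8, 4, rng)
x_i, x_j = t(x), t_prime(x)  # two correlated views of x: a positive pair
```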
The contrastive loss function
$$ \mathsf{sim}(u,v) = \dfrac{u^\intercal v}{\|u\|\|v\|} $$
$$ \ell_{i,j} = -\log \dfrac{\exp\left(\mathsf{sim}(z_i, z_j) / \tau\right)}{ \sum_{k=1}^{2N} \mathbb{1}_{[k \ne i]} \exp\left(\mathsf{sim}(z_i, z_k) / \tau\right)} \qquad (1) $$
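Given a set $\{\tilde{x}_k\}$ including a positive pair of examples $\tilde{x}_i$ and $\tilde{x}_j$, the contrastive prediction task aims to identify $\tilde{x}_j$ in $\{\tilde{x}_k\}_{k \neq i}$ for a given $\tilde{x}_i$. In Eq. (1), $z_i = g(f(\tilde{x}_i))$ is the projected representation of $\tilde{x}_i$, $\mathbb{1}_{[k \ne i]} \in \{0, 1\}$ is an indicator function that equals 1 iff $k \ne i$, and $\tau$ is a temperature parameter.
The final loss is computed across all positive pairs, both $(i, j)$ and $(j, i)$, in a mini-batch of $N$ examples ($2N$ augmented views): $\mathcal{L} = \frac{1}{2N} \sum_{k=1}^{N} \left[\ell_{2k-1,2k} + \ell_{2k,2k-1}\right]$.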
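Eq. (1) and the averaged loss $\mathcal{L}$ can be checked numerically with a plain-Python sketch that follows the formulas term by term. It is quadratic in the batch size and only meant for tiny examples. Indexing is 0-based, so `z[2k]` and `z[2k + 1]` form a positive pair; the function names and the default `tau=0.5` are illustrative choices, not values from the paper.

```python
import math

def cosine_sim(u, v):
    """sim(u, v) = u^T v / (||u|| ||v||)."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))

def nt_xent_pair(z, i, j, tau):
    """ell_{i,j} from Eq. (1); z holds all 2N projections."""
    numerator = math.exp(cosine_sim(z[i], z[j]) / tau)
    # The indicator 1[k != i] drops the anchor's similarity with itself.
    denominator = sum(math.exp(cosine_sim(z[i], z[k]) / tau)
                      for k in range(len(z)) if k != i)
    return -math.log(numerator / denominator)

def nt_xent(z, tau=0.5):
    """L = 1/(2N) * sum_k [ell(2k, 2k+1) + ell(2k+1, 2k)] (0-based pairs)."""
    total = 0.0
    for k in range(0, len(z), 2):
        total += nt_xent_pair(z, k, k + 1, tau) + nt_xent_pair(z, k + 1, k, tau)
    return total / len(z)

aligned = [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]]     # positives identical
mismatched = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]]  # positives orthogonal
```

With perfectly aligned pairs the loss is small but not zero, because the denominator always includes the negatives; for `aligned` every anchor contributes $\log\left((e^{2} + 2)/e^{2}\right)$ at $\tau = 0.5$. Pairing orthogonal vectors, as in `mismatched`, raises the loss.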
A tale of large batch sizes
We do not train the model with a memory bank (Wu et al., 2018). Instead, we vary the training batch size N from 256 to 8192. Training with large batch size may be unstable when using standard SGD/Momentum with linear learning rate scaling (Goyal et al., 2017). To stabilize the training, we use the LARS [2] optimizer (You et al., 2017) for all batch sizes.
With a batch of N examples, each positive pair has 2(N − 1) negatives drawn from both augmentation views, so N = 8192 gives 16382 negatives per positive pair.
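The core of LARS can be sketched for a single layer in plain Python. This is a simplified version of the rule described by You et al. (2017), not SimCLR's optimizer code, which may differ in details such as which parameters are excluded from adaptation. The name `lars_step` and the defaults (`trust_coef=0.001`, `momentum=0.9`) are assumptions for illustration. The layer-wise local learning rate is proportional to $\|w\| / \|\nabla L\|$, so the step size tracks the weight norm rather than the raw gradient norm.

```python
import math

def l2_norm(vec):
    return math.sqrt(sum(x * x for x in vec))

def lars_step(w, grad, velocity, lr, momentum=0.9, weight_decay=1e-6,
              trust_coef=0.001):
    """One simplified LARS update for a single layer's weights.

    local_lr = trust_coef * ||w|| / (||grad|| + weight_decay * ||w||)
    velocity <- momentum * velocity + lr * local_lr * (grad + weight_decay * w)
    w        <- w - velocity
    """
    w_norm, g_norm = l2_norm(w), l2_norm(grad)
    denom = g_norm + weight_decay * w_norm
    # Fall back to a local rate of 1.0 when the weight norm or denominator is zero.
    local_lr = trust_coef * w_norm / denom if w_norm > 0 and denom > 0 else 1.0
    new_velocity = [
        momentum * v + lr * local_lr * (g + weight_decay * x)
        for v, g, x in zip(velocity, grad, w)
    ]
    new_w = [x - v for x, v in zip(w, new_velocity)]
    return new_w, new_velocity

# Same weights, gradients differing in scale by 100x (no momentum, no decay).
small, _ = lars_step([3.0, 4.0], [0.0, 10.0], [0.0, 0.0], lr=1.0,
                     momentum=0.0, weight_decay=0.0, trust_coef=1.0)
large, _ = lars_step([3.0, 4.0], [0.0, 1000.0], [0.0, 0.0], lr=1.0,
                     momentum=0.0, weight_decay=0.0, trust_coef=1.0)
```

In this setting both calls produce the same update, `[3.0, -1.0]`: rescaling the gradient does not change the step size, which is the intuition behind using LARS for large-batch training.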
Dataset and metrics
Most of our study for unsupervised pretraining (learning encoder network f without labels) is done using the ImageNet ILSVRC-2012 dataset (...)
To evaluate the learned representations, we follow the widely used linear evaluation protocol (Zhang et al., 2016; Oord et al., 2018; Bachman et al., 2019; Kolesnikov et al., 2019), where a linear classifier is trained on top of the frozen base network, and test accuracy is used as a proxy for representation quality.
Beyond linear evaluation, we also compare against state-of-the-art on semi-supervised and transfer learning.
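The linear evaluation protocol described above can be sketched with a toy example: features come from a frozen encoder, only a linear classifier on top of them is trained, and test accuracy is reported. Everything here is a hypothetical stand-in (the two-dimensional `frozen_encoder`, the synthetic data, binary logistic regression trained by SGD); the paper's protocol trains a linear classifier on ResNet-50 features for the 1000 ImageNet classes.

```python
import math
import random

def frozen_encoder(x):
    """Hypothetical stand-in for the pretrained encoder f; never updated."""
    return [x[0] + x[1], x[0] - x[1]]

def sigmoid(z):
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)

def logit(w, b, h):
    return sum(wi * hi for wi, hi in zip(w, h)) + b

def train_linear_probe(features, labels, lr=0.5, epochs=100):
    """Binary logistic regression by SGD; only w and b are learned."""
    w, b = [0.0] * len(features[0]), 0.0
    for _ in range(epochs):
        for h, y in zip(features, labels):
            g = sigmoid(logit(w, b, h)) - y  # d(log-loss)/d(logit)
            w = [wi - lr * g * hi for wi, hi in zip(w, h)]
            b -= lr * g
    return w, b

def accuracy(w, b, features, labels):
    correct = sum(int(logit(w, b, h) > 0) == y for h, y in zip(features, labels))
    return correct / len(labels)

rng = random.Random(0)
points = [[rng.uniform(-1, 1), rng.uniform(-1, 1)] for _ in range(300)]
labels = [int(p[0] + p[1] > 0) for p in points]
features = [frozen_encoder(p) for p in points]  # computed once and kept fixed
w, b = train_linear_probe(features[:200], labels[:200])
test_acc = accuracy(w, b, features[200:], labels[200:])  # proxy for representation quality
```

Because the encoder is never updated, the test accuracy reflects how linearly separable its features are, which is why the protocol is used as a proxy for representation quality.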
Default setting
(...) for data augmentation we use random crop and resize (with random flip), color distortions, and Gaussian blur (...)
- $\mathcal{T} = \{\text{crop and resize (with random flip)},\ \text{color distortion},\ \text{Gaussian blur}\}$
We use ResNet-50 as the base encoder network (...)
- $h_i = f(\tilde x_i) = \text{ResNet}(\tilde{x}_i)$ where $h_i \in \mathbb{R}^{2048}$
(...) 2-layer MLP projection head to project the representation to a 128-dimensional latent space (...)
- $z_i = g(h_i) = W^{(2)}\sigma(W^{(1)}h_i),\ \sigma = \text{ReLU}$ where $z_i \in \mathbb{R}^{128}$
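The projection head $g$ can be sketched in plain Python to make the shapes concrete. The sizes are scaled down (16 → 16 → 4 instead of 2048 → hidden → 128), and the hidden width is an assumption, set equal to the input width here because the note does not state it. The weights are random placeholders rather than learned values, and biases are omitted to mirror the formula above.

```python
import random

def matvec(W, h):
    """Multiply matrix W (a list of rows) by vector h."""
    return [sum(w * x for w, x in zip(row, h)) for row in W]

def relu(v):
    return [max(0.0, a) for a in v]

def projection_head(h, W1, W2):
    """g(h) = W2 . ReLU(W1 . h), biases omitted as in the formula above."""
    return matvec(W2, relu(matvec(W1, h)))

def random_matrix(rows, cols, rng):
    # Placeholder initialization; in SimCLR these weights are learned.
    scale = cols ** -0.5
    return [[rng.gauss(0.0, scale) for _ in range(cols)] for _ in range(rows)]

rng = random.Random(0)
d_h, d_hidden, d_z = 16, 16, 4  # scaled-down stand-ins for 2048, (assumed) 2048, 128
W1 = random_matrix(d_hidden, d_h, rng)
W2 = random_matrix(d_z, d_hidden, rng)
h = [rng.uniform(-1.0, 1.0) for _ in range(d_h)]  # stand-in for h_i = f(x~_i)
z = projection_head(h, W1, W2)                    # z_i, fed to the contrastive loss
```

After pretraining, the paper discards $g$ and uses the encoder output $h$ for downstream tasks such as linear evaluation.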
(...) optimized using LARS with linear learning rate scaling (i.e. $\text{LearningRate} = 0.3 \times \text{BatchSize}/256$) and weight decay of $10^{-6}$. We train at batch size 4096 for 100 epochs (...)
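The linear scaling rule is simple arithmetic; a small helper (the name `scaled_learning_rate` is illustrative) makes the numbers explicit. The default batch size of 4096 gives a learning rate of 0.3 × 4096 / 256 = 4.8.

```python
def scaled_learning_rate(batch_size, base_lr=0.3, base_batch=256):
    """Linear scaling: LearningRate = base_lr * BatchSize / base_batch."""
    return base_lr * batch_size / base_batch

# Learning rates across the batch sizes studied (256 to 8192).
schedule = {n: scaled_learning_rate(n) for n in (256, 1024, 4096, 8192)}
```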
Data Augmentation for Contrastive Representation Learning
Composition of data augmentation operations is crucial for learning good representations
The authors consider several common augmentations:
- spatial/geometric transformations: cropping and resizing (with horizontal flipping), rotation, and cutout
- appearance transformations: color distortion (including color dropping, brightness, contrast, saturation, and hue), Gaussian blur, and Sobel filtering
We observe that no single transformation suffices to learn good representations (...) When composing augmentations, the contrastive prediction task becomes harder, but the quality of representation improves dramatically.
Normalized cross entropy loss with adjustable temperature works better than alternatives
Contrastive learning benefits (more) from larger batch sizes and longer training
Comparison with State-of-the-art
Not only does SimCLR outperform previous work (Figure 1), but it is also simpler, requiring neither specialized architectures (...)
Conclusions
In this work, we present a simple framework and its instantiation for contrastive visual representation learning. We carefully study its components, and show the effects of different design choices.
Our approach differs from standard supervised learning on ImageNet only in the choice of data augmentation, the use of a nonlinear head at the end of the network, and the loss function.
The strength of this simple framework suggests that, despite a recent surge in interest, self-supervised learning remains undervalued.
Code samples
Image preprocessing
$$ \begin{aligned} &\textbf{for} \text{ sampled minibatch } \{x_k\}_{k=1}^{N} \textbf{ do} \\ &\quad \textbf{for all } k \in \{1, \dots, N\} \textbf{ do} \\ &\qquad \text{draw two augmentation functions } t \sim \mathcal{T},\ t' \sim \mathcal{T} \\ &\qquad \tilde{x}_{2k-1} = t(x_k), \quad \tilde{x}_{2k} = t'(x_k) \end{aligned} $$
`preprocess_for_train` can be seen as $\mathcal{T}$ in the training setting: each call applies one random draw of crop and resize, horizontal flip, and color distortion. Gaussian blur, the third augmentation in the default setting, is not applied by this function. For additional implementation details, see `simclr/data_util.py` [3].
```python
import tensorflow as tf

# Helpers defined in simclr/data_util.py (see [3]).
from data_util import random_crop_with_resize, random_color_jitter


def preprocess_for_train(image, height, width,
                         color_distort=True, crop=True, flip=True):
  """Preprocesses the given image for training.

  Args:
    image: `Tensor` representing an image of arbitrary size.
    height: Height of output image.
    width: Width of output image.
    color_distort: Whether to apply the color distortion.
    crop: Whether to crop the image.
    flip: Whether or not to flip left and right of an image.

  Returns:
    A preprocessed image `Tensor`.
  """
  if crop:
    # Random crop, resized to (height, width).
    image = random_crop_with_resize(image, height, width)
  if flip:
    # Random horizontal flip.
    image = tf.image.random_flip_left_right(image)
  if color_distort:
    # Random color distortion.
    image = random_color_jitter(image)
  image = tf.reshape(image, [height, width, 3])
  return image
```
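The batch construction in the pseudocode, where each of the $N$ images yields two views and the $2N$ views are interleaved, can be sketched in plain Python. The scalar "images" and the noise augmentation `sample_noise_aug` are hypothetical placeholders for real images and `preprocess_for_train`. Indices are 0-based, so the paper's pair $(2k-1, 2k)$ becomes `(2k, 2k + 1)`, and the partner of view `i` is `i ^ 1`. The last line shows how interleaved views split into two aligned halves; the same split applies to their projections before calling `nt_xent_loss`.

```python
import random

def make_two_view_batch(batch, sample_aug, rng):
    """For each x_k, draw t, t' ~ T and append t(x_k), t'(x_k) (interleaved)."""
    views = []
    for x in batch:
        t, t_prime = sample_aug(rng), sample_aug(rng)
        views.append(t(x))        # x~_{2k-1} in the paper's 1-based notation
        views.append(t_prime(x))  # x~_{2k}
    return views

def positive_index(i):
    """0-based index of view i's positive partner: 0 <-> 1, 2 <-> 3, ..."""
    return i ^ 1

def sample_noise_aug(rng):
    # Hypothetical augmentation family on scalars: add a random offset.
    delta = rng.uniform(-0.1, 0.1)
    return lambda x: x + delta

rng = random.Random(0)
batch = [1.0, 2.0, 3.0]                         # N = 3 toy "images"
views = make_two_view_batch(batch, sample_noise_aug, rng)
x, v = views[0::2], views[1::2]                 # row k of x pairs with row k of v
```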
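NT-Xent loss
`nt_xent_loss` implements Eq. (1) averaged over all positive pairs. `x` and `v` are the projections $z$ of the two augmented views of the same mini-batch, so row $k$ of `x` and row $k$ of `v` form a positive pair; which view goes into `x` and which into `v` does not matter. For additional implementation details, see `simclr/objective.py` [3].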
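```python
import tensorflow as tf

LARGE_NUM = 1e9


def nt_xent_loss(x, v, temperature=1.0):
  """NT-Xent loss for N positive pairs (x[k], v[k])."""
  # L2-normalize so that dot products equal sim(u, v) from Eq. (1).
  x = tf.math.l2_normalize(x, axis=1)
  v = tf.math.l2_normalize(v, axis=1)
  batch_size = tf.shape(x)[0]
  # Subtracting masks * LARGE_NUM removes each anchor's similarity with itself.
  masks = tf.one_hot(tf.range(batch_size), batch_size)
  # For row k, the positive is column k of the cross-view block (placed first).
  labels = tf.one_hot(tf.range(batch_size), batch_size * 2)
  logits_x_x = tf.matmul(x, x, transpose_b=True) / temperature
  logits_x_x = logits_x_x - masks * LARGE_NUM
  logits_v_v = tf.matmul(v, v, transpose_b=True) / temperature
  logits_v_v = logits_v_v - masks * LARGE_NUM
  logits_x_v = tf.matmul(x, v, transpose_b=True) / temperature
  logits_v_x = tf.matmul(v, x, transpose_b=True) / temperature
  loss_x = tf.nn.softmax_cross_entropy_with_logits(
      labels=labels, logits=tf.concat([logits_x_v, logits_x_x], 1))
  loss_v = tf.nn.softmax_cross_entropy_with_logits(
      labels=labels, logits=tf.concat([logits_v_x, logits_v_v], 1))
  # Average over all 2N anchors, i.e. both (i, j) and (j, i).
  loss = tf.reduce_mean(tf.concat([loss_x, loss_v], 0))
  return loss
```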
1. Chen, T., Kornblith, S., Norouzi, M., & Hinton, G. (2020). A Simple Framework for Contrastive Learning of Visual Representations.
2. You et al. (2017). LARS: Large Batch Training of Convolutional Networks.
3. google-research/simclr. https://github.com/google-research/simclr