FlatNCE

FlatNCE provides a way to compute the gradient of InfoNCE without introducing the rounding error that arises when subtracting two nearly equal floating-point numbers.

Specifically, let \(g^\ominus_{ij}\) be the affinity score between the reference sample \(x_i\) and a negative (noise) sample \(y'_j\), and let \(g^\oplus_{ii}\) be the affinity score between \(x_i\) and its positive sample (e.g., its own transformation) \(y_i\), where \(i\) is the batch index; write \(g_{ij}\) for \(g^\oplus_{ii}\) when \(j = i\) and for \(g^\ominus_{ij}\) otherwise. Denote by \(\hat l_\text{InfoNCE}\) the batch estimate of the InfoNCE loss, in which the positive is contrasted against itself and the negatives: \[ \newcommand{\detach}{\mathop{\text{detach}}} \newcommand{\logsumexp}{\mathop{\text{logsumexp}}} \begin{gather} \hat l_\text{InfoNCE} = \logsumexp_j g_{ij} - g^\oplus_{ii} = \log \Big( \exp g^\oplus_{ii} + \sum_{j \ne i} \exp g^\ominus_{ij} \Big) - g^\oplus_{ii} \end{gather} \]

In practice, the above is computed with the usual max trick: \[ \hat l_\text{InfoNCE} = \log\Big[\sum_{j \in \{\oplus, \ominus\}} \exp\big(g_{ij} - \max_k g_{ik}\big) \Big] + \max_k g_{ik} - g^\oplus_{ii} \] When learning saturates, \(\hat l_\text{InfoNCE}\) approaches \(0\), which means \(\max_k g_{ik}\) becomes \(g^\oplus_{ii}\) and thus \[ \hat l_\text{InfoNCE} = \log\Big[\sum_{j \in \{\oplus, \ominus\}} \exp\big(g_{ij} - g^\oplus_{ii}\big) \Big] + \underbrace{g^\oplus_{ii} - g^\oplus_{ii} }_{\text{error-prone}} \] A rounding error is very likely when subtracting two nearly equal numbers (here \(\max_k g_{ik}\) and \(g^\oplus_{ii}\)), and the informative residual of the loss is tiny compared with the cancelled terms. Such errors accumulate and can make InfoNCE training fail. As stated above, FlatNCE provides a way to circumvent this rounding error.
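As a quick numerical illustration (a toy sketch with made-up scores, not from the paper), the snippet below evaluates this saturated InfoNCE loss in float32 and float64; at this magnitude the informative residual sits below float32's resolution and is largely rounded away:

```python
import torch

# Toy scores: g^+_{ii} = 20 (saturated positive), 255 negatives g^-_{ij} = 0.
g_pos = torch.tensor(20.0)
g_neg = torch.zeros(255)

# InfoNCE with the positive included among the contrasted samples.
logits = torch.cat([g_pos.view(1), g_neg])
loss_fp32 = torch.logsumexp(logits, dim=0) - g_pos
loss_fp64 = torch.logsumexp(logits.double(), dim=0) - g_pos.double()

print(loss_fp32.item())  # ~0.0: the tiny residual is (mostly) rounded away in float32
print(loss_fp64.item())  # ~5.3e-07: the value the loss should actually take
```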

Gradient Perspective

Denote by \(\hat l_\text{FlatNCE}\) the batch estimate of the FlatNCE loss (per reference sample \(x_i\)): \[ \begin{aligned} \hat l_\text{FlatNCE} &= \exp [ \logsumexp_{j \ne i} ( g^\ominus_{ij} - g^\oplus_{ii} ) - \detach \logsumexp_{j \ne i} ( g^\ominus_{ij} - g^\oplus_{ii} ) ] \\ &= \frac{\exp \logsumexp_{j \ne i} ( g^\ominus_{ij} - g^\oplus_{ii} ) } {\detach [ \exp \logsumexp_{j \ne i} ( g^\ominus_{ij} - g^\oplus_{ii} ) ] } \\ &= \frac{\exp \log \sum_{j \ne i} \exp [ g_\theta(x_i, y'_j) - g_\theta (x_i, y_i) ]} {\detach \{\exp \log \sum_{j \ne i} \exp [ g_\theta(x_i, y'_j) - g_\theta (x_i, y_i) ] \} } \\ &= \frac{\sum_{j \ne i} \exp [ g_\theta(x_i, y'_j) - g_\theta (x_i, y_i) ]} {\detach \{ \sum_{j \ne i} \exp [ g_\theta(x_i, y'_j) - g_\theta (x_i, y_i) ] \} } \end{aligned} \]
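As a minimal PyTorch sketch (the function name and tensor shapes are my own, assuming the affinity scores are already computed), the detach trick amounts to exponentiating the logsumexp minus its own detached copy:

```python
import torch

def flatnce_loss(g_pos: torch.Tensor, g_neg: torch.Tensor) -> torch.Tensor:
    """Sketch of the FlatNCE loss above.

    g_pos: (B,)   affinity scores g_theta(x_i, y_i)
    g_neg: (B, K) affinity scores g_theta(x_i, y'_j) for the negatives
    """
    diffs = g_neg - g_pos.unsqueeze(1)        # negative score minus positive score, (B, K)
    lse = torch.logsumexp(diffs, dim=1)       # logsumexp_j of the differences
    # exp(lse - detach(lse)) = S / detach(S): its value is identically 1 ("flat"),
    # but its gradient with respect to theta is the quantity derived below.
    return torch.exp(lse - lse.detach()).mean()
```

In a typical contrastive setup, `g_pos` would be the diagonal of a similarity matrix between two augmented views of a batch and `g_neg` its off-diagonal entries.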

By also putting the positive sample into the contrasted samples, \[ \hat l_\text{FlatNCE}^\oplus = \frac{1 + \sum_{j \ne i} \exp \big( g_\theta(x_i, y'_j) - g_\theta (x_i, y_i) \big)} {1 + \detach\big[\sum_{j \ne i} \exp \big( g_\theta(x_i, y'_j) - g_\theta (x_i, y_i) \big)\big]} \] where the \(1\) comes from adding the positive sample \(y_i\) to the set of contrasted "negatives" (denote this extra sample by \(y'_0\)). One can easily verify that

\[ \nabla_\theta \hat l_\text{FlatNCE}^\oplus (g_\theta) = \nabla_\theta \hat l_\text{InfoNCE} (g_\theta) \]

One can further see that this gradient is an importance-weighted estimator: \[ \begin{aligned} \nabla_\theta \hat l^\oplus_\text{FlatNCE} &= \frac{\sum_{j \ne i} \exp [ g_\theta(x_i, y'_j) - g_\theta (x_i, y_i) ] \, [\nabla_\theta g_\theta(x_i, y'_j) - \nabla_\theta g_\theta(x_i, y_i)] } {\detach \big\{ 1 + \sum_{j \ne i} \exp [ g_\theta(x_i, y'_j) - g_\theta (x_i, y_i) ] \big\} } \\ &= \frac{\sum_{j \ne i} \exp [ g_\theta(x_i, y'_j) ] \, [\nabla_\theta g_\theta(x_i, y'_j) - \nabla_\theta g_\theta(x_i, y_i)] } { \exp [ g_\theta(x_i, y_i) ] + \sum_{j \ne i} \exp [ g_\theta(x_i, y'_j) ] } \\ &= \sum_{k} \underbrace {\frac{ \exp [ g_\theta(x_i, y'_k) ] } { \sum_{j} \exp [ g_\theta(x_i, y'_j) ] } }_{w_k} \nabla_\theta g_\theta(x_i, y'_k) \;-\; \nabla_\theta g_\theta(x_i, y_i) \end{aligned} \] where, in the last line, \(j\) and \(k\) run over the positive \(y'_0 = y_i\) together with all the negatives; the \(k = 0\) term simply contributes \(w_0 \nabla_\theta g_\theta(x_i, y_i)\), since \(\exp[g_\theta(x_i, y'_0) - g_\theta(x_i, y_i)] = 1\). As learning progresses, the weights \(w_k\) with \(k \ne 0\) go to \(0\) while \(w_0\) goes to \(1\), so the weighted sum approaches \(\nabla_\theta g_\theta(x_i, y_i)\) and the gradient vanishes.
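The gradient identity between \(\hat l^\oplus_\text{FlatNCE}\) and \(\hat l_\text{InfoNCE}\) is easy to check numerically with autograd (a self-contained sketch with random scores; not from the paper):

```python
import torch

torch.manual_seed(0)
g = torch.randn(1, 9, requires_grad=True)      # one row of scores: index 0 is the positive
g_pos, g_neg = g[:, 0], g[:, 1:]               # g_theta(x_i, y_i) and the negatives

# InfoNCE with the positive included among the contrasted samples
infonce = torch.logsumexp(g, dim=1) - g_pos

# FlatNCE^+ : (1 + S) / detach(1 + S), with S = sum_j exp(g_neg - g_pos)
diffs = torch.cat([torch.zeros_like(g_pos).unsqueeze(1),
                   g_neg - g_pos.unsqueeze(1)], dim=1)
lse = torch.logsumexp(diffs, dim=1)
flatnce_plus = torch.exp(lse - lse.detach())

grad_infonce, = torch.autograd.grad(infonce.sum(), g, retain_graph=True)
grad_flatnce, = torch.autograd.grad(flatnce_plus.sum(), g)
print(torch.allclose(grad_infonce, grad_flatnce, atol=1e-6))   # expected: True
```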

Lower-bound Perspective

\(-\hat l_\text{InfoNCE}\) and \(-\hat l_\text{FlatNCE}\) appear as parts of lower bounds on the mutual information in the two methods. Let \(y_0\) be the positive sample and \(y_{j>0}\) the negative samples; then,

\[ \label{lemma3.3} \begin{aligned} -\hat l^{K, \theta}_\text{InfoNCE} &= -\log \Big\{ \frac 1 K \sum_{j > 0} \exp[g_\theta(x_0,y_j) - g_\theta(x_0,y_0)] \Big\} \\ &= \sup_{v < 0} \Big( v \cdot \frac 1 K \sum_{j > 0} \exp[g_\theta(x_0,y_j) - g_\theta(x_0,y_0)] - \big(-1 - \log (-v)\big) \Big) \\ &\Downarrow_{v = -e^{-u}} \\ &\ge -e^{-u} \frac 1 K \sum_{j > 0} \exp[g_\theta(x_0,y_j) - g_\theta(x_0,y_0)] - \big(-1 - \log (e^{-u})\big) \\ &= 1 - u - \frac 1 K \sum_{j > 0} \exp[g_\theta(x_0,y_j) - g_\theta(x_0,y_0) - u] \end{aligned} \]
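The second line uses the convex conjugate of \(-\log\); a brief justification (standard convex duality, not spelled out in the note): for \(z > 0\),

\[ \frac{\partial}{\partial v}\Big( vz + 1 + \log(-v) \Big) = z + \frac 1 v = 0 \;\Longrightarrow\; v^\ast = -\frac 1 z, \qquad v^\ast z + 1 + \log(-v^\ast) = -1 + 1 - \log z = -\log z , \]

so \(-\log z = \sup_{v < 0} \big( vz - (-1 - \log(-v)) \big)\), and any fixed choice \(v = -e^{-u}\) gives a lower bound.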

Consider \(g_\theta\) as the primal critic and \(u\) as the dual critic. Since any choice of \(u\) and \(g_\theta\) yields a lower bound on the mutual information, we can either optimize \(u\) and \(g_\theta\) jointly or, preferably, train them in an alternating fashion. Given \(\theta\), set \(u\) to

\[ \hat u(g_\theta) = \log \Big( \frac 1 K \sum_{j > 0} \exp[g_\theta(x_0,y_j) - g_\theta(x_0, y_0)] \Big) \] Then we fix \(u\) and update only \(\theta\). Because \(u\) is fixed, the only gradient comes from \(g_\theta\). Plugging \(\hat u\) into the right-hand side of \(\eqref{lemma3.3}\) gives \[ \begin{align} &1 - u - \frac 1 K \sum_{j > 0} \exp[g_\theta(x_0,y_j) - g_\theta(x_0,y_0) - u] \notag \\ &= 1 - \log \Big( \frac 1 K \sum_{j > 0} \exp[g_\theta(x_0,y_j) - g_\theta(x_0, y_0)] \Big) \notag \\ &\quad - \frac 1 K \, \frac{\sum_{j > 0} \exp[g_\theta(x_0,y_j) - g_\theta(x_0,y_0)]}{\frac 1 K \sum_{j > 0} \exp[g_\theta(x_0,y_j) - g_\theta(x_0, y_0)]} \notag \\ &= -\log \Big( \frac 1 K \sum_{j > 0} \exp[g_\theta(x_0,y_j) - g_\theta(x_0, y_0)] \Big) \label{obj} \\ &= -\hat l^K_\text{InfoNCE} \notag \end{align} \]

which shows that the bound is tight at \(u = \hat u(g_\theta)\): it recovers \(-\hat l^K_\text{InfoNCE}\) exactly. Moreover, however we update \(\theta\) to \(\theta'\), \(\eqref{lemma3.3}\) continues to hold. The objective of FlatNCE as a whole is therefore to enlarge \(\eqref{obj}\) with \(u\) held fixed (detached) at \(\hat u(g_\theta)\); since the bound remains valid at \(\theta'\), \(-\hat l^{K, \theta'}_\text{InfoNCE}\) floats up with it.
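To make the iterative view concrete, here is a small sketch of the right-hand side of \(\eqref{lemma3.3}\) with the dual critic fixed at its detached optimum (shapes and names follow the earlier sketches and are my own, not the paper's). Its value equals \(-\hat l^K_\text{InfoNCE}\) at the current \(\theta\), and its gradient is the negative of the FlatNCE gradient derived above, so maximizing it reproduces FlatNCE training:

```python
import math
import torch

def flatnce_bound(g_pos: torch.Tensor, g_neg: torch.Tensor) -> torch.Tensor:
    """Sketch of the lower bound (lemma3.3) with u = detach(u_hat(g_theta)).

    g_pos: (B,)   scores g_theta(x_0, y_0)
    g_neg: (B, K) scores g_theta(x_0, y_j), j > 0
    """
    K = g_neg.shape[1]
    diffs = g_neg - g_pos.unsqueeze(1)                       # g_j - g_0, shape (B, K)
    # dual critic u_hat = log( (1/K) * sum_j exp(g_j - g_0) ), detached so that
    # only the primal critic g_theta receives gradient
    u = (torch.logsumexp(diffs, dim=1) - math.log(K)).detach()
    # 1 - u - (1/K) * sum_j exp(g_j - g_0 - u)
    bound = 1.0 - u - torch.exp(diffs - u.unsqueeze(1)).mean(dim=1)
    return bound.mean()
```

Gradient ascent on `flatnce_bound` is then equivalent to gradient descent on the `flatnce_loss` sketch above.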

External Materials

FlatNCE: Is floating-point error really why small-batch contrastive learning performs poorly? - 科学空间|Scientific Spaces (kexue.fm)
