Contrastive Predictive Coding

Rationale

CPC was initially proposed with autoregressive models. It improves the learned representation by raising a lower bound on the mutual information between the encoded representation and the original data. The original data can be either the input before encoding or future observations several steps ahead.

CPC learns the representation by minimizing the following loss function:
\[ \newcommand{\c}{\mathrm{c}} \label{loss} \mathcal L_N = -\E_{t \sim \Phi} \log \frac{f_\theta(\x_\ell,\c^*)} {\sum_{\x^\prime \in X}f_\theta(\x^\prime, \c^*)} = -\E_{p(\x_{1:N+1},\c^*)} \E_{p(\ell|\x_{1:N+1}, \c^*)} \log \frac{f_\theta(\x_\ell,\c^*)} {\sum_{\x^\prime \in X}f_\theta(\x^\prime, \c^*)} \]
where \(t = (\x_1, \dots, \x_{N+1}, \c^*; \ell)\) is a tuple of random variables, \(\Phi\) is the distribution from which \(t\) is drawn, and \(X = \{\x_1, \dots, \x_{N+1}\}\). The pairs \((\x_i, \c_i)_{i=1}^{N+1}\) are drawn from the joint distribution \(\tilde p(\x,\c)\), and then all contexts except one randomly chosen \(\c^*\) are discarded. \(\c^*\) is observed, but it is unknown which sample it belongs to; \(\ell\) denotes the index of that sample, which is what we are trying to predict. In essence, \(\theta\) parameterizes the representation \(\c\). The score function \(f_\theta\) must be positive, so it is chosen as a simple deterministic function such as the exponentiated cosine similarity between \(\x\) and \(\c\).
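The sampling of \(t\) and the per-tuple term of \eqref{loss} can be sketched in a few lines of Python. This is a minimal illustration under stated assumptions, not the original CPC implementation: the helper names, the plain-list vectors, the temperature `tau`, and the callable `sample_joint` are all hypothetical, and the score is taken to be the exponentiated cosine similarity so that it stays positive. In practice \(\x\) and \(\c\) would be encoder outputs rather than raw vectors.

```python
import math
import random


def sample_tuple(sample_joint, n_neg, rng=random):
    """Draw t = (x_1, ..., x_{N+1}, c*; ell).

    sample_joint is a hypothetical callable returning one (x, c) pair from p~(x, c).
    """
    pairs = [sample_joint() for _ in range(n_neg + 1)]  # N+1 independent joint draws
    ell = rng.randrange(n_neg + 1)                      # index whose context is kept
    xs = [x for x, _ in pairs]
    return xs, pairs[ell][1], ell                       # all other contexts are discarded


def cosine(u, v):
    """Cosine similarity between two equal-length, non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


def score(x, c, tau=1.0):
    """Positive score f(x, c) = exp(cos(x, c) / tau); tau is an assumed temperature."""
    return math.exp(cosine(x, c) / tau)


def info_nce_term(xs, c_star, ell, tau=1.0):
    """-log [f(x_ell, c*) / sum over x' in X of f(x', c*)] for one tuple."""
    scores = [score(x, c_star, tau) for x in xs]
    return -math.log(scores[ell] / sum(scores))


def info_nce_loss(tuples, tau=1.0):
    """Monte Carlo estimate of L_N: average the per-tuple term over sampled tuples."""
    return sum(info_nce_term(xs, c, ell, tau) for xs, c, ell in tuples) / len(tuples)
```

When all scores in a tuple are equal, the term reduces to \(\log(N+1)\), the loss of guessing \(\ell\) uniformly at random; a useful score drives the term below that value.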

\(\x_{1:N+1}\) consists of one positive sample \(\x^*\), which is matched with \(\c^*\), and \(N\) independent negative (noise) samples \(\x_i\) that are not matched with \(\c^*\). \(N\) is thus the fixed ratio of the number of negative samples to the number of positive samples. Let \(P(\ell=i|\x_{1:N+1},\c^*)\) denote the probability that \(\x_i\) is the positive sample given \(\x_1, \dots, \x_{N+1}\) and \(\c^*\). By Bayes' rule,
\[ P(\ell=i|\x_{1:N+1},\c^*) = \frac{P(\ell=i, \x_{1:N+1},\c^*)}{\sum_{j=1}^{N+1} P(\ell=j, \x_{1:N+1},\c^*)} \]

Since \(\ell\) is uniform over \(\{1, \dots, N+1\}\), the positive pair follows the joint distribution \(\tilde p\), and the negatives follow the marginal \(\tilde p_X\), we can substitute \(P(\ell=i,\x_{1:N+1},\c^*) = \frac{1}{N+1}\tilde p(\x_i,\c^*) \prod_{j=1,j \ne i}^{N+1} \tilde p_X(\x_j)\). The constant factor \(\frac{1}{N+1}\) cancels between numerator and denominator, giving
\[ \begin{aligned} P(\ell=i|\x_{1:N+1},\c^*) &= \frac{\tilde p(\x_i,\c^*) \prod_{j=1,j \ne i}^{N+1} \tilde p_X(\x_j)} {\sum_{j=1}^{N+1} [\tilde p(\x_j,\c^*) \prod_{k=1,k \ne j}^{N+1} \tilde p_X(\x_k)]} \\ &= \frac{\frac{\tilde p(\x_i,\c^*)}{\tilde p_X(\x_i)} } {\sum_{j=1}^{N+1} \frac{\tilde p(\x_j,\c^*)}{\tilde p_X(\x_j)}} \end{aligned} \]
where the last step divides numerator and denominator by \(\prod_{k=1}^{N+1} \tilde p_X(\x_k)\). The loss function \(\eqref{loss}\) is in fact the expectation (the outer \(\E\)) of the categorical cross entropy (the inner \(\E\)) between the true posterior over \(\ell\) and the model's prediction of which sample is positive. Since cross entropy is minimized when the predicted distribution equals the true one, the minimum of the loss is reached when the two categorical distributions are identical. That is,
\[ \begin{gather} P_\theta(\ell = i|\x_{1:N+1},\c^*) = \frac{f_\theta(\x_i,\c^*)}{\sum_{\x^\prime \in X}f_\theta(\x^\prime, \c^*)} = \frac{\frac{\tilde p(\x_i,\c^*)}{\tilde p_X(\x_i)} } {\sum_{j=1}^{N+1} \frac{\tilde p(\x_j,\c^*)}{\tilde p_X(\x_j)}} = P(\ell=i|\x_{1:N+1},\c^*) \\ f_\theta(\x_i,\c^*) = \frac{\sum_{\x^\prime \in X}f_\theta(\x^\prime, \c^*)}{\sum_{\x^\prime \in X} \frac{\tilde p(\x^\prime,\c^*)}{\tilde p_X(\x^\prime)}} \frac{\tilde p(\x_i,\c^*)}{\tilde p_X(\x_i)} \\ f_\theta(\x_i,\c^*) \propto \frac{\tilde p(\x_i,\c^*)}{\tilde p_X(\x_i)} \propto \frac{\tilde p(\x_i|\c^*)}{\tilde p_X(\x_i)} \end{gather} \]
The last proportionality holds because \(\tilde p(\x_i,\c^*) = \tilde p(\x_i|\c^*)\,\tilde p_C(\c^*)\) and \(\tilde p_C(\c^*)\) does not depend on \(i\).
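The equivalence between the product form and the normalized density-ratio form of the posterior can be checked numerically. The sketch below assumes a toy \(\tilde p(\x,\c)\), a standard bivariate normal with correlation `RHO = 0.8` (an illustrative choice, not from the source), and omits the uniform prior over \(\ell\) because it cancels. The function names are hypothetical.

```python
import math

RHO = 0.8  # assumed correlation of the toy bivariate Gaussian p~(x, c)


def p_joint(x, c, rho=RHO):
    """Density of a standard bivariate normal with correlation rho."""
    det = 1.0 - rho * rho
    q = (x * x - 2.0 * rho * x * c + c * c) / det
    return math.exp(-0.5 * q) / (2.0 * math.pi * math.sqrt(det))


def p_marg(x):
    """Standard normal marginal p~_X(x)."""
    return math.exp(-0.5 * x * x) / math.sqrt(2.0 * math.pi)


def posterior_direct(xs, c):
    """P(ell=i | x_{1:N+1}, c*) from joint-times-marginals products."""
    terms = []
    for i in range(len(xs)):
        t = p_joint(xs[i], c)
        for j in range(len(xs)):
            if j != i:
                t *= p_marg(xs[j])
        terms.append(t)
    z = sum(terms)
    return [t / z for t in terms]


def posterior_ratio(xs, c):
    """Same posterior from normalized density ratios p~(x_i, c*) / p~_X(x_i)."""
    r = [p_joint(x, c) / p_marg(x) for x in xs]
    z = sum(r)
    return [v / z for v in r]
```

For example, `posterior_direct([0.9, -1.2, 0.1, 2.0], 1.0)` and `posterior_ratio([0.9, -1.2, 0.1, 2.0], 1.0)` agree up to floating-point rounding, which is exactly what an optimal score \(f_\theta\) has to reproduce.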

Bounding the Mutual Information

Optimizing the InfoNCE loss \(\eqref{loss}\) also yields a lower bound on the mutual information between the encoded representation and the original data. At the optimum, the scores are proportional to \(\tilde p(\x,\c^*)/\tilde p_X(\x)\); because a constant factor cancels in the normalized ratio, they can equally be written as \(\tilde p(\x|\c^*)/\tilde p_X(\x)\), which differs only by the factor \(\tilde p_C(\c^*)\). Hence
\[ \begin{aligned} \mathcal L_N^{\text{opt}} &= -\E_{p(\x_{1:N+1},\c^*)} \E_{p(\ell|\x_{1:N+1}, \c^*)} \log \frac{\frac{\tilde p(\x_\ell|\c^*)}{\tilde p_X(\x_\ell)}} {\sum_{\x' \in X} \frac{\tilde p(\x'|\c^*)}{\tilde p_X(\x')}} \\ &= \E_{p(\x_{1:N+1},\c^*)} \E_{p(\ell|\x_{1:N+1}, \c^*)} \log \frac{\frac{\tilde p(\x_\ell|\c^*)}{\tilde p_X(\x_\ell)} + \sum_{\x' \in X, \x' \ne \x_\ell} \frac{\tilde p(\x'|\c^*)}{\tilde p_X(\x')}} {\frac{\tilde p(\x_\ell|\c^*)}{\tilde p_X(\x_\ell)}} \\ &= \E_{p(\x_{1:N+1},\c^*)} \E_{p(\ell|\x_{1:N+1}, \c^*)} \log \big( 1 + \frac{\tilde p_X(\x_\ell)} {\tilde p(\x_\ell|\c^*)} \sum_{\x' \in X, \x' \ne \x_\ell} \frac{\tilde p(\x'|\c^*)}{\tilde p_X(\x')} \big) \\ &\approx \E_{p(\x_{1:N+1},\c^*)} \E_{p(\ell|\x_{1:N+1}, \c^*)} \log \big( 1 + \frac{\tilde p_X(\x_\ell)} {\tilde p(\x_\ell|\c^*)} N \E_{\tilde p_X(\x')} \frac{\tilde p(\x'|\c^*)}{\tilde p_X(\x')} \big) \\ &= \E_{p(\x_{1:N+1},\c^*)} \E_{p(\ell|\x_{1:N+1}, \c^*)} \log \big( 1 + \frac{\tilde p_X(\x_\ell)} {\tilde p(\x_\ell|\c^*)} N \big) \\ &\ge \E_{p(\x_{1:N+1},\c^*)} \E_{p(\ell|\x_{1:N+1}, \c^*)} \log \big( \frac{\tilde p_X(\x_\ell)} {\tilde p(\x_\ell|\c^*)} N \big) \\ &= \E_{p(\x_{1:N+1},\c^*)} \E_{p(\ell|\x_{1:N+1}, \c^*)} [\log \frac{\tilde p_X(\x_\ell)} {\tilde p(\x_\ell|\c^*)}] + \log N \\ &= -I(\x;\c^*) + \log N \end{aligned} \]
The approximation replaces the sum over the \(N\) negatives, which are drawn independently from \(\tilde p_X\), by \(N\) times its expectation, and that expectation equals \(\int \tilde p(\x'|\c^*)\,d\x' = 1\). The last step uses the fact that the positive pair \((\x_\ell, \c^*)\) is distributed as \(\tilde p(\x,\c)\). Therefore, \(I(\x;\c^*) \ge \log N - \mathcal L_N^{\text{opt}}\).
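The bound can be checked on a toy problem, assuming \((\x, \c)\) is a standard bivariate normal with correlation \(\rho\), for which \(I(\x;\c) = -\frac12\log(1-\rho^2)\) is known in closed form. The sketch plugs the true density ratio \(\tilde p(\x|\c^*)/\tilde p_X(\x)\) in as the optimal score, estimates \(\mathcal L_N^{\text{opt}}\) by Monte Carlo, and lets you compare \(\log N - \mathcal L_N^{\text{opt}}\) with the true mutual information. `RHO`, the function names, `n_neg`, `n_tuples`, and the seed are illustrative assumptions.

```python
import math
import random

RHO = 0.8  # assumed correlation of the toy Gaussian pair (x, c)


def log_ratio(x, c, rho=RHO):
    """log p~(x | c) - log p~_X(x), with p(x|c) = N(rho*c, 1-rho^2) and p_X = N(0, 1)."""
    det = 1.0 - rho * rho
    return -0.5 * math.log(det) - 0.5 * (x - rho * c) ** 2 / det + 0.5 * x * x


def optimal_info_nce(n_neg, n_tuples, rho=RHO, seed=0):
    """Monte Carlo estimate of L_N^opt using the true density ratio as the score."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_tuples):
        # Positive pair: c* ~ N(0, 1), then x* ~ p(x | c*).
        c = rng.gauss(0.0, 1.0)
        x_pos = rho * c + math.sqrt(1.0 - rho * rho) * rng.gauss(0.0, 1.0)
        # N negatives drawn independently from the marginal p~_X.
        xs = [x_pos] + [rng.gauss(0.0, 1.0) for _ in range(n_neg)]
        logs = [log_ratio(x, c, rho) for x in xs]
        m = max(logs)  # log-sum-exp for numerical stability
        lse = m + math.log(sum(math.exp(v - m) for v in logs))
        # The positive sits at index 0; the loss does not depend on its position.
        total += -(logs[0] - lse)
    return total / n_tuples


true_mi = -0.5 * math.log(1.0 - RHO * RHO)
```

Because the estimate is a lower bound, `math.log(16) - optimal_info_nce(16, 5000)` should stay below `true_mi` up to Monte Carlo noise. Since \(\mathcal L_N^{\text{opt}}\) is a cross entropy and therefore non-negative, the estimate can never exceed \(\log N\), however large the true mutual information is.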

Externals

Paper Review || CPC Formulation || NCE and InfoNCE || Demo of Bounding Mutual Information
