How TensorFlow's skip-gram model predicts words: An Intuitive Understanding of Conditional Probability and NCE Loss

Photo by Glen Carrie on Unsplash

Imagine you're a detective trying to solve a mystery. You come across a series of messages, and you need to figure out the meaning behind them. The messages are like words in a sentence, and you want to understand how they relate to each other. You know that when people talk to each other, they provide context for what they're saying. For instance, if someone says "I'm going to the store to buy some ice cream", you can infer that the person is probably craving something sweet. You want to use this same logic to understand how words are related in sentences.

To do this, you come across the skip-gram model. The skip-gram model takes in a word and tries to predict the words around it; in other words, it tries to predict that word's context. For example, in the sentence "The quick brown fox jumped over the lazy dog", if the target word is "jumped", the context words (using a window of one word on each side) are "fox" and "over". The model tries to predict these context words from the target word.

To make this task easier and faster, the skip-gram model uses a modified objective function called noise-contrastive estimation (NCE) loss, which is based on conditional probability. Essentially, instead of scoring every word in the vocabulary, this objective draws a small number of noise words for each training example and trains the model to distinguish the real context word from the noise words. Based on these predictions, it updates the embedding vectors for each word. These vectors contain information about each word's context, and it is this information that allows the skip-gram model to predict words and their relationships.
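As a rough sketch of the idea (following the formulation popularized by the original word2vec papers and the TensorFlow word2vec tutorial; the symbols below are illustrative), the per-example objective can be written as:

```latex
J = \log Q_\theta\big(D = 1 \mid w_{\text{context}}, w_{\text{target}}\big)
  + k \, \mathbb{E}_{\tilde{w} \sim P_{\text{noise}}}\big[\log Q_\theta\big(D = 0 \mid \tilde{w}, w_{\text{target}}\big)\big]
```

Here Q_θ(D = 1 | w, w_target) is the model's probability, computed from the embedding vectors θ, that w really appeared in the context of w_target, k is the number of noise words, and P_noise is the noise distribution. Maximizing this objective rewards assigning high probability to real (target, context) pairs and low probability to the sampled noise words.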

To understand this better, let's continue with the detective analogy. Imagine you're given a message that reads "The quick brown fox jumped over the lazy dog". You want to understand how the words in this message relate to each other. Using the skip-gram model, you would form a dataset of words and the contexts in which they appear. You can define 'context' in any way that makes sense, but let's stick to the window of words to the left and to the right of a target word. Using a window size of 1, we then have (context, target) pairs such as ([the, brown], quick) and ([quick, fox], brown). Skip-gram inverts contexts and targets and tries to predict each context word from its target word, so the task becomes predicting "the" and "brown" from "quick", "quick" and "fox" from "brown", and so on, as the sketch below shows.
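To make this concrete, here is a minimal sketch in plain Python (no TensorFlow needed; the variable names are just illustrative) that builds those skip-gram pairs with a window size of 1:

```python
# Build (target, context) skip-gram pairs with a window of 1 word on each side.
sentence = "the quick brown fox jumped over the lazy dog".split()
window_size = 1

pairs = []  # each entry is (target word, one of its context words)
for i, target in enumerate(sentence):
    # look at the neighbours within the window, skipping the target itself
    for j in range(max(0, i - window_size), min(len(sentence), i + window_size + 1)):
        if j != i:
            pairs.append((target, sentence[j]))

print(pairs[:4])
# [('the', 'quick'), ('quick', 'the'), ('quick', 'brown'), ('brown', 'quick')]
```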

Let's say at a certain training step we select the example "predict 'the' from 'quick'". We draw noise examples from some noise distribution, typically the unigram distribution, which samples words in proportion to their frequency. For this example, let's say the noise word we draw is "sheep". Next, we compute the loss for this pair of observed and noise examples by comparing the conditional probability the model assigns to the real context word with the probability it assigns to the noise word; TensorFlow provides the tf.nn.nce_loss() helper function for exactly this. The goal is to minimize this loss (equivalently, to maximize the NCE objective), so we compute the gradient of the loss with respect to the embedding parameters and update the embedding vectors by taking a small step against the gradient using stochastic gradient descent (SGD). Over time, this process moves the embedding vectors for each word around until the model becomes good at discriminating real words from noise words.
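As a rough illustration of that training step (TensorFlow 2.x; the vocabulary size, embedding dimension, learning rate and number of noise words below are placeholder values, not taken from the article), a step built around tf.nn.nce_loss() might look like this:

```python
import tensorflow as tf  # assumes TensorFlow 2.x

vocab_size = 10_000    # placeholder vocabulary size
embedding_dim = 128    # placeholder embedding dimension
num_sampled = 5        # number of noise words drawn per training example

# Trainable parameters: the input embeddings plus the NCE output weights/biases.
embeddings = tf.Variable(tf.random.uniform([vocab_size, embedding_dim], -1.0, 1.0))
nce_weights = tf.Variable(tf.random.truncated_normal([vocab_size, embedding_dim], stddev=0.1))
nce_biases = tf.Variable(tf.zeros([vocab_size]))

optimizer = tf.keras.optimizers.SGD(learning_rate=1.0)

def train_step(target_ids, context_ids):
    # target_ids: [batch] integer word ids; context_ids: [batch, 1] integer word ids
    with tf.GradientTape() as tape:
        embedded = tf.nn.embedding_lookup(embeddings, target_ids)  # [batch, dim]
        loss = tf.reduce_mean(
            tf.nn.nce_loss(
                weights=nce_weights,
                biases=nce_biases,
                labels=context_ids,       # the observed (real) context words
                inputs=embedded,          # the target words' embeddings
                num_sampled=num_sampled,  # noise words drawn by the built-in candidate sampler
                num_classes=vocab_size))
    variables = [embeddings, nce_weights, nce_biases]
    gradients = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(gradients, variables))  # one small SGD step
    return loss

# e.g. one pair: predict "the" (say id 0) from "quick" (say id 1)
loss = train_step(tf.constant([1]), tf.constant([[0]], dtype=tf.int64))
```

Note that by default tf.nn.nce_loss() draws its noise words from a log-uniform (Zipfian) sampler over word ids, which approximates the unigram distribution described above when ids are assigned in order of decreasing frequency.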

We can visualize the learned vectors by projecting them down to 2 dimensions using something like the t-SNE dimensionality reduction technique. When we inspect these visualizations, it becomes apparent that the vectors capture general and useful semantic information about words and their relationships to one another. Certain directions in the induced vector space specialize towards particular semantic relationships, e.g. male-female, verb tense, and country-capital relationships between words. This is illustrated in several papers, including Mikolov et al. (2013).
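A sketch of that projection step, assuming the embeddings variable trained above and a hypothetical id_to_word list mapping row indices back to vocabulary words, might look like:

```python
# Project the learned embeddings to 2-D with t-SNE and label each point with its word.
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

num_points = 500                           # plot only a subset so the labels stay readable
vectors = embeddings.numpy()[:num_points]  # rows of the trained embedding matrix
points = TSNE(n_components=2, perplexity=30, init="pca").fit_transform(vectors)

plt.figure(figsize=(12, 12))
for i, (x, y) in enumerate(points):
    plt.scatter(x, y, s=4)
    plt.annotate(id_to_word[i], (x, y), fontsize=7)  # id_to_word: hypothetical id -> word lookup
plt.show()
```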

To sum it up, the skip-gram model learns by predicting the context words around each target word, and it uses NCE loss based on conditional probability to update the embedding vectors. Over time, this allows the model to learn the relationships between words and eventually capture semantic information about them.