A Simpler Description of the Self-Attention Mechanism
This piece is based on this incredible video and this amazing in-depth article, so if you want to learn more, please visit these resources.
Introduction
It took me years to understand how the self-attention mechanism inside the transformer architecture works, so I wrote this article for people who want to understand it at a fairly low level, without equations. This piece aims to give you understanding and intuition, not math.
The goal of the attention mechanism is to allow models to have better context for understanding their inputs. For example, if a model is processing the sentence “The bank on the side of the river”, it should be able to know that “bank” refers to a river bank, not a financial institution.
In a computer, words are represented as vectors, so we can multiply these vectors by matrices (whose values are learned during training). There are three main matrices used in the attention mechanism:
- Query matrix
- Key matrix
- Value matrix
These names loosely follow an analogy to databases, where you look up values using a query or a key, but the analogy doesn’t quite hold, so don’t worry about it too much.
In our example, we want the machine learning model to have some way to represent the word “bank” differently when it’s a river bank versus a money bank. Specifically, we want the model to be able to learn that the word “river” should influence its representation of the word “bank”.
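As a concrete picture of the shapes involved, here is a minimal NumPy sketch; the dimensions and variable names here are made up for illustration (real models use much larger embeddings):

```python
import numpy as np

embedding_dim = 8   # hypothetical size of each word's embedding vector
qkv_dim = 8         # hypothetical size of the query/key/value vectors

# Each word starts out as an embedding vector.
bank_embedding = np.random.randn(embedding_dim)
river_embedding = np.random.randn(embedding_dim)

# The three learned matrices (random here; their values are learned during training).
W_query = np.random.randn(embedding_dim, qkv_dim)
W_key = np.random.randn(embedding_dim, qkv_dim)
W_value = np.random.randn(embedding_dim, qkv_dim)
```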
The Algorithm
Below, I describe how the self-attention mechanism inside the transformer architecture works. Each step listed below identifies what happens, then poses the question that the algorithm is essentially trying to answer during that step.
This sequence of steps below is followed for every pair of words in the sentence (including pairs where both words are the same), where one of the words is assigned to be the query word (the word getting context) and the other word is assigned to be the key word (the word giving context). If there are five words in the sentence, each word will be the query word five times and each word will be the key word five times.
To highlight how this mechanism adds context, we use the example of “bank” being the query word, and “river” being the key word, to show how “river” can add context to “bank”.
- Multiply “bank” embedding vector by query matrix; get back query vector. (How should we represent the query word (“bank”) when trying to figure out what key words (like “river”) are important to get context from?)
- Multiply “river” embedding vector by key matrix; get back key vector. (How should we represent the key word (“river”) when trying to figure out what query words (like “bank”) are important to give context to?)
- Dot product query vector and key vector; get back attention scalar. (How important is the key word (“river”) for providing context to the query word (“bank”)? In other words, how much should the query word (“bank”) pay attention to the key word (“river”)? Remember that the dot product essentially measures how aligned two vectors are (it is closely related to the angle between them); the more aligned the vectors, the higher the attention scalar will be, and the more context information will be added from the key word (“river”) to the query word (“bank”).)
- (Normalize attention scalars — I discuss this below.)
- Multiply original “river” embedding vector by value matrix; get back value vector. (Assuming that the key word (“river”) is important for providing context to the query word (“bank”), how should we represent that context information?)
- Multiply value vector by (normalized) attention scalar; get back scaled value vector. (If getting context from the key word (“river”) isn’t all that important [the attention scalar is low], then make the context’s representation small. If getting context from the key word (“river”) is very important [the attention scalar is high], make the context’s representation big.)
- Repeat the steps above for each possible key word (not just “river” like we did above!). Then, add all of the scaled value vectors together; get back final contextualized vector. (We now have a contextualized embedding vector that we can use for the rest of our computations in place of the original query (“bank”) vector!)
And those are the main steps of the attention mechanism! There is only one query, one key, and one value matrix for all of the words to share. The values of those three matrices are learned throughout training.
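To make the steps concrete, here is a minimal single-head sketch in NumPy for one query word. The function name and dimensions are my own for illustration; a real implementation would compute this for every query word at once with matrix operations.

```python
import numpy as np

def contextualize(query_embedding, all_embeddings, W_query, W_key, W_value):
    """Return the contextualized vector for one query word (e.g. "bank")."""
    # Project the query word's embedding into a query vector.
    query_vec = query_embedding @ W_query

    # Project every word's embedding into a key vector, then take the dot
    # product with the query vector to get one attention scalar per key word.
    key_vecs = all_embeddings @ W_key
    attention_scalars = key_vecs @ query_vec

    # Normalize (softmax) so the scalars become weights that sum to 1.
    weights = np.exp(attention_scalars - attention_scalars.max())
    weights = weights / weights.sum()

    # Project every word's embedding into a value vector and scale each one
    # by its attention weight.
    value_vecs = all_embeddings @ W_value
    scaled_values = weights[:, None] * value_vecs

    # Sum the scaled value vectors: this is the contextualized vector.
    return scaled_values.sum(axis=0)
```

Called with the “bank” embedding as `query_embedding` and all eight word embeddings (stacked as rows) as `all_embeddings`, it returns the contextualized vector that stands in for “bank” afterwards.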
If you understood all that, then you understand how the self-attention mechanism works! What follows are less important details of how the algorithm works.
Normalization
The normalization step exists because the raw attention scalars can come out arbitrarily large or small, and we want to turn them into weights that are comparable across words.
To do the attention mechanism correctly, once you have the attention scalar from each key word (for a single fixed query word), you normalize all of them (using the softmax function) so that the attention scalars from all the words sum to 1. This essentially converts each attention scalar into a percentage answering “compared to the other words in this input, how important is it that this key word provides context?” It is this percentage, rather than the original scalar, that is multiplied by the value vector.
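As a small illustration with made-up numbers, here is what that normalization looks like for one query word:

```python
import numpy as np

# Hypothetical raw attention scalars for the query word "bank", one per key
# word in "The bank on the side of the river" (eight words, so eight scalars).
raw_scores = np.array([0.2, 1.1, -0.3, 0.2, 0.1, -0.5, 0.2, 2.4])

# Softmax: exponentiate, then divide by the sum so the weights add up to 1.
weights = np.exp(raw_scores) / np.exp(raw_scores).sum()

print(weights.sum())   # 1.0 (up to floating-point rounding)
print(weights[-1])     # "river" (the last word) gets the largest weight
```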
(There’s also one more small piece to normalization mentioned in the “Side Notes” section at the bottom.)
Multiple Heads
The algorithm described above can be thought of as a transformer attention “head” that’s applied to each (query) word in the sentence to get a more contextualized embedding for that (query) word. However, we may want to learn multiple different contextualization patterns simultaneously.
To do this, we can just rerun the exact same algorithm but with different query, key, and value matrices. We’ll randomly initialize these matrices to different values, and when we train, they’ll likely settle on different values as well! Each rerun of this algorithm is called a transformer “head”, and they’re usually run in parallel (at the same time, rather than one after the other).
We can then merge (by concatenating) the final contextualized embeddings of all of the heads together into one big final contextualized embedding, which we can use for future computations. People usually choose the dimensions of the query, key, and value matrices such that when you concatenate all of the (small) final contextualized embeddings, you get a vector that’s the same size as the original embedding. That is, you set: (number of heads) * (size of final contextualized embedding from attention mechanism) = (size of original vector). See the linked article for details.
This is an extremely common technique; the original transformer paper uses 8 heads running in parallel.
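As a quick sanity check on the arithmetic, here is a sketch using the dimensions from the original paper (the variable names are mine):

```python
import numpy as np

embedding_dim = 512                     # size of each original embedding
num_heads = 8                           # heads run in parallel
head_dim = embedding_dim // num_heads   # 64: each head outputs a smaller vector

# Pretend each head has produced its own small contextualized vector for "bank".
head_outputs = [np.random.randn(head_dim) for _ in range(num_heads)]

# Concatenating them recovers a vector of the original embedding size.
contextualized = np.concatenate(head_outputs)
assert contextualized.shape == (embedding_dim,)
```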
Multiple Layers
Well, if after using the attention mechanism, we now just have another set of (contextualized) embeddings… why not do it all over again with these (contextualized) embeddings? If you decide to add context between your already contextualized embeddings by using another attention mechanism, that’s called adding a transformer layer.
This, too, is an extremely common technique. The original transformer paper uses 6 layers.
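In code, stacking layers is essentially a loop that feeds each layer’s output back in as the next layer’s input. A rough sketch (the layer function here is a stand-in; a real transformer layer also contains a feed-forward network and other pieces, and each layer has its own learned matrices):

```python
import numpy as np

def attention_layer(embeddings):
    # Stand-in for the self-attention mechanism described above.
    return embeddings  # placeholder: a real layer would contextualize these

sentence_embeddings = np.random.randn(8, 512)   # 8 words, 512-dimensional embeddings

num_layers = 6
for _ in range(num_layers):
    # Each layer adds context on top of the previous layer's (already
    # contextualized) embeddings.
    sentence_embeddings = attention_layer(sentence_embeddings)
```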
Why Transformers Work
Parallel computation. Essentially, most of this computation can be expressed as matrix multiplications that run in parallel, meaning you can get far more computation done in the same amount of time, so your algorithm performs better.
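One way to see this: all of the per-word steps above can be packed into a handful of matrix multiplications that compute every query, key, and value vector, and every attention weight, at once. A rough NumPy sketch with made-up dimensions:

```python
import numpy as np

num_words, dim = 8, 64
X = np.random.randn(num_words, dim)   # all word embeddings, stacked as rows
W_query = np.random.randn(dim, dim)
W_key = np.random.randn(dim, dim)
W_value = np.random.randn(dim, dim)

Q, K, V = X @ W_query, X @ W_key, X @ W_value   # every query/key/value vector at once
scores = Q @ K.T / np.sqrt(dim)                 # every attention scalar at once
weights = np.exp(scores)
weights = weights / weights.sum(axis=1, keepdims=True)   # softmax over each row
contextualized = weights @ V                    # every contextualized vector at once
```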
Side Notes
Tokens
While I say that transformers work on words, I should be clearer that they work on tokens, which are usually parts of words rather than entire words. So, in practice, the attention mechanism provides context to tokens, not full words.
Normalization
In addition, inside the normalization step, you also divide each attention scalar by the square root of the number of dimensions in the query (equivalently, the key) vector before doing the final normalization. (See the linked article for details.) The authors of the original paper do this because they suspect the dot product values can become so big that, when you normalize them (using the softmax function), the neural network doesn’t learn well: softmax has a very small slope (gradient) for very big values.
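In code terms, it is a one-line tweak to the normalization example above (the names and numbers here are hypothetical):

```python
import numpy as np

key_dim = 64                                   # dimensionality of the query/key vectors
raw_scores = np.array([0.2, 1.1, -0.3, 2.4])   # hypothetical attention scalars

scaled_scores = raw_scores / np.sqrt(key_dim)  # divide by sqrt(dimension) first
weights = np.exp(scaled_scores) / np.exp(scaled_scores).sum()   # then softmax
```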
Position Embeddings
A lot of newer neural networks have a position embedding that is added (not concatenated) to the original word embedding, so that the attention mechanism can take into account where words are in the sentence when figuring out how to add context. For more details on how this is done, see the linked article or the original BERT paper (which has a nice diagram).
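A rough sketch of the “added, not concatenated” point, with made-up shapes:

```python
import numpy as np

num_positions, embedding_dim = 8, 512

word_embeddings = np.random.randn(num_positions, embedding_dim)      # one row per token
position_embeddings = np.random.randn(num_positions, embedding_dim)  # one row per position

# Added elementwise (the shape stays the same), not concatenated
# (which would double the vector size).
inputs = word_embeddings + position_embeddings
assert inputs.shape == (num_positions, embedding_dim)
```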
Other Types of Attention
The reason that what I described above is called Self-Attention is that the query words come from the same text as the key and value words. In other types of attention, such as Cross-Attention, the query words can come from one sequence (the decoder), while the key and value words come from a second sequence (the encoder). (Note that the key and value vectors are generally computed from the same words [and thus come from the same sequence]. This makes sense in terms of attention adding context: the key and value vectors should correspond to the same word, where the key matrix determines the word’s representation for finding out which query words need context, and the value matrix determines the word’s representation for adding the context information.)
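A minimal sketch of the difference, assuming a generic attention function whose queries can come from one sequence and whose keys/values come from another (the names and dimensions are mine):

```python
import numpy as np

def attention(query_source, key_value_source, W_query, W_key, W_value):
    """Queries come from one sequence; keys and values come from another."""
    Q = query_source @ W_query
    K = key_value_source @ W_key
    V = key_value_source @ W_value
    scores = Q @ K.T / np.sqrt(Q.shape[-1])
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=1, keepdims=True)
    return weights @ V

dim = 64
W_q, W_k, W_v = [np.random.randn(dim, dim) for _ in range(3)]
decoder_embeddings = np.random.randn(5, dim)   # hypothetical decoder sequence (5 tokens)
encoder_embeddings = np.random.randn(8, dim)   # hypothetical encoder sequence (8 tokens)

# Self-attention: queries, keys, and values all come from the same sequence.
self_attended = attention(decoder_embeddings, decoder_embeddings, W_q, W_k, W_v)

# Cross-attention: queries come from the decoder sequence, while keys and
# values both come from the encoder sequence.
cross_attended = attention(decoder_embeddings, encoder_embeddings, W_q, W_k, W_v)
```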
Update 2/11/2024: This article was updated to add a section on other types of attention and to clarify that the scaled value vector is not added to the original encoding vector (as was previously written). Instead, all scaled value vectors are summed together (without the original encoding vector) to create the final representation.