3 Coding Attention Mechanisms

This chapter covers

  • Exploring the reasons for using attention mechanisms in neural networks
  • Introducing a basic self-attention framework and progressing to an enhanced self-attention mechanism
  • Implementing a causal attention module that allows LLMs to generate one token at a time
  • Masking randomly selected attention weights with dropout to reduce overfitting
  • Stacking multiple causal attention modules into a multi-head attention module

In the previous chapter, you learned how to prepare the input text for training LLMs. This involved splitting text into individual word and subword tokens, which can be encoded into vector representations, the so-called embeddings, for the LLM.

In this chapter, we will now look at an integral part of the LLM architecture itself, attention mechanisms, as illustrated in Figure 3.1.

Figure 3.1 A mental model of the three main stages of coding an LLM, pretraining the LLM on a general text dataset, and finetuning it on a labeled dataset. This chapter focuses on attention mechanisms, which are an integral part of an LLM architecture.

Attention mechanisms are a comprehensive topic, which is why we are devoting a whole chapter to them. We will largely look at these attention mechanisms in isolation and focus on them at a mechanistic level. In the next chapter, we will then code the remaining parts of the LLM surrounding the self-attention mechanism to see it in action and to create a model to generate text.

Figure 3.2 The figure depicts different attention mechanisms we will code in this chapter, starting with a simplified version of self-attention before adding the trainable weights. The causal attention mechanism adds a mask to self-attention that allows the LLM to generate one word at a time. Finally, multi-head attention organizes the attention mechanism into multiple heads, allowing the model to capture various aspects of the input data in parallel.

These different attention variants shown in Figure 3.2 build on each other, and the goal is to arrive at a compact and efficient implementation of multi-head attention at the end of this chapter that we can then plug into the LLM architecture we will code in the next chapter.

3.1 The problem with modeling long sequences

Before we dive into the self-attention mechanism at the heart of LLMs later in this chapter, let's consider the problem with the architectures that predate LLMs and do not include attention mechanisms. Suppose we want to develop a language translation model that translates text from one language into another. As shown in Figure 3.3, we can't simply translate a text word by word because the grammatical structures of the source and target languages differ.

Figure 3.3 When translating text from one language to another, such as German to English, it's not possible to merely translate word by word. Instead, the translation process requires contextual understanding and grammar alignment.

To address the issue that we cannot translate text word by word, it is common to use a deep neural network with two submodules, a so-called encoder and decoder. The job of the encoder is to first read in and process the entire text, and the decoder then produces the translated text.

We already briefly discussed encoder-decoder networks when we introduced the transformer architecture in chapter 1 (section 1.4, Using LLMs for different tasks). Before the advent of transformers, recurrent neural networks (RNNs) were the most popular encoder-decoder architecture for language translation.

An RNN is a type of neural network where outputs from previous steps are fed as inputs to the current step, making them well-suited for sequential data like text. If you are unfamiliar with RNNs, don't worry, you don't need to know the detailed workings of RNNs to follow this discussion; our focus here is more on the general concept of the encoder-decoder setup.

In an encoder-decoder RNN, the input text is fed into the encoder, which processes it sequentially. The encoder updates its hidden state (the internal values at the hidden layers) at each step, trying to capture the entire meaning of the input sentence in the final hidden state, as illustrated in Figure 3.4. The decoder then takes this final hidden state to start generating the translated sentence, one word at a time. It also updates its hidden state at each step, which is supposed to carry the context necessary for the next-word prediction.

Figure 3.4 Before the advent of transformer models, encoder-decoder RNNs were a popular choice for machine translation. The encoder takes a sequence of tokens from the source language as input, where a hidden state (an intermediate neural network layer) of the encoder encodes a compressed representation of the entire input sequence. Then, the decoder uses its current hidden state to begin the translation, token by token.

While we don't need to know the inner workings of these encoder-decoder RNNs, the key idea here is that the encoder part processes the entire input text into a hidden state (memory cell). The decoder then takes in this hidden state to produce the output. You can think of this hidden state as an embedding vector, a concept we discussed in chapter 2.

The big issue and limitation of encoder-decoder RNNs is that the RNN can't directly access earlier hidden states from the encoder during the decoding phase. Consequently, it relies solely on the current hidden state, which encapsulates all relevant information. This can lead to a loss of context, especially in complex sentences where dependencies might span long distances.
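
The bottleneck can be made concrete with a toy example. The following sketch is not a real encoder-decoder RNN: it assumes a scalar hidden state and made-up weights (w_in and w_rec are hypothetical) purely to show how information about early tokens fades when everything must pass through one final hidden state.

```python
import math

def encode(tokens, w_in=0.5, w_rec=0.5):
    """Toy scalar RNN encoder (hypothetical weights); returns only the final hidden state."""
    h = 0.0
    for x in tokens:
        # The previous hidden state is fed back in at every step
        h = math.tanh(w_in * x + w_rec * h)
    return h

padding = [0.0] * 20

# Two sequences that differ only in their FIRST token
early_diff = abs(encode([1.0] + padding) - encode([-1.0] + padding))

# Two sequences that differ only in their LAST token
late_diff = abs(encode(padding + [1.0]) - encode(padding + [-1.0]))

print(f"Effect of changing the first token: {early_diff:.2e}")
print(f"Effect of changing the last token:  {late_diff:.2e}")
```

In this toy setting, changing the first token barely moves the final hidden state, while changing the last token moves it substantially. A decoder that sees only that final state has little chance of recovering the early token.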

For readers unfamiliar with RNNs, it is not essential to understand or study this architecture as we will not be using it in this book. The takeaway message of this section is that encoder-decoder RNNs had a shortcoming that motivated the design of attention mechanisms.

3.2 Capturing data dependencies with attention mechanisms

Before transformer LLMs, it was common to use RNNs for language modeling tasks such as language translation, as mentioned previously. RNNs work fine for translating short sentences but don't work well for longer texts as they don't have direct access to previous words in the input.


One major shortcoming in this approach is that the RNN must remember the entire encoded input in a single hidden state before passing it to the decoder, as illustrated in Figure 3.4 in the previous section.

Hence, researchers developed the so-called Bahdanau attention mechanism for RNNs in 2014 (named after the first author of the respective paper), which modifies the encoder-decoder RNN such that the decoder can selectively access different parts of the input sequence at each decoding step as illustrated in Figure 3.5.

Figure 3.5 Using an attention mechanism, the text-generating decoder part of the network can access all input tokens selectively. This means that some input tokens are more important than others for generating a given output token. The importance is determined by the so-called attention weights, which we will compute later. Note that this figure shows the general idea behind attention and does not depict the exact implementation of the Bahdanau mechanism, which is an RNN method outside this book's scope.

Interestingly, only three years later, researchers found that RNN architectures are not required for building deep neural networks for natural language processing and proposed the original transformer architecture (discussed in chapter 1) with a self-attention mechanism inspired by the Bahdanau attention mechanism.

Self-attention is a mechanism that allows each position in the input sequence to attend to all positions in the same sequence when computing the representation of a sequence. Self-attention is a key component of contemporary LLMs based on the transformer architecture, such as the GPT series.

This chapter focuses on coding and understanding this self-attention mechanism used in GPT-like models, as illustrated in Figure 3.6. In the next chapter, we will then code the remaining parts of the LLM.

Figure 3.6 Self-attention is a mechanism in transformers that is used to compute more efficient input representations by allowing each position in a sequence to interact with and weigh the importance of all other positions within the same sequence. In this chapter, we will code this self-attention mechanism from the ground up before we code the remaining parts of the GPT-like LLM in the following chapter.

3.3 Attending to different parts of the input with self-attention

We'll now delve into the inner workings of the self-attention mechanism and learn how to code it from the ground up. Self-attention serves as the cornerstone of every LLM based on the transformer architecture. It's worth noting that this topic may require a lot of focus and attention (no pun intended), but once you grasp its fundamentals, you will have conquered one of the toughest aspects of this book and implementing LLMs in general.

Since self-attention can appear complex, especially if you are encountering it for the first time, we will begin by introducing a simplified version of self-attention in the next subsection. Afterwards, in section 3.4, we will then implement the self-attention mechanism with trainable weights, which is used in LLMs.

3.3.1 A simple self-attention mechanism without trainable weights

In this section, we implement a simplified variant of self-attention, free from any trainable weights, which is summarized in Figure 3.7. The goal of this section is to illustrate a few key concepts in self-attention before adding trainable weights next in section 3.4.

Figure 3.7 The goal of self-attention is to compute a context vector, for each input element, that combines information from all other input elements. In the example depicted in this figure, we compute the context vector z(2). The importance or contribution of each input element for computing z(2) is determined by the attention weights α21 to α2T. When computing z(2), the attention weights are calculated with respect to input element x(2) and all other inputs. The exact computation of these attention weights is discussed later in this section.

Figure 3.7 shows an input sequence, denoted as x, consisting of T elements represented as x(1) to x(T). This sequence typically represents text, such as a sentence, that has already been transformed into token embeddings, as explained in chapter 2.

For example, consider an input text like "Your journey starts with one step." In this case, each element of the sequence, such as x(1), corresponds to a d-dimensional embedding vector representing a specific token, like "Your." In Figure 3.7, these input vectors are shown as 3-dimensional embeddings.

In self-attention, our goal is to calculate context vectors z(i) for each element x(i) in the input sequence. A context vector can be interpreted as an enriched embedding vector.

To illustrate this concept, let's focus on the embedding vector of the second input element, x(2) (which corresponds to the token "journey"), and the corresponding context vector, z(2), shown at the bottom of Figure 3.7. This enhanced context vector, z(2), is an embedding that contains information about x(2) and all other input elements x(1) to x(T).

In self-attention, context vectors play a crucial role. Their purpose is to create enriched representations of each element in an input sequence (like a sentence) by incorporating information from all other elements in the sequence, as illustrated in Figure 3.7. This is essential in LLMs, which need to understand the relationship and relevance of words in a sentence to each other. Later, we will add trainable weights that help an LLM learn to construct these context vectors so that they are relevant for the LLM to generate the next token.

In this section, we implement a simplified self-attention mechanism to compute these weights and the resulting context vector one step at a time.

Consider the following input sentence, which has already been embedded into 3-dimensional vectors as discussed in chapter 2. We choose a small embedding dimension for illustration purposes to ensure it fits on the page without line breaks:

import torch

inputs = torch.tensor(
  [[0.43, 0.15, 0.89], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.22, 0.58, 0.33], # with     (x^4)
   [0.77, 0.25, 0.10], # one      (x^5)
   [0.05, 0.80, 0.55]] # step     (x^6)
)

The first step of implementing self-attention is to compute the intermediate values ω, referred to as attention scores, as illustrated in Figure 3.8.

Figure 3.8 The overall goal of this section is to illustrate the computation of the context vector z(2) using the second input element, x(2), as a query. This figure shows the first intermediate step, computing the attention scores ω between the query x(2) and all other input elements as a dot product.

Figure 3.8 illustrates how we calculate the intermediate attention scores between the query token and each input token. We determine these scores by computing the dot product of the query, x(2), with every other input token:

query = inputs[1]  # The second input token serves as the query
attn_scores_2 = torch.empty(inputs.shape[0])
for i, x_i in enumerate(inputs):
    attn_scores_2[i] = torch.dot(x_i, query)
print(attn_scores_2)

The computed attention scores are as follows:

tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])
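
A dot product multiplies two vectors element-wise and sums the products; the larger the result, the more closely the two vectors are aligned. As a framework-free cross-check, the sketch below recomputes the first two scores in plain Python (the helper name dot is chosen only for illustration).

```python
def dot(a, b):
    """Dot product: multiply element-wise, then sum the products."""
    return sum(a_k * b_k for a_k, b_k in zip(a, b))

x_1 = [0.43, 0.15, 0.89]  # Your
x_2 = [0.55, 0.87, 0.66]  # journey (the query)

score_21 = dot(x_2, x_1)  # query vs. first input
score_22 = dot(x_2, x_2)  # query vs. itself
print(round(score_21, 4), round(score_22, 4))
```

Up to rounding, these values agree with the first two entries of the tensor above.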

In the next step, as shown in Figure 3.9, we normalize each of the attention scores that we computed previously.

Figure 3.9 After computing the attention scores ω21 to ω2T with respect to the input query x(2), the next step is to obtain the attention weights α21 to α2T by normalizing the attention scores.

The main goal behind the normalization shown in Figure 3.9 is to obtain attention weights that sum up to 1. This normalization is a convention that is useful for interpretation and for maintaining training stability in an LLM. Here's a straightforward method for achieving this normalization step:

attn_weights_2_tmp = attn_scores_2 / attn_scores_2.sum()
print("Attention weights:", attn_weights_2_tmp)
print("Sum:", attn_weights_2_tmp.sum())

As the output shows, the attention weights now sum to 1:

Attention weights: tensor([0.1455, 0.2278, 0.2249, 0.1285, 0.1077, 0.1656])
Sum: tensor(1.0000)

In practice, it's more common and advisable to use the softmax function for normalization. This approach is better at managing extreme values and offers more favorable gradient properties during training. Below is a basic implementation of the softmax function for normalizing the attention scores:

def softmax_naive(x):
    return torch.exp(x) / torch.exp(x).sum(dim=0)
attn_weights_2_naive = softmax_naive(attn_scores_2)
print("Attention weights:", attn_weights_2_naive)
print("Sum:", attn_weights_2_naive.sum())

As the output shows, the softmax function also meets the objective and normalizes the attention weights such that they sum to 1:

Attention weights: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum: tensor(1.)
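
One way to see what managing extreme values means in practice: dot-product scores can be negative, and dividing by the sum can then produce negative weights or weights larger than 1, whereas softmax always yields values strictly between 0 and 1. The following plain-Python sketch uses made-up scores to illustrate this; the function names are hypothetical.

```python
import math

def sum_normalize(scores):
    total = sum(scores)
    return [s / total for s in scores]

def softmax(scores):
    exps = [math.exp(s) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

scores = [2.0, -1.5, 0.1]  # hypothetical scores, one of them negative

by_sum = sum_normalize(scores)
by_softmax = softmax(scores)
print("Sum-normalized:", [round(w, 4) for w in by_sum])
print("Softmax:       ", [round(w, 4) for w in by_softmax])
```

Both results sum to 1, but only the softmax output can be read as a set of proportions.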

Note that this naive softmax implementation (softmax_naive) may encounter numerical instability problems, such as overflow and underflow, when dealing with large or small input values. Therefore, in practice, it's advisable to use the PyTorch implementation of softmax, which has been extensively optimized for performance:

attn_weights_2 = torch.softmax(attn_scores_2, dim=0)
print("Attention weights:", attn_weights_2)
print("Sum:", attn_weights_2.sum())

In this case, we can see that it yields the same results as our previous softmax_naive function:

Attention weights: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum: tensor(1.)
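
The overflow problem and a common remedy can be sketched in plain Python. Subtracting the maximum score before exponentiating leaves the result mathematically unchanged but keeps every exponential at or below 1. This illustrates the general technique only; it is not a claim about how torch.softmax is implemented internally, and the function names are hypothetical.

```python
import math

def softmax_naive(scores):
    exps = [math.exp(s) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def softmax_stable(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]  # largest term is exp(0) = 1
    total = sum(exps)
    return [e / total for e in exps]

large_scores = [1000.0, 1001.0, 1002.0]  # made-up extreme values

try:
    softmax_naive(large_scores)
    naive_failed = False
except OverflowError:
    naive_failed = True  # math.exp(1000.0) exceeds the float range

stable = softmax_stable(large_scores)
print("Naive version overflowed:", naive_failed)
print("Stable result:", [round(w, 4) for w in stable])
```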

Now that we computed the normalized attention weights, we are ready for the final step illustrated in Figure 3.10: calculating the context vector z(2) by multiplying the embedded input tokens, x(i), with the corresponding attention weights and then summing the resulting vectors.

Figure 3.10 The final step, after calculating and normalizing the attention scores to obtain the attention weights for query x(2), is to compute the context vector z(2). This context vector is a combination of all input vectors x(1) to x(T) weighted by the attention weights.

The context vector z(2) depicted in Figure 3.10 is calculated as a weighted sum of all input vectors. This involves multiplying each input vector by its corresponding attention weight:

query = inputs[1]  # 2nd input token is the query
context_vec_2 = torch.zeros(query.shape)
for i, x_i in enumerate(inputs):
    context_vec_2 += attn_weights_2[i] * x_i
print(context_vec_2)

The results of this computation are as follows:

tensor([0.4419, 0.6515, 0.5683])
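
Because the attention weights are positive and sum to 1, the context vector is a weighted average of the inputs, so each of its components lies between the smallest and largest corresponding input component. The plain-Python sketch below checks this, assuming the rounded attention weights printed earlier.

```python
inputs = [
    [0.43, 0.15, 0.89],  # Your
    [0.55, 0.87, 0.66],  # journey
    [0.57, 0.85, 0.64],  # starts
    [0.22, 0.58, 0.33],  # with
    [0.77, 0.25, 0.10],  # one
    [0.05, 0.80, 0.55],  # step
]
# Rounded attention weights for the query "journey", copied from the earlier output
attn_weights_2 = [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581]

dim = len(inputs[0])
context_vec_2 = [
    sum(w * x[k] for w, x in zip(attn_weights_2, inputs))
    for k in range(dim)
]
print([round(z, 4) for z in context_vec_2])

# Each component stays within the range spanned by the inputs
for k in range(dim):
    column = [x[k] for x in inputs]
    assert min(column) <= context_vec_2[k] <= max(column)
```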

In the next section, we will generalize this procedure for computing context vectors to calculate all context vectors simultaneously.

3.3.2 Computing attention weights for all input tokens

In the previous section, we computed attention weights and the context vector for input 2, as shown in the highlighted row in Figure 3.11. Now, we are extending this computation to calculate attention weights and context vectors for all inputs.

Figure 3.11 The highlighted row shows the attention weights for the second input element as a query, as we computed in the previous section. This section generalizes the computation to obtain all other attention weights.

We follow the same three steps as before, as summarized in Figure 3.12, except that we make a few modifications in the code to compute all context vectors instead of only the second context vector, z(2).

Figure 3.12 In self-attention, we begin by computing the attention scores, which are then normalized to obtain attention weights that sum up to 1. These attention weights are used to compute the context vectors as a weighted sum of the inputs.

First, in step 1 as illustrated in Figure 3.12, we add an additional for-loop to compute the dot products for all pairs of inputs.

attn_scores = torch.empty(6, 6)
for i, x_i in enumerate(inputs):
    for j, x_j in enumerate(inputs):
        attn_scores[i, j] = torch.dot(x_i, x_j)
print(attn_scores)

The resulting attention scores are as follows:

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])

Each element in the preceding tensor represents an attention score between each pair of inputs, as illustrated in Figure 3.11. Note that the values in Figure 3.11 are normalized, which is why they differ from the unnormalized attention scores in the preceding tensor. We will take care of the normalization later.

When computing the preceding attention score tensor, we used for-loops in Python. However, for-loops are generally slow, and we can achieve the same results using matrix multiplication:

attn_scores = inputs @ inputs.T
print(attn_scores)

We can visually confirm that the results are the same as before:

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])
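
Instead of comparing the two printouts by eye, the equivalence can also be checked programmatically. The sketch below assumes torch.allclose with its default tolerances is an adequate comparison for these small float32 values, and it also checks that the score matrix is symmetric, since the dot product of x(i) and x(j) equals that of x(j) and x(i).

```python
import torch

inputs = torch.tensor(
  [[0.43, 0.15, 0.89],
   [0.55, 0.87, 0.66],
   [0.57, 0.85, 0.64],
   [0.22, 0.58, 0.33],
   [0.77, 0.25, 0.10],
   [0.05, 0.80, 0.55]]
)

# Version 1: explicit Python loops
scores_loop = torch.empty(6, 6)
for i, x_i in enumerate(inputs):
    for j, x_j in enumerate(inputs):
        scores_loop[i, j] = torch.dot(x_i, x_j)

# Version 2: a single matrix multiplication
scores_matmul = inputs @ inputs.T

same_result = torch.allclose(scores_loop, scores_matmul)
is_symmetric = torch.allclose(scores_matmul, scores_matmul.T)
print("Loop and matmul agree:", same_result)
print("Score matrix is symmetric:", is_symmetric)
```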

In step 2, as illustrated in Figure 3.12, we now normalize each row so that the values in each row sum to 1:

attn_weights = torch.softmax(attn_scores, dim=1)
print(attn_weights)

This returns the following attention weight tensor that matches the values shown in Figure 3.11:

tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])

Before we move on to step 3, the final step shown in Figure 3.12, let's briefly verify that the rows indeed all sum to 1:

row_2_sum = sum([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
print("Row 2 sum:", row_2_sum)
print("All row sums:", attn_weights.sum(dim=1))

The result is as follows:

Row 2 sum: 1.0
All row sums: tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000])
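
The dim argument of torch.softmax decides along which axis the values are normalized. For a two-dimensional tensor, dim=1 (equivalently dim=-1) normalizes across the columns so that each row sums to 1, while dim=0 normalizes across the rows so that each column sums to 1. The small score matrix below is made up for illustration.

```python
import torch

scores = torch.tensor([[1.0, 2.0, 3.0],
                       [1.0, 1.0, 1.0]])  # hypothetical 2x3 score matrix

row_norm = torch.softmax(scores, dim=-1)  # each row sums to 1
col_norm = torch.softmax(scores, dim=0)   # each column sums to 1

print("Row sums with dim=-1:", row_norm.sum(dim=-1))
print("Column sums with dim=0:", col_norm.sum(dim=0))
```

For attention, each row corresponds to one query, which is why the row-wise variant is the one used in this chapter.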

In the third and last step, we now use these attention weights to compute all context vectors via matrix multiplication:

all_context_vecs = attn_weights @ inputs
print(all_context_vecs)

In the resulting output tensor, each row contains a 3-dimensional context vector:

tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])

We can double-check that the code is correct by comparing the 2nd row with the context vector z(2) that we computed previously in section 3.3.1:

print("Previous 2nd context vector:", context_vec_2)

Based on the result, we can see that the previously calculated context_vec_2 matches the second row in the previous tensor exactly:

Previous 2nd context vector: tensor([0.4419, 0.6515, 0.5683])

This concludes the code walkthrough of a simple self-attention mechanism. In the next section, we will add trainable weights, enabling the LLM to learn from data and improve its performance on specific tasks.

3.4 Implementing self-attention with trainable weights

In this section, we are implementing the self-attention mechanism that is used in the original transformer architecture, the GPT models, and most other popular LLMs. This self-attention mechanism is also called scaled dot-product attention. Figure 3.13 provides a mental model illustrating how this self-attention mechanism fits into the broader context of implementing an LLM.

Figure 3.13 A mental model illustrating how the self-attention mechanism we code in this section fits into the broader context of this book and chapter. In the previous section, we coded a simplified attention mechanism to understand the basic mechanism behind attention mechanisms. In this section, we add trainable weights to this attention mechanism. In the upcoming sections, we will then extend this self-attention mechanism by adding a causal mask and multiple heads.

As illustrated in Figure 3.13, the self-attention mechanism with trainable weights builds on the previous concepts: we want to compute context vectors as weighted sums over the input vectors specific to a certain input element. As you will see, there are only slight differences compared to the basic self-attention mechanism we coded earlier in section 3.3.

The most notable difference is the introduction of weight matrices that are updated during model training. These trainable weight matrices are crucial so that the model (specifically, the attention module inside the model) can learn to produce "good" context vectors. (Note that we will train the LLM in chapter 5.)

We will tackle this self-attention mechanism in two subsections. First, we will code it step by step as before. Second, we will organize the code into a compact Python class that can be imported into an LLM architecture, which we will code in chapter 4.

3.4.1 Computing the attention weights step by step

We will implement the self-attention mechanism step by step by introducing the three trainable weight matrices Wq, Wk, and Wv. These three matrices are used to project the embedded input tokens, x(i), into query, key, and value vectors as illustrated in Figure 3.14.

Figure 3.14 In the first step of the self-attention mechanism with trainable weight matrices, we compute query (q), key (k), and value (v) vectors for input elements x. Similar to previous sections, we designate the second input, x(2), as the query input. The query vector q(2) is obtained via matrix multiplication between the input x(2) and the weight matrix Wq. Similarly, we obtain the key and value vectors via matrix multiplication involving the weight matrices Wk and Wv.

Earlier in section 3.3.1, we defined the second input element x(2) as the query when we computed the simplified attention weights to compute the context vector z(2). Later, in section 3.3.2, we generalized this to compute all context vectors z(1) ... z(T) for the six-word input sentence "Your journey starts with one step."

Similarly, we will start by computing only one context vector, z(2), for illustration purposes. In the next section, we will modify this code to calculate all context vectors.

Let's begin by defining a few variables:

x_2 = inputs[1]          # The second input element
d_in = inputs.shape[1]   # The input embedding size, d_in=3
d_out = 2                # The output embedding size, d_out=2

Note that in GPT-like models, the input and output dimensions are usually the same, but for illustration purposes, to better follow the computation, we choose different input (d_in=3) and output (d_out=2) dimensions here.


Next, we initialize the three weight matrices Wq, Wk, and Wv that are shown in Figure 3.14:

torch.manual_seed(123)  # Fix the random seed so that the random weights are reproducible
W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_key   = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

Note that we are setting requires_grad=False to reduce clutter in the outputs for illustration purposes, but if we were to use the weight matrices for model training, we would set requires_grad=True to update these matrices during model training.

Next, we compute the query, key, and value vectors as shown earlier in Figure 3.14:

query_2 = x_2 @ W_query
key_2 = x_2 @ W_key
value_2 = x_2 @ W_value
print(query_2)

As we can see based on the output for the query, this results in a 2-dimensional vector since we set the number of columns of the corresponding weight matrix, via d_out, to 2:

tensor([0.4306, 1.4551])
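
To make the shape change explicit, the following plain-Python sketch spells out the vector-matrix product: each output component is the dot product of the input vector with one column of the weight matrix. The weight values here are hypothetical and are not the random weights used above.

```python
x_2 = [0.55, 0.87, 0.66]  # "journey", d_in = 3

# Hypothetical 3x2 weight matrix (d_in rows, d_out columns)
W_query = [
    [0.1, 0.9],
    [0.5, 0.3],
    [0.2, 0.7],
]

d_out = len(W_query[0])
query_2 = [
    sum(x_2[i] * W_query[i][j] for i in range(len(x_2)))  # input times column j
    for j in range(d_out)
]
print(len(query_2), [round(q, 4) for q in query_2])
```

The result has d_out entries regardless of the input dimension, which is why the number of columns of the weight matrix determines the size of the query, key, and value vectors.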

Even though our temporary goal is to only compute the one context vector, z(2), we still require the key and value vectors for all input elements as they are involved in computing the attention weights with respect to the query q(2), as illustrated in Figure 3.14.

We can obtain all keys and values via matrix multiplication:

keys = inputs @ W_key 
values = inputs @ W_value
print("keys.shape:", keys.shape)
print("values.shape:", values.shape)

As we can tell from the outputs, we successfully projected the 6 input tokens from a 3D onto a 2D embedding space:

keys.shape: torch.Size([6, 2])
values.shape: torch.Size([6, 2])

The second step is now to compute the attention scores, as shown in Figure 3.15.

Figure 3.15 The attention score computation is a dot-product computation similar to what we have used in the simplified self-attention mechanism in section 3.3. The new aspect here is that we are not directly computing the dot-product between the input elements but using the query and key obtained by transforming the inputs via the respective weight matrices.

First, let's compute the attention score ω22:

keys_2 = keys[1]  # Python indexing starts at 0, so index 1 is the second element
attn_score_22 = query_2.dot(keys_2)
print(attn_score_22)

This results in the following unnormalized attention score:

tensor(1.8524)

Again, we can generalize this computation to all attention scores via matrix multiplication:

attn_scores_2 = query_2 @ keys.T  # All attention scores for the given query
print(attn_scores_2)

As we can see, as a quick check, the second element in the output matches attn_score_22 we computed previously:

tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440])

The third step is now going from the attention scores to the attention weights, as illustrated in Figure 3.16.

Figure 3.16 After computing the attention scores ω, the next step is to normalize these scores using the softmax function to obtain the attention weights α.

Next, as illustrated in Figure 3.16, we compute the attention weights by scaling the attention scores and applying the softmax function we used earlier. The difference from before is that we now scale the attention scores by dividing them by the square root of the embedding dimension of the keys (note that taking the square root is mathematically the same as exponentiating by 0.5):

d_k = keys.shape[-1]
attn_weights_2 = torch.softmax(attn_scores_2 / d_k**0.5, dim=-1)
print(attn_weights_2)

The resulting attention weights are as follows:

tensor([0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820])
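
The effect of the scaling becomes more visible with larger scores, which tend to occur as the embedding dimension grows. Without scaling, softmax can place almost all of the weight on a single token. The sketch below uses made-up scores and a hypothetical key dimension of 64 to illustrate the difference.

```python
import math

def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

d_k = 64                       # hypothetical key dimension
raw_scores = [4.0, 8.0, 12.0]  # hypothetical unscaled attention scores

unscaled = softmax(raw_scores)
scaled = softmax([s / d_k**0.5 for s in raw_scores])

print("Without scaling:", [round(w, 4) for w in unscaled])
print("With scaling:   ", [round(w, 4) for w in scaled])
```

With scaling, the weights remain spread across the tokens instead of collapsing onto the largest score.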

Now, the final step is to compute the context vectors, as illustrated in Figure 3.17.

Figure 3.17 In the final step of the self-attention computation, we compute the context vector by combining all value vectors via the attention weights.

Similar to section 3.3, where we computed the context vector as a weighted sum over the input vectors, we now compute the context vector as a weighted sum over the value vectors. Here, the attention weights serve as a weighting factor that weighs the respective importance of each value vector. Similar to section 3.3, we can use matrix multiplication to obtain the output in one step:

context_vec_2 = attn_weights_2 @ values
print(context_vec_2)

The contents of the resulting vector are as follows:

tensor([0.3061, 0.8210])

So far, we only computed a single context vector, z(2). In the next section, we will generalize the code to compute all context vectors in the input sequence, z(1) to z(T).

3.4.2 Implementing a compact self-attention Python class

In the previous sections, we have gone through a lot of steps to compute the self-attention outputs. This was mainly done for illustration purposes so we could go through one step at a time. In practice, with the LLM implementation in the next chapter in mind, it is helpful to organize this code into a Python class as follows:

Listing 3.1 A compact self-attention class
import torch.nn as nn

class SelfAttention_v1(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()  # Required so that nn.Module can register the parameters
        self.d_out = d_out
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key   = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        keys = x @ self.W_key
        queries = x @ self.W_query
        values = x @ self.W_value
        attn_scores = queries @ keys.T  # omega
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1)
        context_vec = attn_weights @ values
        return context_vec

In this PyTorch code, SelfAttention_v1 is a class derived from nn.Module, a fundamental building block of PyTorch models that provides the necessary functionality for creating and managing model layers.

The __init__ method initializes trainable weight matrices (W_query, W_key, and W_value) for queries, keys, and values, each transforming the input dimension d_in to an output dimension d_out.

During the forward pass, using the forward method, we compute the attention scores (attn_scores) by multiplying queries and keys, then scale these scores and normalize them with softmax to obtain the attention weights. Finally, we create the context vectors by weighting the values with these normalized attention weights.
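The steps of the forward pass can also be traced shape by shape. The following is a minimal sketch outside of any class, using random stand-in tensors; the sizes (six tokens, d_in=3, d_out=2) are assumed to mirror the running example, and the seed is arbitrary.

```python
import torch

torch.manual_seed(0)  # arbitrary seed for reproducibility

d_in, d_out, num_tokens = 3, 2, 6          # sizes assumed to mirror the running example
x = torch.rand(num_tokens, d_in)           # random stand-in for the input embeddings
W_query = torch.rand(d_in, d_out)
W_key = torch.rand(d_in, d_out)
W_value = torch.rand(d_in, d_out)

queries = x @ W_query                      # (6, 2)
keys = x @ W_key                           # (6, 2)
values = x @ W_value                       # (6, 2)
attn_scores = queries @ keys.T             # (6, 6): one score per query-key pair
attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
context_vecs = attn_weights @ values       # (6, 2): one context vector per token

print(attn_scores.shape, context_vecs.shape)
print(attn_weights.sum(dim=-1))            # each row sums to 1 (up to floating-point error)
```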

We can use this class as follows:

sa_v1 = SelfAttention_v1(d_in, d_out)
print(sa_v1(inputs))

Since inputs contains six embedding vectors, this results in a matrix storing the six context vectors:

tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]], grad_fn=<MmBackward0>)

As a quick check, notice how the second row ([0.3061, 0.8210]) matches the contents of context_vec_2 in the previous section.

Figure 3.18 summarizes the self-attention mechanism we just implemented.

Figure 3.18 In self-attention, we transform the input vectors in the input matrix X with the three weight matrices, Wq, Wk, and Wv. Then, we compute the attention weight matrix based on the resulting queries (Q) and keys (K). Using the attention weights and values (V), we then compute the context vectors (Z). (For visual clarity, we focus on a single input text with n tokens in this figure, not a batch of multiple inputs. Consequently, the 3D input tensor is simplified to a 2D matrix in this context. This approach allows for a more straightforward visualization and understanding of the processes involved.)

As shown in Figure 3.18, self-attention involves the trainable weight matrices Wq, Wk, and Wv. These matrices transform input data into queries, keys, and values, which are crucial components of the attention mechanism. As the model is exposed to more data during training, it adjusts these trainable weights, as we will see in upcoming chapters.

We can improve the SelfAttention_v1 implementation further by utilizing PyTorch's nn.Linear layers, which effectively perform matrix multiplication when the bias units are disabled. Additionally, a significant advantage of using nn.Linear instead of manually implementing nn.Parameter(torch.rand(...)) is that nn.Linear has an optimized weight initialization scheme, contributing to more stable and effective model training.

Listing 3.2 A self-attention class using PyTorch's Linear layers
class SelfAttention_v2(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()  # initialize nn.Module before registering layers
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
    def forward(self, x):
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        context_vec = attn_weights @ values
        return context_vec

You can use SelfAttention_v2 in the same way as SelfAttention_v1. However, note that SelfAttention_v1 and SelfAttention_v2 give different outputs because their weight matrices start from different initial values: nn.Linear uses a more sophisticated weight initialization scheme than torch.rand.
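To check that the two variants differ only in their initial weights, we can copy the weights from the nn.Linear version into the nn.Parameter version and compare the outputs. The sketch below uses two compact stand-in classes (SelfAttentionParam and SelfAttentionLinear are hypothetical names that mirror SelfAttention_v1 and SelfAttention_v2) and random inputs. Note that nn.Linear stores its weight matrix with shape (d_out, d_in), which is why the copy transposes it.

```python
import torch
import torch.nn as nn

class SelfAttentionParam(nn.Module):   # hypothetical name; mirrors SelfAttention_v1
    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        q, k, v = x @ self.W_query, x @ self.W_key, x @ self.W_value
        w = torch.softmax(q @ k.T / k.shape[-1] ** 0.5, dim=-1)
        return w @ v

class SelfAttentionLinear(nn.Module):  # hypothetical name; mirrors SelfAttention_v2
    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=False)
        self.W_key = nn.Linear(d_in, d_out, bias=False)
        self.W_value = nn.Linear(d_in, d_out, bias=False)

    def forward(self, x):
        q, k, v = self.W_query(x), self.W_key(x), self.W_value(x)
        w = torch.softmax(q @ k.T / k.shape[-1] ** 0.5, dim=-1)
        return w @ v

torch.manual_seed(0)                   # arbitrary seed
x = torch.rand(6, 3)                   # random stand-in for the inputs
sa_param = SelfAttentionParam(3, 2)
sa_linear = SelfAttentionLinear(3, 2)

# nn.Linear keeps its weight as (d_out, d_in), so we transpose when copying
with torch.no_grad():
    sa_param.W_query.copy_(sa_linear.W_query.weight.T)
    sa_param.W_key.copy_(sa_linear.W_key.weight.T)
    sa_param.W_value.copy_(sa_linear.W_value.weight.T)

same_after = torch.allclose(sa_param(x), sa_linear(x), atol=1e-6)
print(same_after)
```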

In the next section, we will make enhancements to the self-attention mechanism, focusing specifically on incorporating causal and multi-head elements. The causal aspect involves modifying the attention mechanism to prevent the model from accessing future information in the sequence, which is crucial for tasks like language modeling, where each word prediction should only depend on previous words.

The multi-head component involves splitting the attention mechanism into multiple "heads." Each head learns different aspects of the data, allowing the model to simultaneously attend to information from different representation subspaces at different positions. This improves the model's performance in complex tasks.

3.5 Hiding future words with causal attention

In this section, we modify the standard self-attention mechanism to create a causal attention mechanism, which is essential for developing an LLM in the subsequent chapters.

Causal attention, also known as masked attention, is a specialized form of self-attention. It restricts a model to only consider previous and current inputs in a sequence when processing any given token. This is in contrast to the standard self-attention mechanism, which allows access to the entire input sequence at once.

Consequently, when computing attention scores, the causal attention mechanism ensures that the model only factors in tokens that occur at or before the current token in the sequence.

Figure 3.19 In causal attention, we mask out the attention weights above the diagonal such that for a given input, the LLM can't access future tokens when computing the context vectors using the attention weights. For example, for the word "journey" in the second row, we only keep the attention weights for the words before ("Your") and in the current position ("journey").

As illustrated in Figure 3.19, we mask out the attention weights above the diagonal, and we normalize the non-masked attention weights, such that the attention weights sum to 1 in each row. In the next section, we will implement this masking and normalization procedure in code.

3.5.1 Applying a causal attention mask


In this section, we implement the causal attention mask in code. We start with the procedure summarized in Figure 3.20.

Figure 3.20 One way to obtain the masked attention weight matrix in causal attention is to apply the softmax function to the attention scores, zeroing out the elements above the diagonal and normalizing the resulting matrix.

To implement the steps summarized in Figure 3.20, let's work with the attention scores and weights from the previous section.

In the first step illustrated in Figure 3.20, we compute the attention weights using the softmax function as we have done in previous sections:

queries = sa_v2.W_query(inputs)  # reuse the query and key weight matrices of sa_v2
keys = sa_v2.W_key(inputs)
attn_scores = queries @ keys.T
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
print(attn_weights)

This results in the following attention weights:

tensor([[0.1921, 0.1646, 0.1652, 0.1550, 0.1721, 0.1510],
        [0.2041, 0.1659, 0.1662, 0.1496, 0.1665, 0.1477],
        [0.2036, 0.1659, 0.1662, 0.1498, 0.1664, 0.1480],
        [0.1869, 0.1667, 0.1668, 0.1571, 0.1661, 0.1564],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.1585],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],

We can implement step 2 in Figure 3.20 using PyTorch's tril function to create a mask where the values above the diagonal are zero:

block_size = attn_scores.shape[0]
mask_simple = torch.tril(torch.ones(block_size, block_size))
print(mask_simple)

The resulting mask is as follows:

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])

Now, we can multiply this mask with the attention weights to zero out the values above the diagonal:

masked_simple = attn_weights * mask_simple
print(masked_simple)

As we can see, the elements above the diagonal are successfully zeroed out:

tensor([[0.1921, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2041, 0.1659, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2036, 0.1659, 0.1662, 0.0000, 0.0000, 0.0000],
        [0.1869, 0.1667, 0.1668, 0.1571, 0.0000, 0.0000],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],

The third step in Figure 3.20 is to renormalize the attention weights to sum up to 1 again in each row. We can achieve this by dividing each element in each row by the sum in each row:

row_sums = masked_simple.sum(dim=1, keepdim=True)
masked_simple_norm = masked_simple / row_sums
print(masked_simple_norm)

The result is an attention weight matrix where the attention weights above the diagonal are zeroed out and where the rows sum to 1:

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],

While we could be technically done with implementing causal attention at this point, we can take advantage of a mathematical property of the softmax function and implement the computation of the masked attention weights more efficiently in fewer steps, as shown in Figure 3.21.

Figure 3.21 A more efficient way to obtain the masked attention weight matrix in causal attention is to mask the attention scores with negative infinity values before applying the softmax function.

The softmax function converts its inputs into a probability distribution. When negative infinity values (-∞) are present in a row, the softmax function treats them as zero probability. (Mathematically, this is because e^(-∞) approaches 0.)
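This property is easy to verify in plain Python. The sketch below implements a small softmax helper (not PyTorch's implementation) and applies it to a row with two finite scores followed by masked positions; the two scores match the second row of the example score matrix in this section, and the scaling assumes a key dimension of 2, as in the running example.

```python
import math

def softmax(row):
    # Subtract the maximum for numerical stability; math.exp(-inf) evaluates to 0.0
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

neg_inf = float("-inf")
raw_row = [0.4656, 0.1723, neg_inf, neg_inf, neg_inf, neg_inf]

# Scale by sqrt(d_k) with d_k = 2 (assumed); -inf stays -inf after scaling
scaled_row = [s / 2 ** 0.5 for s in raw_row]
probs = softmax(scaled_row)

print([round(p, 4) for p in probs])
print(sum(probs))
```

The masked positions receive exactly zero probability, and the two remaining entries already sum to 1, so no separate renormalization step is needed.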

We can implement this more efficient masking "trick" by creating a mask with 1's above the diagonal and then replacing these 1's with negative infinity (-inf) values:

mask = torch.triu(torch.ones(block_size, block_size), diagonal=1)
masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
print(masked)

This results in the following masked attention scores:

tensor([[0.2899,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.4656, 0.1723,   -inf,   -inf,   -inf,   -inf],
        [0.4594, 0.1703, 0.1731,   -inf,   -inf,   -inf],
        [0.2642, 0.1024, 0.1036, 0.0186,   -inf,   -inf],
        [0.2183, 0.0874, 0.0882, 0.0177, 0.0786,   -inf],
        [0.3408, 0.1270, 0.1290, 0.0198, 0.1290, 0.0078]],

Now, all we need to do is apply the softmax function to these masked results, and we are done:

attn_weights = torch.softmax(masked / keys.shape[-1]**0.5, dim=-1)
print(attn_weights)

As we can see based on the output, the values in each row sum to 1, and no further normalization is necessary:

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],

We could now use the modified attention weights to compute the context vectors via context_vec = attn_weights @ values, as in section 3.4. However, in the next section, we first cover another minor tweak to the causal attention mechanism that is useful for reducing overfitting when training LLMs.
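The two masking procedures in this section are mathematically equivalent, which we can confirm numerically. The following minimal sketch uses random stand-in attention scores (the seed and sizes are assumptions for illustration) and compares the three-step variant with the -inf variant.

```python
import torch

torch.manual_seed(0)                        # arbitrary seed
num_tokens, d_k = 6, 2                      # assumed sizes, mirroring the running example
attn_scores = torch.rand(num_tokens, num_tokens)   # random stand-in scores

# Approach 1: softmax, zero out above the diagonal, renormalize each row
weights = torch.softmax(attn_scores / d_k ** 0.5, dim=-1)
masked_weights = weights * torch.tril(torch.ones(num_tokens, num_tokens))
approach_1 = masked_weights / masked_weights.sum(dim=-1, keepdim=True)

# Approach 2: fill positions above the diagonal with -inf, then apply softmax once
mask = torch.triu(torch.ones(num_tokens, num_tokens), diagonal=1).bool()
masked_scores = attn_scores.masked_fill(mask, float("-inf"))
approach_2 = torch.softmax(masked_scores / d_k ** 0.5, dim=-1)

print(torch.allclose(approach_1, approach_2, atol=1e-6))
```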

3.5.2 Masking additional attention weights with dropout

Dropout in deep learning is a technique where randomly selected hidden layer units are ignored during training, effectively "dropping" them out. This method helps prevent overfitting by ensuring that a model does not become overly reliant on any specific set of hidden layer units. It's important to emphasize that dropout is only used during training and is disabled afterward.

In the transformer architecture, including models like GPT, dropout in the attention mechanism is typically applied in two specific areas: after calculating the attention weights or after applying the attention weights to the value vectors.

Here, we will apply the dropout mask after computing the attention weights, as illustrated in Figure 3.22, because it's the more common variant in practice.

Figure 3.22 Using the causal attention mask (upper left), we apply an additional dropout mask (upper right) to zero out additional attention weights to reduce overfitting during training.

In the following code example, we use a dropout rate of 50%, which means masking out half of the attention weights. (When we train the GPT model in later chapters, we will use a lower dropout rate, such as 0.1 or 0.2.)

In the following code, we apply PyTorch's dropout implementation first to a 6×6 tensor consisting of ones for illustration purposes:

dropout = torch.nn.Dropout(0.5)  # dropout rate of 50%
example = torch.ones(6, 6)  # 6×6 matrix of ones
print(dropout(example))

As we can see, approximately half of the values are zeroed out:

tensor([[2., 2., 0., 2., 2., 0.],
        [0., 0., 0., 2., 0., 2.],
        [2., 2., 2., 2., 0., 2.],
        [0., 2., 2., 0., 0., 2.],
        [0., 2., 0., 2., 0., 2.],
        [0., 2., 2., 2., 2., 0.]])

When applying dropout to an attention weight matrix with a rate of 50%, approximately half of the elements in the matrix are randomly set to zero. To compensate for the reduction in active elements, the values of the remaining elements are scaled up by a factor of 1/0.5 = 2. This scaling keeps the expected value of each attention weight unchanged, so that the average influence of the attention mechanism remains consistent between the training and inference phases.
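The effect of the rescaling, and the fact that dropout is disabled outside of training, can be checked with a short sketch. It uses a large matrix of ones (an assumption made so that the average is meaningful) and switches the dropout layer between training and evaluation mode.

```python
import torch

torch.manual_seed(0)                 # arbitrary seed
dropout = torch.nn.Dropout(0.5)
example = torch.ones(1000, 1000)     # large matrix so that the average is meaningful

dropout.train()                      # training mode: dropout is active
train_out = dropout(example)
print(train_out.unique())            # surviving entries are scaled to 2, the rest are 0
print(train_out.mean())              # close to 1, the mean of the input

dropout.eval()                       # evaluation mode: dropout passes inputs through
eval_out = dropout(example)
print(torch.equal(eval_out, example))
```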

Now, let's apply dropout to the attention weight matrix itself:

print(dropout(attn_weights))

The resulting attention weight matrix now has additional elements zeroed out and the remaining ones rescaled:

tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5132, 0.7482, 0.7386, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.5646, 0.5592, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.4357, 0.0000, 0.3378, 0.0000, 0.0000],
        [0.0000, 0.4065, 0.3991, 0.2999, 0.2320, 0.0000]])

Having gained an understanding of causal attention and dropout masking, we will develop a concise Python class in the following section. This class is designed to facilitate the efficient application of these two techniques.

3.5.3 Implementing a compact causal attention class

In this section, we will now incorporate the causal attention and dropout modifications into the SelfAttention Python class we developed in section 3.4. This class will then serve as a template for developing multi-head attention in the upcoming section, which is the final attention class we implement in this chapter.

But before we begin, one more thing is to ensure that the code can handle batches consisting of more than one input so that the CausalAttention class supports the batch outputs produced by the data loader we implemented in chapter 2.

For simplicity, to simulate such batch inputs, we duplicate the input text example:

batch = torch.stack((inputs, inputs), dim=0)
print(batch.shape)  # 2 inputs with 6 tokens each; each token has embedding dimension 3

This results in a 3D tensor consisting of 2 input texts with 6 tokens each, where each token is a 3-dimensional embedding vector:

torch.Size([2, 6, 3])

The following CausalAttention class is similar to the SelfAttention class we implemented earlier, except that we now added the dropout and causal mask components as highlighted in the following code:

Listing 3.3 A compact causal attention class
class CausalAttention(nn.Module):
    def __init__(self, d_in, d_out, block_size, dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)  # dropout layer (new compared to SelfAttention_v2)
        self.register_buffer(  # causal mask stored as a buffer (discussed below)
            'mask',
            torch.triu(torch.ones(block_size, block_size), diagonal=1)
        )
    def forward(self, x):
        b, num_tokens, d_in = x.shape  # new batch dimension b
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        attn_scores = queries @ keys.transpose(1, 2)  # transpose dims 1 and 2; the batch dimension stays first
        attn_scores.masked_fill_(  # trailing underscore: the operation is performed in place
            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)  # normalize over the keys (last dimension)
        attn_weights = self.dropout(attn_weights)
        context_vec = attn_weights @ values
        return context_vec

While all added code lines should be familiar from previous sections, we now added a self.register_buffer() call in the __init__ method. The use of register_buffer in PyTorch is not strictly necessary for all use cases but offers several advantages here. For instance, when we use the CausalAttention class in our LLM, buffers are automatically moved to the appropriate device (CPU or GPU) along with our model, which will be relevant when training the LLM in future chapters. This means we don't need to manually ensure that these tensors are on the same device as our model parameters, which avoids device mismatch errors.
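The difference between a registered buffer and a plain tensor attribute can be seen in a small sketch. MaskHolder below is a hypothetical toy module, not part of the chapter's code, and a dtype conversion stands in for a device move so that the example runs without a GPU; in both cases, .to() converts parameters and buffers together but leaves plain attributes untouched.

```python
import torch
import torch.nn as nn

class MaskHolder(nn.Module):  # hypothetical toy module for illustration
    def __init__(self, block_size):
        super().__init__()
        self.linear = nn.Linear(2, 2)
        # Registered buffer: part of the module's state, but not trainable
        self.register_buffer(
            "mask", torch.triu(torch.ones(block_size, block_size), diagonal=1)
        )
        # Plain attribute: ignored by .to() and .state_dict()
        self.plain_mask = torch.triu(torch.ones(block_size, block_size), diagonal=1)

m = MaskHolder(4)
param_names = [name for name, _ in m.named_parameters()]
buffer_names = [name for name, _ in m.named_buffers()]
print(param_names)                   # the mask is not a trainable parameter
print(buffer_names)
print("mask" in m.state_dict())      # buffers are saved with the model by default

# A dtype change stands in for a device change here
m = m.to(torch.float64)
print(m.mask.dtype, m.plain_mask.dtype)
```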

We can use the CausalAttention class as follows, in the same way as SelfAttention previously:

block_size = batch.shape[1]
ca = CausalAttention(d_in, d_out, block_size, 0.0)
context_vecs = ca(batch)
print("context_vecs.shape:", context_vecs.shape)

The resulting context vector is a 3D tensor where each token is now represented by a 2-dimensional embedding:

context_vecs.shape: torch.Size([2, 6, 2])
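Beyond checking the output shape, we can test the causal property directly: changing a later token should leave the context vectors of all earlier positions unchanged. The sketch below uses a compact stand-in class (CausalAttentionSketch is a hypothetical name that follows the structure of Listing 3.3) with random inputs and dropout disabled.

```python
import torch
import torch.nn as nn

class CausalAttentionSketch(nn.Module):  # hypothetical compact variant of Listing 3.3
    def __init__(self, d_in, d_out, block_size, dropout=0.0):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=False)
        self.W_key = nn.Linear(d_in, d_out, bias=False)
        self.W_value = nn.Linear(d_in, d_out, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            "mask", torch.triu(torch.ones(block_size, block_size), diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape
        q, k, v = self.W_query(x), self.W_key(x), self.W_value(x)
        scores = q @ k.transpose(1, 2)
        scores = scores.masked_fill(
            self.mask.bool()[:num_tokens, :num_tokens], float("-inf"))
        weights = torch.softmax(scores / k.shape[-1] ** 0.5, dim=-1)
        return self.dropout(weights) @ v

torch.manual_seed(0)                     # arbitrary seed
ca = CausalAttentionSketch(d_in=3, d_out=2, block_size=6)
batch = torch.rand(2, 6, 3)              # random stand-in for the batch

modified = batch.clone()
modified[:, -1, :] += 1.0                # perturb only the last token

out_original = ca(batch)
out_modified = ca(modified)

# The first five positions must not depend on the last token
earlier_unchanged = torch.allclose(out_original[:, :5], out_modified[:, :5])
# The last position attends to itself, so it does see the change
last_unchanged = torch.allclose(out_original[:, 5], out_modified[:, 5])
print(earlier_unchanged, last_unchanged)
```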

Figure 3.23 provides a mental model that summarizes what we have accomplished so far.

Figure 3.23 A mental model summarizing the four different attention modules we are coding in this chapter. We began with a simplified attention mechanism, added trainable weights, and then added a causal attention mask. In the remainder of this chapter, we will extend the causal attention mechanism and code multi-head attention, which is the final module we will use in the LLM implementation in the next chapter.

As illustrated in Figure 3.23, in this section, we focused on the concept and implementation of causal attention in neural networks. In the next section, we will expand on this concept and implement a multi-head attention module that implements several of such causal attention mechanisms in parallel.

3.6 Extending single-head attention to multi-head attention

In this final section of the chapter, we extend the previously implemented causal attention class over multiple heads. This is also called multi-head attention.

The term "multi-head" refers to dividing the attention mechanism into multiple "heads," each operating independently. In this context, a single causal attention module can be considered single-head attention, where there is only one set of attention weights processing the input sequentially.

In the following subsections, we will tackle this expansion from causal attention to multi-head attention. The first subsection will intuitively build a multi-head attention module by stacking multiple CausalAttention modules for illustration purposes. The second subsection will then implement the same multi-head attention module in a more complicated but computationally more efficient way.

3.6.1 Stacking multiple single-head attention layers

In practical terms, implementing multi-head attention involves creating multiple instances of the self-attention mechanism (depicted earlier in Figure 3.18 in section 3.4.1), each with its own weights, and then combining their outputs. Using multiple instances of the self-attention mechanism can be computationally intensive, but it's crucial for the kind of complex pattern recognition that models like transformer-based LLMs are known for.

Figure 3.24 illustrates the structure of a multi-head attention module, which consists of multiple single-head attention modules, as previously depicted in Figure 3.18, stacked on top of each other.

Figure 3.24 The multi-head attention module in this figure depicts two single-head attention modules stacked on top of each other. So, instead of using a single matrix Wv for computing the value matrices, in a multi-head attention module with two heads, we now have two value weight matrices: Wv1 and Wv2. The same applies to the other weight matrices, Wq and Wk. We obtain two sets of context vectors Z1 and Z2 that we can combine into a single context vector matrix Z.

As mentioned before, the main idea behind multi-head attention is to run the attention mechanism multiple times (in parallel) with different, learned linear projections -- the results of multiplying the input data (like the query, key, and value vectors in attention mechanisms) by a weight matrix.

In code, we can achieve this by implementing a simple MultiHeadAttentionWrapper class that stacks multiple instances of our previously implemented CausalAttention module:

Listing 3.4 A wrapper class to implement multi-head attention
class MultiHeadAttentionWrapper(nn.Module):
    def __init__(self, d_in, d_out, block_size,
                 dropout, num_heads, qkv_bias=False):
        super().__init__()
        self.heads = nn.ModuleList(
            [CausalAttention(d_in, d_out, block_size, dropout, qkv_bias) 
             for _ in range(num_heads)]
        )
    def forward(self, x):
        # concatenate the outputs of all heads along the last (embedding) dimension
        return torch.cat([head(x) for head in self.heads], dim=-1)

For example, if we use this MultiHeadAttentionWrapper class with two attention heads (via num_heads=2) and CausalAttention output dimension d_out=2, this results in 4-dimensional context vectors (d_out*num_heads=4), as illustrated in Figure 3.25.

Figure 3.25 Using the MultiHeadAttentionWrapper, we specified the number of attention heads (num_heads). If we set num_heads=2, as shown in this figure, we obtain a tensor with two sets of context vector matrices. In each context vector matrix, the rows represent the context vectors corresponding to the tokens, and the columns correspond to the embedding dimension specified via d_out=2. We concatenate these context vector matrices along the column dimension. Since we have 2 attention heads and an embedding dimension of 2, the final embedding dimension is 2 × 2 = 4.

To illustrate Figure 3.25 further with a concrete example, we can use the MultiHeadAttentionWrapper class similar to the CausalAttention class before:

block_size = batch.shape[1] # This is the number of tokens
d_in, d_out = 3, 2
mha = MultiHeadAttentionWrapper(d_in, d_out, block_size, 0.0, num_heads=2)
context_vecs = mha(batch)
print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

This results in the following tensor representing the context vectors:

tensor([[[-0.0844,  0.0414,  0.0766,  0.0171],
         [-0.2264, -0.0039,  0.2143,  0.1185],
         [-0.4163, -0.0564,  0.3878,  0.2453],
         [-0.5014, -0.1011,  0.4992,  0.3401],
         [-0.7754, -0.1867,  0.7387,  0.4868],
         [-1.1632, -0.3303,  1.1224,  0.8460]],
        [[-0.0844,  0.0414,  0.0766,  0.0171],
         [-0.2264, -0.0039,  0.2143,  0.1185],
         [-0.4163, -0.0564,  0.3878,  0.2453],
         [-0.5014, -0.1011,  0.4992,  0.3401],
         [-0.7754, -0.1867,  0.7387,  0.4868],
         [-1.1632, -0.3303,  1.1224,  0.8460]]], grad_fn=<CatBackward0>)
context_vecs.shape: torch.Size([2, 6, 4])

The first dimension of the resulting context_vecs tensor is 2 since we have two input texts (the input texts are duplicated, which is why the context vectors are exactly the same for those). The second dimension refers to the 6 tokens in each input. The third dimension refers to the 4-dimensional embedding of each token.
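The concatenation step on its own can be illustrated with two placeholder tensors. In this minimal sketch, the outputs of the two heads are filled with zeros and ones (assumed values, chosen so that each head's contribution is easy to spot in the result).

```python
import torch

head_1_out = torch.zeros(2, 6, 2)   # stand-in for head 1: (batch, num_tokens, d_out)
head_2_out = torch.ones(2, 6, 2)    # stand-in for head 2

# Concatenate along the last (embedding) dimension, as in the wrapper's forward method
combined = torch.cat([head_1_out, head_2_out], dim=-1)

print(combined.shape)
print(combined[0, 0])               # first token: head 1 values followed by head 2 values
```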

In this section, we implemented a MultiHeadAttentionWrapper that combined multiple single-head attention modules. However, note that these are processed sequentially via [head(x) for head in self.heads] in the forward method. We can improve this implementation by processing the heads in parallel. One way to achieve this is by computing the outputs for all attention heads simultaneously via matrix multiplication, as we will explore in the next section.

3.6.2 Implementing multi-head attention with weight splits

In the previous section, we created a MultiHeadAttentionWrapper to implement multi-head attention by stacking multiple single-head attention modules. This was done by instantiating and combining several CausalAttention objects.

Instead of maintaining two separate classes, MultiHeadAttentionWrapper and CausalAttention, we can combine both of these concepts into a single MultiHeadAttention class. Also, in addition to just merging the MultiHeadAttentionWrapper with the CausalAttention code, we will make some other modifications to implement multi-head attention more efficiently.

In the MultiHeadAttentionWrapper, multiple heads are implemented by creating a list of CausalAttention objects (self.heads), each representing a separate attention head. The CausalAttention class independently performs the attention mechanism, and the results from each head are concatenated. In contrast, the following MultiHeadAttention class integrates the multi-head functionality within a single class. It splits the input into multiple heads by reshaping the projected query, key, and value tensors and then combines the results from these heads after computing attention.

Let's take a look at the MultiHeadAttention class before we discuss it further:

Listing 3.5 An efficient multi-head attention class
class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, 
                 block_size, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads  # output dimension of each head
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)  # linear layer to combine the head outputs
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            'mask',
            torch.triu(torch.ones(block_size, block_size), diagonal=1)
        )
    def forward(self, x):
        b, num_tokens, d_in = x.shape
        keys = self.W_key(x)  # shape: (b, num_tokens, d_out)
        queries = self.W_query(x)
        values = self.W_value(x)
        # split d_out into heads: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
        # move the head dimension forward: (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)
        attn_scores = queries @ keys.transpose(2, 3)  # dot product for each head
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]  # truncate the mask to the number of tokens
        mask_unsqueezed = mask_bool.unsqueeze(0).unsqueeze(0)  # add batch and head dimensions for broadcasting
        attn_scores.masked_fill_(mask_unsqueezed, -torch.inf)  # mask out future positions
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)
        context_vec = (attn_weights @ values).transpose(1, 2)  # back to (b, num_tokens, num_heads, head_dim)
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)  # merge heads: d_out = num_heads * head_dim
        context_vec = self.out_proj(context_vec)  # optional output projection
        return context_vec

Even though the reshaping (.view) and transposing (.transpose) of tensors inside the MultiHeadAttention class looks very complicated, mathematically, the MultiHeadAttention class implements the same concept as the MultiHeadAttentionWrapper earlier.

On a big-picture level, in the previous MultiHeadAttentionWrapper, we stacked multiple single-head attention layers that we combined into a multi-head attention layer. The MultiHeadAttention class takes an integrated approach. It starts with a multi-head layer and then internally splits this layer into individual attention heads, as illustrated in Figure 3.26.

Figure 3.26 In the MultiHeadAttentionWrapper class with two attention heads, we initialized two weight matrices Wq1 and Wq2 and computed two query matrices Q1 and Q2 as illustrated at the top of this figure. In the MultiHeadAttention class, we initialize one larger weight matrix Wq, only perform one matrix multiplication with the inputs to obtain a query matrix Q, and then split the query matrix into Q1 and Q2 as shown at the bottom of this figure. We do the same for the keys and values, which are not shown to reduce visual clutter.

The splitting of the query, key, and value tensors, as depicted in Figure 3.26, is achieved through tensor reshaping and transposing operations using PyTorch's .view and .transpose methods. The input is first transformed (via linear layers for queries, keys, and values) and then reshaped to represent multiple heads.

The key operation is to split the d_out dimension into num_heads and head_dim, where head_dim = d_out / num_heads. This splitting is then achieved using the .view method: a tensor of dimensions (b, num_tokens, d_out) is reshaped to dimension (b, num_tokens, num_heads, head_dim).

The tensors are then transposed to bring the num_heads dimension before the num_tokens dimension, resulting in a shape of (b, num_heads, num_tokens, head_dim). This transposition is crucial for correctly aligning the queries, keys, and values across the different heads and performing batched matrix multiplications efficiently.
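The reshaping and transposing steps can be followed on a small tensor whose entries are easy to track. In the sketch below, the sizes are assumed for illustration, and the tensor is filled with consecutive numbers so that we can check which features end up in which head.

```python
import torch

b, num_tokens, num_heads, head_dim = 1, 3, 2, 4   # assumed illustrative sizes
d_out = num_heads * head_dim

# Stand-in for projected queries with shape (b, num_tokens, d_out)
queries = torch.arange(b * num_tokens * d_out, dtype=torch.float32)
queries = queries.view(b, num_tokens, d_out)

split = queries.view(b, num_tokens, num_heads, head_dim)   # (1, 3, 2, 4)
per_head = split.transpose(1, 2)                           # (1, 2, 3, 4)

print(per_head.shape)
# Head 0 holds the first head_dim features of every token, head 1 the remaining ones
print(torch.equal(per_head[0, 0], queries[0, :, :head_dim]))
print(torch.equal(per_head[0, 1], queries[0, :, head_dim:]))
```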

To illustrate this batched matrix multiplication, suppose we have the following example tensor:

a = torch.tensor([[[[0.2745, 0.6584, 0.2775, 0.8573],  # shape: (b, num_heads, num_tokens, head_dim) = (1, 2, 3, 4)
                    [0.8993, 0.0390, 0.9268, 0.7388],
                    [0.7179, 0.7058, 0.9156, 0.4340]],
                   [[0.0772, 0.3565, 0.1479, 0.5331],
                    [0.4066, 0.2318, 0.4545, 0.9737],
                    [0.4606, 0.5159, 0.4220, 0.5786]]]])

Now, we perform a batched matrix multiplication between the tensor itself and a view of the tensor where we transposed the last two dimensions, num_tokens and head_dim:

print(a @ a.transpose(2, 3))

The result is as follows:

tensor([[[[1.3208, 1.1631, 1.2879],
          [1.1631, 2.2150, 1.8424],
          [1.2879, 1.8424, 2.0402]],
         [[0.4391, 0.7003, 0.5903],
          [0.7003, 1.3737, 1.0620],
          [0.5903, 1.0620, 0.9912]]]])

In this case, the matrix multiplication implementation in PyTorch handles the 4-dimensional input tensor so that the matrix multiplication is carried out between the 2 last dimensions (num_tokens, head_dim) and then repeated for the individual heads.

For instance, the above becomes a more compact way to compute the matrix multiplication for each head separately:

first_head = a[0, 0, :, :]
first_res = first_head @ first_head.T
print("First head:\n", first_res)
second_head = a[0, 1, :, :]
second_res = second_head @ second_head.T
print("\nSecond head:\n", second_res)

The results are exactly the same results that we obtained when using the batched matrix multiplication print(a @ a.transpose(2, 3)) earlier:

First head:
 tensor([[1.3208, 1.1631, 1.2879],
        [1.1631, 2.2150, 1.8424],
        [1.2879, 1.8424, 2.0402]])
Second head:
 tensor([[0.4391, 0.7003, 0.5903],
        [0.7003, 1.3737, 1.0620],
        [0.5903, 1.0620, 0.9912]])

Continuing with MultiHeadAttention, after computing the attention weights and context vectors, the context vectors from all heads are transposed back to the shape (b, num_tokens, num_heads, head_dim). These vectors are then reshaped (flattened) into the shape (b, num_tokens, d_out), effectively combining the outputs from all heads.
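The merge step can be sketched in isolation as well. Here, head_outputs is a random stand-in for the per-head context vectors (sizes assumed for illustration). The sketch also shows why the .contiguous() call is there: .transpose returns a view with rearranged strides, and .view needs a compatible memory layout before it can flatten the head dimensions.

```python
import torch

torch.manual_seed(0)                                   # arbitrary seed
b, num_heads, num_tokens, head_dim = 1, 2, 3, 4        # assumed illustrative sizes
d_out = num_heads * head_dim

# Stand-in for the per-head context vectors: (b, num_heads, num_tokens, head_dim)
head_outputs = torch.rand(b, num_heads, num_tokens, head_dim)

merged = head_outputs.transpose(1, 2)                  # (b, num_tokens, num_heads, head_dim)
print(merged.is_contiguous())                          # transpose changes strides, not memory

# Copy into contiguous memory, then flatten (num_heads, head_dim) into d_out
combined = merged.contiguous().view(b, num_tokens, d_out)
print(combined.shape)

# Each token's vector is head 0's output followed by head 1's output
print(torch.equal(combined[0, :, :head_dim], head_outputs[0, 0]))
print(torch.equal(combined[0, :, head_dim:], head_outputs[0, 1]))
```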

Additionally, we added a so-called output projection layer (self.out_proj) to MultiHeadAttention after combining the heads, which is not present in the CausalAttention class. This output projection layer is not strictly necessary (see the References section in Appendix B for more details), but it is commonly used in many LLM architectures, which is why we added it here for completeness.

Even though the MultiHeadAttention class looks more complicated than the MultiHeadAttentionWrapper due to the additional reshaping and transposition of tensors, it is more efficient. The reason is that we only need one matrix multiplication to compute the keys, for instance, keys = self.W_key(x) (the same is true for the queries and values). In the MultiHeadAttentionWrapper, we needed to repeat this matrix multiplication, which is computationally one of the most expensive steps, for each attention head.
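To see that a single larger projection yields the same keys as separate per-head projections, consider the following minimal sketch. It uses plain weight matrices instead of nn.Linear layers and random stand-in values (both are simplifying assumptions), with two heads of dimension 2.

```python
import torch

torch.manual_seed(0)                    # arbitrary seed
num_tokens, d_in, head_dim = 6, 3, 2    # assumed illustrative sizes

x = torch.rand(num_tokens, d_in)
W_key_1 = torch.rand(d_in, head_dim)    # key weights of head 1 (random stand-ins)
W_key_2 = torch.rand(d_in, head_dim)    # key weights of head 2

# Wrapper style: one matrix multiplication per head
keys_1 = x @ W_key_1
keys_2 = x @ W_key_2

# Combined style: one multiplication with the column-wise concatenated weights
W_key = torch.cat([W_key_1, W_key_2], dim=1)        # (d_in, 2 * head_dim)
keys = x @ W_key                                    # (num_tokens, 2 * head_dim)
keys_split = keys.view(num_tokens, 2, head_dim)     # split the last dimension into heads

print(torch.allclose(keys_split[:, 0], keys_1))
print(torch.allclose(keys_split[:, 1], keys_2))
```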

The MultiHeadAttention class can be used in the same way as the SelfAttention and CausalAttention classes we implemented earlier:

batch_size, block_size, d_in = batch.shape
d_out = 2
mha = MultiHeadAttention(d_in, d_out, block_size, 0.0, num_heads=2)
context_vecs = mha(batch)
print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

As we can see based on the results, the output dimension is directly controlled by the d_out argument:

tensor([[[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]],
        [[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]]], grad_fn=<ViewBackward0>)
context_vecs.shape: torch.Size([2, 6, 2])

In this section, we implemented the MultiHeadAttention class that we will use in the upcoming sections when implementing and training the LLM itself. Note that while the code is fully functional, we used relatively small embedding sizes and numbers of attention heads to keep the outputs readable.

For comparison, the smallest GPT-2 model (117 million parameters) has 12 attention heads and a context vector embedding size of 768. The largest GPT-2 model (1.5 billion parameters) has 25 attention heads and a context vector embedding size of 1600. Note that the embedding sizes of the token inputs and context embeddings are the same in GPT models (d_in = d_out).
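As a quick sanity check on these numbers, we can compute the per-head dimension (head_dim = d_out // num_heads) for both models. The dictionary below is a hypothetical container that we introduce only for this illustration; it holds the embedding sizes and head counts quoted above.

```python
# Embedding sizes and head counts quoted in the text; the dictionary
# layout and key names are assumptions made for this illustration
gpt2_configs = {
    "smallest GPT-2": {"emb_dim": 768, "num_heads": 12},
    "largest GPT-2": {"emb_dim": 1600, "num_heads": 25},
}

head_dims = {}
for name, cfg in gpt2_configs.items():
    # d_out must be divisible by num_heads so it can be split evenly
    assert cfg["emb_dim"] % cfg["num_heads"] == 0
    head_dims[name] = cfg["emb_dim"] // cfg["num_heads"]
    print(f"{name}: head_dim = {head_dims[name]}")
```

Both configurations work out to 64 dimensions per head, so the larger model gains capacity by adding heads rather than by widening each one.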

3.7 Summary

  • Attention mechanisms transform input elements into enhanced context vector representations that incorporate information about all inputs.
  • A self-attention mechanism computes the context vector representation as a weighted sum over the inputs.
  • In a simplified attention mechanism, the attention weights are computed via dot products.
  • A dot product is just a concise way of multiplying two vectors element-wise and then summing the products.
  • Matrix multiplications, while not strictly required, help us to implement computations more efficiently and compactly by replacing nested for-loops.
  • In the self-attention mechanisms used in LLMs, also called scaled dot-product attention, we include trainable weight matrices to compute intermediate transformations of the inputs: queries, values, and keys.
  • When working with LLMs that read and generate text from left to right, we add a causal attention mask to prevent the LLM from accessing future tokens.
  • In addition to causal attention masks, which zero out attention weights, we can also add a dropout mask to reduce overfitting in LLMs.
  • The attention modules in transformer-based LLMs involve multiple instances of causal attention, which is called multi-head attention.
  • We can create a multi-head attention module by stacking multiple instances of causal attention modules.
  • A more efficient way of creating multi-head attention modules involves batched matrix multiplications.
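Several of the points above can be condensed into one compact sketch of single-head causal scaled dot-product attention. This is a simplified recap under assumed sizes and randomly initialized weights, without batching or dropout, and it is not a replacement for the classes implemented in this chapter.

```python
import torch

torch.manual_seed(0)

# Illustrative sizes (assumptions)
num_tokens, d_in, d_out = 4, 3, 2
x = torch.rand(num_tokens, d_in)  # one input sequence, no batch dimension

# Trainable weight matrices for queries, keys, and values
W_query = torch.nn.Linear(d_in, d_out, bias=False)
W_key = torch.nn.Linear(d_in, d_out, bias=False)
W_value = torch.nn.Linear(d_in, d_out, bias=False)
queries, keys, values = W_query(x), W_key(x), W_value(x)

# Attention scores via dot products: (num_tokens, num_tokens)
attn_scores = queries @ keys.T

# Causal mask: True above the diagonal marks future tokens
mask = torch.triu(torch.ones(num_tokens, num_tokens), diagonal=1).bool()
attn_scores = attn_scores.masked_fill(mask, float("-inf"))

# Scale by the square root of the key dimension, then normalize each row
attn_weights = torch.softmax(attn_scores / d_out ** 0.5, dim=-1)

# Context vectors as weighted sums over the values
context = attn_weights @ values

print(attn_weights)
print("context.shape:", context.shape)
```

Each row of attn_weights sums to 1, and all entries above the diagonal are 0, so every token attends only to itself and to earlier tokens.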

