本文讲解的是现代大模型(如GPT-4、Gemini等)背后的核心架构——Transformer。与2017年之前的Transformer架构相比,现代大模型在细节上有很多改进和优化。我们将从整体架构、核心组件(如Embedding层、Multi-Head Attention、Feed-Forward Network等)以及训练技巧等方面进行详细讲解。
1772859262355

Embedding层

Transformer模型的输入首先通过一个Embedding层将离散的词汇转换为连续的向量表示。这个过程可以看作是将每个词映射到一个高维空间中,使得语义相似的词在这个空间中也相近。
输入:(batch_size, seq_length) -> 输出:(batch_size, seq_length, d_model)
分词器输出的tokenID是不能直接使用的,比如苹果的tokenID是12345,香蕉的tokenID是67890,这两个数字没有任何语义上的联系。Embedding层通过学习一个词汇表,将每个tokenID映射到一个d_model维的向量空间中,使得苹果和香蕉在这个空间中的表示可以捕捉到它们的语义关系。
具体的实现是通过一个查找表(lookup table)来完成的,这个查找表是一个大小为(vocab_size, d_model)的矩阵,其中vocab_size是词汇表的大小,d_model是每个词向量的维度。对于输入的每个tokenID,Embedding层会在这个查找表中找到对应的行,并将其作为输出。
查找表这个weight矩阵是模型训练过程中学习得到的,初始时通常是随机初始化的。随着训练的进行,模型会根据输入数据和任务目标不断调整这个矩阵中的值,使得最终的词向量能够更好地捕捉到词汇之间的语义关系。

1
2
3
4
5
6
7
8
9
10
11
12
class Embedding(nn.Module):
def __init__(self, num_embeddings, embedding_dim, device=None, dtype=None):
super().__init__()

self.weight = nn.Parameter(
torch.empty(num_embeddings, embedding_dim, device=device, dtype=dtype)
)

torch.nn.init.trunc_normal_(self.weight, mean=0, std=1.0, a=-3, b=3)

def forward(self, token_ids):
return self.weight[token_ids]

transformer block

每个 Transformer 块接收形状为 (batch_size, sequence_length, d_model) 的张量作为输入,并输出完全相同形状的张量 。它主要由两个带有残差连接的子层(Sub-layers)组成 :

1. 因果多头自注意力子层 (Causal Multi-Head Self-Attention)

  • 前置归一化 :在进行核心计算前,首先对输入应用均方根层归一化(RMSNorm)。
  • 核心计算:执行因果多头自注意力计算。在这一步中,模型会使用因果掩码(Causal Masking)以防止模型关注到序列中未来的 Token 。同时,会在查询向量(Query)和键向量(Key)上应用旋转位置嵌入(RoPE),以此来注入相对位置信息 ^^。
  • 残差连接 :将归一化前的原始输入与注意力机制的输出直接相加。其数学表达形式为:y = x + MultiHeadSelfAttention(RMSNorm(x))

2. 位置前馈网络子层 (Position-Wise Feed-Forward Network)

  • 前置归一化 :与上一个子层类似,在进入前馈网络前,先对上一步的输出 y 应用 RMSNorm。
  • 核心计算 :将数据输入前馈网络。现代架构摒弃了原始的 ReLU 激活函数,转而采用了无偏置项(no biases)的 SwiGLU 机制(结合了 SiLU 激活函数与 GLU 门控机制)进行非线性特征变换 。
  • 残差连接 :将该子层的输入与前馈网络的输出再次相加,得出该 Transformer 块的最终输出结果 。

RMSNorm(Root Mean Square Layer Normalization)。

传统的layer normalization会计算输入张量的均值和标准差,并使用这些统计量来归一化输入。相比之下,RMSNorm 只计算输入张量的均方根值,并使用这个值来进行归一化。这种方法在实践中被证明能够提供与传统 Layer Normalization 相似的性能,同时减少了计算复杂度。

$$
LayerNorm(x) = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} \odot \gamma + \beta
$$

$$
RMSNorm(x) = \frac{x}{\sqrt{\frac{1}{d}\sum_{i=1}^{d} x_i^2 + \epsilon}} \odot g
$$

  • $x$:输入张量。
  • $d$:特征维度(即 d_model)。
  • $\epsilon$:一个极小的常数(通常为 $10^{-5}$),防止分母为零。
  • $g$:可学习的增益参数(gain),初始值为全 $1$
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
class RMSNorm(nn.Module):
def __init__(self, d_model, eps=1e-5, device=None, dtype=None):
super().__init__()

self.eps = eps
self.gain = nn.Parameter(torch.ones(d_model, device=device, dtype=dtype))

def forward(self, x):
in_dtype = x.dtype
x = x.to(torch.float32)

rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
x = x / rms
x = x * self.gain

return x.to(in_dtype)

Linear层

$$
y = x W^T
$$

现代大模型(如 Llama、PaLM)发现,去掉线性层里的加法偏置项 $\beta$ 几乎不会降低模型能力,反而能提升训练速度并省下宝贵的显存。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import torch
import torch.nn as nn
import math

class Linear(nn.Module):
def __init__(self, in_features: int, out_features: int, device=None, dtype=None):
"""
构造一个无偏置项的线性变换层。
参数:
in_features: 输入特征的维度
out_features: 输出特征的维度
"""
super().__init__()

# 1. 定义权重矩阵 (Weight Matrix)
# 为什么形状是 (out_features, in_features)?
# 这是为了在底层计算时(比如调用 cuBLAS 库),内存连续性更好,乘法效率更高。
# 必须用 nn.Parameter 包裹,让 PyTorch 知道这是一个需要计算梯度的参数。
self.weight = nn.Parameter(
torch.empty((out_features, in_features), device=device, dtype=dtype)
)

# 2. 参数初始化 (Initialization)
# 计算标准差: sqrt( 2 / (d_in + d_out) )
# 这种基于输入和输出维度的初始化方法,也叫 Glorot / Xavier 初始化,能有效防止梯度在传播时突然放大或缩小
std = math.sqrt(2.0 / (in_features + out_features))

# 使用截断正态分布填充空的张量
# 文档严格要求截断在 [-30, 30] 之间,防止产生极端的初始异常值
nn.init.trunc_normal_(self.weight, mean=0.0, std=std, a=-3*std, b=3*std)

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
前向传播函数。
参数:
x: 形状为 (*, in_features) 的张量,可以有任意多个前置批次维度
返回:
形状为 (*, out_features) 的张量
"""
# --- 方式一:使用 einsum (自带文档,维度极其清晰) ---
# ... 代表任意数量的前置维度 (比如 batch_size, sequence_length)
# i 代表 in_features, j 代表 out_features
# 运算逻辑:x 的最后一个维度 i 和 weight 的最后一个维度 i 发生点积收缩,输出维度 j
# return torch.einsum('... i, j i -> ... j', x, self.weight)

# --- 方式二:使用常规的矩阵乘法 @ (需要转置) ---
# x 的形状是 (..., in_features)
# self.weight.T 的形状是 (in_features, out_features)
# 结果自然就是 (..., out_features)
return x @ self.weight.T

init 中计算 std = math.sqrt(2.0 / (in_features + out_features)) 是深度学习历史上一个非常伟大的洞察(Glorot 初始化):

如果权重一开始的数值太大,经过几次矩阵乘法,输出的特征值就会爆炸成天文数字,也就是“梯度爆炸”。

如果权重一开始的数值太小,乘几次之后,特征值就全部变成了 $0$,也就是“梯度消失”。

这个特定的数学公式 $\sigma = \sqrt{\frac{2}{d_{in} + d_{out}}}$ 刚好能够保证:数据进入线性层前后的方差(分布的离散程度)保持一致。加上 [-30, 30] 的强制截断,相当于给模型在起跑线上穿上了一层防弹衣,让它在训练初期的步伐无比稳健。

Position-Wise Feed-Forward Network

位置前馈网络(Position-Wise Feed-Forward Network)是Transformer架构中的一个重要组件,位于每个Transformer块的第二个子层。它的主要作用是对每个位置的特征进行非线性变换,以增强模型的表达能力。它的经典结构是:线性变换(升维)→ 激活函数 → 线性变换(降维)。
FFN的三大核心作用

  1. 提供非线性能力(最根本作用)
  • 问题:自注意力机制本质是加权求和,属于线性操作
  • FFN的解决方案:通过激活函数(ReLU、GELU等)引入非线性
  • 重要性:没有非线性,无论堆叠多少层Transformer,整体仍等价于单层线性模型,无法学习复杂模式
  1. 对信息进行深度加工
  • 多头注意力的角色:让词语之间“开会讨论”,收集全局信息
  • FFN的角色:让每个词语“独立深度思考”,消化和提炼信息
  • 工作方式:对注意力层输出的每个位置表示进行独立、并行的深度处理
  1. 存储模型知识
  • 参数量对比:在标准Transformer中,FFN的参数量通常占全层参数的2/3以上
  • 作用:如同模型的“长期记忆”,存储从训练数据中学到的各种模式和规律
  • 工作方式:通过大量参数学习特征之间的复杂组合关系

激活函数(引入非线性):

  1. ReLU(Rectified Linear Unit)
  • 定义:$f(x) = \max(0, x)$
  1. GLU(Gated Linear Unit)
  • 定义:$GLU(x) = (xW_1) \odot \sigma(xW_2)$
  1. SiLU(Sigmoid Linear Unit,也叫 Swish)
  • 定义:$f(x) = x \cdot \sigma(x)$,其中 $\sigma(x)$ 是 Sigmoid 函数,即 $\sigma(x) = \frac{1}{1 + e^{-x}}$。
    1772862537077

此处我们实现的SwiGLU是SiLU和GLU的结合体,公式如下:
$FFN(x) = \text{SwiGLU}(x, W_1, W_2, W_3) = W_2(\text{SiLU}(W_1x) \odot W_3x)$
$W_1$ 和 $W_3$ 负责“升维”,把数据从 $d_{model}$ 膨胀到巨大的隐藏维度 $d_{ff}$( $\frac{8}{3} d_{model}$)。$W_2$ 负责“降维”,把处理完的高维特征重新压缩回 $d_{model}$。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import torch
import torch.nn as nn
import torch.nn.functional as F

# 假设我们在同一文件中已经有了之前写的 Linear 类
# from .basics import Linear

class FeedForward(nn.Module):
def __init__(self, d_model: int, d_ff: int):
"""
基于 SwiGLU 机制的位置前馈网络。
参数:
d_model: 模型的隐藏层维度 (输入和最终输出的维度)
d_ff: 前馈网络内部膨胀后的隐藏层维度 (通常是 d_model 的 8/3 倍左右)
"""
super().__init__()

# 1. 激活分支投影层 (对应公式里的 W1)
# 将维度从 d_model 放大到 d_ff
self.w1 = Linear(in_features=d_model, out_features=d_ff)

# 2. 门控分支投影层 (对应公式里的 W3)
# 同样将维度从 d_model 放大到 d_ff
self.w3 = Linear(in_features=d_model, out_features=d_ff)

# 3. 输出投影层 (对应公式里的 W2)
# 将融合后的特征从 d_ff 缩小回 d_model
self.w2 = Linear(in_features=d_ff, out_features=d_model)

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
参数:
x: 形状为 (batch_size, seq_len, d_model) 的张量
返回:
形状相同为 (batch_size, seq_len, d_model) 的张量
"""
# --- 第一步:激活分支 ---
# x 经过 w1 映射后,应用 SiLU (Swish) 激活函数
# SiLU 的数学定义是 x * sigmoid(x),它比 ReLU 平滑得多,梯度永远不会完全消失
branch1 = F.silu(self.w1(x))

# --- 第二步:门控分支 ---
# x 经过 w3 映射,纯线性,不加任何激活函数
branch2 = self.w3(x)

# --- 第三步:门控融合 (Gating) ---
# 两个分支逐元素相乘 (element-wise multiplication)
# branch2 就像是一把锁,通过数值的大小来控制 branch1 里哪些非线性特征可以流向下游
gated = branch1 * branch2

# --- 第四步:输出投影 ---
# 将维度为 d_ff 的融合特征映射回 d_model
return self.w2(gated)

RoPE 旋转位置嵌入

在经典的 Transformer 中,位置编码是绝对的:模型强行给第 3 个词加上一个“我是编号 3”的向量。
但在自然语言中,词与词之间的相对距离往往比绝对位置更重要。比如“我吃苹果”,无论这句话出现在文章的开头还是结尾,“吃”和“苹果”的距离都是 1。
RoPE 的伟大之处在于:它在输入时只注入绝对位置信息,但经过注意力机制的内积运算后,模型神奇地只感受到了相对位置信息。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import torch
import torch.nn as nn


class RotaryPositionalEmbedding(nn.Module):

def __init__(self, theta: float, d_k: int, max_seq_len: int, device=None):
super().__init__()

self.d_k = d_k

# 计算 inverse frequencies
inv_freq = 1.0 / (
theta ** (torch.arange(0, d_k, 2, device=device).float() / d_k)
)

# positions
positions = torch.arange(max_seq_len, device=device).float()

# outer product
freqs = torch.outer(positions, inv_freq)

cos = torch.cos(freqs)
sin = torch.sin(freqs)

# register buffers (不会被训练)
self.register_buffer("cos", cos)
self.register_buffer("sin", sin)

def _rotate_half(self, x):

x_even = x[..., ::2]
x_odd = x[..., 1::2]

x_rot = torch.stack((-x_odd, x_even), dim=-1)

return x_rot.flatten(-2)

def forward(self, x: torch.Tensor, token_positions: torch.Tensor) -> torch.Tensor:
"""
x shape: (..., seq_len, d_k)
token_positions shape: (..., seq_len)
"""

# 确保索引类型正确
token_positions = token_positions.long()

cos = self.cos[token_positions]
sin = self.sin[token_positions]

# 当 x 是 (batch, heads, seq_len, d_k) 而 token_positions 是 (batch, seq_len) 时,
# 在 heads 维上插入一个维度,方便广播
while cos.ndim < x.ndim:
cos = cos.unsqueeze(1)
sin = sin.unsqueeze(1)

# 扩展到 d_k 维
cos = torch.repeat_interleave(cos, 2, dim=-1)
sin = torch.repeat_interleave(sin, 2, dim=-1)

return x * cos + self._rotate_half(x) * sin

缩放点积注意力

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import torch
import math

def scaled_dot_product_attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
"""
实现缩放点积注意力。

参数:
query: 形状为 (batch_size, ..., seq_len, d_k)
key: 形状为 (batch_size, ..., seq_len, d_k)
value: 形状为 (batch_size, ..., seq_len, d_v)
mask: 形状为 (seq_len, seq_len) 的布尔张量 (可选)。
True 表示参与注意力计算,False 表示该位置概率应为 0。

返回:
output: 形状为 (batch_size, ..., seq_len, d_v) 的张量
"""
# 1. 获取 d_k 的维度大小,用于后续缩放
# 无论前面有多少个 batch 维度,d_k 永远在最后一个维度
d_k = query.size(-1)

# 2. 计算 Q 和 K^T 的点积得分
# transpose(-2, -1) 会精确定位并翻转最后两个维度 (seq_len, d_k -> d_k, seq_len)
# `@` 运算符会自动忽略前面的所有的 `...` 批次维度,只对最后两维做矩阵乘法
# scores 形状: (batch_size, ..., seq_len, seq_len)
scores = query @ key.transpose(-2, -1)

# 3. 缩放分数
scores = scores / math.sqrt(d_k)

# 4. 应用掩码 (Masking)
if mask is not None:
# 题目要求:mask 为 False 的地方,最终概率必须为 0。
# 我们利用 masked_fill 把 mask == False 的位置强行填入负无穷 (-inf)
# PyTorch 的广播机制会自动把 (seq_len, seq_len) 的 mask 对齐到 scores 的最后两维
scores = scores.masked_fill(mask == False, float('-inf'))

# 5. 计算 Softmax,得到注意力概率权重
# PyTorch 底层的 F.softmax 已经自带了“减去最大值”的数值稳定保护,直接调用即可
# attn_weights 形状: (batch_size, ..., seq_len, seq_len)
attn_weights = torch.softmax(scores, dim=-1)

# 6. 将注意力权重乘以 Value 矩阵
# (batch_size, ..., seq_len, seq_len) @ (batch_size, ..., seq_len, d_v)
# output 形状: (batch_size, ..., seq_len, d_v)
output = attn_weights @ value

return output

上一层pre-norm输入的tensor形状是(batch_size, seq_len, d_model),query、key、value的形状都是(batch_size, seq_len, d_k),(实际上value的最后一维可以不一样,但是默认一样)其中d_k通常等于d_model // num_heads。
自注意力公式:

$$
Attention(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V
$$

先计算Q和K的点积,得到一个形状为(batch_size,…, seq_len, seq_len)的分数矩阵。然后将这个分数矩阵除以$\sqrt{d_k}$进行缩放,最后通过softmax函数得到注意力权重,再乘以V得到最终的输出。

因果掩码处理,如果是训练阶段,mask是一个下三角矩阵,确保模型只能关注到当前词和之前的词;如果是推理阶段,mask会根据已经生成的序列动态调整,保证模型不会看到未来的词。

softmax函数将最后一维的分数转换为概率分布attn_weights,确保最后一维所有权重加起来等于1。

最后乘以V的操作是将每个位置的值向量根据注意力权重进行加权求和,得到每个位置的最终输出表示(batch_size, …, seq_len, d_v)。

假设 seq_len 是 3(一共3个词)。

对于第 1 个词,attn_weights 的第一行可能是 [0.2, 0.7, 0.1]

当这行数字去乘以 $V$ 矩阵时,它实际上是在做 加权求和

  • 拿出 $V$ 里面第 1 个词特征的 20%
  • 加上 $V$ 里面第 2 个词特征的 70%
  • 加上 $V$ 里面第 3 个词特征的 10%
    融合成一个新的长度为 d_v 的向量,这就是第 1 个词经过注意力机制洗礼后,最终输出的特征!

multihead_self_attention

多头自注意力机制的核心思想是:将输入的特征空间分成多个子空间(head),每个子空间独立地进行注意力计算,最后再将这些子空间的输出拼接起来,经过线性变换得到最终的输出。这种设计允许模型在不同的子空间中捕捉到不同类型的关系和模式,从而增强了模型的表达能力。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
from einops import rearrange

class MultiHeadSelfAttention(nn.Module):
def __init__(self, d_model: int, num_heads: int,rope_theta=10000,max_seq_len=2048):
super().__init__()
assert d_model % num_heads == 0, "d_model 必须能被 num_heads 整除"
self.num_heads = num_heads
self.d_k = d_model // num_heads

# 定义三组线性层,用于“捏”出 Q, K, V
self.W_q = nn.Linear(d_model, d_model, bias=False)
self.W_k = nn.Linear(d_model, d_model, bias=False)
self.W_v = nn.Linear(d_model, d_model, bias=False)

self.rope = RotaryPositionalEmbedding(
rope_theta,
self.d_k,
max_seq_len,
)

# 最后的融合输出层
self.W_o = nn.Linear(d_model, d_model, bias=False)

def forward(self, x: torch.Tensor, token_positions: torch.Tensor) -> torch.Tensor:
batch_size, seq_len, _ = x.shape

# 1. 投影并分头 (Split into heads)
# 结果形状: (batch, heads, seq, d_k)
q = rearrange(self.W_q(x), 'b s (h d) -> b h s d', h=self.num_heads)
k = rearrange(self.W_k(x), 'b s (h d) -> b h s d', h=self.num_heads)
v = rearrange(self.W_v(x), 'b s (h d) -> b h s d', h=self.num_heads)

# 注意:RoPE 只改变 Q 和 K 的值,不改变形状
q = self.rope(q, token_positions)
k = self.rope(k, token_positions)


# 2. 准备因果掩码 (下三角阵)
# 注意:这里我们只生成一个 (seq, seq) 的 2D 掩码
mask = torch.tril(torch.ones(seq_len, seq_len, device=x.device, dtype=torch.bool))

# 3. 直接调用写好的底层函数!
# 这个函数会处理所有的缩放、掩码填充、Softmax 和加权求和
# context 形状: (batch, heads, seq, d_v)
context = scaled_dot_product_attention(q, k, v, mask=mask)

# 4. 合并多头并输出
# 将 h 和 d 重新乘回 d_model: (batch, seq, d_model)
out = rearrange(context, 'b h s d -> b s (h d)')
return self.W_o(out)

学完上面线性层,位置前馈网络,RoPE,缩放点积注意力,我们就已经掌握了Transformer架构的核心组件。接下来我们会把这些组件组合起来,构建完整的Transformer块。
1772938048016

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import torch
import torch.nn as nn

class TransformerBlock(nn.Module):
def __init__(self, d_model, num_heads, d_ff, max_seq_len):
"""
参数:
d_model: 输入和输出的维度
num_heads: 多头注意力的头数
d_ff: 前馈网络 (SwiGLU) 的中间隐藏层维度
"""
super().__init__()

# 1. 实例化两个归一化层 (Pre-norm 模式)
# 注意:这里应使用你之前实现的 RMSNorm
self.norm1 = RMSNorm(d_model)
self.norm2 = RMSNorm(d_model)

# 2. 实例化注意力子层
# 负责 Token 之间的信息交换
self.attn = MultiHeadSelfAttention(d_model=d_model, num_heads=num_heads, max_seq_len=max_seq_len)

# 3. 实例化前馈网络子层
# 负责每个 Token 内部的非线性特征提取 (SwiGLU)
self.ffn = FeedForward(d_model=d_model, d_ff=d_ff)

def forward(self, x: torch.Tensor, token_positions: torch.Tensor) -> torch.Tensor:
"""
参数:
x: 形状为 (batch_size, seq_len, d_model)
token_positions: 形状为 (batch_size, seq_len)
返回:
形状相同的输出张量
"""
# --- 子层 1: 注意力机制 + 残差连接 ---
# 步骤:x -> RMSNorm -> Attention -> + 原始 x
# 这里的加法就是“干净的残差流”,梯度可以直接流向更早的层
x = x + self.attn(self.norm1(x), token_positions)

# --- 子层 2: 前馈网络 + 残差连接 ---
# 步骤:x -> RMSNorm -> FFN -> + 进入该层前的 x
x = x + self.ffn(self.norm2(x))

return x

transformer LM

将多个Transformer块堆叠起来,再加上后面几个层级组件(如输出层),我们就得到了一个完整的Transformer语言模型。
先不实现图中最后一个softmax层,因为在训练阶段我们通常会把softmax和交叉熵损失函数合并成一个更高效的函数(如nn.CrossEntropyLoss),而在推理阶段,我们也会直接对输出的logits进行采样或贪心解码,而不是先计算概率分布。
1772937987730

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
class TransformerLM(nn.Module):

def __init__(
self,
vocab_size,
context_len,
d_model,
num_layers,
num_heads,
d_ff,
):
super().__init__()

self.embedding = Embedding(vocab_size, d_model)

self.layers = nn.ModuleList([
TransformerBlock(
d_model,
num_heads,
d_ff,
context_len,
)
for _ in range(num_layers)
])

self.norm = RMSNorm(d_model)

self.lm_head = Linear(d_model, vocab_size)

def forward(self, tokens):

B, T = tokens.shape

positions = torch.arange(T, device=tokens.device)
positions = positions.unsqueeze(0).expand(B, T)

x = self.embedding(tokens)

for layer in self.layers:
x = layer(x, positions)

x = self.norm(x)

return self.lm_head(x)

cs336_assignment_1中尚未学习的有:Adam优化器、余弦退火学习率、梯度裁剪、checkpoint.RoPE也还没学会