卡卡编程网

专注编程技术分享,涵盖开发教程与实战案例

DeepSeek-R1源码解读

最近和开发者做了很多DeepSeek-R1模型相关的推理项目,这两天抽时间把hugging face上面的源码拉下来仔细看了一遍,在这里做一个分享。主要是解析MOE部分的代码,包括EP并行的代码实现。

整体结构

查看hugging face上面的modeling_deepseek.py文件和config.json文件,可以发现代码结构和DeepSeek-V3是完全相同的。DeepseekV3DecoderLayer类的forward函数如下:

residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
    hidden_states=hidden_states,
    attention_mask=attention_mask,
    position_ids=position_ids,
    past_key_value=past_key_value,
    output_attentions=output_attentions,
    use_cache=use_cache,
    **kwargs,
)
hidden_states = residual + hidden_states

# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
    outputs += (self_attn_weights,)
if use_cache:
    outputs += (present_key_value,)
return outputs

这是非常标准的transformer模型结构,由input_layernorm、attention、Fully Connected部分组成。DeepSeek最大的特点是Fully Connected使用了MOE结构,也就是代码中的self.mlp,调用的是DeepseekV3MoE类。DeepseekV3MoE的forward函数如下:

identity = hidden_states
orig_shape = hidden_states.shape
topk_idx, topk_weight = self.gate(hidden_states)
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
flat_topk_idx = topk_idx.view(-1)
if not self.training:
    y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(*orig_shape)
if self.config.n_shared_experts is not None:
    y = y + self.shared_experts(identity)

上面代码的核心部分是self.gate和self.moe_infer。下面我们将详细看一下这2部分。如果你对MOE的原理不熟悉,建议先看一下之前我写的这篇文章,有助于对下面内容的理解:AI布道Mr.Jin:DeepSeek模型MOE结构代码详解 。

Gate源码解读

Gate的作用是为输入的token选择合适的expert进行计算,并且把expert的权重也计算出来。DeepSeek-R1的gate代码如下(可以复制运行):

import numpy as np 
import torch      
import math
import warnings
from typing import List, Optional, Tuple, Union

import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
import torch.distributed as dist
from transformers.activations import ACT2FN

batch_size = 2  # 输入的batch_size
seq_len = 16  # 输入的seq_len
hidden_dim = 10  # attention输出的隐藏层维度
n_routed_experts = 16  # 专家总数
top_k = 3  # 每个token选择3个专家进行计算
n_group = 4  # 把专家分成4组进行处理
topk_group = 2  # 选择得分前2的专家组进行处理
# 初始化输入
hidden_states = torch.Tensor(np.random.random((batch_size, seq_len, hidden_dim)))
# 初始化gate的权重
weight = torch.Tensor(np.random.random((n_routed_experts, hidden_dim)))
e_score_correction_bias = torch.Tensor(np.random.random(n_routed_experts))

bsz, seq_len, h = hidden_states.shape
### compute gating score
hidden_states = hidden_states.view(-1, h)
# [bsz, seq_len, n_routed_experts]
logits = F.linear(
    hidden_states.type(torch.float32), weight.type(torch.float32), None
)
# 得到各token路由到各专家的概率
scores = logits.sigmoid()

topk_method = "noaux_tc"
### select top-k experts
# [bsz*seq_len, n_routed_experts]
scores_for_choice = scores.view(bsz * seq_len, -1) + e_score_correction_bias.unsqueeze(0)
print("scores_for_choice: ", scores_for_choice)
# 对专家进行分组,计算各组专家得分之和
group_scores = (
    scores_for_choice.view(bsz * seq_len, n_group, -1).topk(2, dim=-1)[0].sum(dim = -1)
)  # [n, n_group],n就是batch_size*seq_len
print("group_scores: ", group_scores)

# 每个token选择 topk_group 里面的专家作为候选专家
group_idx = torch.topk(
    group_scores, k=topk_group, dim=-1, sorted=False
)[
    1
]  # [n, top_k_group]
print("group_idx: ", group_idx)

group_mask = torch.zeros_like(group_scores)  # [n, n_group]
group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
# 获取score_mask,shape为[n, n_routed_experts],对于每个token,被选中的专家组的专家位置的值都是1,否则都是0
score_mask = (
    group_mask.unsqueeze(-1)
    .expand(
        bsz * seq_len, n_group, n_routed_experts // n_group
    )
    .reshape(bsz * seq_len, -1)
)  # [n, e]
print("score_mask: ", score_mask)

# 把没有选中的专家的分数置为-inf
tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(), float("-inf"))  # [n, e]
print("tmp_scores: ", tmp_scores)

# 得到选中的专家id
_, topk_idx = torch.topk(
    tmp_scores, k=top_k, dim=-1, sorted=False
)
print("topk_idx: ", topk_idx)

topk_weight = scores.gather(1, topk_idx)
print("topk_weight: ", topk_weight)

### norm gate to sum 1
norm_topk_prob = True
routed_scaling_factor = 2.5
if top_k > 1 and norm_topk_prob:
    denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
    topk_weight = topk_weight / denominator
topk_weight = topk_weight * routed_scaling_factor # must multiply the scaling factor
print(topk_idx.shape, topk_weight.shape)

可以看到,上面的处理流程和我之前分享的文章中的代码很相似,只是多了一个分组操作,可以加快专家选择的速度。

接下来看一下选好专家后,如何计算。

MOE推理源码解读

MOE推理源码如下:

# 定义每个专家的结构
moe_intermediate_size = 5
class DeepseekV3MLP(nn.Module):
    def __init__(self, hidden_size=None, intermediate_size=None):
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size

        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN["silu"]

    def forward(self, x):
        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
        return down_proj
# 构建专家组
experts = nn.ModuleList(
    [
        DeepseekV3MLP(hidden_dim, moe_intermediate_size)
        for i in range(n_routed_experts)
    ]
)

x = hidden_states  # [bsz*seq_len, hidden_dim]
topk_ids = topk_idx  # [bsz*seq_len, top_k]
cnts = topk_ids.new_zeros((topk_ids.shape[0], n_routed_experts))  # [bsz*seq_len, n_routed_experts]
print("cnts: ", cnts)
# cnts记录每个token的专家路由情况
cnts.scatter_(1, topk_ids, 1)  # [bsz*seq_len, n_routed_experts]
print("cnts: ", cnts)
# 统计每个专家的token数量
tokens_per_expert = cnts.sum(dim=0)  # [n_routed_experts]
# 按照expert编号的顺序,把每个expert对应的token下标取出来
idxs = topk_ids.view(-1).argsort()
print("idxs: ", idxs)
# 按照expert编号的顺序,把每个expert需要处理的token特征取出来
sorted_tokens = x[idxs // topk_ids.shape[1]]
print("sorted_tokens: ", sorted_tokens)

sorted_tokens_shape = sorted_tokens.shape
# 这个脚本可以在单卡上运行
ep_size = 1
# 所有专家都放在一个卡上
experts_per_rank = n_routed_experts
print("tokens_per_expert.shape[0]: ", tokens_per_expert.shape[0])
# 多卡EP并行场景
if ep_size > 1:
    # [ep_size, n_routed_experts // ep_size]->[ep_size]
    tokens_per_ep_rank = tokens_per_expert.view(ep_size, -1).sum(dim=1)
    # [n_routed_experts]
    tokens_per_expert_group = tokens_per_expert.new_empty(
        tokens_per_expert.shape[0]
    )
    dist.all_to_all_single(tokens_per_expert_group, tokens_per_expert)  # tokens_per_expert_group获取的是各个rank上分给本rank的token情况
    # [ep_size, n_routed_experts // ep_size] -> [ep_size]
    output_splits = (
        tokens_per_expert_group.view(ep_size, -1)
        .sum(1)
        .cpu()
        .numpy()
        .tolist()
    )
    # [total_token_on_this_rank, hidden_dim], 存储所有需要在本rank上计算的Token。
    gathered_tokens = sorted_tokens.new_empty(
        tokens_per_expert_group.sum(dim=0).cpu().item(), sorted_tokens.shape[1]
    )
    input_split_sizes = tokens_per_ep_rank.cpu().numpy().tolist()
    # gathered_tokens记录了所有需要在本rank上计算的Token
    dist.all_to_all(
        list(gathered_tokens.split(output_splits)),
        list(sorted_tokens.split(input_split_sizes)),
    )
    # [experts_per_rank,], 记录的是所有节点发送给本rank上各expert的token数量,[expert1_token_num, expert2_token_num,...]
    tokens_per_expert_post_gather = tokens_per_expert_group.view(
        ep_size, experts_per_rank
    ).sum(dim=0)  
    gatherd_idxs = np.zeros(shape=(gathered_tokens.shape[0],), dtype=np.int32)
    s = 0
    for i, k in enumerate(tokens_per_expert_group.cpu().numpy()):
        # 记录每个token对应的expert编号
        gatherd_idxs[s : s + k] = i % experts_per_rank
        s += k
    gatherd_idxs = gatherd_idxs.argsort()  # [expert_total_token_num,]
    sorted_tokens = gathered_tokens[gatherd_idxs]  # [expert_total_token_num, hidden_dim]
    tokens_per_expert = tokens_per_expert_post_gather  # [experts_per_rank,]
tokens_per_expert = tokens_per_expert.cpu().numpy()
print("tokens_per_expert: ", tokens_per_expert)

outputs = []
start_idx = 0
ep_rank = 0
# 遍历每个专家进行计算
for i, num_tokens in enumerate(tokens_per_expert):
    end_idx = start_idx + num_tokens
    if num_tokens == 0:
        continue
    expert = experts[i + ep_rank * experts_per_rank]
    tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
    expert_out = expert(tokens_for_this_expert)
    outputs.append(expert_out)
    start_idx = end_idx

# 把所有专家的计算结果concate起来
outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)
print("outs: ", outs)  # [bsz*seq_len*top_k, hidden_dim]

# EP并行情况下,需要把其他rank上的序列token在本rank上计算的结果返回
if ep_size > 1:
    new_x = torch.empty_like(outs)
    # 把输出按照原来的顺序排列,即各rank给本rank发送的token顺序
    new_x[gatherd_idxs] = outs
    gathered_tokens = new_x.new_empty(*sorted_tokens_shape)
    dist.all_to_all(
        list(gathered_tokens.split(input_split_sizes)),
        list(new_x.split(output_splits)),
    )
    outs = gathered_tokens

new_x = torch.empty_like(outs)
# 把outs的顺序进行重排,让从上到下是按token的顺序进行排列
new_x[idxs] = outs
# 把每个token的多个expert处理结果进行加权求和
final_out = (
    new_x.view(*topk_ids.shape, -1)
    .type(topk_weight.dtype)
    .mul_(topk_weight.unsqueeze(dim=-1))
    .sum(dim=1)
    .type(new_x.dtype)
)
print("final_out: ", final_out)

上面的代码可以直接运行,模拟单卡场景(ep_size=1)的计算流程。对于ep_size>1的情况,主要有以下不同:

1,在EP并行的情况下,各个rank都会收到输入序列,该序列通过gate后,部分token需要发送给其他rank上面的expert进行处理;

2,各rank收到其他rank发送过来的token后,进行计算,计算完成后,需要把结果发送回原来的rank。

所以和单卡场景相比,EP场景下新增了3个分布式通信行为。

第一次通信行为是各个设备之间进行token数量的交换,使用的是all_to_all_single()方法。举个例子,假如EP并行度为4,在0卡、1卡、2卡、3卡上分别部署了0号专家、1号专家、2号专家和3号专家。0卡上的输入序列可能有部分token被分配给了1、2、3号专家,那么就需要把对应的token数量告知其他设备上的专家,这样的话,每个rank就知道自己在下一步进行all_to_all()时的输入输出切割策略。

第二次通信行为,是确定好token分配信息之后,各个设备把本地的token特征进行分割,然后互相发送、接收,这一步调用的是all_to_all()方法。举个例子,假如0卡上一共有12个token,经过Gate模块后,判定为0-2号token使用自己的expert处理,3-5号token使用1号专家处理,6-8号token使用2号专家处理,9-11号token使用3号专家,那么0号卡就会把这12个token的特征矩阵分成4份,把其中3份发给其他卡。同时,0号卡会接受其他卡发送过来的需要让0号专家处理的token特征。完成这一步后,各个卡上的专家就可以进行计算了。计算完成后,就需要进行第三次通信行为了。

第三次通信行为使用的也是all_to_all()方法,目的是把计算结果传回token原属的节点。继续上面的例子,1号专家计算完0号卡发送过来的0-2号token后,需要把计算结果返回给0号卡;2号专家计算完0号卡发送过来的3-5号token后,也需要把计算结果返回给0号卡,以此类推。

以上就是DeepSeek-R1 MOE模块的代码实现解析,大家还有什么问题呢?欢迎讨论!

控制面板
您好,欢迎到访网站!
  查看权限
网站分类
最新留言