友情链接:ACEJoy
DeepSeek-V2是DeepSeek团队最新发布的MoE(Mixture of Experts)架构的LLM(大型语言模型)底座。该模型拥有236B的总参数量和21B的每个token激活参数量,支持128K tokens的上下文长度。DeepSeek-V2的一个核心创新点就是Multi-head Latent Attention(MLA)。
Multi-head Latent Attention(MLA)简介
MLA对传统Transformer中的多头注意力机制(MHA)进行了改进,主要目标是:
- 降低推理时KV Cache的存储开销;
- 缓解GQA(Grouped-Query Attention)和MQA(Multi-Query Attention)等方法导致的模型性能损耗。
标准的MHA结构
在标准的MHA结构中,每个token的query、key和value通过参数矩阵映射得到,并分割成多个注意力头。每个头独立计算注意力权重并得到输出,这个过程虽然能捕捉丰富的上下文信息,但在推理时需要缓存大量的KV Cache。
MLA如何改进?
MLA通过对keys和values进行低秩联合压缩来降低KV Cache:
- 低秩Key-Value联合压缩:
[
\mathbf{c}_t^{KV} = W^{DKV} \mathbf{h}_t
]
[
\mathbf{k}_t^C = W^{UK} \mathbf{c}_t^{KV}
]
[
\mathbf{v}_t^C = W^{UV} \mathbf{c}_t^{KV}
]
其中,(\mathbf{c}_t^{KV})表示压缩后的隐向量,(W^{DKV})是降维映射矩阵,(W^{UK})和(W^{UV})是升维映射矩阵。在推理时,只需要缓存隐向量(\mathbf{c}_t^{KV}),显著减少了KV Cache的容量。 - Queries的低秩压缩:
[
\mathbf{c}_t^Q = W^{DQ} \mathbf{h}_t
]
[
\mathbf{q}_t^C = W^{UQ} \mathbf{c}_t^Q
]
这样即便不能减少KV Cache,但可以降低训练过程中的激活内存。
代码实现
以下是MLA在DeepSeek-V2中的Python代码实现片段:
class DeepSeekV2Attention(nn.Module):
def init(self, config: DeepSeekV2Config, layer_idx: Optional[int] = None):
…
self.w_dq = nn.Linear(self.hidden_size, config.q_lora_rank, bias=config.attention_bias)
self.w_uq = nn.Linear(config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False)
self.w_dkv = nn.Linear(self.hidden_size, self.dc, bias=config.attention_bias)
self.w_uk = nn.Linear(self.dc, self.num_heads * self.q_head_dim, bias=False)
self.w_uv = nn.Linear(self.dc, self.num_heads * self.q_head_dim, bias=False)
def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None,
output_attentions: bool = False, use_cache: bool = False, **kwargs):
bsz, q_len, _ = hidden_states.size()
q = self.w_uq(self.q_a_layernorm(self.w_dq(hidden_states))).view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
kv_seq_len = q.size(-2)
if past_key_value is not None:
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
compressed_kv = self.w_dkv(hidden_states)
if past_key_value is not None:
compressed_kv = past_key_value.update(compressed_kv)
k = self.w_uk(compressed_kv).view(bsz, -1, self.num_heads, self.q_head_dim).transpose(1, 2)
v = self.w_uv(compressed_kv).view(bsz, -1, self.num_heads, self.q_head_dim).transpose(1, 2)
attn_weights = torch.matmul(q, k.transpose(2, 3)) * self.softmax_scale
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
if output_attentions:
outputs = (attn_weights,)
else:
outputs = ()
attn_output = torch.matmul(attn_weights, v)
attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, q_len, self.hidden_size)
attn_output = self.out_proj(attn_output)
outputs = (attn_output,) + outputs
if use_cache:
outputs = outputs + (past_key_value,)
return outputs
```
结论
DeepSeek-V2通过引入Multi-head Latent Attention(MLA)结构,成功优化了传统的多头注意力机制(MHA),在保证模型性能的同时,显著降低了推理时KV Cache的存储开销。这不仅提高了模型的效率,也为未来的大模型架构设计提供了新的思路。
MLA的实现通过对queries、keys和values进行低秩压缩,减少了存储需求,缓解了因GQA和MQA方法导致的性能损耗。这种创新在深度学习模型的设计中具有重要的参考价值。
如果你对于DeepSeek-V2的MLA结构有更多的兴趣,建议查看其开源代码和详细文档,以便深入理解其工作机制和实现细节。