DeepSeek mHC的简单演示
2026-01-04
DeepSeek发布了最新的魔改版Residual Connection:Manifold Constrained Hyper-Connection.
思路
-
其基本思路是把旁路residual限制在某个集合上
- 文中用更“几何”的manifold一词表述;
- 退化的例子就是Kaiming的原版Residual Connection,约束是
residual = x - 本文则将residual projection matrix的谱范数限制在
-
类似的思路还可以在比如物理模拟中看到:
- 通过将物体的 transformation matrix 约束在 ,禁止物体形变,从而模拟刚体。
-
HC的基本思路应该是:
- 原本就有n个stream
- 在主线forward的时候,把n个stream合并为一个(pre-proj),通过这一层网络(),然后再打散回n个stream(post-proj)
- 支线复制输入x,通过一个res-proj进行信息混合之后,加回主线的输出
-
mHC对这个res-proj进行约束:
- 要求其为bistochastic matrix.
- 具体做法就是通过 sinkhorn 迭代直接将其映射到最接近的 doubly stochastic matrix 上。
简单实现(不含优化)
一种可能有错误的简单的代码实现在这里如下.
import torch
import torch.nn as nn
import torch.nn.functional as F
import einops as ein
N_ITER = 20
def sinkhorn_knopp(mat: torch.Tensor) -> torch.Tensor:
for _ in range(N_ITER):
mat = mat / mat.sum(-2, keepdim=True)
mat = mat / mat.sum(-1, keepdim=True)
return mat
n = 4 # stream width
C = 256 # embedding dim
norm = nn.RMSNorm((n * C,))
phi_pre = nn.Parameter(torch.randn(n * C, n))
phi_post = nn.Parameter(torch.randn(n * C, n))
phi_res = nn.Parameter(torch.randn(n * C, n * n))
def broadcast_to_n_stream(xl: torch.Tensor) -> torch.Tensor:
return ein.repeat(xl, "... C -> ... n C", n=n)
def reduce_to_one_stream(xl: torch.Tensor) -> torch.Tensor:
return ein.reduce(xl, "... n C -> ... C", "mean")
def manifold_constrained_hyperconnection(xl: torch.Tensor, layer: nn.Module) -> torch.Tensor:
xl_vec = ein.rearrange(xl, "... n C -> ... (n C)")
xl_vec_prime = norm(xl_vec)
h_tilde_pre = alpha_pre * (xl_vec_prime @ phi_pre) + b_pre
h_tilde_post = alpha_post * (xl_vec_prime @ phi_post) + b_post
h_tilde_res = alpha_res * ein.rearrange((xl_vec_prime @ phi_res), "... (m n) -> ... m n", n=n) + b_res
h_pre = F.sigmoid(h_tilde_pre)
h_post = 2 * F.sigmoid(h_tilde_post)
h_res = sinkhorn_knopp(h_tilde_res.exp())
residual = ein.einsum(h_res, xl, "... m n, ... n C -> ... m C")
x_pre = ein.einsum(h_pre, xl, "... n, ... n C -> ... C")
layer_out = layer(x_pre)
x_post = ein.einsum(h_post, layer_out, "... n, ... C -> ... n C")
out = x_post + residual
return out