不通过反转正向传播的方式计算sinkhorn迭代的梯度
2026-01-05
问题设定
问题
- 输入矩阵: 。
- (element-wise)。
- 通过对 进行 Sinkhorn-knopp迭代,得到bistochastic matrix 。
- 损失函数: ,令 为已知梯度。
目标
对 的梯度:。
TLDR
通过使用CG方法求解下列方程:
可以得到 对 的梯度:
求解
求解线性系统
将上述方程改写成矩阵形式:
组装梯度
性质
1. 多解
考虑非零向量 。
根据bistochastic matrix性质 和 :
由于存在非零向量在 的零空间中,故 。
2. 不变量
虽然解 包含不确定的偏移量 ,但我们的计算目标是确定的。
将通解代入:
3. 形式变换
从原系统消元:
其中 是对称半正定的。
算法
- 准备右端项
- 构建半正定系统
- 用 CG 求解
- 构造解
- 组装结果
- 最终梯度
PyTorch 实现
import torch
def sinkhorn_forward(M, iters=20):
P = torch.exp(M)
R = P
for _ in range(iters):
R = R / R.sum(-2, keepdim=True)
R = R / R.sum(-1, keepdim=True)
return R, P
def batch_cg_solve_singular(A, b):
batch_size, n, _ = A.shape
x = torch.zeros_like(b)
r = b.clone()
p = r.clone()
rs_old = torch.einsum("bi,bi->b", r, r)
for i in range(n):
Ap = torch.einsum("bij,bj->bi", A, p)
pAp = torch.einsum("bi,bi->b", p, Ap)
alpha = rs_old / (pAp + 1e-11)
x += torch.einsum("b,bi->bi", alpha, p)
r -= torch.einsum("b,bi->bi", alpha, Ap)
rs_new = torch.einsum("bi,bi->b", r, r)
beta = rs_new / (rs_old + 1e-11)
p = r + torch.einsum("b,bi->bi", beta, p)
rs_old = rs_new
return x
def sinkhorn_backward_n_rank0(grad_R, R, cg_iters=10):
R_detached = R.detach()
G = grad_R
r = (R_detached * G).sum(dim=-1)
c = (R_detached * G).sum(dim=-2)
R_T = torch.einsum("bij->bji", R_detached)
RTR = torch.einsum("bij,bjk->bik", R_T, R_detached)
eye = torch.eye(n, device=R.device, dtype=R.dtype).unsqueeze(0).expand(batch_size, -1, -1)
S0 = eye - RTR
b = c - torch.einsum("bij,bj->bi", R_T, r)
v_tilde = batch_cg_solve_singular(S0, b)
u = r - torch.einsum("bij,bj->bi", R_detached, v_tilde)
v = v_tilde
M = u.unsqueeze(-1) + v.unsqueeze(-2)
grad_X = (G - M) * R_detached
return grad_X
> 代码开源于Github