不通过反转正向传播的方式计算sinkhorn迭代的梯度

问题设定

问题

  1. 输入矩阵:
  2. (element-wise)。
  3. 通过对 进行 Sinkhorn-knopp迭代,得到bistochastic matrix
  4. 损失函数: ,令 为已知梯度。

目标

的梯度:

TLDR

通过使用CG方法求解下列方程:

可以得到 的梯度:

求解

求解线性系统

将上述方程改写成矩阵形式:

组装梯度

性质

1. 多解

考虑非零向量

根据bistochastic matrix性质

由于存在非零向量在 的零空间中,故

2. 不变量

虽然解 包含不确定的偏移量 ,但我们的计算目标是确定的。

将通解代入:

3. 形式变换

从原系统消元:

其中 是对称半正定的。

算法

  1. 准备右端项
  1. 构建半正定系统
  1. 用 CG 求解
  1. 构造解
  1. 组装结果
  1. 最终梯度

PyTorch 实现

import torch

def sinkhorn_forward(M, iters=20):
P = torch.exp(M)
R = P
for _ in range(iters):
R = R / R.sum(-2, keepdim=True)
R = R / R.sum(-1, keepdim=True)
return R, P

def batch_cg_solve_singular(A, b):
batch_size, n, _ = A.shape
x = torch.zeros_like(b)
r = b.clone()
p = r.clone()
rs_old = torch.einsum("bi,bi->b", r, r)

for i in range(n):
Ap = torch.einsum("bij,bj->bi", A, p)
pAp = torch.einsum("bi,bi->b", p, Ap)
alpha = rs_old / (pAp + 1e-11)
x += torch.einsum("b,bi->bi", alpha, p)
r -= torch.einsum("b,bi->bi", alpha, Ap)
rs_new = torch.einsum("bi,bi->b", r, r)
beta = rs_new / (rs_old + 1e-11)
p = r + torch.einsum("b,bi->bi", beta, p)
rs_old = rs_new

return x

def sinkhorn_backward_n_rank0(grad_R, R, cg_iters=10):
R_detached = R.detach()
G = grad_R

r = (R_detached * G).sum(dim=-1)
c = (R_detached * G).sum(dim=-2)

R_T = torch.einsum("bij->bji", R_detached)
RTR = torch.einsum("bij,bjk->bik", R_T, R_detached)
eye = torch.eye(n, device=R.device, dtype=R.dtype).unsqueeze(0).expand(batch_size, -1, -1)

S0 = eye - RTR
b = c - torch.einsum("bij,bj->bi", R_T, r)

v_tilde = batch_cg_solve_singular(S0, b)
u = r - torch.einsum("bij,bj->bi", R_detached, v_tilde)
v = v_tilde

M = u.unsqueeze(-1) + v.unsqueeze(-2)
grad_X = (G - M) * R_detached

return grad_X

> 代码开源于Github