Triton Tensor Descriptor: The Third Way to Write the Character 茴

Today we introduce Tensor Descriptor, the third API in Triton for tensor pointer arithmetic. The content is drawn from the Triton documentation.

Basic Triton concepts

Using Tensor Descriptor

Creating a descriptor

desc = tl.make_tensor_descriptor(
    pointer,
    shape=[M, N],
    strides=[N, 1],
    block_shape=[M_BLOCK, N_BLOCK],
)

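Here, pointer is the base address of the tensor in global memory, shape gives the tensor's size along each dimension, strides gives the element stride of each dimension (the last dimension must be contiguous), and block_shape is the shape of the block moved by each load or store.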
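The strides=[N, 1] argument above corresponds to a contiguous row-major M x N tensor. As a minimal sketch (the helper name row_major_strides is hypothetical and not part of Triton), the strides for any contiguous shape can be derived in plain Python:

```python
def row_major_strides(shape):
    """Return element strides for a contiguous row-major tensor of the given shape."""
    strides = [1] * len(shape)
    # Walk from the second-to-last dimension backwards; each stride is the
    # product of all dimension sizes to its right.
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return strides


M, N = 256, 256
print(row_major_strides([M, N]))  # [256, 1], i.e. [N, 1]
```

For a contiguous PyTorch tensor, x.stride() reports the same values, so it can be used to cross-check the strides passed to the descriptor.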

Loading and storing

value = desc.load([moffset, noffset])
desc.store([moffset, noffset], tl.abs(value))
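To make the semantics of these two calls concrete, here is a minimal CPU-side sketch in plain Python. The class HostDescriptor is hypothetical and not part of Triton; it only imitates how a descriptor addresses a whole block by its starting offsets. The bounds handling shown (zero fill on out-of-bounds reads, dropped out-of-bounds writes) is an assumption about the default behavior and should be checked against the Triton documentation for your version.

```python
class HostDescriptor:
    """Hypothetical CPU stand-in for a 2D tensor descriptor over a flat buffer."""

    def __init__(self, data, shape, strides, block_shape):
        self.data = data                # flat list acting as global memory
        self.shape = shape              # [M, N]
        self.strides = strides          # [N, 1] for a row-major layout
        self.block_shape = block_shape  # [M_BLOCK, N_BLOCK]

    def _in_bounds(self, m, n):
        return 0 <= m < self.shape[0] and 0 <= n < self.shape[1]

    def _index(self, m, n):
        return m * self.strides[0] + n * self.strides[1]

    def load(self, offsets):
        """Return the block starting at offsets as a list of rows."""
        m0, n0 = offsets
        block = []
        for i in range(self.block_shape[0]):
            row = []
            for j in range(self.block_shape[1]):
                m, n = m0 + i, n0 + j
                # Assumed: out-of-bounds reads yield zero.
                row.append(self.data[self._index(m, n)] if self._in_bounds(m, n) else 0)
            block.append(row)
        return block

    def store(self, offsets, block):
        """Write a block back, starting at offsets."""
        m0, n0 = offsets
        for i, row in enumerate(block):
            for j, v in enumerate(row):
                m, n = m0 + i, n0 + j
                # Assumed: out-of-bounds writes are dropped.
                if self._in_bounds(m, n):
                    self.data[self._index(m, n)] = v


# A 4x4 matrix with alternating signs, processed in 2x2 blocks.
M, N = 4, 4
buf = [(-1) ** k * k for k in range(M * N)]
desc = HostDescriptor(buf, [M, N], [N, 1], [2, 2])

value = desc.load([2, 2])
print(value)  # [[10, -11], [14, -15]]
desc.store([2, 2], [[abs(v) for v in row] for row in value])
```

The point of the sketch is that the caller passes only the block's starting offsets; the per-element address computation and the bounds check live inside the descriptor rather than in user code.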

Examples

Example 1

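import torch
import triton
import triton.language as tl


@triton.jit
def inplace_abs(in_out_ptr, M, N, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr):
    # Describe the M x N row-major tensor and the block size of each access.
    desc = tl.make_tensor_descriptor(
        in_out_ptr,
        shape=[M, N],
        strides=[N, 1],
        block_shape=[M_BLOCK, N_BLOCK],
    )

    # Each program handles one block of the 2D grid.
    moffset = tl.program_id(0) * M_BLOCK
    noffset = tl.program_id(1) * N_BLOCK

    # Load one block, take the absolute value, and write it back in place.
    value = desc.load([moffset, noffset])
    desc.store([moffset, noffset], tl.abs(value))


# Descriptors created inside a kernel need scratch global memory,
# so register an allocator before launching.
def alloc_fn(size, alignment, stream):
    return torch.empty(size, device="cuda", dtype=torch.int8)


triton.set_allocator(alloc_fn)

M, N = 256, 256
x = torch.randn(M, N, device="cuda")
M_BLOCK, N_BLOCK = 32, 32
# Grid dimensions must be integers, so use integer division (256 is a multiple of 32).
grid = (M // M_BLOCK, N // N_BLOCK)
inplace_abs[grid](x, M, N, M_BLOCK, N_BLOCK)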
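The launch grid needs one program per block, with integer dimensions. When M or N is not a multiple of the block size, a ceiling division keeps the trailing partial blocks covered. A minimal sketch in plain Python follows; the helper names cdiv and launch_grid are illustrative, and Triton also provides triton.cdiv for the same purpose.

```python
def cdiv(a, b):
    """Ceiling division for positive integers."""
    return (a + b - 1) // b


def launch_grid(M, N, M_BLOCK, N_BLOCK):
    """Number of programs along each axis so that every element is covered."""
    return (cdiv(M, M_BLOCK), cdiv(N, N_BLOCK))


print(launch_grid(256, 256, 32, 32))  # (8, 8)
print(launch_grid(100, 70, 32, 32))   # (4, 3)
```

With a partial trailing block, the kernel depends on the descriptor's bounds handling for the out-of-range part, so that case is worth verifying on the Triton version and hardware in use.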

Example 2: Flash Attention

Flash Attention with Tensor Descriptor