Triton使用和Softmax实现

参考 OpenAI Triton 主页

参考 Triton论文

参考 GPU MODE Lecture 14: Practitioners Guide to Triton

从Trinton主页引用的话

现代 GPU 的架构大致可以分为三个主要组件 ——DRAM、SRAM 和 ALU—— 在优化 CUDA 代码时必须考虑每个组件:

  1. 来自 DRAM 的内存传输必须合并为更大的事务,以利用现代内存接口的大型总线宽度。
  2. 数据在被再次使用之前,必须手动存储到SRAM中,并且要对数据进行管理,以便在检索数据时尽量减少共享内存存储体冲突的情况。
  3. 计算必须在流式多处理器(SM)之间和内部仔细分区和调度,以促进指令 / 线程级并行性并利用专用 ALU(例如Tensor Core)。

这几句话可能比较抽象,下面给一下这几个组件的指标可能感受更直观,参考Which GPU(s) to Get for Deep Learning: My Experience and Advice for Using GPUs in Deep Learning
其中指出:

  • 全局内存访问(高达 80GB):~380 个周期
  • L2 缓存:~200 个周期
  • L1 缓存或共享内存访问(每个流式多处理器高达 128 kb):~34 个周期
  • 融合乘法和加法,ab+c(FFMA): 4 个周期
  • Tensor Core 矩阵乘法:1 个周期

每个操作总是由 32 个线程组成的Warp执行,Warp中的线程必须相互等待。GPU 上的所有内存操作都针对warp进行了优化。

根据Simplifying CUDA kernels with Triton: A Pythonic Approach to GPU Programming的说法,GPU中的HBM(High Bandwidth Memory)等价于我们讲的Global Memory,SRAM对应的是L1和L2 Cache对应的是Shared Memory,这几个词在一些文档中可能会有不同的叫法,但是意思是一样的。

A100中的内存带宽约为 2TB/s,L1 缓存带宽:~100-200 TB/s 理论带宽,L2 缓存带宽:~4-7 TB/s 理论带宽。

再看OpenAI的三条说明的意思就是:

  1. 因为DRAM很大,比较容易占满总线带宽,所以尽量合并传输的事务可以减少传输的时间,让高速公路跑满。
  2. 如果数据要重复利用,反复参与计算,尽量让他们在SRAM当中能够缓存住,比如L1的读取只要34个cyle,能比从L2中快6到7倍。
  3. 尽量跑满并行度,并且利用更高效的计算单元,比如Tensor Core。

这个是OpenAI给出的GPU架构的简图,我们需要明确不同内存,缓存,和执行单元的周期之间的关系就比较好理解GPU计算当中的性能瓶颈。

Triton的目标其实就是优化 HBM -> SRAM -> 寄存器 的带宽,这在Torch里面直接实现不了,通过一些融合算子是可以减少写回到HBM的。

Triton的文档给出的很多实现的代码,可能都不太奏效了,笔者自己测试下来并没有超过torch本身的实现,
可能torch本身也再不断改进吧,这些差别很快就超越了,但是在一些写自定义融合算子方面应该还是比较有优势的。

Triton的使用

和CUDA对应的关系:

  • 程序(Program):处理数据块的kernel实例。
  • PID(程序 ID):等同于 CUDA 中的块 ID。
  • 向量化操作:在多个数据点上同时操作(triton不需要用户关心向量操作的并行化)。

先给出变量和修饰器的解释,大部分文档都混在注释里面不是很好阅读,我觉得先介绍一些简单概念再看代码会比较好一点。

@triton.jit 装饰器表示这个函数会被编译。
tl.constexpr 代表常量表达式,可以让编译器在编译期间直接求值,可以当作常量使用了。
BLOCK_SIZE对于GPU来说是比较固定的,因为一个block是有threads数上限的。
通过执行cuda-samples中的deviceQuery
可以发现L40的显卡BLOCK_SIZE最大是1024,大部分显卡应该都是这个固定大小。

1
2
3
4
5
6
7
8
9
10
Total number of registers available per block: 65536  
Warp size: 32
L2 Cache Size: 100663296 bytes(96MB)
Maximum number of threads per multiprocessor: 1536
Maximum number of threads per block: 1024
Max dimension size of a thread block (x,y,z): (1024, 1024, 64)
Max dimension size of a grid size (x,y,z): (2147483647, 65535, 65535)
Total amount of shared memory per block: 49152 bytes(48KB)
Total shared memory per multiprocessor: 102400 bytes (100KB)
Total number of registers available per block: 65536

pid = tl.program_id(axis=0) 应该是对应的CUDA中的threadIdx.x的作用,对应block的一维下标,
pid = tl.program_id(axis=1) 应该是对应的CUDA中的threadIdx.y的作用,对应block的二维下标。

autotune是一个黑盒优化,通过内部的小benchmark的方式去基于key的变量,优化configs里面的参数。
下面是一个只有两个配置的搜索空间,当n_elements的值发生变化的时候,会自动选择最优的配置。

1
2
3
4
5
@triton.autotune(configs = [
triton.Config({'BLOCK_SIZE': 128}, num_warps = 4, pre_hook = clone_inplace_updated_params),
triton.Config({'BLOCK_SIZE': 1024}, num_warps = 8, pre_hook = clone_inplace_updated_params),
], key = ['n_elements'])
@triton.jit

BLOCK_SIZE表示的是一个program负责的BLOCK大小,放在triton的语境下更像是L2 Cache的大小
但是cuda当中的block是包含n个thread的,表示的是并行线程的大小
笔者的这个说法可能不太精准,但是这两种风格导致Cuda写一些element wise的操作比较合适.
每个element wise的操作都是一个thread,这样可以充分利用GPU的并行性。
而triton比较适合一些Reduce操作,例如对数据(也就是矩阵)切BLOCK,然后每个kernel去负责一个block,
他的好处就是比如softmax这样的在行上做reduce操作会比较直观,而矩阵乘法也可以沿着MxK,KxN的维度,沿着不同的维度切块。
Triton能够帮你把矩阵乘法优化得很不错,虽然可能还比不上精准手写的Cuda算子。

Triton的范式和CUDA的Single Instruction, Multiple Thread (SIMT)不一样,官网给出了一个简化的例子。

这是CUDA like的写法,每个threadId.x代表的线程只算一个element

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
BLOCK = 512

# This is a GPU kernel in Numba.
# Different instances of this
# function may run in parallel.
@jit
def add(X, Y, Z, N):
# In Numba/CUDA, each kernel
# instance itself uses an SIMT execution
# model, where instructions are executed in
# parallel for different values of threadIdx
tid = threadIdx.x
bid = blockIdx.x
# scalar index
idx = bid * BLOCK + tid
if id < N:
# There is no pointer in Numba.
# Z,X,Y are dense tensors
Z[idx] = X[idx] + Y[idx]


...
grid = (ceil_div(N, BLOCK),)
block = (BLOCK,)
add[grid, block](x, y, z, x.shape[0])

Triton文档中的Matrix乘法简化来说就是并行计算M*N个block(沿着K所代表的维度)。
这是Triton的写法,每个Program负责了一个block,他就少了一个block的切分维度:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
BLOCK = 512

# This is a GPU kernel in Triton.
# Different instances of this
# function may run in parallel.
@jit
def add(X, Y, Z, N):
# In Triton, each kernel instance
# executes block operations on a
# single thread: there is no construct
# analogous to threadIdx
pid = program_id(0)
# block of indices
idx = pid * BLOCK + arange(BLOCK)
mask = idx < N
# Triton uses pointer arithmetics
# rather than indexing operators
x = load(X + idx, mask=mask)
y = load(Y + idx, mask=mask)
store(Z + idx, x + y, mask=mask)


...
grid = (ceil_div(N, BLOCK),)
# no thread-block
add[grid](x, y, z, x.shape[0])

笔者比较疑惑,单就这两个代码他们的并行度貌似是不一样的,难道是把block那一层隐式的放在了loadstore当中,他的loadstore其实是隐含了并行能力的。
援引知乎的文章Triton中一直是以Block为中心来计算,直到Lowering到LLVM和PTX才会转为Thread为中心的计算,而这些对于使用Block抽象进行编程的用户来说都是无感的。是符合笔者预期的,triton简化了CUDA的写法,block具体的线程数的,每个线程处理多少元素,triton自己会去帮你处理。

当使用 triton 的时候,x = tl.load(x_ptr + offsets, mask=mask)时,我们正在加载到 L2 缓存 或者叫 SRAM 中。

根据Torch的blog,以及参考 OpenAI/Triton MLIR 迁移工作简介,Triton编译的过程是@triton.jit装饰器通过遍历提供的 Python 函数的抽象语法树(AST)来工作,以便使用通用的 SSA 构造算法即时生成 Triton-IR。
然后,生成的 IR 代码被我们的编译器后端简化、优化和自动并行化,最后被转换成高质量的 LLVM IR,最终是 PTX(Nvidia GPU的汇编),可以在最近的Nvidia GPU上执行。

矩阵乘法

这段代码是基于K切BLOCK,比上面的代码要好理解一点。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
@triton.jit
def simple_mm(a, b, o, k, n,
K_BLOCK_SIZE: tl.constexpr = 64,
) -> None:
# a -> Matrix of size M x K and b -> Matrix of size K x N
# K is the common inner dimension
num_blocks = k//K_BLOCK_SIZE + 1
row_id = tl.program_id(0)
col_id = tl.program_id(1)

# Lets pick one column and one row and do a dot product
# Like the 1-D example we dont want to look at the entire row/column
# We are making use of the fact that each row/column will be of the size
# 'k' which is the inner common dimension of these matrices
# But this will only be a part of the dot product so we have to keep track of many to cover the entire column or row.

# What we are going to do is to access block size elements from the column
# and the row and compute the dot product and keep adding to a value till
# we run out of numbers
value = 0.
for k_id in range(num_blocks):
row_start = row_id * k + k_id * K_BLOCK_SIZE
row_offsets = tl.arange(0, K_BLOCK_SIZE) + row_start
# The masks are a little more trickier as we cant just see if its
# less than 'k'. We need to account for the row we are in
row_masks = row_offsets < (row_id + 1) * k
row = tl.load(a + row_offsets, mask=row_masks) # Load this into the GPU SRAM

col_start = (K_BLOCK_SIZE * k_id)
col_offsets = n * (tl.arange(0, K_BLOCK_SIZE) + col_start) + col_id # 0, n, 2n || 3n, 4n, 5n for a block size of 3 for eg
col_masks = col_offsets/n < k
col = tl.load(b + col_offsets, mask=col_masks)
value += tl.sum(row * col)

output_offset = row_id * n + col_id
tl.store(o + output_offset, value)

BLOCK_SIZE 和 GROUP_SIZE 的优化。

一次计算的时候尽量用满L2 Cache,所以可以把多个BLOCK放到一个GROUP里面,这个GROUP变成了grid的切分,但在GROUP里面我们
再去做BLOCK级别的计算,要计算好对应的线性空间中的stride。

Softmax

下面的代码只能在n_cols小于BLOCK_SIZE的数据上运行。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import triton
import triton.language as tl

@triton.jit
def softmax(Y, stride_ym, stride_yn, X, stride_xm, stride_xn, M, N):
# row index
m = tl.program_id(0)
# col indices
# this specific kernel only works for matrices that
# have less than BLOCK_SIZE columns
BLOCK_SIZE = 1024
n = tl.arange(0, BLOCK_SIZE)
# the memory address of all the elements
# that we want to load can be computed as follows
X = X + m * stride_xm + n * stride_xn
# load input data; pad out-of-bounds elements with 0
x = tl.load(X, mask=n < N, other=-float('inf'))
# compute numerically-stable softmax
z = x - tl.max(x, axis=0)
num = tl.exp(z)
denom = tl.sum(num, axis=0)
y = num / denom
# write back to Y
Y = Y + m * stride_ym + n * stride_yn
tl.store(Y, y, mask=n < N)

import torch
# Allocate input/output tensors
X = torch.normal(0, 1, size=(583, 931), device='cuda')
Y = torch.empty_like(X)
# SPMD launch grid
grid = (X.shape[0], )
# enqueue GPU kernel
softmax[grid](Y, Y.stride(0), Y.stride(1),
X, X.stride(0), X.stride(1),
X.shape[0] , X.shape[1])