什么是混合精度训练?混合精度训练有什么用?
这里总结一下。
本文总结自kapathy的build gpt2
通常在训练过程中,model里面的数据默认都是torch.float32类型,
也就是用32bit的float型数据表示变量。
比如特征提取中提取的特征,描述子,网络的参数等,显示是torch.float32.
32bit的float带来的影响是比如拿它和int8比,占用内存会比较多,计算量会比较大。
所以训练时也就会比较慢,我们可以考虑丢弃一些精度来达到性能的提升。
看一下GPU支持的类型,比如nv a100吧,
可以看到实际上GPU是支持到float64的,但实际上模型训练时不需要那么高的精度,也会导致训练速度变慢,所以默认一般是float32.
精度下调的性能提升
float32最大可达到19.5TFLOPS,意味着可达到每秒19.5 trillion次float操作。
那么如果我们降低一些精度,比如降到TF32(见上图),每秒就可以达到156TFLOPS的性能,差不多达到了8倍性能提升。
如果进一步降低精度,到BFLOAT16, 就可以达到16倍性能提升。
注:右边*的性能指的是用稀疏化。
int8一般用于推理而不是训练。因为int8是均匀的空间,而训练时activation, weight都是正态分布的, 训练时要用float。
另一方面, 精度降低之后占的内存少,易于搬运,这涉及到memory bandwidth和模型的memory. 可以参考图中的GPU memory bandwidth.
解释一下memory bandwidth, 一般情况下data需要搬到GPU上然后运算,运算完再搬回去(涉及到GPU内存),但是受限于这个bandwidth, 明明有多余算力,但是data还没搬进来,就需要等,导致利用度不够高。然而如果你降低了精度,数据占的内存就会变小,一次就可以搬运更多,那么每次参与计算的就会更多,在这个搬运上也会提升性能。
小结一下,适当降低精度会让计算量减少从而提升性能,另一方面,会让数据占的内存减少,每次搬运的数据更多,从而在有限的memory bandwidth上达到性能提升。
图上有个名词叫tensor core, 现在介绍一下什么是tensor core.
它是a100中的一个instruction, 它做的事情是4x4的矩阵乘法。
矩阵的乘法会broke up成这些4x4的矩阵乘。
比如在transformer中很多linear layer就需要矩阵乘法,特别是最后一层的classify layer是一个大矩阵乘。
矩阵乘就通过tensor core来加速。
不同精度的数据
TF32
通过这个图可以看到TF32和FP32表示的范围是一样的,最左边的sign是符号位,中间8位range表示数字可表达的范围,它们是一样的,区别就是TF32舍弃了一些小数点位。整个只有19bit, 而不是32bit。
这些是在硬件上完成的,pytorch代码上是不可见的。
为什么叫混合精度呢,是因为你的input还是fp32, output还是fp32, 但是在内部计算上,后面的bit被舍弃了,为提升性能降低了一些精度。所以结果会类似是一个近似结果,但你会几乎看不出差别。
虽然说明书上写的用TF32会有8x性能提升,但实际上那只是矩阵乘的时候用了TF32, 其他部分仍然是FP32, 另外还受限于memory bandwidth, 所以实际上大概率是达不到的。
但是注意一点,这是a100支持的TF32,有的GPU可能不支持。
说了这么多,到底怎么用TF32训练呢。
只需要一行code. 用到了torch.set_float32_matmul_precision, 有兴趣的可以查看官方文档。它有一些参数,“highest”, ”high“,”medium“等, 其中"high"就是TF32. 默认是"highest", 也就是float32.
torch.set_float32_matmul_precision('high')#model = GPT(..)
#model.to(device)
#training process
BFLOAT16
每个float只有16bit.
它和FP16有什么区别呢,还是看上面的图,它的range和float32是一样的,比FP16表示的范围要宽,只不过精度进一步被cut.
另外当你用FP16训练时,由于它表示的范围比FP32要小,所以你还要做gradient scaling操作。用BF16不需要做gradient scaling.
具体怎么用BF16呢,你可以参考torch.autocast, 但你不需要考虑gradient scaling的部分。
官方文档上有说,autocasting时,不要在model或input处call half() or bfloat16()。
你只能在forward和计算loss处用BF16.
具体如下, 只需要加一句torch.autocast:
model = GPT(..)
model.to(device)optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
for i in range(iterations):x, y = train_loader.next_batch()x, y = x.to(device), y.to(device)optimizer.zero_grad()#用BF16with torch.autocast(device_type=device, dtype=torch.bfloat16):logits, loss = model(x, y)loss.backward()optimizer.step()
但是你看transformer里面的embedding table的weight, 仍然是float32, 具体哪些模块能cast到BF16,你可以参考官方文档里面的CUDA Ops that can autocast to float16. 需要矩阵乘的时候会cast, 很多操作仍然会保持float32不变,比如norm。