FLOPs 的计算

模型训练过程中大多数浮点运算都是矩阵乘法,对于一个 m×nm \times n 的矩阵 AA 和一个 n×pn \times p 的矩阵 BBA×BA \times B 需要 m×n×pm \times n \times p 次乘法和 m×n×pm \times n \times p 次加法,即需要 2mnp2mnp FLOPs。

Transformer Architecture 的 FLOPs 计算

alt text

alt text

Attention

Q,K,V transformation: 3×2Bsh23 \times 2Bsh^2

QKTQK^T: 2Bs2h2Bs^2h

attention over values: 2Bs2h2Bs^2h

post-attention linear projection: 2Bsh22Bsh^2

Feed Forward Network

linear h->4h: 8Bsh28Bsh^2

linear 4h->h: 8Bsh28Bsh^2

Total

forward: (6+2+8+8)Bsh2+(2+2)Bs2h=24Bsh2+4Bs2h(6 + 2 + 8 + 8)Bsh^2 + (2 + 2)Bs^2h = 24Bsh^2 + 4Bs^2h

backward 的 FLOPs 大致是 forward 的 2 倍,所以 forward + backward 的 FLOPs 大致是 72Bsh2+12Bs2h72Bsh^2 + 12Bs^2h

参考 Megatron-LM 2 APPENDIX