DeepGEMM is an FP8 GEMM library for AI training and inference, designed for both dense and mixture-of-experts (MoE) matrix operations and built to support the training and inference of the DeepSeek-V3 and R1 models.
Having already open-sourced FlashMLA and DeepEP, DeepSeek now introduces DeepGEMM, a matrix multiplication library optimised specifically for Hopper-architecture GPUs. The library supports both standard dense GEMMs and mixture-of-experts (MoE) GEMMs, reaching more than 1350 FP8 TFLOPS on Hopper GPUs.
DeepGEMM is designed to be simple and efficient: the core kernel is only about 300 lines of code, yet it outperforms existing solutions for most matrix sizes. The library supports three data layouts: a standard dense layout and two special layouts (contiguous and masked) designed for mixture-of-experts models. DeepGEMM uses just-in-time (JIT) compilation, so no kernels need to be compiled at installation time, and its clear, easy-to-follow code structure makes it well suited for learning GPU optimisation techniques.
DeepGEMM performs well across a range of computational scenarios. For standard dense matrix multiplication, speedups over an optimised CUTLASS 3.6-based implementation range from 1.0x to 2.7x, with the largest gains (up to 2.7x) at small batch sizes (M = 64 or 128). For mixture-of-experts computation, the two special data layouts also offer clear advantages: the contiguous layout suits the training and batch-inference phases, with speedups of roughly 1.1x to 1.2x, while the masked layout is designed for real-time inference, is compatible with CUDA graphs, and delivers similar 1.1x to 1.2x speedups.
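As a rough sketch of how the three layouts map onto the library, the snippet below lists the FP8 GEMM entry points as I recall them from the project README; the exact argument shapes and dtypes are assumptions and should be checked against the installed version.

```python
import torch
import deep_gemm

# Sketch only: the shapes/dtypes implied below are assumptions, not verified signatures.
# Inputs are FP8 (torch.float8_e4m3fn) with per-block float32 scales; the output is BF16.

# 1) Standard dense GEMM: out = A @ B^T
# deep_gemm.gemm_fp8_fp8_bf16_nt((a_fp8, a_scales), (b_fp8, b_scales), out)

# 2) Contiguous-grouped GEMM for MoE training / batch inference:
#    all experts' tokens are packed back to back along M, with a per-row group index.
# deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
#     (a_fp8, a_scales), (b_fp8, b_scales), out, m_indices)

# 3) Masked-grouped GEMM for real-time (CUDA-graph) inference:
#    fixed-size per-expert buffers plus a count of valid rows per expert.
# deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
#     (a_fp8, a_scales), (b_fp8, b_scales), out, masked_m, expected_m)
```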
Standard dense GEMMs:

M | N | K | Compute throughput | Memory bandwidth | Speedup |
---|---|---|---|---|---|
64 | 2112 | 7168 | 206 TFLOPS | 1688 GB/s | 2.7x |
64 | 24576 | 1536 | 289 TFLOPS | 2455 GB/s | 1.7x |
64 | 32768 | 512 | 219 TFLOPS | 2143 GB/s | 1.8x |
64 | 7168 | 16384 | 336 TFLOPS | 2668 GB/s | 1.4x |
64 | 4096 | 7168 | 287 TFLOPS | 2320 GB/s | 1.4x |
64 | 7168 | 2048 | 295 TFLOPS | 2470 GB/s | 1.7x |
128 | 2112 | 7168 | 352 TFLOPS | 1509 GB/s | 2.4x |
128 | 24576 | 1536 | 535 TFLOPS | 2448 GB/s | 1.6x |
128 | 32768 | 512 | 358 TFLOPS | 2103 GB/s | 1.5x |
128 | 7168 | 16384 | 645 TFLOPS | 2604 GB/s | 1.4x |
128 | 4096 | 7168 | 533 TFLOPS | 2221 GB/s | 2.0x |
128 | 7168 | 2048 | 510 TFLOPS | 2277 GB/s | 1.7x |
4096 | 2112 | 7168 | 1058 TFLOPS | 527 GB/s | 1.1x |
4096 | 24576 | 1536 | 990 TFLOPS | 786 GB/s | 1.0x |
4096 | 32768 | 512 | 590 TFLOPS | 1232 GB/s | 1.0x |
4096 | 7168 | 16384 | 1358 TFLOPS | 343 GB/s | 1.2x |
4096 | 4096 | 7168 | 1304 TFLOPS | 500 GB/s | 1.1x |
4096 | 7168 | 2048 | 1025 TFLOPS | 697 GB/s | 1.1x |
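To see why the small-M rows sit far below the ~1350 TFLOPS peak while the large-M rows approach it, it helps to convert one row back into time and bytes. The back-of-the-envelope check below assumes 2·M·N·K FLOPs per GEMM, FP8 (1-byte) inputs and a BF16 (2-byte) output, and ignores the small scale tensors; under those assumptions the first row is clearly memory-bound.

```python
# Back-of-the-envelope check for the first row of the dense-GEMM table
# (assumes 2*M*N*K FLOPs, FP8 inputs, BF16 output; scale tensors ignored).
M, N, K = 64, 2112, 7168
tflops = 206e12

flops = 2 * M * N * K                       # ~1.94 GFLOP
time_s = flops / tflops                     # ~9.4 microseconds
bytes_moved = M * K + N * K + 2 * M * N     # A and B in FP8, C in BF16
print(f"time ≈ {time_s * 1e6:.1f} us, bandwidth ≈ {bytes_moved / time_s / 1e9:.0f} GB/s")
# Prints roughly 1690 GB/s, matching the ~1688 GB/s in the table: at this size the
# kernel is limited by reading the large weight matrix, not by FP8 compute.
```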
Grouped GEMMs for MoE models (contiguous layout):

#Groups | M per group | N | K | Compute throughput | Memory bandwidth | Speedup |
---|---|---|---|---|---|---|
4 | 8192 | 4096 | 7168 | 1297 TFLOPS | 418 GB/s | 1.2x |
4 | 8192 | 7168 | 2048 | 1099 TFLOPS | 681 GB/s | 1.2x |
8 | 4096 | 4096 | 7168 | 1288 TFLOPS | 494 GB/s | 1.2x |
8 | 4096 | 7168 | 2048 | 1093 TFLOPS | 743 GB/s | 1.1x |
Grouped GEMMs for MoE models (masked layout):

#Groups | M per group | N | K | Compute throughput | Memory bandwidth | Speedup |
---|---|---|---|---|---|---|
1 | 1024 | 4096 | 7168 | 1233 TFLOPS | 924 GB/s | 1.2x |
1 | 1024 | 7168 | 2048 | 925 TFLOPS | 968 GB/s | 1.2x |
2 | 512 | 4096 | 7168 | 1040 TFLOPS | 1288 GB/s | 1.2x |
2 | 512 | 7168 | 2048 | 916 TFLOPS | 1405 GB/s | 1.2x |
4 | 256 | 4096 | 7168 | 932 TFLOPS | 2064 GB/s | 1.1x |
4 | 256 | 7168 | 2048 | 815 TFLOPS | 2047 GB/s | 1.2x |
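For readers unfamiliar with the two grouped layouts, the snippet below illustrates the data-preparation idea in plain PyTorch. It is purely illustrative: names such as m_indices and masked_m mirror the grouped-GEMM arguments as I understand them from the project documentation, not exact APIs.

```python
import torch

num_experts, hidden = 4, 8
expert_of_token = torch.tensor([2, 0, 2, 1, 0, 3, 1, 2])   # router output per token
tokens = torch.randn(expert_of_token.numel(), hidden)

# Contiguous layout: tokens are sorted so each expert's rows form one consecutive block,
# and a per-row index records which expert (weight group) each row belongs to.
order = torch.argsort(expert_of_token, stable=True)
packed = tokens[order]                  # (M_total, hidden), expert blocks back to back
m_indices = expert_of_token[order]      # group index for every row of `packed`

# Masked layout: one fixed-size buffer per expert plus a count of valid rows, so all
# shapes stay static and the kernel can be captured in a CUDA graph.
capacity = 4
buffers = torch.zeros(num_experts, capacity, hidden)
masked_m = torch.zeros(num_experts, dtype=torch.int32)
for e in range(num_experts):
    rows = tokens[expert_of_token == e][:capacity]
    buffers[e, : rows.shape[0]] = rows
    masked_m[e] = rows.shape[0]
```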
Requirements
To use DeepGEMM, you need Hopper architecture GPUs with sm_90a support, Python 3.8 or higher, CUDA 12.3 or higher (12.8 or higher is recommended for best performance), PyTorch 2.1 or higher, and CUTLASS 3.6 or higher.
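Before installing, a quick sanity check of the environment can save time. This is just a sketch, not an official script shipped with the library:

```python
import sys
import torch

# Quick check against the requirements above.
assert sys.version_info >= (3, 8), "Python 3.8+ required"

torch_major, torch_minor = (int(x) for x in torch.__version__.split(".")[:2])
assert (torch_major, torch_minor) >= (2, 1), "PyTorch 2.1+ required"

print("CUDA runtime bundled with PyTorch:", torch.version.cuda)  # want 12.3+, ideally 12.8+
assert torch.cuda.is_available(), "a CUDA device is required"
assert torch.cuda.get_device_capability() == (9, 0), "DeepGEMM targets Hopper (sm_90a) GPUs"
```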
Development
```bash
# Submodules must be cloned as well
git clone --recursive git@github.com:deepseek-ai/DeepGEMM.git

# Make symbolic links for third-party (CUTLASS and CuTe) include directories
python setup.py develop

# Test JIT compilation
python tests/test_jit.py

# Test all GEMM implementations (normal, contiguous-grouped and masked-grouped)
python tests/test_core.py
```
Installation
```bash
python setup.py install
```
Finally, import deep_gemm and you’re done!