DeepGEMM es una biblioteca GEMM FP8 para el entrenamiento y la inferencia de IA, diseñada específicamente para operaciones matriciales densas y de mezcla de expertos (MoE), que proporciona un soporte robusto para el entrenamiento y la inferencia de modelos DeepSeek V3 y R1.
DeepSeek ha abierto sucesivamente sus fuentes FlashMLA y DeepEPy hoy presenta DeepGEMMuna biblioteca de multiplicación de matrices optimizada específicamente para las GPU de arquitectura Hopper. Esta biblioteca admite el cálculo de matrices estándar y el cálculo de mezclas de expertos (MoE), lo que proporciona un sólido soporte para el entrenamiento y la inferencia de DeepSeek-V3/R1, logrando un alto rendimiento de 1350+ FP8 TFLOPS en GPUs Hopper.
DeepGEMM está diseñada para ser sencilla y eficiente, con sólo unas 300 líneas de código central, al tiempo que supera a las soluciones existentes en la mayoría de los tamaños de matriz. La biblioteca admite tres alineaciones de datos: una alineación estándar y dos alineaciones especiales (secuencial y enmascarada) diseñadas para modelos expertos híbridos. deepGEMM utiliza compilación sobre la marcha, lo que elimina la necesidad de compilar en el momento de la instalación, y tiene una estructura de código clara y fácil de entender que la hace ideal para aprender técnicas de optimización en la GPU.
DeepGEMM obtiene buenos resultados en diversos escenarios computacionales. Para la multiplicación de matrices estándar, el aumento de velocidad oscila entre 1,0 y 2,7 veces en comparación con la implementación optimizada basada en CUTLASS 3.6. Los aumentos de velocidad más significativos, de hasta 2,7 veces, se lograron para lotes pequeños de datos (M=64 o 128). Para el cálculo de modelos expertos híbridos, las dos alineaciones de datos especiales que ofrece DeepGEMM también ofrecen ventajas significativas. La disposición secuencial es adecuada tanto para las fases de entrenamiento como de inferencia por lotes, con aumentos de velocidad de entre 1,1 y 1,2 veces, mientras que la disposición enmascarada está diseñada para la inferencia en tiempo real y admite el uso de técnicas de grafos CUDA, también con aumentos de velocidad de entre 1,1 y 1,2 veces.
M | N | K | Cómputo | Ancho de banda de memoria | Aceleración |
---|---|---|---|---|---|
64 | 2112 | 7168 | 206 TFLOPS | 1688 GB/s | 2.7x |
64 | 24576 | 1536 | 289 TFLOPS | 2455 GB/s | 1.7x |
64 | 32768 | 512 | 219 TFLOPS | 2143 GB/s | 1.8x |
64 | 7168 | 16384 | 336 TFLOPS | 2668 GB/s | 1.4x |
64 | 4096 | 7168 | 287 TFLOPS | 2320 GB/s | 1.4x |
64 | 7168 | 2048 | 295 TFLOPS | 2470 GB/s | 1.7x |
128 | 2112 | 7168 | 352 TFLOPS | 1509 GB/s | 2.4x |
128 | 24576 | 1536 | 535 TFLOPS | 2448 GB/s | 1.6x |
128 | 32768 | 512 | 358 TFLOPS | 2103 GB/s | 1.5x |
128 | 7168 | 16384 | 645 TFLOPS | 2604 GB/s | 1.4x |
128 | 4096 | 7168 | 533 TFLOPS | 2221 GB/s | 2.0x |
128 | 7168 | 2048 | 510 TFLOPS | 2277 GB/s | 1.7x |
4096 | 2112 | 7168 | 1058 TFLOPS | 527 GB/s | 1.1x |
4096 | 24576 | 1536 | 990 TFLOPS | 786 GB/s | 1.0x |
4096 | 32768 | 512 | 590 TFLOPS | 1232 GB/s | 1.0x |
4096 | 7168 | 16384 | 1358 TFLOPS | 343 GB/s | 1.2x |
4096 | 4096 | 7168 | 1304 TFLOPS | 500 GB/s | 1.1x |
4096 | 7168 | 2048 | 1025 TFLOPS | 697 GB/s | 1.1x |
#Grupos | M por grupo | N | K | Cómputo | Ancho de banda de memoria | Aceleración |
---|---|---|---|---|---|---|
4 | 8192 | 4096 | 7168 | 1297 TFLOPS | 418 GB/s | 1.2x |
4 | 8192 | 7168 | 2048 | 1099 TFLOPS | 681 GB/s | 1.2x |
8 | 4096 | 4096 | 7168 | 1288 TFLOPS | 494 GB/s | 1.2x |
8 | 4096 | 7168 | 2048 | 1093 TFLOPS | 743 GB/s | 1.1x |
#Grupos | M por grupo | N | K | Cómputo | Ancho de banda de memoria | Aceleración |
---|---|---|---|---|---|---|
1 | 1024 | 4096 | 7168 | 1233 TFLOPS | 924 GB/s | 1.2x |
1 | 1024 | 7168 | 2048 | 925 TFLOPS | 968 GB/s | 1.2x |
2 | 512 | 4096 | 7168 | 1040 TFLOPS | 1288 GB/s | 1.2x |
2 | 512 | 7168 | 2048 | 916 TFLOPS | 1405 GB/s | 1.2x |
4 | 256 | 4096 | 7168 | 932 TFLOPS | 2064 GB/s | 1.1x |
4 | 256 | 7168 | 2048 | 815 TFLOPS | 2047 GB/s | 1.2x |
Para utilizar DeepGEMM, se necesitan GPUs de arquitectura Hopper compatibles con sm_90a, Python 3.8 o superior, CUDA 12.3 o superior (se recomienda 12.8 o superior para obtener el mejor rendimiento), PyTorch 2.1 o superior y CUTLASS 3.6 o superior.
Desarrollo
El submódulo # debe clonarse
git clone --recursive git@github.com:deepseek-ai/DeepGEMM.git
# Crear enlaces simbólicos para directorios de inclusión de terceros (CUTLASS y CuTe)
python setup.py desarrollar
# Prueba de compilación JIT
python pruebas/prueba_jit.py
# Probar todos los implementos GEMM (normal, agrupado-contiguo y agrupado-enmascarado)
python pruebas/prueba_core.py
Instalación
python setup.py instalar
Por último, importe deep_gemm ¡y listo!