From f70c62d69c611dc6c8d736b01c52abe3a184baa7 Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Tue, 29 Jul 2025 18:23:40 -0700
Subject: [PATCH] Simple gemm example

---
 mlx/backend/cuda/CMakeLists.txt       |  1 +
 mlx/backend/cuda/gemms/simple_gemm.cu | 73 +++++++++++++++++++++++++++
 mlx/backend/cuda/gemms/simple_gemm.h  | 22 ++++++++
 mlx/backend/cuda/matmul.cpp           | 12 +++++
 4 files changed, 108 insertions(+)
 create mode 100644 mlx/backend/cuda/gemms/simple_gemm.cu
 create mode 100644 mlx/backend/cuda/gemms/simple_gemm.h

diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt
index 2e12c8c3e..6d9857343 100644
--- a/mlx/backend/cuda/CMakeLists.txt
+++ b/mlx/backend/cuda/CMakeLists.txt
@@ -26,6 +26,7 @@ target_sources(
           ${CMAKE_CURRENT_SOURCE_DIR}/event.cu
           ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
           ${CMAKE_CURRENT_SOURCE_DIR}/gemms/gemv.cu
+          ${CMAKE_CURRENT_SOURCE_DIR}/gemms/simple_gemm.cu
           ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm.cpp
           ${CMAKE_CURRENT_SOURCE_DIR}/jit_module.cpp
           ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
diff --git a/mlx/backend/cuda/gemms/simple_gemm.cu b/mlx/backend/cuda/gemms/simple_gemm.cu
new file mode 100644
index 000000000..12ceda068
--- /dev/null
+++ b/mlx/backend/cuda/gemms/simple_gemm.cu
@@ -0,0 +1,73 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/backend/cuda/device.h"
+#include "mlx/backend/cuda/kernel_utils.cuh"
+#include "mlx/backend/cuda/steel/gemm.cuh"
+#include "mlx/dtype_utils.h"
+
+#include <stdexcept>
+#include <string>
+
+namespace mlx::core::cu {
+
+// Encode the experimental GEMM kernel `ab_t_aligned` on `enc`.
+//
+// `a` and `b` are registered as inputs and `out` as the output of the encoded
+// kernel. The element type is dispatched on a.dtype() and reused for `b` and
+// `out`. The grid is (N / BN, M / BM) with BM = BN = 128 and no remainder
+// handling, so M and N must be multiples of 128; N and K are forwarded to the
+// kernel unchanged.
+//
+// Throws std::runtime_error if the device cannot provide the dynamic shared
+// memory the kernel needs for the dispatched element type.
+void simple_gemm(
+    const array& a,
+    const array& b,
+    array& out,
+    int M,
+    int N,
+    int K,
+    cu::CommandEncoder& enc) {
+  enc.set_input_array(a);
+  enc.set_input_array(b);
+  enc.set_output_array(out);
+  dispatch_float_types(a.dtype(), "simple_gemm", [&](auto type_tag) {
+    using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
+    constexpr int BM = 128;
+    constexpr int BN = 128;
+    constexpr int BK = 64;
+
+    // Dynamic shared memory requested for the launch. It scales with the
+    // element size, so the opt-in limit set below is derived from it rather
+    // than hard-coded (a fixed 98304 bytes is too small for 4-byte types).
+    constexpr size_t smem_bytes = 2 * sizeof(DataType) * (BM * BK + BN * BK);
+
+    auto kernel = ab_t_aligned<DataType, BM, BN, BK>;
+    cudaError_t err = cudaFuncSetAttribute(
+        kernel,
+        cudaFuncAttributeMaxDynamicSharedMemorySize,
+        static_cast<int>(smem_bytes));
+    if (err != cudaSuccess) {
+      throw std::runtime_error(
+          std::string("[simple_gemm] Failed to reserve shared memory: ") +
+          cudaGetErrorString(err));
+    }
+    // The carveout is only a preference, so its status is not checked.
+    cudaFuncSetAttribute(
+        kernel, cudaFuncAttributePreferredSharedMemoryCarveout, 100);
+
+    dim3 grid(N / BN, M / BM);
+    enc.add_kernel_node(
+        kernel,
+        grid,
+        4 * WARP_SIZE,
+        smem_bytes,
+        a.data<DataType>(),
+        b.data<DataType>(),
+        out.data<DataType>(),
+        N,
+        K);
+  });
+}
+
+} // namespace mlx::core::cu
diff --git a/mlx/backend/cuda/gemms/simple_gemm.h b/mlx/backend/cuda/gemms/simple_gemm.h
new file mode 100644
index 000000000..e89e03310
--- /dev/null
+++ b/mlx/backend/cuda/gemms/simple_gemm.h
@@ -0,0 +1,22 @@
+// Copyright © 2025 Apple Inc.
+
+#pragma once
+
+#include "mlx/backend/cuda/device.h"
+
+namespace mlx::core::cu {
+
+// Encodes the experimental GEMM kernel launch on `enc`, reading `a` and `b`
+// and writing `out`. The launch covers N / 128 by M / 128 blocks with no
+// remainder handling, so M and N must be multiples of 128. Matmul::eval_gpu
+// only takes this path when MLX_ENABLE_TEST_GEMM=1.
+void simple_gemm(
+    const array& a,
+    const array& b,
+    array& out,
+    int M,
+    int N,
+    int K,
+    cu::CommandEncoder& enc);
+
+} // namespace mlx::core::cu
diff --git a/mlx/backend/cuda/matmul.cpp b/mlx/backend/cuda/matmul.cpp
index b11fae538..562837f78 100644
--- a/mlx/backend/cuda/matmul.cpp
+++ b/mlx/backend/cuda/matmul.cpp
@@ -4,6 +4,7 @@
 #include "mlx/backend/cuda/device.h"
 #include "mlx/backend/cuda/gemms/cublas_gemm.h"
 #include "mlx/backend/cuda/gemms/gemv.h"
+#include "mlx/backend/cuda/gemms/simple_gemm.h"
 #include "mlx/backend/gpu/copy.h"
 #include "mlx/primitives.h"
 
@@ -11,6 +12,7 @@
 #include <numeric>
 
 namespace mlx::core {
+
 namespace {
 
 std::tuple<bool, int64_t, array>
@@ -95,6 +97,16 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
     return;
   }
 
+  // Opt-in test path (MLX_ENABLE_TEST_GEMM=1): route a single, non-batched
+  // matmul with an untransposed `a`, a transposed `b`, and M, N, K all
+  // multiples of 512 to the experimental kernel instead of cublasLt.
+  if (M % 512 == 0 && N % 512 == 0 && K % 512 == 0 && !a_transposed &&
+      b_transposed && batch_count == 1 &&
+      env::get_var("MLX_ENABLE_TEST_GEMM", 0) == 1) {
+    cu::simple_gemm(a, b, out, M, N, K, encoder);
+    return;
+  }
+
   /////////////////////////////////////////////////////////////////////////////
   // Invoke cublasLt
   CublasGemm gemm(