Simple gemm example

This commit is contained in:
Angelos Katharopoulos 2025-07-29 18:23:40 -07:00
parent 0c5fc63a36
commit f70c62d69c
4 changed files with 75 additions and 0 deletions

View File

@ -26,6 +26,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/event.cu
${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
${CMAKE_CURRENT_SOURCE_DIR}/gemms/gemv.cu
${CMAKE_CURRENT_SOURCE_DIR}/gemms/simple_gemm.cu
${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm.cpp
${CMAKE_CURRENT_SOURCE_DIR}/jit_module.cpp
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp

View File

@ -0,0 +1,47 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/steel/gemm.cuh"
#include "mlx/dtype_utils.h"
namespace mlx::core::cu {

// Encode a tiled GEMM kernel launch on `enc`.
//
// `a` and `b` are registered as inputs and `out` as the output of the encoded
// node. The kernel is `ab_t_aligned`, instantiated for the element type of `a`
// with 128x128x64 (BM x BN x BK) tiles.
//
// Preconditions (not checked here):
//   - The launch grid is (N / BN, M / BM) using integer division, so M and N
//     are expected to be multiples of the tile sizes; any remainder would not
//     be covered by a thread block.
//   - NOTE(review): the "aligned" kernel presumably also expects K to be a
//     multiple of BK and `b` to be laid out transposed -- confirm against the
//     kernel in steel/gemm.cuh.
void simple_gemm(
    const array& a,
    const array& b,
    array& out,
    int M,
    int N,
    int K,
    cu::CommandEncoder& enc) {
  enc.set_input_array(a);
  enc.set_input_array(b);
  enc.set_output_array(out);

  dispatch_float_types(a.dtype(), "simple_gemm", [&](auto type_tag) {
    using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;

    constexpr int BM = 128;
    constexpr int BN = 128;
    constexpr int BK = 64;

    // Dynamic shared memory requested at launch: two copies of one A tile
    // (BM x BK) plus one B tile (BN x BK). This is 64 KiB for 2-byte element
    // types and 128 KiB for 4-byte element types.
    constexpr auto smem_bytes = 2 * sizeof(DataType) * (BM * BK + BN * BK);

    auto kernel = ab_t_aligned<DataType, BM, BN, BK>;

    // Requests above the default 48 KiB dynamic shared memory limit need an
    // explicit opt-in, and the opt-in value must be at least the size passed
    // at launch. Derive it from `smem_bytes` rather than a fixed constant so
    // the two cannot disagree for wider element types.
    // NOTE(review): the return codes of these two calls are not checked; a
    // device that cannot provide `smem_bytes` will only fail at launch time.
    cudaFuncSetAttribute(
        kernel,
        cudaFuncAttributeMaxDynamicSharedMemorySize,
        static_cast<int>(smem_bytes));
    // Ask for the largest possible shared memory carveout (100 percent).
    cudaFuncSetAttribute(
        kernel, cudaFuncAttributePreferredSharedMemoryCarveout, 100);

    // One thread block per BM x BN tile of the output.
    dim3 grid(N / BN, M / BM);
    enc.add_kernel_node(
        kernel,
        grid,
        4 * WARP_SIZE,
        smem_bytes,
        a.data<DataType>(),
        b.data<DataType>(),
        out.data<DataType>(),
        N,
        K);
  });
}

} // namespace mlx::core::cu

View File

@ -0,0 +1,18 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/backend/cuda/device.h"
namespace mlx::core::cu {
// Encode a GEMM kernel launch on `enc` that reads `a` and `b` and writes
// `out`.
//
//   a, b : input arrays; the kernel is selected from the dtype of `a`.
//   out  : output array.
//   M, N : used to size the launch grid (one block per output tile).
//   K    : forwarded to the kernel together with N.
//   enc  : command encoder that records the inputs, the output and the kernel.
//
// No shape or layout validation is performed. The visible caller in
// Matmul::eval_gpu only dispatches here when M, N and K are all multiples of
// 512, `a` is not transposed, `b` is transposed and the batch count is 1.
void simple_gemm(
const array& a,
const array& b,
array& out,
int M,
int N,
int K,
cu::CommandEncoder& enc);
} // namespace mlx::core::cu

View File

@ -4,6 +4,7 @@
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/gemms/cublas_gemm.h"
#include "mlx/backend/cuda/gemms/gemv.h"
#include "mlx/backend/cuda/gemms/simple_gemm.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/primitives.h"
@ -11,6 +12,7 @@
#include <numeric>
namespace mlx::core {
namespace {
std::tuple<bool, int64_t, array>
@ -95,6 +97,13 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
return;
}
if (M % 512 == 0 && N % 512 == 0 && K % 512 == 0 && !a_transposed &&
b_transposed && batch_count == 1 &&
env::get_var("MLX_ENABLE_TEST_GEMM", 0) == 1) {
cu::simple_gemm(a, b, out, M, N, K, encoder);
return;
}
/////////////////////////////////////////////////////////////////////////////
// Invoke cublasLt
CublasGemm gemm(