From cf5eef095d30c3b426f32266eabb3afb54d4fa0f Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Thu, 14 Aug 2025 12:29:53 -0700
Subject: [PATCH] tmp

---
 mlx/backend/cuda/CMakeLists.txt       |   3 +
 mlx/backend/cuda/gemms/simple_gemm.cu |  34 +++++-
 mlx/backend/cuda/matmul.cpp           |  11 +-
 mlx/backend/cuda/steel/gemm.cuh       | 158 ++++++++++++++++++++------
 mlx/backend/cuda/steel/tiles.cuh      |  89 +++++++++++++--
 mlx/backend/cuda/steel/utils.cuh      |   6 +-
 6 files changed, 248 insertions(+), 53 deletions(-)

diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt
index d40eeec99..d307d108b 100644
--- a/mlx/backend/cuda/CMakeLists.txt
+++ b/mlx/backend/cuda/CMakeLists.txt
@@ -90,6 +90,9 @@ target_include_directories(mlx PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/gen")
 
 target_compile_options(mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--extended-lambda>")
 
+# Keep ptx around for inspection
+target_compile_options(mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--keep>")
+
 # Enable calling host constexpr functions from device. This is needed because
 # the constexpr version of isnan is host only.
 target_compile_options(
diff --git a/mlx/backend/cuda/gemms/simple_gemm.cu b/mlx/backend/cuda/gemms/simple_gemm.cu
index 7598b3795..e8320f2cd 100644
--- a/mlx/backend/cuda/gemms/simple_gemm.cu
+++ b/mlx/backend/cuda/gemms/simple_gemm.cu
@@ -5,8 +5,29 @@
 #include "mlx/backend/cuda/steel/gemm.cuh"
 #include "mlx/dtype_utils.h"
 
+#include <iostream>
+
 namespace mlx::core::cu {
 
+namespace {
+
+template <typename Kernel>
+static void configure_smem(Kernel kernel, int SM) {
+  static bool done = false;
+  if (done) {
+    return;
+  }
+  std::cout << "configuring" << std::endl;
+  cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, SM);
+  cudaFuncSetAttribute(
+      kernel,
+      cudaFuncAttributePreferredSharedMemoryCarveout,
+      cudaSharedmemCarveoutMaxShared);
+  done = true;
+}
+
+} // namespace
+
 void simple_gemm(
     const array& a,
     const array& b,
@@ -23,17 +44,20 @@
     constexpr int BM = 128;
     constexpr int BN = 128;
     constexpr int BK = 32;
+    constexpr int PIPE = 3;
+    constexpr int SM = PIPE * sizeof(DataType) * (BM * BK + BN * BK);
+    constexpr int WM = 2;
+    constexpr int WN = 4;
 
-    auto kernel = ab_t_aligned;
-    cudaFuncSetAttribute(
-        kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
+    auto kernel = ab_t_aligned;
+    configure_smem(kernel, SM);
 
     dim3 grid(N / BN, M / BM);
     enc.add_kernel_node(
         kernel,
         grid,
-        8 * WARP_SIZE,
-        4 * sizeof(DataType) * (BM * BK + BN * BK),
+        WM * WN * WARP_SIZE,
+        SM,
         a.data<DataType>(),
         b.data<DataType>(),
         out.data<DataType>(),
diff --git a/mlx/backend/cuda/matmul.cpp b/mlx/backend/cuda/matmul.cpp
index a5307ee6b..e30afb36b 100644
--- a/mlx/backend/cuda/matmul.cpp
+++ b/mlx/backend/cuda/matmul.cpp
@@ -16,6 +16,11 @@ namespace mlx::core {
 
 namespace {
 
+int get_test_gemm() {
+  static int t = env::get_var("MLX_ENABLE_TEST_GEMM", 0);
+  return t;
+}
+
 std::tuple<bool, int64_t, array>
 check_transpose(cu::CommandEncoder& enc, const Stream& s, const array& arr) {
   auto stx = arr.strides()[arr.ndim() - 2];
@@ -99,15 +104,13 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
   }
 
   if (M % 512 == 0 && N % 512 == 0 && K % 512 == 0 && !a_transposed &&
-      b_transposed && batch_count == 1 &&
-      env::get_var("MLX_ENABLE_TEST_GEMM", 0) == 1) {
+      b_transposed && batch_count == 1 && get_test_gemm() == 1) {
     cu::simple_gemm(a, b, out, M, N, K, encoder);
     return;
   }
 
   if (M % 512 == 0 && N % 512 == 0 && K % 512 == 0 && !a_transposed &&
-      b_transposed && batch_count == 1 &&
-      env::get_var("MLX_ENABLE_TEST_GEMM", 0) == 2) {
+      b_transposed && batch_count == 1 && get_test_gemm() == 2) {
     cu::cutlass_gemm(a, b, out, M, N, K, encoder);
     return;
   }
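Note on the shared memory request in simple_gemm.cu: each pipeline stage holds one BM x BK tile of A and one BN x BK tile of B, so the dynamic shared memory per block is PIPE * sizeof(DataType) * (BM * BK + BN * BK). For the values above (16-bit types, BM = BN = 128, BK = 32) that is 49152 bytes with PIPE = 3, versus the 65536 bytes the previous hard-coded attribute requested for the old 4-deep pipeline; both exceed the 48 KB a kernel may use without opting in, which is what configure_smem() does. A minimal host-side sketch of the same sizing, not part of the patch (my_kernel is a placeholder kernel):

    #include <cuda_runtime.h>
    #include <cstdio>

    __global__ void my_kernel() {}

    int main() {
      constexpr int BM = 128, BN = 128, BK = 32, PIPE = 3;
      constexpr int elem_size = 2; // e.g. __half or __nv_bfloat16
      // PIPE stages, each holding a BM x BK tile of A and a BN x BK tile of B.
      constexpr int smem_bytes = PIPE * elem_size * (BM * BK + BN * BK);
      std::printf("dynamic smem per block: %d bytes\n", smem_bytes); // 49152

      // Anything above the 48 KB default has to be requested explicitly
      // before launching with that much dynamic shared memory.
      cudaFuncSetAttribute(
          my_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
      return 0;
    }
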
diff --git a/mlx/backend/cuda/steel/gemm.cuh b/mlx/backend/cuda/steel/gemm.cuh
index 31ba00fcf..5050ac7bd 100644
--- a/mlx/backend/cuda/steel/gemm.cuh
+++ b/mlx/backend/cuda/steel/gemm.cuh
@@ -8,20 +8,19 @@ template
 __device__ inline void gemm_ab_t(
     RegisterTile& C,
     SharedTile& As,
-    SharedTile& Bs,
-    int lane_row_a,
-    int lane_row_b,
-    int lane_col) {
+    SharedTile& Bs,
+    RegisterTileLoader>& rloader_a,
+    RegisterTileLoader>& rloader_b) {
   RegisterTile A[2];
   RegisterTile B[2];
 
-  A[0].load(As, As.base_addr(), lane_row_a, lane_col);
-  B[0].load(Bs, Bs.base_addr(), lane_row_b, lane_col);
+  rloader_a.load(A[0], As.base_addr(), 0);
+  rloader_b.load(B[0], Bs.base_addr(), 0);
 
   MLX_UNROLL
   for (int k = 1; k < BK / 16; k++) {
-    A[k & 1].load(As, As.base_addr(), lane_row_a, lane_col + k * 16);
-    B[k & 1].load(Bs, Bs.base_addr(), lane_row_b, lane_col + k * 16);
+    rloader_a.load(A[k & 1], As.base_addr(), k);
+    rloader_b.load(B[k & 1], Bs.base_addr(), k);
 
     mma_t(C, A[(k - 1) & 1], B[(k - 1) & 1]);
   }
@@ -33,25 +32,91 @@ __device__ inline void gemm_ab_t(
 *
 * Computes A @ B.T when A and B are all aligned with the block sizes.
 */
-template
-__global__ void ab_t_aligned(const T* a, const T* b, T* y, int N, int K) {
-  constexpr int WARPS_M = 4;
-  constexpr int WARPS_N = 2;
-  constexpr int NUM_WARPS = WARPS_M * WARPS_N;
-  constexpr int WARP_STEP_M = BM / WARPS_M;
-  constexpr int WARP_STEP_N = BN / WARPS_N;
-  constexpr int PIPE = 4;
+// template
+//__global__ __launch_bounds__(WM * WN * WARP_SIZE, 1)
+// void ab_t_aligned(const T* a, const T* b, T* y, int N, int K) {
+//   constexpr int NUM_WARPS = WM * WN;
+//   constexpr int WARP_STEP_M = BM / WM;
+//   constexpr int WARP_STEP_N = BN / WN;
+//
+//   // Precompute some offsets for each thread
+//   const int warpid = threadIdx.x / 32;
+//   const int laneid = threadIdx.x % 32;
+//   const int wm = warpid / WN;
+//   const int wn = warpid % WN;
+//   const int offset_m = wm * WARP_STEP_M;
+//   const int offset_n = wn * WARP_STEP_N;
+//
+//   // Allocate shared memory
+//   extern __shared__ char shmem[];
+//   SharedTile(&as)[PIPE] =
+//       *(SharedTile(*)[PIPE])(&shmem[0]);
+//   SharedTile(&bs)[PIPE] =
+//       *(SharedTile(*)[PIPE])(&shmem[sizeof(T) * PIPE * BM * BK]);
+//
+//   // Move the global pointers to the tile
+//   a += blockIdx.y * BM * K;
+//   b += blockIdx.x * BN * K;
+//   y += blockIdx.y * BM * N + blockIdx.x * BN;
+//
+//   // Make the loaders to/from SMEM
+//   SharedTileLoader> sloader_a(a, K);
+//   SharedTileLoader> sloader_b(b, K);
+//   RegisterTileLoader> rloader_a(offset_m, laneid);
+//   RegisterTileLoader> rloader_b(offset_n, laneid);
+//
+//   // Start the SM pipeline
+//   MLX_UNROLL
+//   for (int i = 0; i < PIPE - 1; i++) {
+//     sloader_a.load_async(as[i].base_addr());
+//     sloader_b.load_async(bs[i].base_addr());
+//     cp_async_commit();
+//     sloader_a.next();
+//     sloader_b.next();
+//   }
+//
+//   // Allocate and zero the MMA accumulator
+//   RegisterTile C;
+//   C.fill(0);
+//
+//   // Matmul loop
+//   int num_blocks = K / BK;
+//   int sread = 0;
+//   int swrite = PIPE - 1;
+//   for (int i = 0; i < num_blocks; i++) {
+//     cp_async_wait();
+//
+//     gemm_ab_t(
+//         C, as[sread], bs[sread], rloader_a, rloader_b);
+//
+//     sloader_a.load_async(as[swrite].base_addr());
+//     sloader_b.load_async(bs[swrite].base_addr());
+//     cp_async_commit();
+//     sloader_a.next(i + PIPE < num_blocks);
+//     sloader_b.next(i + PIPE < num_blocks);
+//
+//     swrite = sread;
+//     sread = (sread + 1) % PIPE;
+//   }
+//
+//   C.store_global(y, N, offset_m, offset_n);
+// }
+
+template
+__global__ __launch_bounds__(
+    WM* WN* WARP_SIZE,
+    1) void ab_t_aligned(const T* a, const T* b, T* y, int N, int K) {
+  constexpr int NUM_WARPS = WM * WN;
+  constexpr int WARP_STEP_M = BM / WM;
+  constexpr int WARP_STEP_N = BN / WN;
 
   // Precompute some offsets for each thread
   const int warpid = threadIdx.x / 32;
   const int laneid = threadIdx.x % 32;
-  const int wm = warpid / WARPS_N;
-  const int wn = warpid % WARPS_N;
+  const int wm = warpid / WN;
+  const int wn = warpid % WN;
   const int offset_m = wm * WARP_STEP_M;
   const int offset_n = wn * WARP_STEP_N;
-  const int lane_row_a = offset_m + (laneid & 15);
-  const int lane_row_b = offset_n + (laneid & 15);
-  const int lane_col = (laneid >> 4) << 3;
 
   // Allocate shared memory
   extern __shared__ char shmem[];
@@ -65,34 +130,59 @@ __global__ void ab_t_aligned(const T* a, const T* b, T* y, int N, int K) {
   b += blockIdx.x * BN * K;
   y += blockIdx.y * BM * N + blockIdx.x * BN;
 
+  // Make the loaders to/from SMEM
+  using sloader = SharedTileLoader>;
+  constexpr int SSTEP = sloader::STEP_ROWS * sizeof(T) * BK;
+  const int srow = threadIdx.x / sloader::NUM_LOADS_PER_ROW;
+  const int scol =
+      (threadIdx.x % sloader::NUM_LOADS_PER_ROW) * sloader::ELEMENTS_PER_LOAD;
+  a += srow * K + scol;
+  b += srow * K + scol;
+  uint32_t sm_offsets[PIPE][2];
+  MLX_UNROLL
+  for (int s = 0; s < PIPE; s++) {
+    sm_offsets[s][0] = as[s].loc(as[s].base_addr(), srow, scol);
+    sm_offsets[s][1] = bs[s].loc(bs[s].base_addr(), srow, scol);
+  }
+  RegisterTileLoader> rloader_a(offset_m, laneid);
+  RegisterTileLoader> rloader_b(offset_n, laneid);
+
   // Start the SM pipeline
   MLX_UNROLL
-  for (int i = 0; i < PIPE - 1; i++) {
-    load_async(as[i], as[i].base_addr(), a + i * BK, K);
-    load_async(bs[i], bs[i].base_addr(), b + i * BK, K);
+  for (int s = 0; s < PIPE - 1; s++) {
+    MLX_UNROLL
+    for (int l = 0; l < sloader::NUM_LOADS_PER_THREAD; l++) {
+      // Each thread copies NUM_LOADS_PER_THREAD 16-byte chunks of the current
+      // BK-wide block, STEP_ROWS rows apart.
+      cp_async<16>(
+          sm_offsets[s][0] + l * SSTEP, a + l * sloader::STEP_ROWS * K);
+      cp_async<16>(
+          sm_offsets[s][1] + l * SSTEP, b + l * sloader::STEP_ROWS * K);
+    }
+    // Advance the global pointers to the next K block
+    a += BK;
+    b += BK;
     cp_async_commit();
   }
 
   // Allocate and zero the MMA accumulator
-  RegisterTile C;
+  RegisterTile C;
   C.fill(0);
 
   // Matmul loop
   int num_blocks = K / BK;
-  int k_block = (PIPE - 1) * BK;
   int sread = 0;
   int swrite = PIPE - 1;
   for (int i = 0; i < num_blocks; i++) {
-    cp_async_wait();
+    cp_async_wait();
 
-    if (k_block < K) {
-      load_async(as[swrite], as[swrite].base_addr(), a + k_block, K);
-      load_async(bs[swrite], bs[swrite].base_addr(), b + k_block, K);
+    gemm_ab_t(
+        C, as[sread], bs[sread], rloader_a, rloader_b);
+
+    // Prefetch the next K block into the stage that was just consumed, as
+    // long as one is left (same bound as the old k_block < K check).
+    if (i + PIPE - 1 < num_blocks) {
+      MLX_UNROLL
+      for (int l = 0; l < sloader::NUM_LOADS_PER_THREAD; l++) {
+        cp_async<16>(
+            sm_offsets[swrite][0] + l * SSTEP, a + l * sloader::STEP_ROWS * K);
+        cp_async<16>(
+            sm_offsets[swrite][1] + l * SSTEP, b + l * sloader::STEP_ROWS * K);
+      }
+      a += BK;
+      b += BK;
     }
-
-    gemm_ab_t(
-        C, as[sread], bs[sread], lane_row_a, lane_row_b, lane_col);
-
     cp_async_commit();
 
     swrite = sread;
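
Note on the pipelined main loop in ab_t_aligned above: with a PIPE-deep ring of shared memory stages, PIPE - 1 K-blocks are prefetched before the loop, and each iteration computes out of stage sread while refilling the one free stage (swrite). A small host-side sketch of just that index rotation, not part of the patch:

    #include <cstdio>

    int main() {
      constexpr int PIPE = 3;
      constexpr int num_blocks = 8; // e.g. K / BK for K = 256, BK = 32

      int sread = 0;
      int swrite = PIPE - 1;
      for (int i = 0; i < num_blocks; i++) {
        // Same bound the kernel uses to decide whether a block is left to load.
        bool prefetch = (i + PIPE - 1) < num_blocks;
        std::printf(
            "iter %d: compute from stage %d, %s stage %d\n",
            i, sread, prefetch ? "refill" : "skip", swrite);
        swrite = sread;
        sread = (sread + 1) % PIPE;
      }
      return 0;
    }
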
diff --git a/mlx/backend/cuda/steel/tiles.cuh b/mlx/backend/cuda/steel/tiles.cuh
index a2c51ef20..08eedb1df 100644
--- a/mlx/backend/cuda/steel/tiles.cuh
+++ b/mlx/backend/cuda/steel/tiles.cuh
@@ -225,6 +225,8 @@ struct RegisterTile {
 
 template
 struct SharedTile {
+  using value_type = T;
+
   static constexpr int ROWS = ROWS_;
   static constexpr int COLS = COLS_;
   static constexpr int TILES_X = COLS / 16;
@@ -266,23 +268,26 @@ struct SharedTile {
     }
   }
 
-  // Return the location of the element at (row, col) using the swizzle.
-  __device__ static inline uint32_t loc(uint32_t ptr, int row, int col) {
+  __device__ static inline uint32_t offset(int row, int col) {
     if constexpr (swizzle_bytes > 0) {
       static constexpr int swizzle_repeat = swizzle_bytes * 8;
       static constexpr int subtile_cols = swizzle_bytes / sizeof(T);
       const int outer_idx = col / subtile_cols;
-      const uint32_t addr = ptr +
-          sizeof(T) *
-              (outer_idx * ROWS * subtile_cols + row * subtile_cols +
-               col % subtile_cols);
+      const uint32_t addr = sizeof(T) *
+          (outer_idx * ROWS * subtile_cols + row * subtile_cols +
+           col % subtile_cols);
       const int swizzle = ((addr % swizzle_repeat) >> 7) << 4;
       return (addr ^ swizzle);
     } else {
-      return ptr + sizeof(T) * (row * COLS + col);
+      return sizeof(T) * (row * COLS + col);
     }
   }
 
+  // Return the location of the element at (row, col) using the swizzle.
+  __device__ static inline uint32_t loc(uint32_t ptr, int row, int col) {
+    return ptr + offset(row, col);
+  }
+
   // Convenience functions to edit elements going through the swizzle.
   __device__ inline T& operator()(int row, int col) {
     return *ptr(data, row, col);
   }
@@ -313,6 +318,76 @@
   }
 };
 
+template
+struct SharedTileLoader {
+  using T = typename Tile::value_type;
+
+  static constexpr int NUM_THREADS = NUM_WARPS * 32;
+  static constexpr int ELEMENTS_PER_LOAD = sizeof(float4) / sizeof(T);
+  static constexpr int NUM_LOADS = Tile::NUMEL / ELEMENTS_PER_LOAD;
+  static constexpr int NUM_LOADS_PER_THREAD = NUM_LOADS / NUM_THREADS;
+  static constexpr int NUM_LOADS_PER_ROW = Tile::COLS / ELEMENTS_PER_LOAD;
+  static constexpr int STEP_ROWS = NUM_THREADS / NUM_LOADS_PER_ROW;
+
+  const T* x_;
+  int N_;
+  uint32_t offset_;
+
+  __device__ SharedTileLoader(const T* x, int N) : x_(x), N_(N) {
+    const int row = threadIdx.x / NUM_LOADS_PER_ROW;
+    const int col = threadIdx.x % NUM_LOADS_PER_ROW;
+
+    x_ += row * N + col * ELEMENTS_PER_LOAD;
+    offset_ = Tile::offset(row, col * ELEMENTS_PER_LOAD);
+  }
+
+  __device__ inline void load_async(uint32_t base_address) {
+    MLX_UNROLL
+    for (int i = 0; i < NUM_LOADS_PER_THREAD; i++) {
+      cp_async<16>(
+          base_address + offset_ + i * STEP_ROWS * sizeof(T) * Tile::COLS,
+          x_ + i * STEP_ROWS * N_);
+    }
+  }
+
+  __device__ inline void next() {
+    x_ += Tile::COLS;
+  }
+};
+
+template
+struct RegisterTileLoader {
+  using T = typename Tile::value_type;
+
+  uint32_t offset_[Tile::COLS / 16];
+
+  __device__ RegisterTileLoader(int offset_row, int laneid) {
+    const int row = offset_row + (laneid & 15);
+    const int col = (laneid >> 4) << 3;
+
+    MLX_UNROLL
+    for (int i = 0; i < Tile::COLS / 16; i++) {
+      offset_[i] = Tile::offset(row, col + i * 16);
+    }
+  }
+
+  template
+  __device__ inline void
+  load(RegisterTile& x, uint32_t base_address, int col) {
+    constexpr int TILES_Y = RegisterTile::TILES_Y;
+    constexpr int TILES_X = RegisterTile::TILES_X;
+
+    MLX_UNROLL
+    for (int i = 0; i < TILES_Y; i++) {
+      MLX_UNROLL
+      for (int j = 0; j < TILES_X; j++) {
+        x.data[i * TILES_X + j].load(
+            base_address + offset_[j + col] + i * 16 * Tile::COLS * sizeof(T));
+      }
+    }
+  }
+};
+
 /**
  * Load the tile from global memory by loading 16 bytes at a time and storing
  * them immediately.
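
Note on SharedTile::offset() above: the swizzle is folded into the returned byte offset, and loc() is now just base + offset(row, col), which is what lets gemm.cuh precompute sm_offsets once per stage. A host-side copy of the same arithmetic, not part of the patch, with ROWS, COLS, sizeof(T) and swizzle_bytes passed as plain parameters (the 128x32 tile of 2-byte elements with 128-byte swizzling below is an example configuration only):

    #include <cstdint>
    #include <cstdio>

    uint32_t tile_offset(
        int row, int col, int rows, int cols, int swizzle_bytes, int elem_size) {
      if (swizzle_bytes > 0) {
        const int swizzle_repeat = swizzle_bytes * 8;
        const int subtile_cols = swizzle_bytes / elem_size;
        const int outer_idx = col / subtile_cols;
        const uint32_t addr = elem_size *
            (outer_idx * rows * subtile_cols + row * subtile_cols +
             col % subtile_cols);
        // Fold bits 7+ of the address (the row within one repeat window) into
        // bits 4..6, so rows that would start in the same 16-byte lane get
        // spread across different lanes and hence different banks.
        const int swizzle = ((addr % swizzle_repeat) >> 7) << 4;
        return addr ^ swizzle;
      }
      return elem_size * (row * cols + col);
    }

    int main() {
      // Where the first element of each of the first 8 rows lands.
      for (int r = 0; r < 8; r++) {
        std::printf(
            "row %d -> byte offset %u\n",
            r,
            static_cast<unsigned>(tile_offset(r, 0, 128, 32, 128, 2)));
      }
      return 0;
    }
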
diff --git a/mlx/backend/cuda/steel/utils.cuh b/mlx/backend/cuda/steel/utils.cuh
index 0957c09d0..c1f25d1a0 100644
--- a/mlx/backend/cuda/steel/utils.cuh
+++ b/mlx/backend/cuda/steel/utils.cuh
@@ -21,15 +21,15 @@ __device__ inline void cp_async(uint32_t row_address, const T* x) {
 #if defined(MLX_CUDA_SM_80_ENABLED)
   if constexpr (N == 16) {
     asm volatile(
-        "cp.async.ca.shared::cta.global [%0], [%1], 16;\n" ::"r"(row_address),
+        "cp.async.cg.shared::cta.global [%0], [%1], 16;\n" ::"r"(row_address),
         "l"(reinterpret_cast(x)));
   } else if constexpr (N == 8) {
     asm volatile(
         "cp.async.ca.shared::cta.global [%0], [%1], 8;\n" ::"r"(row_address),
         "l"(reinterpret_cast(x)));
   } else if constexpr (N == 4) {
     asm volatile(
         "cp.async.ca.shared::cta.global [%0], [%1], 4;\n" ::"r"(row_address),
         "l"(reinterpret_cast(x)));
   }
 #endif
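
Note on the .ca -> .cg change above: cp.async.cg caches the copy only at L2 and bypasses L1, which is the usual choice for GEMM operands that are staged into shared memory and never re-read through L1; cp.async.ca would keep them in L1 as well. PTX defines the .cg variant only for 16-byte copies, so the 8- and 4-byte paths stay on .ca. A self-contained sketch of the copy/commit/wait pattern the kernel relies on, not part of the patch (compile for sm_80 or newer); the cp_async_commit() and cp_async_wait<N>() helpers used in gemm.cuh presumably wrap the same commit_group/wait_group instructions:

    #include <cuda_runtime.h>
    #include <cstdint>

    __global__ void stage_copy(const float4* src, float4* dst) {
      __shared__ float4 buf[32];
      // Shared memory state-space address for the inline PTX below.
      auto smem =
          static_cast<uint32_t>(__cvta_generic_to_shared(&buf[threadIdx.x]));

      // 16-byte global -> shared copy, cached at L2 only (.cg).
      asm volatile(
          "cp.async.cg.shared::cta.global [%0], [%1], 16;\n" ::"r"(smem),
          "l"(src + threadIdx.x));
      // Close the group of copies issued so far ...
      asm volatile("cp.async.commit_group;\n" ::);
      // ... and wait until no group is still in flight.
      asm volatile("cp.async.wait_group 0;\n" ::);
      __syncthreads();

      dst[threadIdx.x] = buf[threadIdx.x];
    }
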