Compare commits

1 Commit

Author       SHA1        Message                  Date
Awni Hannun  1034009b82  faster general unary op  2025-08-14 19:44:46 -07:00
6 changed files with 113 additions and 437 deletions

View File

@@ -24,7 +24,6 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
${CMAKE_CURRENT_SOURCE_DIR}/gemms/gemv.cu
${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm.cpp
${CMAKE_CURRENT_SOURCE_DIR}/gemms/steel_gemm.cu
${CMAKE_CURRENT_SOURCE_DIR}/jit_module.cpp
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu

View File

@@ -1,301 +0,0 @@
#include "mlx/backend/common/matmul.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/utils.cuh"
#include "mlx/backend/cuda/gemms/steel_gemm.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/primitives.h"
#include <nvtx3/nvtx3.hpp>
#include <numeric>
#include <cooperative_groups.h>
#include "mlx/backend/cuda/steel/gemm.cuh"
#include "mlx/backend/cuda/steel/mma.cuh"
#include "mlx/backend/cuda/steel/tiles.cuh"
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
struct GemmParams {
int M;
int N;
int K;
int lda;
int ldb;
int ldd;
int NblockM;
int NblockN;
int NblockK;
};
template <
typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
int SL,
int Nstages>
__global__ void kernel_steel_gemm(
const T* a,
const T* b,
T* d,
__grid_constant__ const GemmParams params) {
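// Swizzled block mapping: the low SL bits of blockIdx.x pick an M-tile within
// a (1 << SL)-wide group and the remaining bits pick the N-tile, so
// neighboring blocks reuse the same B tiles in L2. With SL = 0 it is the
// identity mapping.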
const int bM_idx = (blockIdx.y << SL) + (blockIdx.x & ((1 << SL) - 1));
const int bN_idx = blockIdx.x >> SL;
if (params.NblockN <= bN_idx || params.NblockM <= bM_idx) {
return;
}
const int d_row = bM_idx * BM;
const int d_col = bN_idx * BN;
const size_t d_row_long = size_t(d_row);
const size_t d_col_long = size_t(d_col);
a += transpose_a ? d_row_long : d_row_long * params.K;
b += transpose_b ? d_col_long * params.K : d_col_long;
d += d_row_long * params.ldd + d_col_long;
auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<32>(block);
const int lane_idx = warp.thread_rank();
const int warp_idx = warp.meta_group_rank();
const int wm = warp_idx / WN;
const int wn = warp_idx % WN;
constexpr int SM = BM / WM;
constexpr int SN = BN / WN;
constexpr int SK = BK;
constexpr int TK = SK / 16;
constexpr int NUM_WARPS = WM * WN;
// Allocate shared memory
extern __shared__ char shmem[];
SharedTile<T, BM, BK>(&as)[Nstages] =
*(SharedTile<T, BM, BK>(*)[Nstages])(&shmem[0]);
SharedTile<T, BN, BK>(&bs)[Nstages] = *(SharedTile<T, BN, BK>(*)[Nstages])(
&shmem[sizeof(T) * Nstages * BM * BK]);
// Allocate registers for the MMA
RegisterTile<float, SM, SN> C;
RegisterTile<T, SM, 16> A[TK];
RegisterTile<T, SN, 16> B[TK];
// Zero the accumulators
C.fill(0);
// Start gmem -> smem copies
int k_block_read = 0;
MLX_UNROLL
for (int bk = 0; bk < (Nstages - 1); bk++) {
load_async<NUM_WARPS>(
as[bk], as[bk].base_addr(), a + k_block_read, params.K);
load_async<NUM_WARPS>(
bs[bk], bs[bk].base_addr(), b + k_block_read, params.K);
k_block_read += BK;
cp_async_commit();
}
int smem_pipe_read = 0;
int smem_pipe_write = Nstages - 1;
// Wait until only one stage is still loading
cp_async_wait<1>();
block.sync();
const int offset_m = wm * SM;
const int offset_n = wn * SN;
// Start smem -> register copy
A[0].load(
as[smem_pipe_read],
as[smem_pipe_read].base_addr(),
offset_m + lane_idx % 16,
lane_idx / 16 * 8);
B[0].load(
bs[smem_pipe_read],
bs[smem_pipe_read].base_addr(),
offset_n + lane_idx % 16,
lane_idx / 16 * 8);
// Main loop
for (int kb = 0; kb < params.NblockK; kb++) {
// Prepare next registers
{
A[1].load(
as[smem_pipe_read],
as[smem_pipe_read].base_addr(),
offset_m + lane_idx % 16,
16 + lane_idx / 16 * 8);
B[1].load(
bs[smem_pipe_read],
bs[smem_pipe_read].base_addr(),
offset_n + lane_idx % 16,
16 + lane_idx / 16 * 8);
}
// Prepare next smem
if ((kb + Nstages - 1) < params.NblockK) {
load_async<NUM_WARPS>(
as[smem_pipe_write],
as[smem_pipe_write].base_addr(),
a + k_block_read,
params.K);
load_async<NUM_WARPS>(
bs[smem_pipe_write],
bs[smem_pipe_write].base_addr(),
b + k_block_read,
params.K);
}
k_block_read += BK;
cp_async_commit();
smem_pipe_write = smem_pipe_read;
smem_pipe_read = smem_pipe_read + 1;
smem_pipe_read = (smem_pipe_read == Nstages) ? 0 : smem_pipe_read;
// Do current gemm
mma_t(C, A[0], B[0]);
// Wait for the next smem stage before reloading registers
cp_async_wait<1>();
block.sync();
// Prepare next register (smem_pipe_read has moved to the next)
{
A[0].load(
as[smem_pipe_read],
as[smem_pipe_read].base_addr(),
offset_m + lane_idx % 16,
lane_idx / 16 * 8);
B[0].load(
bs[smem_pipe_read],
bs[smem_pipe_read].base_addr(),
offset_n + lane_idx % 16,
lane_idx / 16 * 8);
}
// Do current gemm
mma_t(C, A[1], B[1]);
}
// Wait and clear
cp_async_wait_all();
block.sync();
C.store_global(d, params.ldd, offset_m, offset_n);
}
} // namespace cu
void dispatch_steel_gemm(
const Stream& s,
cu::CommandEncoder& encoder,
const array& a,
const array& b,
array& d,
int M,
int N,
int K,
int lda,
int ldb,
int ldd,
bool a_transposed,
bool b_transposed) {
using DataType = cuda_type_t<float16_t>;
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(d);
constexpr int BM = 128;
constexpr int BN = 128;
constexpr int BK = 32;
constexpr int WM = 2;
constexpr int WN = 2;
constexpr int SL = 0;
constexpr int Nstages = 3;
constexpr uint32_t smem_bytes = BK * (BM + BN) * Nstages * sizeof(DataType);
const int NblockM = (M + BM - 1) / BM;
const int NblockN = (N + BN - 1) / BN;
const int NblockK = (K + BK - 1) / BK;
cu::GemmParams params{
/* int M = */ M,
/* int N = */ N,
/* int K = */ K,
/* int lda = */ lda,
/* int ldb = */ ldb,
/* int ldd = */ ldd,
/* int NblockM = */ NblockM,
/* int NblockN = */ NblockN,
/* int NblockK = */ NblockK,
};
// Prepare launch grid params
int tile = 1 << SL;
int tm = (NblockM + tile - 1) / tile;
int tn = NblockN * tile;
dim3 grid_dim(tn, tm, 1);
dim3 block_dim(32 * WM * WN, 1, 1);
dispatch_bool(a_transposed, [&](auto ta_) {
dispatch_bool(b_transposed, [&](auto tb_) {
constexpr bool ta = ta_.value;
constexpr bool tb = tb_.value;
auto kernel = cu::ab_t_aligned<DataType, BM, BN, BK>;
cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
encoder.add_kernel_node(
kernel,
grid_dim,
block_dim,
smem_bytes,
a.data<DataType>(),
b.data<DataType>(),
d.data<DataType>(),
N,
K);
// auto kernel = cu::kernel_steel_gemm<DataType, BM, BN, BK, WM, WN, ta,
// tb, SL, Nstages>;
// cudaFuncSetAttribute(kernel,
// cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
// encoder.add_kernel_node(
// kernel,
// grid_dim,
// block_dim,
// smem_bytes,
// a.data<DataType>(),
// b.data<DataType>(),
// d.data<DataType>(),
// params);
});
});
}
} // namespace mlx::core
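
Context for the file deleted above: kernel_steel_gemm implements a classic Nstages-deep cp.async software pipeline, overlapping global-to-shared copies of upcoming K-blocks with tensor-core MMAs on resident ones. With BM = BN = 128, BK = 32, Nstages = 3 and float16 data, the dynamic shared memory comes to BK * (BM + BN) * Nstages * sizeof(half) = 32 * 256 * 3 * 2 = 49152 bytes (48 KB), hence the cudaFuncSetAttribute call in dispatch_steel_gemm. A minimal standalone sketch of the same pipelining pattern, using CUDA's pipeline primitives in place of MLX's load_async/cp_async_* helpers and a per-element sum as a stand-in for the MMA work:

#include <cuda_pipeline.h>

constexpr int BK = 32;      // elements each thread consumes per k-step
constexpr int NSTAGES = 3;  // pipeline depth, as in the deleted kernel

// One block reduces num_k_blocks x BK floats to BK partial sums; assumes
// blockDim.x == BK and num_k_blocks >= NSTAGES - 1.
__global__ void pipelined_sum(
    const float* __restrict__ in, float* __restrict__ out, int num_k_blocks) {
  __shared__ float smem[NSTAGES][BK];
  const int tid = threadIdx.x;
  const float* src = in + (size_t)blockIdx.x * num_k_blocks * BK + tid;

  // Prologue: start copies for all but one stage, one commit per stage.
  for (int s = 0; s < NSTAGES - 1; ++s) {
    __pipeline_memcpy_async(&smem[s][tid], src, sizeof(float));
    src += BK;
    __pipeline_commit();
  }

  int read = 0, write = NSTAGES - 1;
  float acc = 0.0f;
  for (int kb = 0; kb < num_k_blocks; ++kb) {
    // Prefetch a future k-block while computing on a resident one.
    if (kb + NSTAGES - 1 < num_k_blocks) {
      __pipeline_memcpy_async(&smem[write][tid], src, sizeof(float));
      src += BK;
    }
    __pipeline_commit();
    // Let at most one commit stay in flight: the stage read next has landed.
    __pipeline_wait_prior(1);
    __syncthreads();
    acc += smem[read][tid];  // stand-in for mma_t(C, A, B)
    write = read;
    read = (read + 1 == NSTAGES) ? 0 : read + 1;
    __syncthreads();  // everyone is done with this stage before it is reused
  }
  out[blockIdx.x * BK + tid] = acc;
}

Keeping Nstages - 1 copies in flight while computing on the remaining stage is what hides global-memory latency; the deleted kernel additionally double-buffers its smem-to-register loads (the A[0]/A[1] and B[0]/B[1] pairs) so shared-memory reads overlap the MMAs as well.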

View File

@@ -1,27 +0,0 @@
#pragma once
#include "mlx/backend/common/matmul.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/primitives.h"
#include <nvtx3/nvtx3.hpp>
#include <numeric>
namespace mlx::core {
void dispatch_steel_gemm(
const Stream& s,
cu::CommandEncoder& encoder,
const array& a,
const array& b,
array& d,
int M,
int N,
int K,
int lda,
int ldb,
int ldd,
bool a_transposed,
bool b_transposed);
} // namespace mlx::core

View File

@@ -7,8 +7,6 @@
#include "mlx/backend/gpu/copy.h"
#include "mlx/primitives.h"
#include "mlx/backend/cuda/gemms/steel_gemm.h"
#include <nvtx3/nvtx3.hpp>
#include <numeric>
@@ -97,24 +95,6 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
return;
}
if (out.dtype() == float16 && batch_count == 1 && !a_transposed &&
b_transposed) {
return dispatch_steel_gemm(
/* const Stream& s = */ s,
/* cu::CommandEncoder& encoder = */ encoder,
/* const array& a = */ a,
/* const array& b = */ b,
/* array& d = */ out,
/* int M = */ M,
/* int N = */ N,
/* int K = */ K,
/* int lda = */ lda,
/* int ldb = */ ldb,
/* int ldd = */ N,
/* bool a_transposed = */ a_transposed,
/* bool b_transposed = */ b_transposed);
}
/////////////////////////////////////////////////////////////////////////////
// Invoke cublasLt
CublasGemm gemm(

View File

@@ -143,87 +143,85 @@ struct Tile16x16 {
}
};
// /**
// * A simple container of multiple Tile16x16.
// *
// * Provides utility functions for loading and manipulating collections of
// basic
// * tiles.
// */
// template <typename T, int ROWS_, int COLS_>
// struct RegisterTile {
// static constexpr int ROWS = ROWS_;
// static constexpr int COLS = COLS_;
// static constexpr int TILES_X = COLS / 16;
// static constexpr int TILES_Y = ROWS / 16;
/**
* A simple container of multiple Tile16x16.
*
* Provides utility functions for loading and manipulating collections of basic
* tiles.
*/
template <typename T, int ROWS_, int COLS_>
struct RegisterTile {
static constexpr int ROWS = ROWS_;
static constexpr int COLS = COLS_;
static constexpr int TILES_X = COLS / 16;
static constexpr int TILES_Y = ROWS / 16;
// Tile16x16<T> data[TILES_X * TILES_Y];
Tile16x16<T> data[TILES_X * TILES_Y];
// __device__ inline void fill(T v) {
// MLX_UNROLL
// for (int i = 0; i < TILES_Y; i++) {
// MLX_UNROLL
// for (int j = 0; j < TILES_X; j++) {
// data[i * TILES_X + j].fill(v);
// }
// }
// }
__device__ inline void fill(T v) {
MLX_UNROLL
for (int i = 0; i < TILES_Y; i++) {
MLX_UNROLL
for (int j = 0; j < TILES_X; j++) {
data[i * TILES_X + j].fill(v);
}
}
}
// template <typename Tile>
// __device__ __forceinline__ void
// load(Tile& tile, uint32_t base_address, int row, int col) {
// MLX_UNROLL
// for (int i = 0; i < TILES_Y; i++) {
// MLX_UNROLL
// for (int j = 0; j < TILES_X; j++) {
// data[i * TILES_X + j].load(
// tile.loc(base_address, row + i * 16, col + j * 16));
// }
// }
// }
template <typename Tile>
__device__ __forceinline__ void
load(Tile& tile, uint32_t base_address, int row, int col) {
MLX_UNROLL
for (int i = 0; i < TILES_Y; i++) {
MLX_UNROLL
for (int j = 0; j < TILES_X; j++) {
data[i * TILES_X + j].load(
tile.loc(base_address, row + i * 16, col + j * 16));
}
}
}
// template <typename Tile, typename F>
// __device__ __forceinline__ void
// load(Tile& tile, F f, uint32_t base_address, int row, int col) {
// MLX_UNROLL
// for (int i = 0; i < TILES_Y; i++) {
// MLX_UNROLL
// for (int j = 0; j < TILES_X; j++) {
// f(data[i * TILES_X + j],
// tile,
// base_address,
// row + i * 16,
// col + j * 16);
// }
// }
// }
template <typename Tile, typename F>
__device__ __forceinline__ void
load(Tile& tile, F f, uint32_t base_address, int row, int col) {
MLX_UNROLL
for (int i = 0; i < TILES_Y; i++) {
MLX_UNROLL
for (int j = 0; j < TILES_X; j++) {
f(data[i * TILES_X + j],
tile,
base_address,
row + i * 16,
col + j * 16);
}
}
}
// template <typename U>
// __device__ inline void store_global(U* x, int N, int row, int col) {
// MLX_UNROLL
// for (int i = 0; i < TILES_Y; i++) {
// MLX_UNROLL
// for (int j = 0; j < TILES_X; j++) {
// data[i * TILES_X + j].store_global(
// x + (row + i * 16) * N + col + j * 16, N);
// }
// }
// }
template <typename U>
__device__ inline void store_global(U* x, int N, int row, int col) {
MLX_UNROLL
for (int i = 0; i < TILES_Y; i++) {
MLX_UNROLL
for (int j = 0; j < TILES_X; j++) {
data[i * TILES_X + j].store_global(
x + (row + i * 16) * N + col + j * 16, N);
}
}
}
// template <typename U>
// __device__ inline void
// store_global_safe(U* x, int N, int row, int col, int max_rows) {
// MLX_UNROLL
// for (int i = 0; i < TILES_Y; i++) {
// MLX_UNROLL
// for (int j = 0; j < TILES_X; j++) {
// data[i * TILES_X + j].store_global_safe(
// x + (row + i * 16) * N + col + j * 16, N, max_rows - row - i *
// 16);
// }
// }
// }
// };
template <typename U>
__device__ inline void
store_global_safe(U* x, int N, int row, int col, int max_rows) {
MLX_UNROLL
for (int i = 0; i < TILES_Y; i++) {
MLX_UNROLL
for (int j = 0; j < TILES_X; j++) {
data[i * TILES_X + j].store_global_safe(
x + (row + i * 16) * N + col + j * 16, N, max_rows - row - i * 16);
}
}
}
};
/**
* A simple container of multiple Tile16x16.
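
The hunk above deletes the commented-out copy of RegisterTile and restores the live definition. Structurally, RegisterTile<T, ROWS, COLS> is a (ROWS/16) x (COLS/16) array of Tile16x16 fragments whose methods unroll the per-fragment operation over both axes. A reduced host-side sketch of that container pattern (Fragment16 is a hypothetical dense 16x16 stand-in; the real Tile16x16 distributes its elements across a warp's 32 lanes in the tensor-core fragment layout):

#include <cstdio>

// Dense stand-in for MLX's warp-distributed Tile16x16.
template <typename T>
struct Fragment16 {
  T v[16 * 16];
  void fill(T x) {
    for (auto& e : v) e = x;
  }
  void store(T* out, int ld) {  // write the dense 16x16 block
    for (int r = 0; r < 16; r++)
      for (int c = 0; c < 16; c++) out[r * ld + c] = v[r * 16 + c];
  }
};

// The container pattern RegisterTile implements: a ROWS x COLS tile held as
// (ROWS/16) x (COLS/16) fragments, every method a loop over both axes.
template <typename T, int ROWS, int COLS>
struct RegisterTileSketch {
  static constexpr int TILES_Y = ROWS / 16;
  static constexpr int TILES_X = COLS / 16;
  Fragment16<T> data[TILES_Y * TILES_X];

  void fill(T x) {
    for (auto& f : data) f.fill(x);
  }
  // Fragment (i, j) lands at out[(row + 16 i) * ld + col + 16 j], exactly the
  // addressing in RegisterTile::store_global above.
  void store_global(T* out, int ld, int row, int col) {
    for (int i = 0; i < TILES_Y; i++)
      for (int j = 0; j < TILES_X; j++)
        data[i * TILES_X + j].store(
            out + (row + i * 16) * ld + col + j * 16, ld);
  }
};

int main() {
  static float out[64 * 64] = {};
  RegisterTileSketch<float, 32, 32> tile;  // 2 x 2 = 4 fragments
  tile.fill(1.0f);
  tile.store_global(out, /*ld=*/64, /*row=*/16, /*col=*/16);
  std::printf("%.0f\n", out[16 * 64 + 16]);  // 1
}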

View File

@@ -37,19 +37,35 @@ __global__ void unary_v(const In* in, Out* out, IdxT size) {
}
}
template <typename Op, typename In, typename Out, typename IdxT>
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
__global__ void unary_g(
const In* in,
Out* out,
IdxT size,
IdxT size_rest,
const __grid_constant__ Shape shape,
const __grid_constant__ Strides strides,
int ndim) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto idx = elem_to_loc(index, shape.data(), strides.data(), ndim);
out[index] = Op{}(in[idx]);
auto block = cg::this_thread_block();
auto grid = cg::this_grid();
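// 2-D launch: y enumerates the collapsed outer dims (one innermost-axis row
// per thread) and x tiles the innermost axis in N_READS-element chunks.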
IdxT index_rest =
grid.block_index().y * block.dim_threads().y + block.thread_index().y;
if (index_rest >= size_rest) {
return;
}
auto shape_x = shape[ndim - 1];
auto stride_x = strides[ndim - 1];
IdxT index_x =
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
auto idx = elem_to_loc(index_rest * shape_x, shape.data(), strides.data(), ndim);
auto in_vec = load_vector<N_READS>(
in + idx, index_x, shape_x, stride_x, In(0));
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec[i] = Op{}(in_vec[i]);
}
store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x);
}
template <typename Op, typename In, typename Out>
@@ -127,8 +143,7 @@ void unary_op_gpu_inplace(
using OutType = cuda_type_t<CTYPE_OUT>;
if (contig) {
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
// TODO: Choose optimized value based on type size.
constexpr int N_READS = 4;
constexpr int N_READS = 16 / sizeof(OutType);
auto [num_blocks, block_dims] = get_launch_args(
out.data_size(), out.shape(), out.strides(), large, N_READS);
encoder.add_kernel_node(
@@ -142,18 +157,30 @@ void unary_op_gpu_inplace(
} else {
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
auto [shape, strides] = collapse_contiguous_dims(in);
auto [num_blocks, block_dims] = get_launch_args(out, large);
auto ndim = shape.size();
int work_per_thread = 1;
auto kernel = cu::unary_g<Op, InType, OutType, IdxT, 1>;
auto dim0 = ndim > 0 ? shape.back() : 1;
auto rest = out.size() / dim0;
if (dim0 >= 4) {
kernel = cu::unary_g<Op, InType, OutType, IdxT, 4>;
work_per_thread = 4;
}
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
auto block_dims = get_block_dims(dim0, rest, 1);
uint32_t num_blocks_x = cuda::ceil_div(dim0, block_dims.x);
uint32_t num_blocks_y = cuda::ceil_div(rest, block_dims.y);
encoder.add_kernel_node(
cu::unary_g<Op, InType, OutType, IdxT>,
num_blocks,
kernel,
{num_blocks_x, num_blocks_y},
block_dims,
0,
in.data<InType>(),
out.data<OutType>(),
out.data_size(),
rest,
const_param(shape),
const_param(strides),
shape.size());
ndim);
}
});
} else {
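
A note on the unary changes above: on the contiguous path, N_READS = 16 / sizeof(OutType) sizes each thread's chunk to a full 16-byte vector (4 floats, 8 halves, 2 doubles), while the general path uses 4 elements per thread when the innermost extent is at least 4 and falls back to the scalar variant otherwise. The rewritten unary_g launches a 2-D grid: y enumerates the collapsed outer dims and x tiles the innermost axis in N_READS-wide chunks, so elem_to_loc runs once per row instead of once per element. A reduced standalone sketch of that mapping (unary_neg_g, row_stride, and x_stride are illustrative stand-ins; the real kernel gets its input offset from elem_to_loc, which handles arbitrary per-dimension strides, and its vector accesses from load_vector/store_vector):

// Grid y picks one innermost-axis row out of the collapsed outer dims;
// grid x tiles that row in N_READS-wide chunks. Op{} is shown as negation.
template <int N_READS>
__global__ void unary_neg_g(
    const float* in,
    float* out,
    int shape_x,        // innermost extent
    long row_stride,    // input stride between consecutive rows (stand-in)
    long x_stride,      // input stride along the innermost axis
    long size_rest) {   // product of all outer dims
  long row = blockIdx.y * (long)blockDim.y + threadIdx.y;
  if (row >= size_rest) {
    return;
  }
  int x0 = (blockIdx.x * blockDim.x + threadIdx.x) * N_READS;
  const float* src = in + row * row_stride;
  float* dst = out + row * (long)shape_x;  // output is written contiguously
#pragma unroll
  for (int i = 0; i < N_READS; ++i) {
    int x = x0 + i;
    if (x < shape_x) {
      dst[x] = -src[x * x_stride];
    }
  }
}

Amortizing the per-row index computation over N_READS elements, and turning the innermost-axis accesses into single wide loads and stores when that axis has stride 1, is presumably where the "faster general unary op" speedup comes from.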