diff --git a/mlx/backend/cuda/matmul/tiles.cuh b/mlx/backend/cuda/matmul/tiles.cuh
index bd0a240b8..b04cd29e6 100644
--- a/mlx/backend/cuda/matmul/tiles.cuh
+++ b/mlx/backend/cuda/matmul/tiles.cuh
@@ -122,6 +122,25 @@ struct Tile16x16 {
           __floats2bfloat162_rn(values[3].x, values[3].y);
     }
   }
+
+  template <typename U>
+  __device__ inline void store_global_safe(U* x, int N, int max_rows) {
+    const int laneid = threadIdx.x % 32;
+    const int row = laneid / 4;
+    const int col = laneid % 4;
+    if (row < max_rows) {
+      x[(row + 0) * N + 2 * col + 0] = static_cast<U>(values[0].x);
+      x[(row + 0) * N + 2 * col + 1] = static_cast<U>(values[0].y);
+      x[(row + 0) * N + 2 * col + 8] = static_cast<U>(values[2].x);
+      x[(row + 0) * N + 2 * col + 9] = static_cast<U>(values[2].y);
+    }
+    if (row + 8 < max_rows) {
+      x[(row + 8) * N + 2 * col + 0] = static_cast<U>(values[1].x);
+      x[(row + 8) * N + 2 * col + 1] = static_cast<U>(values[1].y);
+      x[(row + 8) * N + 2 * col + 8] = static_cast<U>(values[3].x);
+      x[(row + 8) * N + 2 * col + 9] = static_cast<U>(values[3].y);
+    }
+  }
 };
 
 /**
@@ -173,6 +192,19 @@ struct RegisterTile {
       }
     }
   }
+
+  template <typename U>
+  __device__ inline void
+  store_global_safe(U* x, int N, int row, int col, int max_rows) {
+    MLX_UNROLL
+    for (int i = 0; i < TILES_Y; i++) {
+      MLX_UNROLL
+      for (int j = 0; j < TILES_X; j++) {
+        data[i * TILES_X + j].store_global_safe(
+            x + (row + i * 16) * N + col + j * 16, N, max_rows - row - i * 16);
+      }
+    }
+  }
 };
 
 template
@@ -352,4 +384,36 @@ load_async(Tile& tile, uint32_t base_address, const T* x, int N) {
   }
 }
 
+template <int NUM_WARPS, typename Tile, typename T>
+__device__ inline void load_async_safe(
+    Tile& tile,
+    uint32_t base_address,
+    const T* x,
+    int N,
+    int max_rows) {
+  constexpr int NUM_THREADS = NUM_WARPS * 32;
+  constexpr int ELEMENTS_PER_LOAD = sizeof(float4) / sizeof(T);
+  constexpr int NUM_LOADS = Tile::NUMEL / ELEMENTS_PER_LOAD;
+  constexpr int NUM_LOADS_PER_THREAD = NUM_LOADS / NUM_THREADS;
+  constexpr int NUM_LOADS_PER_ROW = Tile::COLS / ELEMENTS_PER_LOAD;
+  constexpr int STEP_ROWS = NUM_THREADS / NUM_LOADS_PER_ROW;
+
+  const int row = threadIdx.x / NUM_LOADS_PER_ROW;
+  const int col = threadIdx.x % NUM_LOADS_PER_ROW;
+
+  x += row * N + col * ELEMENTS_PER_LOAD;
+
+  MLX_UNROLL
+  for (int i = 0; i < NUM_LOADS_PER_THREAD; i++) {
+    if (row + i * STEP_ROWS < max_rows) {
+      cp_async_16(
+          tile.loc(base_address, row + i * STEP_ROWS, col * ELEMENTS_PER_LOAD),
+          x + i * STEP_ROWS * N);
+    } else {
+      float4 tmp = {0, 0, 0, 0};
+      tile.store(tmp, row + i * STEP_ROWS, col * ELEMENTS_PER_LOAD);
+    }
+  }
+}
+
 } // namespace mlx::core::cu
diff --git a/mlx/backend/cuda/quantized/qmm.cu b/mlx/backend/cuda/quantized/qmm.cu
index 6e6c772a2..c96897d5d 100644
--- a/mlx/backend/cuda/quantized/qmm.cu
+++ b/mlx/backend/cuda/quantized/qmm.cu
@@ -50,8 +50,15 @@ __device__ inline void load_quantized(
   }
 }
 
-template <typename T, int BM, int BN, int BK, int group_size, int bits>
-__global__ void qmm_t_aligned(
+template <
+    typename T,
+    int BM,
+    int BN,
+    int BK,
+    int group_size,
+    int bits,
+    bool aligned_M>
+__global__ void qmm_t(
     const T* x,
     const uint8_t* w,
     const T* scales,
@@ -82,6 +89,8 @@ __global__ void qmm_t_aligned(
   RegisterTile A;
   RegisterTile B;
 
+  const int max_rows = M - blockIdx.y * BM;
+
   x += blockIdx.y * BM * K;
   w += blockIdx.x * BN * K / get_pack_factor();
   scales += blockIdx.x * BN * K / group_size;
@@ -95,36 +104,66 @@ __global__ void qmm_t_aligned(
   base_addr_xs[0] = __cvta_generic_to_shared(&xs[0].data[0]);
   base_addr_ws[0] = __cvta_generic_to_shared(&ws[0].data[0]);
 
-  for (int k_block = 0; k_block < K; k_block += BK) {
-    load_async(xs[tic], base_addr_xs[tic], x + k_block, K);
-    cp_async_commit();
-    // load(xs[tic], x + k_block, K);
-    load_quantized(
-        ws[tic],
-        w + k_block / get_pack_factor(),
-        scales + k_block / group_size,
-        biases + k_block / group_size,
-        K);
-    cp_async_wait_all();
-    __syncthreads();
-
-    MLX_UNROLL
-    for (int k = 0; k < BK / 16; k++) {
-      A.load(
-          xs[tic],
-          base_addr_xs[tic],
-          offset_m + laneid % 16,
-          k * 16 + laneid / 16 * 8);
-      B.load(
-          ws[tic],
-          base_addr_ws[tic],
-          offset_n + laneid % 16,
-          k * 16 + laneid / 16 * 8);
-      mma_t(C, A, B);
-    }
-  }
-
-  C.store_global(y, N, offset_m, offset_n);
+  if (aligned_M || max_rows >= BM) {
+    for (int k_block = 0; k_block < K; k_block += BK) {
+      load_async(xs[tic], base_addr_xs[tic], x + k_block, K);
+      cp_async_commit();
+      load_quantized(
+          ws[tic],
+          w + k_block / get_pack_factor(),
+          scales + k_block / group_size,
+          biases + k_block / group_size,
+          K);
+      cp_async_wait_all();
+      __syncthreads();
+
+      MLX_UNROLL
+      for (int k = 0; k < BK / 16; k++) {
+        A.load(
+            xs[tic],
+            base_addr_xs[tic],
+            offset_m + laneid % 16,
+            k * 16 + laneid / 16 * 8);
+        B.load(
+            ws[tic],
+            base_addr_ws[tic],
+            offset_n + laneid % 16,
+            k * 16 + laneid / 16 * 8);
+        mma_t(C, A, B);
+      }
+    }
+    C.store_global(y, N, offset_m, offset_n);
+  } else {
+    for (int k_block = 0; k_block < K; k_block += BK) {
+      load_async_safe(
+          xs[tic], base_addr_xs[tic], x + k_block, K, max_rows);
+      cp_async_commit();
+      load_quantized(
+          ws[tic],
+          w + k_block / get_pack_factor(),
+          scales + k_block / group_size,
+          biases + k_block / group_size,
+          K);
+      cp_async_wait_all();
+      __syncthreads();
+
+      MLX_UNROLL
+      for (int k = 0; k < BK / 16; k++) {
+        A.load(
+            xs[tic],
+            base_addr_xs[tic],
+            offset_m + laneid % 16,
+            k * 16 + laneid / 16 * 8);
+        B.load(
+            ws[tic],
+            base_addr_ws[tic],
+            offset_n + laneid % 16,
+            k * 16 + laneid / 16 * 8);
+        mma_t(C, A, B);
+      }
+    }
+    C.store_global_safe(y, N, offset_m, offset_n, max_rows);
+  }
 }
 
 } // namespace cu
@@ -143,17 +182,30 @@ void qmm(
     int K,
     cu::CommandEncoder& enc,
     const Stream& s) {
+  if (x.dtype() != bfloat16) {
+    throw std::invalid_argument("[qmm] Only bfloat16 is supported for now");
+  }
+  if (!transpose_) {
+    throw std::invalid_argument(
+        "[qmm] Only transposed matmul is supported for now");
+  }
+
   dispatch_float_types(x.dtype(), "qmm", [&](auto type_tag) {
     dispatch_groups(group_size_, [&](auto group_size) {
       dispatch_bits(bits_, [&](auto bits) {
         using DataType = cuda_type_t;
+        constexpr int BM = 128;
         constexpr int BN = 128;
         constexpr int BK = 32;
 
-        auto kernel = cu::
-            qmm_t_aligned;
+        auto kernel =
+            cu::qmm_t;
+        if (M % BM != 0) {
+          kernel = cu::
+              qmm_t;
+        }
 
-        dim3 grid(N / BN, M / BM);
+        dim3 grid((N + BN - 1) / BN, (M + BM - 1) / BM);
 
         enc.add_kernel_node(
             kernel,
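
Note, not part of the patch: a minimal host-side sketch of the ragged-M dispatch introduced above. The sizes M = 300 and N = 256 are made up for illustration; BM = BN = 128 mirror the tiling in the patched qmm(). Row blocks with max_rows >= BM keep the original unguarded path, while the final partial block falls back to load_async_safe / store_global_safe.

```cpp
// Illustrative only: mirrors the grid / max_rows arithmetic from the patch.
// M and N are hypothetical sizes; BM and BN match the values used in qmm().
#include <cstdio>

int main() {
  constexpr int BM = 128, BN = 128;
  const int M = 300, N = 256; // hypothetical problem with M % BM != 0

  // Same ceil-division grid as the patched qmm(): covers the ragged tail.
  const int grid_x = (N + BN - 1) / BN;
  const int grid_y = (M + BM - 1) / BM;
  std::printf("grid = (%d, %d)\n", grid_x, grid_y);

  for (int by = 0; by < grid_y; by++) {
    // max_rows as computed inside qmm_t; here the last block gets 300 - 256 = 44,
    // so only it takes the bounds-checked load/store path.
    const int max_rows = M - by * BM;
    std::printf(
        "block row %d: max_rows = %d (%s)\n",
        by,
        max_rows,
        max_rows >= BM ? "aligned path" : "guarded path");
  }
  return 0;
}
```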