Support unaligned M

This commit is contained in:
Angelos Katharopoulos
2025-07-23 00:40:27 -07:00
parent 903b40627c
commit 8269c9d02d
2 changed files with 149 additions and 33 deletions

View File

@@ -122,6 +122,25 @@ struct Tile16x16 {
__floats2bfloat162_rn(values[3].x, values[3].y);
}
}
template <typename U>
__device__ inline void store_global_safe(U* x, int N, int max_rows) {
const int laneid = threadIdx.x % 32;
const int row = laneid / 4;
const int col = laneid % 4;
if (row < max_rows) {
x[(row + 0) * N + 2 * col + 0] = static_cast<U>(values[0].x);
x[(row + 0) * N + 2 * col + 1] = static_cast<U>(values[0].y);
x[(row + 0) * N + 2 * col + 8] = static_cast<U>(values[2].x);
x[(row + 0) * N + 2 * col + 9] = static_cast<U>(values[2].y);
}
if (row + 8 < max_rows) {
x[(row + 8) * N + 2 * col + 0] = static_cast<U>(values[1].x);
x[(row + 8) * N + 2 * col + 1] = static_cast<U>(values[1].y);
x[(row + 8) * N + 2 * col + 8] = static_cast<U>(values[3].x);
x[(row + 8) * N + 2 * col + 9] = static_cast<U>(values[3].y);
}
}
};
/**
@@ -173,6 +192,19 @@ struct RegisterTile {
}
}
}
template <typename U>
__device__ inline void
store_global_safe(U* x, int N, int row, int col, int max_rows) {
MLX_UNROLL
for (int i = 0; i < TILES_Y; i++) {
MLX_UNROLL
for (int j = 0; j < TILES_X; j++) {
data[i * TILES_X + j].store_global_safe(
x + (row + i * 16) * N + col + j * 16, N, max_rows - row - i * 16);
}
}
}
};
template <typename T, int ROWS_, int COLS_>
@@ -352,4 +384,36 @@ load_async(Tile& tile, uint32_t base_address, const T* x, int N) {
}
}
template <int NUM_WARPS, typename T, typename Tile>
__device__ inline void load_async_safe(
Tile& tile,
uint32_t base_address,
const T* x,
int N,
int max_rows) {
constexpr int NUM_THREADS = NUM_WARPS * 32;
constexpr int ELEMENTS_PER_LOAD = sizeof(float4) / sizeof(T);
constexpr int NUM_LOADS = Tile::NUMEL / ELEMENTS_PER_LOAD;
constexpr int NUM_LOADS_PER_THREAD = NUM_LOADS / NUM_THREADS;
constexpr int NUM_LOADS_PER_ROW = Tile::COLS / ELEMENTS_PER_LOAD;
constexpr int STEP_ROWS = NUM_THREADS / NUM_LOADS_PER_ROW;
const int row = threadIdx.x / NUM_LOADS_PER_ROW;
const int col = threadIdx.x % NUM_LOADS_PER_ROW;
x += row * N + col * ELEMENTS_PER_LOAD;
MLX_UNROLL
for (int i = 0; i < NUM_LOADS_PER_THREAD; i++) {
if (row + i * STEP_ROWS < max_rows) {
cp_async_16(
tile.loc(base_address, row + i * STEP_ROWS, col * ELEMENTS_PER_LOAD),
x + i * STEP_ROWS * N);
} else {
float4 tmp = {0, 0, 0, 0};
tile.store(tmp, row + i * STEP_ROWS, col * ELEMENTS_PER_LOAD);
}
}
}
} // namespace mlx::core::cu

View File

@@ -50,8 +50,15 @@ __device__ inline void load_quantized(
}
}
template <typename T, int BM, int BN, int BK, int group_size, int bits>
__global__ void qmm_t_aligned(
template <
typename T,
int BM,
int BN,
int BK,
int group_size,
int bits,
bool aligned_M>
__global__ void qmm_t(
const T* x,
const uint8_t* w,
const T* scales,
@@ -82,6 +89,8 @@ __global__ void qmm_t_aligned(
RegisterTile<T, BM / WARPS_M, 16> A;
RegisterTile<T, BN / WARPS_N, 16> B;
const int max_rows = M - blockIdx.y * BM;
x += blockIdx.y * BM * K;
w += blockIdx.x * BN * K / get_pack_factor<bits>();
scales += blockIdx.x * BN * K / group_size;
@@ -95,10 +104,10 @@ __global__ void qmm_t_aligned(
base_addr_xs[0] = __cvta_generic_to_shared(&xs[0].data[0]);
base_addr_ws[0] = __cvta_generic_to_shared(&ws[0].data[0]);
if (aligned_M || max_rows >= BM) {
for (int k_block = 0; k_block < K; k_block += BK) {
load_async<NUM_WARPS>(xs[tic], base_addr_xs[tic], x + k_block, K);
cp_async_commit();
// load<NUM_WARPS>(xs[tic], x + k_block, K);
load_quantized<NUM_WARPS, group_size, bits>(
ws[tic],
w + k_block / get_pack_factor<bits>(),
@@ -123,8 +132,38 @@ __global__ void qmm_t_aligned(
mma_t(C, A, B);
}
}
C.store_global(y, N, offset_m, offset_n);
} else {
for (int k_block = 0; k_block < K; k_block += BK) {
load_async_safe<NUM_WARPS>(
xs[tic], base_addr_xs[tic], x + k_block, K, max_rows);
cp_async_commit();
load_quantized<NUM_WARPS, group_size, bits>(
ws[tic],
w + k_block / get_pack_factor<bits>(),
scales + k_block / group_size,
biases + k_block / group_size,
K);
cp_async_wait_all();
__syncthreads();
MLX_UNROLL
for (int k = 0; k < BK / 16; k++) {
A.load(
xs[tic],
base_addr_xs[tic],
offset_m + laneid % 16,
k * 16 + laneid / 16 * 8);
B.load(
ws[tic],
base_addr_ws[tic],
offset_n + laneid % 16,
k * 16 + laneid / 16 * 8);
mma_t(C, A, B);
}
}
C.store_global_safe(y, N, offset_m, offset_n, max_rows);
}
}
} // namespace cu
@@ -143,17 +182,30 @@ void qmm(
int K,
cu::CommandEncoder& enc,
const Stream& s) {
if (x.dtype() != bfloat16) {
throw std::invalid_argument("[qmm] Only bfloat16 is supported for now");
}
if (!transpose_) {
throw std::invalid_argument(
"[qmm] Only transposed matmul is supported for now");
}
dispatch_float_types(x.dtype(), "qmm", [&](auto type_tag) {
dispatch_groups(group_size_, [&](auto group_size) {
dispatch_bits(bits_, [&](auto bits) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
constexpr int BM = 128;
constexpr int BN = 128;
constexpr int BK = 32;
auto kernel = cu::
qmm_t_aligned<DataType, BM, BN, BK, group_size.value, bits.value>;
auto kernel =
cu::qmm_t<DataType, BM, BN, BK, group_size.value, bits.value, true>;
if (M % BM != 0) {
kernel = cu::
qmm_t<DataType, BM, BN, BK, group_size.value, bits.value, false>;
}
dim3 grid(N / BN, M / BM);
dim3 grid((N + BN - 1) / BN, (M + BM - 1) / BM);
enc.add_kernel_node(
kernel,