mirror of https://github.com/ml-explore/mlx.git
Support unaligned M
@@ -122,6 +122,25 @@ struct Tile16x16 {
           __floats2bfloat162_rn(values[3].x, values[3].y);
     }
   }
+
+  template <typename U>
+  __device__ inline void store_global_safe(U* x, int N, int max_rows) {
+    const int laneid = threadIdx.x % 32;
+    const int row = laneid / 4;
+    const int col = laneid % 4;
+    if (row < max_rows) {
+      x[(row + 0) * N + 2 * col + 0] = static_cast<U>(values[0].x);
+      x[(row + 0) * N + 2 * col + 1] = static_cast<U>(values[0].y);
+      x[(row + 0) * N + 2 * col + 8] = static_cast<U>(values[2].x);
+      x[(row + 0) * N + 2 * col + 9] = static_cast<U>(values[2].y);
+    }
+    if (row + 8 < max_rows) {
+      x[(row + 8) * N + 2 * col + 0] = static_cast<U>(values[1].x);
+      x[(row + 8) * N + 2 * col + 1] = static_cast<U>(values[1].y);
+      x[(row + 8) * N + 2 * col + 8] = static_cast<U>(values[3].x);
+      x[(row + 8) * N + 2 * col + 9] = static_cast<U>(values[3].y);
+    }
+  }
 };
 
 /**
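A small host-side check of the lane-to-element mapping used above (illustrative C++, not MLX code): each of the 32 lanes writes eight elements, and together they tile the 16x16 block exactly once, which is why the two row guards in store_global_safe suffice.

// Illustrative check of the warp's lane -> element mapping: 32 lanes x 8 writes
// cover a 16x16 tile exactly once with the row/col arithmetic shown above.
#include <cassert>

int main() {
  bool written[16][16] = {};
  for (int laneid = 0; laneid < 32; ++laneid) {
    const int row = laneid / 4;
    const int col = laneid % 4;
    const int rows[2] = {row, row + 8}; // values[0]/[2] vs values[1]/[3]
    const int cols[4] = {2 * col, 2 * col + 1, 2 * col + 8, 2 * col + 9};
    for (int r : rows) {
      for (int c : cols) {
        assert(!written[r][c]); // no element written twice
        written[r][c] = true;
      }
    }
  }
  for (auto& line : written) {
    for (bool w : line) {
      assert(w); // no element missed
    }
  }
  return 0;
}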
@@ -173,6 +192,19 @@ struct RegisterTile {
       }
     }
   }
+
+  template <typename U>
+  __device__ inline void
+  store_global_safe(U* x, int N, int row, int col, int max_rows) {
+    MLX_UNROLL
+    for (int i = 0; i < TILES_Y; i++) {
+      MLX_UNROLL
+      for (int j = 0; j < TILES_X; j++) {
+        data[i * TILES_X + j].store_global_safe(
+            x + (row + i * 16) * N + col + j * 16, N, max_rows - row - i * 16);
+      }
+    }
+  }
 };
 
 template <typename T, int ROWS_, int COLS_>
@@ -352,4 +384,36 @@ load_async(Tile& tile, uint32_t base_address, const T* x, int N) {
   }
 }
+
+template <int NUM_WARPS, typename T, typename Tile>
+__device__ inline void load_async_safe(
+    Tile& tile,
+    uint32_t base_address,
+    const T* x,
+    int N,
+    int max_rows) {
+  constexpr int NUM_THREADS = NUM_WARPS * 32;
+  constexpr int ELEMENTS_PER_LOAD = sizeof(float4) / sizeof(T);
+  constexpr int NUM_LOADS = Tile::NUMEL / ELEMENTS_PER_LOAD;
+  constexpr int NUM_LOADS_PER_THREAD = NUM_LOADS / NUM_THREADS;
+  constexpr int NUM_LOADS_PER_ROW = Tile::COLS / ELEMENTS_PER_LOAD;
+  constexpr int STEP_ROWS = NUM_THREADS / NUM_LOADS_PER_ROW;
+
+  const int row = threadIdx.x / NUM_LOADS_PER_ROW;
+  const int col = threadIdx.x % NUM_LOADS_PER_ROW;
+
+  x += row * N + col * ELEMENTS_PER_LOAD;
+
+  MLX_UNROLL
+  for (int i = 0; i < NUM_LOADS_PER_THREAD; i++) {
+    if (row + i * STEP_ROWS < max_rows) {
+      cp_async_16(
+          tile.loc(base_address, row + i * STEP_ROWS, col * ELEMENTS_PER_LOAD),
+          x + i * STEP_ROWS * N);
+    } else {
+      float4 tmp = {0, 0, 0, 0};
+      tile.store(tmp, row + i * STEP_ROWS, col * ELEMENTS_PER_LOAD);
+    }
+  }
+}
 
 } // namespace mlx::core::cu
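load_async_safe issues the 16-byte cp.async only for rows below max_rows and zero-fills the shared-memory slot otherwise, so rows that fall outside the matrix contribute zeros to the later MMA. A minimal host-side model of that predication (hypothetical names, not the MLX API):

// Hypothetical sketch of the guarded tile load: in-range rows are copied from
// the source matrix, out-of-range rows are zero-filled.
#include <cstring>
#include <vector>

template <typename T>
void load_tile_safe(
    std::vector<T>& tile, const T* x, int N, int tile_rows, int tile_cols, int max_rows) {
  // Zero the whole tile first, mirroring the float4{0,0,0,0} fallback above.
  tile.assign(static_cast<std::size_t>(tile_rows) * tile_cols, T(0));
  for (int r = 0; r < tile_rows && r < max_rows; ++r) {
    std::memcpy(&tile[static_cast<std::size_t>(r) * tile_cols],
                x + static_cast<std::size_t>(r) * N,
                tile_cols * sizeof(T));
  }
}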
@@ -50,8 +50,15 @@ __device__ inline void load_quantized(
   }
 }
 
-template <typename T, int BM, int BN, int BK, int group_size, int bits>
-__global__ void qmm_t_aligned(
+template <
+    typename T,
+    int BM,
+    int BN,
+    int BK,
+    int group_size,
+    int bits,
+    bool aligned_M>
+__global__ void qmm_t(
     const T* x,
     const uint8_t* w,
     const T* scales,
@@ -82,6 +89,8 @@ __global__ void qmm_t_aligned(
   RegisterTile<T, BM / WARPS_M, 16> A;
   RegisterTile<T, BN / WARPS_N, 16> B;
 
+  const int max_rows = M - blockIdx.y * BM;
+
   x += blockIdx.y * BM * K;
   w += blockIdx.x * BN * K / get_pack_factor<bits>();
   scales += blockIdx.x * BN * K / group_size;
@@ -95,36 +104,66 @@ __global__ void qmm_t_aligned(
   base_addr_xs[0] = __cvta_generic_to_shared(&xs[0].data[0]);
   base_addr_ws[0] = __cvta_generic_to_shared(&ws[0].data[0]);
 
-  for (int k_block = 0; k_block < K; k_block += BK) {
-    load_async<NUM_WARPS>(xs[tic], base_addr_xs[tic], x + k_block, K);
-    cp_async_commit();
-    // load<NUM_WARPS>(xs[tic], x + k_block, K);
-    load_quantized<NUM_WARPS, group_size, bits>(
-        ws[tic],
-        w + k_block / get_pack_factor<bits>(),
-        scales + k_block / group_size,
-        biases + k_block / group_size,
-        K);
-    cp_async_wait_all();
-    __syncthreads();
-
-    MLX_UNROLL
-    for (int k = 0; k < BK / 16; k++) {
-      A.load(
-          xs[tic],
-          base_addr_xs[tic],
-          offset_m + laneid % 16,
-          k * 16 + laneid / 16 * 8);
-      B.load(
-          ws[tic],
-          base_addr_ws[tic],
-          offset_n + laneid % 16,
-          k * 16 + laneid / 16 * 8);
-      mma_t(C, A, B);
-    }
-  }
-
-  C.store_global(y, N, offset_m, offset_n);
-}
+  if (aligned_M || max_rows >= BM) {
+    for (int k_block = 0; k_block < K; k_block += BK) {
+      load_async<NUM_WARPS>(xs[tic], base_addr_xs[tic], x + k_block, K);
+      cp_async_commit();
+      load_quantized<NUM_WARPS, group_size, bits>(
+          ws[tic],
+          w + k_block / get_pack_factor<bits>(),
+          scales + k_block / group_size,
+          biases + k_block / group_size,
+          K);
+      cp_async_wait_all();
+      __syncthreads();
+
+      MLX_UNROLL
+      for (int k = 0; k < BK / 16; k++) {
+        A.load(
+            xs[tic],
+            base_addr_xs[tic],
+            offset_m + laneid % 16,
+            k * 16 + laneid / 16 * 8);
+        B.load(
+            ws[tic],
+            base_addr_ws[tic],
+            offset_n + laneid % 16,
+            k * 16 + laneid / 16 * 8);
+        mma_t(C, A, B);
+      }
+    }
+    C.store_global(y, N, offset_m, offset_n);
+  } else {
+    for (int k_block = 0; k_block < K; k_block += BK) {
+      load_async_safe<NUM_WARPS>(
+          xs[tic], base_addr_xs[tic], x + k_block, K, max_rows);
+      cp_async_commit();
+      load_quantized<NUM_WARPS, group_size, bits>(
+          ws[tic],
+          w + k_block / get_pack_factor<bits>(),
+          scales + k_block / group_size,
+          biases + k_block / group_size,
+          K);
+      cp_async_wait_all();
+      __syncthreads();
+
+      MLX_UNROLL
+      for (int k = 0; k < BK / 16; k++) {
+        A.load(
+            xs[tic],
+            base_addr_xs[tic],
+            offset_m + laneid % 16,
+            k * 16 + laneid / 16 * 8);
+        B.load(
+            ws[tic],
+            base_addr_ws[tic],
+            offset_n + laneid % 16,
+            k * 16 + laneid / 16 * 8);
+        mma_t(C, A, B);
+      }
+    }
+    C.store_global_safe(y, N, offset_m, offset_n, max_rows);
+  }
+}
 
 } // namespace cu
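Because aligned_M is a compile-time template parameter, the condition `aligned_M || max_rows >= BM` folds to true in the aligned instantiation and the guarded branch becomes dead code; in the unaligned instantiation, interior row-blocks still pass the runtime check and take the fast path, so only the ragged last row-block pays for bounds checks. A stripped-down sketch of the same pattern (hypothetical kernel, not the MLX one):

// Hypothetical kernel illustrating the compile-time/runtime split used above.
template <bool aligned_M, int BM>
__global__ void copy_rows(const float* x, float* y, int M, int N) {
  const int row0 = blockIdx.y * BM;
  const int max_rows = M - row0;
  if (aligned_M || max_rows >= BM) {
    // Fast path: the whole BM-row block is in range, no per-row checks.
    for (int r = threadIdx.x; r < BM; r += blockDim.x)
      for (int c = 0; c < N; ++c)
        y[(row0 + r) * N + c] = x[(row0 + r) * N + c];
  } else {
    // Guarded path: only the ragged final row-block loops up to max_rows.
    for (int r = threadIdx.x; r < max_rows; r += blockDim.x)
      for (int c = 0; c < N; ++c)
        y[(row0 + r) * N + c] = x[(row0 + r) * N + c];
  }
}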
@@ -143,17 +182,30 @@ void qmm(
     int K,
     cu::CommandEncoder& enc,
     const Stream& s) {
+  if (x.dtype() != bfloat16) {
+    throw std::invalid_argument("[qmm] Only bfloat16 is supported for now");
+  }
+  if (!transpose_) {
+    throw std::invalid_argument(
+        "[qmm] Only transposed matmul is supported for now");
+  }
+
   dispatch_float_types(x.dtype(), "qmm", [&](auto type_tag) {
     dispatch_groups(group_size_, [&](auto group_size) {
       dispatch_bits(bits_, [&](auto bits) {
         using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
 
         constexpr int BM = 128;
         constexpr int BN = 128;
         constexpr int BK = 32;
-        auto kernel = cu::
-            qmm_t_aligned<DataType, BM, BN, BK, group_size.value, bits.value>;
-        dim3 grid(N / BN, M / BM);
+        auto kernel =
+            cu::qmm_t<DataType, BM, BN, BK, group_size.value, bits.value, true>;
+        if (M % BM != 0) {
+          kernel = cu::
+              qmm_t<DataType, BM, BN, BK, group_size.value, bits.value, false>;
+        }
+
+        dim3 grid((N + BN - 1) / BN, (M + BM - 1) / BM);
 
         enc.add_kernel_node(
             kernel,
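The ceil-div grid plus the per-block max_rows determines which path runs. As a worked example with the BM = BN = 128 from the dispatch above and assumed sizes M = 300, N = 512: the grid becomes (4, 3) instead of (4, 2), the three row-blocks see max_rows = 300, 172 and 44, and only the last one takes the guarded path. A small host-side sketch (example sizes, not MLX code):

// Sketch of the launch arithmetic with example sizes.
#include <cstdio>

int main() {
  constexpr int BM = 128, BN = 128;
  const int M = 300, N = 512;
  const int grid_x = (N + BN - 1) / BN; // 4 column blocks
  const int grid_y = (M + BM - 1) / BM; // 3 row blocks (M / BM would give only 2)
  std::printf("grid = (%d, %d)\n", grid_x, grid_y);
  for (int by = 0; by < grid_y; ++by) {
    const int max_rows = M - by * BM; // 300, 172, 44
    std::printf("block row %d: max_rows = %d -> %s path\n",
                by, max_rows, max_rows >= BM ? "fast" : "guarded");
  }
  return 0;
}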