Support unaligned M
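
This commit adds bounds-checked ("safe") load and store paths to the CUDA quantized matmul tiles and threads a bool aligned_M template parameter through the kernel, so the former qmm_t_aligned kernel (now qmm_t) also handles an M that is not a multiple of the BM tile. The host dispatch selects the unaligned instantiation when M % BM != 0 and launches a ceil-divided grid.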
@@ -122,6 +122,25 @@ struct Tile16x16 {
             __floats2bfloat162_rn(values[3].x, values[3].y);
     }
   }
+
+  template <typename U>
+  __device__ inline void store_global_safe(U* x, int N, int max_rows) {
+    const int laneid = threadIdx.x % 32;
+    const int row = laneid / 4;
+    const int col = laneid % 4;
+    if (row < max_rows) {
+      x[(row + 0) * N + 2 * col + 0] = static_cast<U>(values[0].x);
+      x[(row + 0) * N + 2 * col + 1] = static_cast<U>(values[0].y);
+      x[(row + 0) * N + 2 * col + 8] = static_cast<U>(values[2].x);
+      x[(row + 0) * N + 2 * col + 9] = static_cast<U>(values[2].y);
+    }
+    if (row + 8 < max_rows) {
+      x[(row + 8) * N + 2 * col + 0] = static_cast<U>(values[1].x);
+      x[(row + 8) * N + 2 * col + 1] = static_cast<U>(values[1].y);
+      x[(row + 8) * N + 2 * col + 8] = static_cast<U>(values[3].x);
+      x[(row + 8) * N + 2 * col + 9] = static_cast<U>(values[3].y);
+    }
+  }
 };
 
 /**
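
A note on the hunk above: with the usual mma.sync 16x16 accumulator layout (the same one the existing store_global uses), lane l of the warp owns rows l / 4 and l / 4 + 8 and the column pairs 2 * (l % 4) and 2 * (l % 4) + 8, one float2 fragment each. The two if guards therefore clip the upper and lower 8-row halves of the tile independently. A minimal host-side sketch of that indexing (the program is illustrative, not part of the commit) checks that the 32 lanes cover a 16x16 tile exactly once:

    #include <cstdio>

    int main() {
      int hits[16][16] = {};
      for (int lane = 0; lane < 32; lane++) {
        const int row = lane / 4;
        const int col = lane % 4;
        const int halves[2] = {0, 8}; // values[0]/[2] vs. values[1]/[3]
        for (int h = 0; h < 2; h++) {
          hits[row + halves[h]][2 * col + 0]++; // fragment .x
          hits[row + halves[h]][2 * col + 1]++; // fragment .y
          hits[row + halves[h]][2 * col + 8]++; // right-half fragment .x
          hits[row + halves[h]][2 * col + 9]++; // right-half fragment .y
        }
      }
      int errors = 0;
      for (int r = 0; r < 16; r++)
        for (int c = 0; c < 16; c++)
          errors += (hits[r][c] != 1);
      printf("%s\n", errors ? "gap or overlap found" : "16x16 covered exactly once");
      return errors != 0;
    }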
@@ -173,6 +192,19 @@ struct RegisterTile {
       }
     }
   }
+
+  template <typename U>
+  __device__ inline void
+  store_global_safe(U* x, int N, int row, int col, int max_rows) {
+    MLX_UNROLL
+    for (int i = 0; i < TILES_Y; i++) {
+      MLX_UNROLL
+      for (int j = 0; j < TILES_X; j++) {
+        data[i * TILES_X + j].store_global_safe(
+            x + (row + i * 16) * N + col + j * 16, N, max_rows - row - i * 16);
+      }
+    }
+  }
 };
 
 template <typename T, int ROWS_, int COLS_>
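
Here the register tile just forwards a shrinking row budget to each 16-row sub-tile: sub-tile row i sees max_rows - row - i * 16, so sub-tiles that start past the matrix edge receive a non-positive budget and Tile16x16::store_global_safe writes nothing for them. A small sketch with hypothetical numbers (the TILES_Y, row, and max_rows values are chosen purely for illustration):

    #include <cstdio>

    int main() {
      const int TILES_Y = 4;   // a 64-row register tile
      const int row = 0;       // warp's row offset within the block
      const int max_rows = 37; // rows left before the matrix edge
      for (int i = 0; i < TILES_Y; i++) {
        const int budget = max_rows - row - i * 16;
        // Each lane of Tile16x16 stores rows r and r + 8 (r in 0..7) only
        // while they stay below the budget, so <= 0 stores nothing.
        const int stored = budget >= 16 ? 16 : (budget > 0 ? budget : 0);
        printf("sub-tile %d: budget %3d -> %2d rows stored\n", i, budget, stored);
      }
      return 0;
    }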
@@ -352,4 +384,36 @@ load_async(Tile& tile, uint32_t base_address, const T* x, int N) {
   }
 }
+
+template <int NUM_WARPS, typename T, typename Tile>
+__device__ inline void load_async_safe(
+    Tile& tile,
+    uint32_t base_address,
+    const T* x,
+    int N,
+    int max_rows) {
+  constexpr int NUM_THREADS = NUM_WARPS * 32;
+  constexpr int ELEMENTS_PER_LOAD = sizeof(float4) / sizeof(T);
+  constexpr int NUM_LOADS = Tile::NUMEL / ELEMENTS_PER_LOAD;
+  constexpr int NUM_LOADS_PER_THREAD = NUM_LOADS / NUM_THREADS;
+  constexpr int NUM_LOADS_PER_ROW = Tile::COLS / ELEMENTS_PER_LOAD;
+  constexpr int STEP_ROWS = NUM_THREADS / NUM_LOADS_PER_ROW;
+
+  const int row = threadIdx.x / NUM_LOADS_PER_ROW;
+  const int col = threadIdx.x % NUM_LOADS_PER_ROW;
+
+  x += row * N + col * ELEMENTS_PER_LOAD;
+
+  MLX_UNROLL
+  for (int i = 0; i < NUM_LOADS_PER_THREAD; i++) {
+    if (row + i * STEP_ROWS < max_rows) {
+      cp_async_16(
+          tile.loc(base_address, row + i * STEP_ROWS, col * ELEMENTS_PER_LOAD),
+          x + i * STEP_ROWS * N);
+    } else {
+      float4 tmp = {0, 0, 0, 0};
+      tile.store(tmp, row + i * STEP_ROWS, col * ELEMENTS_PER_LOAD);
+    }
+  }
+}
 
 } // namespace mlx::core::cu
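
The constexpr arithmetic fixes each thread's slice of the tile: a thread issues NUM_LOADS_PER_THREAD 16-byte cp.async copies, STEP_ROWS rows apart, and any row at or past max_rows is replaced by a zero fill so the shared-memory tile (and hence the mma) always sees defined data. Plugging in one plausible configuration (assumptions: T = bfloat16 and a 128 x 32 tile, per the BM/BK constants in the host dispatch further down; the real NUM_WARPS is set by the kernel launch):

    #include <cstdio>

    int main() {
      const int NUM_WARPS = 8, NUM_THREADS = NUM_WARPS * 32;     // 256 threads
      const int ELEMENTS_PER_LOAD = 16 / 2;                      // float4 / bf16 = 8
      const int ROWS = 128, COLS = 32;                           // BM x BK tile
      const int NUM_LOADS = ROWS * COLS / ELEMENTS_PER_LOAD;     // 512
      const int NUM_LOADS_PER_THREAD = NUM_LOADS / NUM_THREADS;  // 2
      const int NUM_LOADS_PER_ROW = COLS / ELEMENTS_PER_LOAD;    // 4
      const int STEP_ROWS = NUM_THREADS / NUM_LOADS_PER_ROW;     // 64
      printf("each thread: %d loads of %d elements, rows %d apart\n",
             NUM_LOADS_PER_THREAD, ELEMENTS_PER_LOAD, STEP_ROWS);
      return 0;
    }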
@@ -50,8 +50,15 @@ __device__ inline void load_quantized(
   }
 }
 
-template <typename T, int BM, int BN, int BK, int group_size, int bits>
-__global__ void qmm_t_aligned(
+template <
+    typename T,
+    int BM,
+    int BN,
+    int BK,
+    int group_size,
+    int bits,
+    bool aligned_M>
+__global__ void qmm_t(
     const T* x,
     const uint8_t* w,
     const T* scales,
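
Folding aligned_M into the template (rather than keeping a separate qmm_t_aligned kernel) means the bounds-checked path costs nothing when the host proves M is a multiple of BM: the branch is uniform across the block, and in the aligned instantiation the dead arm is pruned at compile time. A minimal sketch of the pattern, using a hypothetical kernel rather than the mlx one:

    template <int BM, bool aligned_M>
    __global__ void tile_kernel(const float* x, float* y, int M) {
      // Rows this block may touch; less than BM only for the trailing block.
      const int max_rows = M - blockIdx.y * BM;
      if (aligned_M || max_rows >= BM) {
        // Fast path: a full BM-row tile is in range, no per-row guards.
      } else {
        // Safe path: reached only in the aligned_M == false instantiation,
        // and only by the trailing block; guards every row against max_rows.
      }
    }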
@@ -82,6 +89,8 @@ __global__ void qmm_t_aligned(
   RegisterTile<T, BM / WARPS_M, 16> A;
   RegisterTile<T, BN / WARPS_N, 16> B;
 
+  const int max_rows = M - blockIdx.y * BM;
+
   x += blockIdx.y * BM * K;
   w += blockIdx.x * BN * K / get_pack_factor<bits>();
   scales += blockIdx.x * BN * K / group_size;
@@ -95,36 +104,66 @@ __global__ void qmm_t_aligned(
   base_addr_xs[0] = __cvta_generic_to_shared(&xs[0].data[0]);
   base_addr_ws[0] = __cvta_generic_to_shared(&ws[0].data[0]);
 
-  for (int k_block = 0; k_block < K; k_block += BK) {
-    load_async<NUM_WARPS>(xs[tic], base_addr_xs[tic], x + k_block, K);
-    cp_async_commit();
-    // load<NUM_WARPS>(xs[tic], x + k_block, K);
-    load_quantized<NUM_WARPS, group_size, bits>(
-        ws[tic],
-        w + k_block / get_pack_factor<bits>(),
-        scales + k_block / group_size,
-        biases + k_block / group_size,
-        K);
-    cp_async_wait_all();
-    __syncthreads();
-
-    MLX_UNROLL
-    for (int k = 0; k < BK / 16; k++) {
-      A.load(
-          xs[tic],
-          base_addr_xs[tic],
-          offset_m + laneid % 16,
-          k * 16 + laneid / 16 * 8);
-      B.load(
-          ws[tic],
-          base_addr_ws[tic],
-          offset_n + laneid % 16,
-          k * 16 + laneid / 16 * 8);
-      mma_t(C, A, B);
-    }
-  }
-
-  C.store_global(y, N, offset_m, offset_n);
+  if (aligned_M || max_rows >= BM) {
+    for (int k_block = 0; k_block < K; k_block += BK) {
+      load_async<NUM_WARPS>(xs[tic], base_addr_xs[tic], x + k_block, K);
+      cp_async_commit();
+      load_quantized<NUM_WARPS, group_size, bits>(
+          ws[tic],
+          w + k_block / get_pack_factor<bits>(),
+          scales + k_block / group_size,
+          biases + k_block / group_size,
+          K);
+      cp_async_wait_all();
+      __syncthreads();
+
+      MLX_UNROLL
+      for (int k = 0; k < BK / 16; k++) {
+        A.load(
+            xs[tic],
+            base_addr_xs[tic],
+            offset_m + laneid % 16,
+            k * 16 + laneid / 16 * 8);
+        B.load(
+            ws[tic],
+            base_addr_ws[tic],
+            offset_n + laneid % 16,
+            k * 16 + laneid / 16 * 8);
+        mma_t(C, A, B);
+      }
+    }
+
+    C.store_global(y, N, offset_m, offset_n);
+  } else {
+    for (int k_block = 0; k_block < K; k_block += BK) {
+      load_async_safe<NUM_WARPS>(
+          xs[tic], base_addr_xs[tic], x + k_block, K, max_rows);
+      cp_async_commit();
+      load_quantized<NUM_WARPS, group_size, bits>(
+          ws[tic],
+          w + k_block / get_pack_factor<bits>(),
+          scales + k_block / group_size,
+          biases + k_block / group_size,
+          K);
+      cp_async_wait_all();
+      __syncthreads();
+
+      MLX_UNROLL
+      for (int k = 0; k < BK / 16; k++) {
+        A.load(
+            xs[tic],
+            base_addr_xs[tic],
+            offset_m + laneid % 16,
+            k * 16 + laneid / 16 * 8);
+        B.load(
+            ws[tic],
+            base_addr_ws[tic],
+            offset_n + laneid % 16,
+            k * 16 + laneid / 16 * 8);
+        mma_t(C, A, B);
+      }
+    }
+
+    C.store_global_safe(y, N, offset_m, offset_n, max_rows);
+  }
 }
 
 } // namespace cu
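
The whole K-loop is duplicated rather than guarded per load, so the hot loop of the common case stays branch-free; only the last block row of an unaligned M ever takes the safe path, and even then out-of-range rows are zero-filled on load and skipped on store. For example (hypothetical sizes), with BM = 128 and M = 1000:

    #include <cstdio>

    int main() {
      const int M = 1000, BM = 128;
      const int grid_y = (M + BM - 1) / BM; // ceil-div, as in the dispatch below
      for (int y = 0; y < grid_y; y++) {
        const int max_rows = M - y * BM;
        // Blocks 0-6 see >= 128 rows (fast path); block 7 sees 104 (safe path).
        printf("block row %d: max_rows = %4d -> %s path\n",
               y, max_rows, max_rows >= BM ? "fast" : "safe");
      }
      return 0;
    }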
@@ -143,17 +182,30 @@ void qmm(
     int K,
     cu::CommandEncoder& enc,
     const Stream& s) {
+  if (x.dtype() != bfloat16) {
+    throw std::invalid_argument("[qmm] Only bfloat16 is supported for now");
+  }
+  if (!transpose_) {
+    throw std::invalid_argument(
+        "[qmm] Only transposed matmul is supported for now");
+  }
+
   dispatch_float_types(x.dtype(), "qmm", [&](auto type_tag) {
     dispatch_groups(group_size_, [&](auto group_size) {
       dispatch_bits(bits_, [&](auto bits) {
         using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
 
         constexpr int BM = 128;
         constexpr int BN = 128;
         constexpr int BK = 32;
-        auto kernel = cu::
-            qmm_t_aligned<DataType, BM, BN, BK, group_size.value, bits.value>;
+        auto kernel =
+            cu::qmm_t<DataType, BM, BN, BK, group_size.value, bits.value, true>;
+        if (M % BM != 0) {
+          kernel = cu::
+              qmm_t<DataType, BM, BN, BK, group_size.value, bits.value, false>;
+        }
 
-        dim3 grid(N / BN, M / BM);
+        dim3 grid((N + BN - 1) / BN, (M + BM - 1) / BM);
 
         enc.add_kernel_node(
             kernel,