Add the metal version of segmented mm
@@ -71,6 +71,7 @@ set(STEEL_HEADERS
    steel/gemm/kernels/steel_gemm_fused.h
    steel/gemm/kernels/steel_gemm_gather.h
    steel/gemm/kernels/steel_gemm_masked.h
    steel/gemm/kernels/steel_gemm_segmented.h
    steel/gemm/kernels/steel_gemm_splitk.h
    steel/utils/type_traits.h
    steel/utils/integral_constant.h)
@@ -120,6 +121,7 @@ if(NOT MLX_METAL_JIT)
  build_kernel(steel/gemm/kernels/steel_gemm_gather ${STEEL_HEADERS})
  build_kernel(steel/gemm/kernels/steel_gemm_masked ${STEEL_HEADERS})
  build_kernel(steel/gemm/kernels/steel_gemm_splitk ${STEEL_HEADERS})
  build_kernel(steel/gemm/kernels/steel_gemm_segmented ${STEEL_HEADERS})
  build_kernel(gemv_masked steel/utils.h)
endif()
@@ -0,0 +1,266 @@
// Copyright © 2025 Apple Inc.

using namespace mlx::steel;

constant bool segments_contiguous [[function_constant(199)]];
constant bool align_M [[function_constant(200)]];
constant bool align_N [[function_constant(201)]];
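// These function constants are bound by the host when the compute pipeline is
// built (see the MTLFCList in the host code further down), so Metal compiles
// a specialized variant per combination and the untaken alignment branches
// below are eliminated at pipeline-creation time.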

template <
    typename T,
    int BM,
    int BN,
    int BK,
    int WM,
    int WN,
    bool transpose_a,
    bool transpose_b,
    typename AccumType = float>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void segmented_mm(
    const device T* A [[buffer(0)]],
    const device T* B [[buffer(1)]],
    const device int32_t* segments [[buffer(2)]],
    device T* C [[buffer(3)]],
    const constant GEMMParams* params [[buffer(4)]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint3 tid [[threadgroup_position_in_grid]]) {
  using gemm_kernel = GEMMKernel<
      T,
      T,
      BM,
      BN,
      BK,
      WM,
      WN,
      transpose_a,
      transpose_b,
      true,
      true,
      AccumType>;

  using loader_a_t = typename gemm_kernel::loader_a_t;
  using loader_b_t = typename gemm_kernel::loader_b_t;
  using mma_t = typename gemm_kernel::mma_t;

  if (params->tiles_n <= static_cast<int>(tid.x) ||
      params->tiles_m <= static_cast<int>(tid.y)) {
    return;
  }

  // Prepare threadgroup memory
  threadgroup T As[gemm_kernel::tgp_mem_size_a];
  threadgroup T Bs[gemm_kernel::tgp_mem_size_b];

  // Find the block in A, B, C
  const int c_row = tid.y * BM;
  const int c_col = tid.x * BN;
  const size_t c_row_long = size_t(c_row);
  const size_t c_col_long = size_t(c_col);

  // Prepare threadgroup bounds
  const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row));
  const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col));

  // Move the pointers to the output tile
  A += transpose_a ? c_row_long : c_row_long * params->lda;
  B += transpose_b ? c_col_long * params->ldb : c_col_long;
  C += c_row_long * params->ldd + c_col_long;

  // Move the pointers to the start of the segment
  int32_t k_start, k_end;
  if (segments_contiguous) {
    k_start = segments[2 * tid.z];
    k_end = segments[2 * tid.z + 1];
  } else {
    // We also accept a layout where each segment starts where the previous
    // one ended, i.e. the last two strides of the segments array are both 1.
    k_start = segments[tid.z];
    k_end = segments[tid.z + 1];
  }
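  // Illustrative example (values are hypothetical): for K = 96 split into
  // three segments of 32, the contiguous layout stores (start, end) pairs,
  //   segments = {0, 32, 32, 64, 64, 96},
  // while the overlapping layout stores only the boundaries,
  //   segments = {0, 32, 64, 96},
  // so segment i spans [segments[i], segments[i + 1]) in both cases.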
  A += transpose_a ? k_start * params->lda : k_start;
  B += transpose_b ? k_start : k_start * params->ldb;
  C += tid.z * params->batch_stride_d;
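  // Above, tid.z indexes the segment: each segment reduces over its own K
  // range and writes an independent M x N tile of the batched output.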

  // Prepare threadgroup mma operation
  thread mma_t mma_op(simd_group_id, simd_lane_id);

  // Prepare threadgroup loading operations
  thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
  thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);

  // Matrix level alignment so only check K
  if (align_M && align_N) {
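    // k tracks the end of the current BK block, so the loop below consumes
    // full BK-sized blocks only; the tail of k_end - (k - BK) elements is
    // handled with bounds-checked loads after the loop.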
    int k = k_start + BK;
    for (; k <= k_end; k += BK) {
      threadgroup_barrier(mem_flags::mem_threadgroup);

      // Load elements into threadgroup
      loader_a.load_unsafe();
      loader_b.load_unsafe();

      threadgroup_barrier(mem_flags::mem_threadgroup);

      // Multiply and accumulate threadgroup elements
      mma_op.mma(As, Bs);

      // Prepare for next iteration
      loader_a.next();
      loader_b.next();
    }
    short k_remain = k_end - (k - BK);
    const short2 tile_dims_A =
        transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
    const short2 tile_dims_B =
        transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
    if (k_remain > 0) {
      threadgroup_barrier(mem_flags::mem_threadgroup);
      loader_a.load_safe(tile_dims_A);
      loader_b.load_safe(tile_dims_B);
      threadgroup_barrier(mem_flags::mem_threadgroup);
      mma_op.mma(As, Bs);
    }
    mma_op.store_result(C, params->ldd);
  } else {
    // Tile is aligned so do the same as above
    if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) {
      int k = k_start + BK;
      for (; k <= k_end; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Load elements into threadgroup
        loader_a.load_unsafe();
        loader_b.load_unsafe();

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Multiply and accumulate threadgroup elements
        mma_op.mma(As, Bs);

        // Prepare for next iteration
        loader_a.next();
        loader_b.next();
      }
      short k_remain = k_end - (k - BK);
      const short2 tile_dims_A =
          transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
      const short2 tile_dims_B =
          transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
      if (k_remain > 0) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_a.load_safe(tile_dims_A);
        loader_b.load_safe(tile_dims_B);
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(As, Bs);
      }
      mma_op.store_result(C, params->ldd);
    }

    // Tile partially aligned so bounds-check rows only
    else if (align_N || tgp_bn == BN) {
      int k = k_start + BK;
      for (; k <= k_end; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Load elements into threadgroup
        loader_a.load_safe(
            transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm));
        loader_b.load_unsafe();

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Multiply and accumulate threadgroup elements
        mma_op.mma(As, Bs);

        // Prepare for next iteration
        loader_a.next();
        loader_b.next();
      }
      short k_remain = k_end - (k - BK);
      const short2 tile_dims_A =
          transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
      const short2 tile_dims_B =
          transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
      if (k_remain > 0) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_a.load_safe(tile_dims_A);
        loader_b.load_safe(tile_dims_B);
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(As, Bs);
      }
      mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm));
    }

    // Tile partially aligned so bounds-check cols only
    else if (align_M || tgp_bm == BM) {
      int k = k_start + BK;
      for (; k <= k_end; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Load elements into threadgroup
        loader_a.load_unsafe();
        loader_b.load_safe(
            transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK));

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Multiply and accumulate threadgroup elements
        mma_op.mma(As, Bs);

        // Prepare for next iteration
        loader_a.next();
        loader_b.next();
      }
      short k_remain = k_end - (k - BK);
      const short2 tile_dims_A =
          transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
      const short2 tile_dims_B =
          transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
      if (k_remain > 0) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_a.load_safe(tile_dims_A);
        loader_b.load_safe(tile_dims_B);
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(As, Bs);
      }
      mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm));
    }

    // Nothing aligned so bounds-check both rows and cols
    else {
      int k = k_start + BK;
      for (; k <= k_end; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Load elements into threadgroup
        loader_a.load_safe(
            transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm));
        loader_b.load_safe(
            transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK));

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Multiply and accumulate threadgroup elements
        mma_op.mma(As, Bs);

        // Prepare for next iteration
        loader_a.next();
        loader_b.next();
      }
      short k_remain = k_end - (k - BK);
      const short2 tile_dims_A =
          transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
      const short2 tile_dims_B =
          transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
      if (k_remain > 0) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_a.load_safe(tile_dims_A);
        loader_b.load_safe(tile_dims_B);
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(As, Bs);
      }
      mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm));
    }
  }
}

@@ -0,0 +1,43 @@
// Copyright © 2024 Apple Inc.

// clang-format off
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
#include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h"

#define instantiate_segmented_mm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_kernel(                                                          \
      "steel_segmented_mm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn   \
      "_bk" #bk "_wm" #wm "_wn" #wn,                                           \
      segmented_mm,                                                            \
      itype,                                                                   \
      bm,                                                                      \
      bn,                                                                      \
      bk,                                                                      \
      wm,                                                                      \
      wn,                                                                      \
      trans_a,                                                                 \
      trans_b,                                                                 \
      float)

#define instantiate_segmented_mm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_segmented_mm(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn)      \
  instantiate_segmented_mm(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn)      \
  instantiate_segmented_mm(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn)      \
  instantiate_segmented_mm(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)

#define instantiate_segmented_mm_shapes_helper(iname, itype, oname, otype)                \
  instantiate_segmented_mm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \
  instantiate_segmented_mm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 1, 2) \
  instantiate_segmented_mm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
  instantiate_segmented_mm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 1, 2) \
  instantiate_segmented_mm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2)
// clang-format on

instantiate_segmented_mm_shapes_helper(float16, half, float16, half);
instantiate_segmented_mm_shapes_helper(
    bfloat16,
    bfloat16_t,
    bfloat16,
    bfloat16_t);
instantiate_segmented_mm_shapes_helper(float32, float, float32, float);
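// For instance, each shapes helper above expands to twenty kernels (4
// transpose variants x 5 tile shapes); one of the float32 ones is named
// "steel_segmented_mm_nn_float32_float32_bm64_bn64_bk16_wm2_wn2".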

@@ -1874,11 +1874,37 @@ void segmented_mm(
    int K,
    metal::Device& d,
    const Stream& s) {
  auto check_segments_layout = [&d, &s](const array& x) {
    // Contiguous so return early
    if (x.flags().row_contiguous) {
      return std::make_tuple(true, x);
    }

    bool rc = true;
    for (int i = 0; i < x.ndim() - 2; i++) {
      rc &=
          (x.strides(i + 1) * x.shape(i) == x.strides(i)) || (x.shape(i) == 1);
    }
    rc &= x.strides(x.ndim() - 1) == 1;
    if (x.ndim() > 1) {
      rc &= x.strides(x.ndim() - 2) == 1;
    }

    if (rc) {
      return std::make_tuple(false, x);
    }

    array x_copy(x.shape(), x.dtype(), nullptr, {});
    copy_gpu(x, x_copy, CopyType::General, s);
    d.add_temporary(x_copy, s.index);
    return std::make_tuple(true, x_copy);
  };
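  // The returned boolean feeds the kernel's segments_contiguous function
  // constant: true selects the (start, end) pair layout (including the copy
  // fallback above), false selects the overlapping-boundaries layout whose
  // last two strides are both 1.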

  // Copy if needed
  std::vector<array> copies;
  auto [transpose_a, lda, a] = check_transpose(copies, s, a_, false);
  auto [transpose_b, ldb, b] = check_transpose(copies, s, b_, false);
  auto [segments_contiguous, segments] = check_segments_layout(segments_);
  d.add_temporaries(std::move(copies), s.index);

  // Determine dispatch kernel
@@ -1916,6 +1942,7 @@ void segmented_mm(
      wn);

  metal::MTLFCList func_consts = {
      {&segments_contiguous, MTL::DataType::DataTypeBool, 199},
      {&align_M, MTL::DataType::DataTypeBool, 200},
      {&align_N, MTL::DataType::DataTypeBool, 201},
  };
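  // The indices 199-201 here must match the [[function_constant(...)]]
  // indices declared in steel_gemm_segmented.h above.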
@@ -1926,6 +1953,8 @@ void segmented_mm(
  concatenate(
      hash_name,
      base_name,
      "_segments_contiguous_",
      segments_contiguous ? 't' : 'n',
      "_align_M_",
      align_M ? 't' : 'n',
      "_align_N_",
      align_N ? 't' : 'n');