MoE backward improvements (#2335)
parent a4fcc893cd
commit 4a9b29a875
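Taken together, the changes below add a segmented GEMM path used by the MoE backward pass: A is M x K, B is K x N, and each [k_start, k_end) segment of the shared K axis is multiplied into its own M x N output slice. A minimal CPU sketch of that semantics follows (illustrative only; Segment and segmented_mm_reference are not MLX APIs, they just mirror what the Metal kernel further down computes):

#include <cstddef>
#include <cstdint>
#include <vector>

// Illustrative segment descriptor: a [k_start, k_end) slice of the shared K axis.
struct Segment {
  uint32_t k_start;
  uint32_t k_end;
};

// Reference semantics: for each segment s, C[s] = A[:, ks:ke] @ B[ks:ke, :],
// with A (M x K) and B (K x N) row-major and C holding one M x N matrix per
// segment (matching batch_stride_d = M * N in the kernel below).
void segmented_mm_reference(
    const float* A,
    const float* B,
    const std::vector<Segment>& segments,
    float* C,
    int M,
    int N,
    int K) {
  for (size_t s = 0; s < segments.size(); ++s) {
    float* Cs = C + s * static_cast<size_t>(M) * N;
    for (int m = 0; m < M; ++m) {
      for (int n = 0; n < N; ++n) {
        float acc = 0.0f;
        for (uint32_t k = segments[s].k_start; k < segments[s].k_end; ++k) {
          acc += A[m * static_cast<size_t>(K) + k] * B[k * static_cast<size_t>(N) + n];
        }
        Cs[m * static_cast<size_t>(N) + n] = acc;
      }
    }
  }
}

In the Metal kernel, each threadgroup's tid.z selects one segment and the output pointer advances by batch_stride_d = M * N per segment.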
@@ -63,6 +63,7 @@ if(MLX_METAL_JIT)
  make_jit_source(steel/gemm/kernels/steel_gemm_masked kernels/steel/defines.h)
  make_jit_source(steel/gemm/kernels/steel_gemm_gather)
  make_jit_source(steel/gemm/kernels/steel_gemm_splitk)
  make_jit_source(steel/gemm/kernels/steel_gemm_segmented)
  make_jit_source(
    steel/conv/conv
    kernels/steel/utils.h
@@ -575,9 +575,17 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
  compute_encoder.set_output_array(out, 2);

  // Set source info
- compute_encoder.set_vector_bytes(remove_index(idx.shape(), axis_), 3);
- compute_encoder.set_vector_bytes(remove_index(upd.strides(), axis_), 4);
- compute_encoder.set_vector_bytes(remove_index(idx.strides(), axis_), 5);
+ if (ndim > 1) {
+   compute_encoder.set_vector_bytes(remove_index(idx.shape(), axis_), 3);
+   compute_encoder.set_vector_bytes(remove_index(upd.strides(), axis_), 4);
+   compute_encoder.set_vector_bytes(remove_index(idx.strides(), axis_), 5);
+ } else {
+   // The following will be ignored in the kernel but we still have to set
+   // some value so that metal validation passes.
+   compute_encoder.set_vector_bytes(idx.shape(), 3);
+   compute_encoder.set_vector_bytes(upd.strides(), 4);
+   compute_encoder.set_vector_bytes(idx.strides(), 5);
+ }
  compute_encoder.set_bytes(ndim - 1, 6);
  compute_encoder.set_bytes(axis_, 7);
  compute_encoder.set_bytes(out.shape(axis_), 8);
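The new branch appears to guard the one-dimensional case: remove_index presumably drops the scatter axis from the shape/stride vectors, which for a 1-d index array leaves nothing to encode, and binding zero-length data is what the Metal validation layer rejects. A minimal sketch of that edge case (remove_index here is a stand-in with assumed behavior, not the helper from the diff):

#include <vector>

// Stand-in for the helper used above (assumed behavior): drop the entry at `axis`.
std::vector<int> remove_index(std::vector<int> v, int axis) {
  v.erase(v.begin() + axis);
  return v;
}

int main() {
  std::vector<int> idx_shape = {10};          // 1-d index array, ndim == 1
  auto reduced = remove_index(idx_shape, 0);  // -> {} : nothing left to bind
  // Encoding this empty vector via set_vector_bytes would bind a zero-length
  // argument, which the else branch above avoids by passing the original,
  // unused shape/stride values instead.
  return static_cast<int>(reduced.size());
}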
@@ -34,6 +34,7 @@ const char* steel_gemm_fused();
const char* steel_gemm_masked();
const char* steel_gemm_splitk();
const char* steel_gemm_gather();
const char* steel_gemm_segmented();
const char* conv();
const char* steel_conv();
const char* steel_conv_general();
@@ -652,6 +652,43 @@ MTL::ComputePipelineState* get_steel_gemm_gather_kernel(
  return d.get_kernel(kernel_name, lib, hash_name, func_consts);
}

MTL::ComputePipelineState* get_steel_gemm_segmented_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const std::string& hash_name,
    const metal::MTLFCList& func_consts,
    const array& out,
    bool transpose_a,
    bool transpose_b,
    int bm,
    int bn,
    int bk,
    int wm,
    int wn) {
  const auto& lib_name = kernel_name;
  auto lib = d.get_library(lib_name, [&]() {
    std::string kernel_source;
    concatenate(
        kernel_source,
        metal::utils(),
        metal::gemm(),
        metal::steel_gemm_segmented(),
        get_template_definition(
            lib_name,
            "segmented_mm",
            get_type_string(out.dtype()),
            bm,
            bn,
            bk,
            wm,
            wn,
            transpose_a,
            transpose_b));
    return kernel_source;
  });
  return d.get_kernel(kernel_name, lib, hash_name, func_consts);
}

MTL::ComputePipelineState* get_gemv_masked_kernel(
    metal::Device& d,
    const std::string& kernel_name,
@@ -175,6 +175,20 @@ MTL::ComputePipelineState* get_steel_gemm_gather_kernel(
    int wn,
    bool rhs);

MTL::ComputePipelineState* get_steel_gemm_segmented_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const std::string& hash_name,
    const metal::MTLFCList& func_consts,
    const array& out,
    bool transpose_a,
    bool transpose_b,
    int bm,
    int bn,
    int bk,
    int wm,
    int wn);

MTL::ComputePipelineState* get_steel_conv_kernel(
    metal::Device& d,
    const std::string& kernel_name,
@@ -71,6 +71,7 @@ set(STEEL_HEADERS
    steel/gemm/kernels/steel_gemm_fused.h
    steel/gemm/kernels/steel_gemm_gather.h
    steel/gemm/kernels/steel_gemm_masked.h
    steel/gemm/kernels/steel_gemm_segmented.h
    steel/gemm/kernels/steel_gemm_splitk.h
    steel/utils/type_traits.h
    steel/utils/integral_constant.h)
@@ -120,6 +121,7 @@ if(NOT MLX_METAL_JIT)
  build_kernel(steel/gemm/kernels/steel_gemm_gather ${STEEL_HEADERS})
  build_kernel(steel/gemm/kernels/steel_gemm_masked ${STEEL_HEADERS})
  build_kernel(steel/gemm/kernels/steel_gemm_splitk ${STEEL_HEADERS})
  build_kernel(steel/gemm/kernels/steel_gemm_segmented ${STEEL_HEADERS})
  build_kernel(gemv_masked steel/utils.h)
endif()
@@ -0,0 +1,266 @@
// Copyright © 2025 Apple Inc.

using namespace mlx::steel;

constant bool segments_contiguous [[function_constant(199)]];
constant bool align_M [[function_constant(200)]];
constant bool align_N [[function_constant(201)]];

template <
    typename T,
    int BM,
    int BN,
    int BK,
    int WM,
    int WN,
    bool transpose_a,
    bool transpose_b,
    typename AccumType = float>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void segmented_mm(
    const device T* A [[buffer(0)]],
    const device T* B [[buffer(1)]],
    const device uint32_t* segments [[buffer(2)]],
    device T* C [[buffer(3)]],
    const constant GEMMParams* params [[buffer(4)]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint3 tid [[threadgroup_position_in_grid]]) {
  using gemm_kernel = GEMMKernel<
      T,
      T,
      BM,
      BN,
      BK,
      WM,
      WN,
      transpose_a,
      transpose_b,
      true,
      true,
      AccumType>;

  using loader_a_t = typename gemm_kernel::loader_a_t;
  using loader_b_t = typename gemm_kernel::loader_b_t;
  using mma_t = typename gemm_kernel::mma_t;

  if (params->tiles_n <= static_cast<int>(tid.x) ||
      params->tiles_m <= static_cast<int>(tid.y)) {
    return;
  }

  // Prepare threadgroup memory
  threadgroup T As[gemm_kernel::tgp_mem_size_a];
  threadgroup T Bs[gemm_kernel::tgp_mem_size_b];

  // Find the block in A, B, C
  const int c_row = tid.y * BM;
  const int c_col = tid.x * BN;
  const size_t c_row_long = size_t(c_row);
  const size_t c_col_long = size_t(c_col);

  // Prepare threadgroup bounds
  const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row));
  const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col));

  // Move the pointers to the output tile
  A += transpose_a ? c_row_long : c_row_long * params->lda;
  B += transpose_b ? c_col_long * params->ldb : c_col_long;
  C += c_row_long * params->ldd + c_col_long;

  // Move the pointers to the start of the segment
  uint32_t k_start, k_end;
  if (segments_contiguous) {
    k_start = segments[2 * tid.z];
    k_end = segments[2 * tid.z + 1];
  } else {
    // Non-contiguous case: we also accept an overlapping layout in which each
    // segment starts where the previous one ends, i.e. the last two strides
    // of `segments` are both 1.
    k_start = segments[tid.z];
    k_end = segments[tid.z + 1];
  }
  A += transpose_a ? k_start * params->lda : k_start;
  B += transpose_b ? k_start : k_start * params->ldb;
  C += tid.z * params->batch_stride_d;

  // Prepare threadgroup mma operation
  thread mma_t mma_op(simd_group_id, simd_lane_id);

  // Prepare threadgroup loading operations
  thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
  thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);

  // Matrix level alignment so only check K
  if (align_M && align_N) {
    uint32_t k = k_start + BK;
    for (; k <= k_end; k += BK) {
      threadgroup_barrier(mem_flags::mem_threadgroup);

      // Load elements into threadgroup
      loader_a.load_unsafe();
      loader_b.load_unsafe();

      threadgroup_barrier(mem_flags::mem_threadgroup);

      // Multiply and accumulate threadgroup elements
      mma_op.mma(As, Bs);

      // Prepare for next iteration
      loader_a.next();
      loader_b.next();
    }
    short k_remain = BK - short(k - k_end);
    const short2 tile_dims_A =
        transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
    const short2 tile_dims_B =
        transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
    if (k_remain > 0) {
      threadgroup_barrier(mem_flags::mem_threadgroup);
      loader_a.load_safe(tile_dims_A);
      loader_b.load_safe(tile_dims_B);
      threadgroup_barrier(mem_flags::mem_threadgroup);
      mma_op.mma(As, Bs);
    }
    mma_op.store_result(C, params->ldd);
  } else {
    // Tile aligned do the same as above
    if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) {
      uint32_t k = k_start + BK;
      for (; k <= k_end; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Load elements into threadgroup
        loader_a.load_unsafe();
        loader_b.load_unsafe();

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Multiply and accumulate threadgroup elements
        mma_op.mma(As, Bs);

        // Prepare for next iteration
        loader_a.next();
        loader_b.next();
      }
      short k_remain = BK - short(k - k_end);
      const short2 tile_dims_A =
          transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
      const short2 tile_dims_B =
          transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
      if (k_remain > 0) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_a.load_safe(tile_dims_A);
        loader_b.load_safe(tile_dims_B);
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(As, Bs);
      }
      mma_op.store_result(C, params->ldd);
    }

    // Tile partially aligned check rows
    else if (align_N || tgp_bn == BN) {
      uint32_t k = k_start + BK;
      for (; k <= k_end; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Load elements into threadgroup
        loader_a.load_safe(
            transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm));
        loader_b.load_unsafe();

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Multiply and accumulate threadgroup elements
        mma_op.mma(As, Bs);

        // Prepare for next iteration
        loader_a.next();
        loader_b.next();
      }
      short k_remain = BK - short(k - k_end);
      const short2 tile_dims_A =
          transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
      const short2 tile_dims_B =
          transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
      if (k_remain > 0) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_a.load_safe(tile_dims_A);
        loader_b.load_safe(tile_dims_B);
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(As, Bs);
      }
      mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm));
    }

    // Tile partially aligned check cols
    else if (align_M || tgp_bm == BM) {
      uint32_t k = k_start + BK;
      for (; k <= k_end; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Load elements into threadgroup
        loader_a.load_unsafe();
        loader_b.load_safe(
            transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK));

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Multiply and accumulate threadgroup elements
        mma_op.mma(As, Bs);

        // Prepare for next iteration
        loader_a.next();
        loader_b.next();
      }
      short k_remain = BK - short(k - k_end);
      const short2 tile_dims_A =
          transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
      const short2 tile_dims_B =
          transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
      if (k_remain > 0) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_a.load_safe(tile_dims_A);
        loader_b.load_safe(tile_dims_B);
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(As, Bs);
      }
      mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm));
    }

    // Nothing aligned so check both rows and cols
    else {
      uint32_t k = k_start + BK;
      for (; k <= k_end; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Load elements into threadgroup
        loader_a.load_safe(
            transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm));
        loader_b.load_safe(
            transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK));

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Multiply and accumulate threadgroup elements
        mma_op.mma(As, Bs);

        // Prepare for next iteration
        loader_a.next();
        loader_b.next();
      }
      short k_remain = BK - short(k - k_end);
      const short2 tile_dims_A =
          transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
      const short2 tile_dims_B =
          transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
      if (k_remain > 0) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_a.load_safe(tile_dims_A);
        loader_b.load_safe(tile_dims_B);
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(As, Bs);
      }
      mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm));
    }
  }
}
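A small host-side sketch of the per-segment K tiling used by the loops above (illustrative only, not part of the diff): full BK tiles are consumed first, then k_remain = BK - (k - k_end) sizes the final guarded partial tile.

#include <cstdint>
#include <utility>

// Returns {number of full BK tiles, size of the trailing partial tile} for a
// segment [k_start, k_end), mirroring the kernel's loop bounds above.
std::pair<uint32_t, int> split_segment(uint32_t k_start, uint32_t k_end, int BK) {
  uint32_t k = k_start + BK;
  uint32_t full_tiles = 0;
  for (; k <= k_end; k += BK) {
    ++full_tiles; // one load + mma per full tile
  }
  int k_remain = BK - static_cast<int>(k - k_end);
  // Example: BK = 16, segment [0, 40) -> full_tiles = 2 (K offsets [0, 32)),
  // k_remain = 16 - (48 - 40) = 8, so the guarded tail covers offsets [32, 40).
  // For an exactly tile-aligned segment k_remain comes out as 0 and the tail
  // block is skipped.
  return {full_tiles, k_remain};
}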
@@ -0,0 +1,43 @@
// Copyright © 2024 Apple Inc.

// clang-format off
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
#include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h"

#define instantiate_segmented_mm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_kernel( \
      "steel_segmented_mm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn \
      "_bk" #bk "_wm" #wm "_wn" #wn, \
      segmented_mm, \
      itype, \
      bm, \
      bn, \
      bk, \
      wm, \
      wn, \
      trans_a, \
      trans_b, \
      float)

#define instantiate_segmented_mm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_segmented_mm(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_segmented_mm(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_segmented_mm(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_segmented_mm(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)

#define instantiate_segmented_mm_shapes_helper(iname, itype, oname, otype) \
  instantiate_segmented_mm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \
  instantiate_segmented_mm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 1, 2) \
  instantiate_segmented_mm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
  instantiate_segmented_mm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 1, 2) \
  instantiate_segmented_mm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2)
// clang-format on

instantiate_segmented_mm_shapes_helper(float16, half, float16, half);
instantiate_segmented_mm_shapes_helper(
    bfloat16,
    bfloat16_t,
    bfloat16,
    bfloat16_t);
instantiate_segmented_mm_shapes_helper(float32, float, float32, float);
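For instance, the first transpose/shape combination above expands to a kernel named steel_segmented_mm_nn_float16_float16_bm64_bn64_bk16_wm2_wn2. The base_name string assembled by the host code in the next hunk concatenates the same pieces, so the runtime lookup resolves to one of these instantiations (or to a JIT-compiled equivalent when MLX_METAL_JIT is enabled).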
@@ -1864,4 +1864,166 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
  gather_mm(a, b, lhs_indices, rhs_indices, out, M, N, K, d, s);
}

void segmented_mm(
    const array& a_,
    const array& b_,
    const array& segments_,
    array& out,
    int M,
    int N,
    int K,
    metal::Device& d,
    const Stream& s) {
  auto check_segments_layout = [&d, &s](const array& x) {
    // Contiguous so return early
    if (x.flags().row_contiguous) {
      return std::make_tuple(true, x);
    }

    bool rc = true;
    for (int i = 0; i < x.ndim() - 2; i++) {
      rc &=
          (x.strides(i + 1) * x.shape(i) == x.strides(i)) || (x.shape(i) == 1);
    }
    rc &= x.strides(x.ndim() - 1) == 1;
    if (x.ndim() > 1) {
      rc &= x.strides(x.ndim() - 2) == 1;
    }

    if (rc) {
      return std::make_tuple(false, x);
    }

    array x_copy(x.shape(), x.dtype(), nullptr, {});
    copy_gpu(x, x_copy, CopyType::General, s);
    d.add_temporary(x_copy, s.index);
    return std::make_tuple(true, x_copy);
  };

  // Copy if needed
  std::vector<array> copies;
  auto [transpose_a, lda, a] = check_transpose(copies, s, a_, false);
  auto [transpose_b, ldb, b] = check_transpose(copies, s, b_, false);
  auto [segments_contiguous, segments] = check_segments_layout(segments_);
  d.add_temporaries(std::move(copies), s.index);

  // Determine dispatch kernel
  int bm = 64, bn = 64, bk = 16;
  int wm = 2, wn = 2;
  size_t batch_size_out = out.size() / M / N;

  char devc = d.get_architecture().back();
  GEMM_TPARAM_MACRO(devc)

  const bool align_M = (M % bm) == 0;
  const bool align_N = (N % bn) == 0;

  // Define the kernel name
  std::string base_name;
  base_name.reserve(128);
  concatenate(
      base_name,
      "steel_segmented_mm_",
      transpose_a ? 't' : 'n',
      transpose_b ? 't' : 'n',
      "_",
      type_to_name(a),
      "_",
      type_to_name(out),
      "_bm",
      bm,
      "_bn",
      bn,
      "_bk",
      bk,
      "_wm",
      wm,
      "_wn",
      wn);

  metal::MTLFCList func_consts = {
      {&segments_contiguous, MTL::DataType::DataTypeBool, 199},
      {&align_M, MTL::DataType::DataTypeBool, 200},
      {&align_N, MTL::DataType::DataTypeBool, 201},
  };

  // And the kernel hash that includes the function constants
  std::string hash_name;
  hash_name.reserve(128);
  concatenate(
      hash_name,
      base_name,
      "_segments_contiguous_",
      segments_contiguous ? 't' : 'n',
      "_align_M_",
      align_M ? 't' : 'n',
      "_align_N_",
      align_N ? 't' : 'n');

  // Get and set the kernel
  auto& compute_encoder = d.get_command_encoder(s.index);
  auto kernel = get_steel_gemm_segmented_kernel(
      d,
      base_name,
      hash_name,
      func_consts,
      out,
      transpose_a,
      transpose_b,
      bm,
      bn,
      bk,
      wm,
      wn);
  compute_encoder.set_compute_pipeline_state(kernel);

  // Prepare the matmul params
  steel::GEMMParams params{
      /* const int M = */ M,
      /* const int N = */ N,
      /* const int K = */ K,
      /* const int lda = */ static_cast<int>(lda),
      /* const int ldb = */ static_cast<int>(ldb),
      /* const int ldd = */ N,
      /* const int tiles_n = */ (N + bn - 1) / bn,
      /* const int tiles_m = */ (M + bm - 1) / bm,
      /* const int64_t batch_stride_a = */ 0,
      /* const int64_t batch_stride_b = */ 0,
      /* const int64_t batch_stride_d = */ M * N,
      /* const int swizzle_log = */ 0,
      /* const int gemm_k_iterations_aligned = */ 0,
      /* const int batch_ndim = */ 0};

  // Prepare the grid
  MTL::Size group_dims = MTL::Size(32, wn, wm);
  MTL::Size grid_dims =
      MTL::Size(params.tiles_n, params.tiles_m, batch_size_out);

  // Launch kernel
  compute_encoder.set_input_array(a, 0);
  compute_encoder.set_input_array(b, 1);
  compute_encoder.set_input_array(segments, 2);
  compute_encoder.set_output_array(out, 3);
  compute_encoder.set_bytes(params, 4);
  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}

void SegmentedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
  auto& s = stream();
  auto& d = metal::device(s.device);

  auto& a = inputs[0];
  auto& b = inputs[1];
  auto& segments = inputs[2];

  out.set_data(allocator::malloc(out.nbytes()));

  // Extract shapes from inputs.
  int M = a.shape(-2);
  int N = b.shape(-1);
  int K = a.shape(-1);

  segmented_mm(a, b, segments, out, M, N, K, d, s);
}

} // namespace mlx::core
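To make the two segment layouts accepted by check_segments_layout and the kernel's segments_contiguous branches concrete, here is a standalone check with assumed example boundaries (illustrative only, not part of the diff):

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // Three segments of the K axis: [0, 3), [3, 7), [7, 10) (example values).
  // segments_contiguous == true: shape (num_segments, 2), row-contiguous.
  std::vector<uint32_t> pairs = {0, 3, 3, 7, 7, 10};
  // segments_contiguous == false: overlapping boundaries where each segment
  // starts where the previous one ends (last two strides both equal to 1).
  std::vector<uint32_t> bounds = {0, 3, 7, 10};

  for (uint32_t z = 0; z < 3; ++z) {
    // Indexing used by the kernel's two branches.
    assert(pairs[2 * z] == bounds[z]);         // k_start
    assert(pairs[2 * z + 1] == bounds[z + 1]); // k_end
  }
  return 0;
}

Any other layout is copied to a row-contiguous temporary by check_segments_layout, which then takes the segments_contiguous path.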
@@ -210,6 +210,22 @@ MTL::ComputePipelineState* get_steel_gemm_gather_kernel(
  return d.get_kernel(kernel_name, hash_name, func_consts);
}

MTL::ComputePipelineState* get_steel_gemm_segmented_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const std::string& hash_name,
    const metal::MTLFCList& func_consts,
    const array&,
    bool,
    bool,
    int,
    int,
    int,
    int,
    int) {
  return d.get_kernel(kernel_name, hash_name, func_consts);
}

MTL::ComputePipelineState* get_gemv_masked_kernel(
    metal::Device& d,
    const std::string& kernel_name,