Some cleanup

This commit is contained in:
Jagrit Digani 2025-06-11 09:43:34 -07:00
parent 4b02d3e738
commit b1d95a3880
2 changed files with 37 additions and 36 deletions

View File

@ -33,8 +33,8 @@ template <
device T* D [[buffer(3)]], device T* D [[buffer(3)]],
const constant GEMMParams* params [[buffer(4)]], const constant GEMMParams* params [[buffer(4)]],
const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]], const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]],
const constant int* batch_shape [[buffer(6)]], const constant int* batch_shape [[buffer(6), function_constant(has_batch)]],
const constant int64_t* batch_strides [[buffer(7)]], const constant int64_t* batch_strides [[buffer(7), function_constant(has_batch)]],
uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]], uint3 tid [[threadgroup_position_in_grid]],

View File

@ -164,6 +164,10 @@ ensure_batch_contiguous(const array& x, metal::Device& d, const Stream& s) {
wn = 2; \ wn = 2; \
} }
///////////////////////////////////////////////////////////////////////////////
// Regular steel matmul dispatch
///////////////////////////////////////////////////////////////////////////////
template <bool CHECK_AB> template <bool CHECK_AB>
void steel_matmul_regular_axpby( void steel_matmul_regular_axpby(
const Stream& s, const Stream& s,
@ -296,8 +300,10 @@ void steel_matmul_regular_axpby(
compute_encoder.set_bytes(params, 4); compute_encoder.set_bytes(params, 4);
compute_encoder.set_vector_bytes(batch_shape, 6); if (has_batch) {
compute_encoder.set_vector_bytes(batch_strides, 7); compute_encoder.set_vector_bytes(batch_shape, 6);
compute_encoder.set_vector_bytes(batch_strides, 7);
}
if (use_out_source) { if (use_out_source) {
int ldc = c.strides()[c.ndim() - 2]; int ldc = c.strides()[c.ndim() - 2];
@ -320,6 +326,10 @@ void steel_matmul_regular_axpby(
d.add_temporaries(std::move(copies), s.index); d.add_temporaries(std::move(copies), s.index);
} }
///////////////////////////////////////////////////////////////////////////////
// Split k steel matmul
///////////////////////////////////////////////////////////////////////////////
template <bool CHECK_AB = true> template <bool CHECK_AB = true>
void steel_gemm_splitk_axpby( void steel_gemm_splitk_axpby(
const Stream& s, const Stream& s,
@ -466,38 +476,9 @@ void steel_gemm_splitk_axpby(
d.add_temporaries(std::move(copies), s.index); d.add_temporaries(std::move(copies), s.index);
} }
inline void steel_gemm_splitk( ///////////////////////////////////////////////////////////////////////////////
const Stream& s, // Split matmul routing
metal::Device& d, ///////////////////////////////////////////////////////////////////////////////
const array& a,
const array& b,
array& out,
int M,
int N,
int K,
int batch_size_out,
int lda,
int ldb,
bool transpose_a,
bool transpose_b,
std::vector<array>& copies) {
return steel_gemm_splitk_axpby<false>(
/* const Stream& s = */ s,
/* metal::Device& d = */ d,
/* const array& a = */ a,
/* const array& b = */ b,
/* const array& c = */ b,
/* array& out = */ out,
/* int M = */ M,
/* int N = */ N,
/* int K = */ K,
/* int batch_size_out = */ batch_size_out,
/* int lda = */ lda,
/* int ldb = */ ldb,
/* bool transpose_a = */ transpose_a,
/* bool transpose_b = */ transpose_b,
/* std::vector<array>& copies = */ copies);
}
template <bool CHECK_AB> template <bool CHECK_AB>
void steel_matmul_axpby( void steel_matmul_axpby(
@ -637,6 +618,10 @@ void steel_matmul_axpby(
/* float beta = */ beta); /* float beta = */ beta);
} }
///////////////////////////////////////////////////////////////////////////////
// GEMV dispatch
///////////////////////////////////////////////////////////////////////////////
template <bool CHECK_AB = true> template <bool CHECK_AB = true>
void gemv_axbpy( void gemv_axbpy(
const Stream& s, const Stream& s,
@ -812,6 +797,10 @@ inline void gemv(
/* Strides B_batch_stride = */ B_batch_stride); /* Strides B_batch_stride = */ B_batch_stride);
} }
///////////////////////////////////////////////////////////////////////////////
// Matmul implementation
///////////////////////////////////////////////////////////////////////////////
void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) { void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); assert(inputs.size() == 2);
if (!issubdtype(out.dtype(), floating)) { if (!issubdtype(out.dtype(), floating)) {
@ -913,6 +902,10 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
/* Strides B_batch_stride = */ std::move(B_batch_stride)); /* Strides B_batch_stride = */ std::move(B_batch_stride));
} }
///////////////////////////////////////////////////////////////////////////////
// AddMM implementation
///////////////////////////////////////////////////////////////////////////////
void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) { void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 3); assert(inputs.size() == 3);
if (!issubdtype(out.dtype(), floating)) { if (!issubdtype(out.dtype(), floating)) {
@ -1043,6 +1036,10 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
/* float beta = */ beta_); /* float beta = */ beta_);
} }
///////////////////////////////////////////////////////////////////////////////
// BlockMaskedMM implementation
///////////////////////////////////////////////////////////////////////////////
void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) { void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
using namespace mlx::steel; using namespace mlx::steel;
// assert(inputs.size() == 2); // assert(inputs.size() == 2);
@ -1431,6 +1428,10 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
d.add_temporaries(std::move(copies), s.index); d.add_temporaries(std::move(copies), s.index);
} }
///////////////////////////////////////////////////////////////////////////////
// GatherMM implementation
///////////////////////////////////////////////////////////////////////////////
void gather_mm_rhs( void gather_mm_rhs(
const array& a_, const array& a_,
const array& b_, const array& b_,