Some cleanup

Jagrit Digani 2025-06-11 09:43:34 -07:00
parent 4b02d3e738
commit b1d95a3880
2 changed files with 37 additions and 36 deletions


@@ -33,8 +33,8 @@ template <
device T* D [[buffer(3)]],
const constant GEMMParams* params [[buffer(4)]],
const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]],
const constant int* batch_shape [[buffer(6)]],
const constant int64_t* batch_strides [[buffer(7)]],
const constant int* batch_shape [[buffer(6), function_constant(has_batch)]],
const constant int64_t* batch_strides [[buffer(7), function_constant(has_batch)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
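
The hunk above gates the batch_shape and batch_strides arguments behind the has_batch function constant, so unbatched pipeline variants can drop those bindings at compile time. A minimal standalone MSL sketch of the pattern (the constant index, kernel name, and body are illustrative, not the actual MLX kernel):

    // Standalone .metal sketch of a function-constant-gated buffer argument.
    constant bool has_batch [[function_constant(10)]];  // index 10 is a placeholder

    kernel void example_gemm_stub(
        device float* out [[buffer(0)]],
        const constant int* batch_shape
            [[buffer(6), function_constant(has_batch)]],  // only present when has_batch is set
        uint tid [[thread_position_in_grid]]) {
      // When the pipeline is built with has_batch = false, batch_shape is
      // stripped from the argument table and must not be dereferenced.
      out[tid] = has_batch ? float(batch_shape[0]) : 1.0f;
    }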


@@ -164,6 +164,10 @@ ensure_batch_contiguous(const array& x, metal::Device& d, const Stream& s) {
wn = 2; \
}
///////////////////////////////////////////////////////////////////////////////
// Regular steel matmul dispatch
///////////////////////////////////////////////////////////////////////////////
template <bool CHECK_AB>
void steel_matmul_regular_axpby(
const Stream& s,
@@ -296,8 +300,10 @@ void steel_matmul_regular_axpby(
compute_encoder.set_bytes(params, 4);
if (has_batch) {
compute_encoder.set_vector_bytes(batch_shape, 6);
compute_encoder.set_vector_bytes(batch_strides, 7);
}
if (use_out_source) {
int ldc = c.strides()[c.ndim() - 2];
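
In the hunk above, binding batch_shape and batch_strides (buffers 6 and 7) is now guarded by if (has_batch), matching the kernel change: when the pipeline is specialized with has_batch = false, those arguments are compiled out of the function and binding them is unnecessary. A rough metal-cpp sketch of building such a specialized pipeline (the constant index, kernel name, and helper name are placeholders, not MLX's actual plumbing):

    #include <Foundation/Foundation.hpp>
    #include <Metal/Metal.hpp>

    // Sketch: build a pipeline variant with `has_batch` baked in as a
    // function constant. Index 10 is a placeholder and must match the
    // [[function_constant(N)]] declaration in the kernel source.
    MTL::ComputePipelineState* make_gemm_pipeline(
        MTL::Device* device, MTL::Library* lib, bool has_batch) {
      auto* consts = MTL::FunctionConstantValues::alloc()->init();
      consts->setConstantValue(&has_batch, MTL::DataTypeBool, NS::UInteger(10));

      NS::Error* error = nullptr;  // error handling elided for brevity
      auto* fn = lib->newFunction(
          NS::String::string("example_gemm_stub", NS::UTF8StringEncoding),
          consts, &error);
      auto* pso = device->newComputePipelineState(fn, &error);

      fn->release();
      consts->release();
      return pso;  // caller owns the returned pipeline state
    }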
@@ -320,6 +326,10 @@ void steel_matmul_regular_axpby(
d.add_temporaries(std::move(copies), s.index);
}
///////////////////////////////////////////////////////////////////////////////
// Split k steel matmul
///////////////////////////////////////////////////////////////////////////////
template <bool CHECK_AB = true>
void steel_gemm_splitk_axpby(
const Stream& s,
@@ -466,38 +476,9 @@ void steel_gemm_splitk_axpby(
d.add_temporaries(std::move(copies), s.index);
}
inline void steel_gemm_splitk(
const Stream& s,
metal::Device& d,
const array& a,
const array& b,
array& out,
int M,
int N,
int K,
int batch_size_out,
int lda,
int ldb,
bool transpose_a,
bool transpose_b,
std::vector<array>& copies) {
return steel_gemm_splitk_axpby<false>(
/* const Stream& s = */ s,
/* metal::Device& d = */ d,
/* const array& a = */ a,
/* const array& b = */ b,
/* const array& c = */ b,
/* array& out = */ out,
/* int M = */ M,
/* int N = */ N,
/* int K = */ K,
/* int batch_size_out = */ batch_size_out,
/* int lda = */ lda,
/* int ldb = */ ldb,
/* bool transpose_a = */ transpose_a,
/* bool transpose_b = */ transpose_b,
/* std::vector<array>& copies = */ copies);
}
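
The deleted steel_gemm_splitk above was a thin wrapper that only forwarded to steel_gemm_splitk_axpby<false>, passing b again for the unused c operand. If any call sites remain, a direct call would presumably look like this (hypothetical, and assuming the axpby overload's trailing alpha/beta parameters keep their defaults, as the removed wrapper implied):

    // Hypothetical direct call replacing the removed wrapper. With
    // CHECK_AB = false the c operand is never read, so b stands in for it.
    steel_gemm_splitk_axpby<false>(
        s, d, a, b, /* c = */ b, out,
        M, N, K, batch_size_out, lda, ldb,
        transpose_a, transpose_b, copies);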
///////////////////////////////////////////////////////////////////////////////
// Split matmul routing
///////////////////////////////////////////////////////////////////////////////
template <bool CHECK_AB>
void steel_matmul_axpby(
@@ -637,6 +618,10 @@ void steel_matmul_axpby(
/* float beta = */ beta);
}
///////////////////////////////////////////////////////////////////////////////
// GEMV dispatch
///////////////////////////////////////////////////////////////////////////////
template <bool CHECK_AB = true>
void gemv_axbpy(
const Stream& s,
@@ -812,6 +797,10 @@ inline void gemv(
/* Strides B_batch_stride = */ B_batch_stride);
}
///////////////////////////////////////////////////////////////////////////////
// Matmul implementation
///////////////////////////////////////////////////////////////////////////////
void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
if (!issubdtype(out.dtype(), floating)) {
@@ -913,6 +902,10 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
/* Strides B_batch_stride = */ std::move(B_batch_stride));
}
///////////////////////////////////////////////////////////////////////////////
// AddMM implementation
///////////////////////////////////////////////////////////////////////////////
void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 3);
if (!issubdtype(out.dtype(), floating)) {
@@ -1043,6 +1036,10 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
/* float beta = */ beta_);
}
///////////////////////////////////////////////////////////////////////////////
// BlockMaskedMM implementation
///////////////////////////////////////////////////////////////////////////////
void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
using namespace mlx::steel;
// assert(inputs.size() == 2);
@@ -1431,6 +1428,10 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
d.add_temporaries(std::move(copies), s.index);
}
///////////////////////////////////////////////////////////////////////////////
// GatherMM implementation
///////////////////////////////////////////////////////////////////////////////
void gather_mm_rhs(
const array& a_,
const array& b_,