mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Some cleanup
This commit is contained in:
parent
4b02d3e738
commit
b1d95a3880
@ -33,8 +33,8 @@ template <
|
|||||||
device T* D [[buffer(3)]],
|
device T* D [[buffer(3)]],
|
||||||
const constant GEMMParams* params [[buffer(4)]],
|
const constant GEMMParams* params [[buffer(4)]],
|
||||||
const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]],
|
const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]],
|
||||||
const constant int* batch_shape [[buffer(6)]],
|
const constant int* batch_shape [[buffer(6), function_constant(has_batch)]],
|
||||||
const constant int64_t* batch_strides [[buffer(7)]],
|
const constant int64_t* batch_strides [[buffer(7), function_constant(has_batch)]],
|
||||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||||
uint3 tid [[threadgroup_position_in_grid]],
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
|
@ -164,6 +164,10 @@ ensure_batch_contiguous(const array& x, metal::Device& d, const Stream& s) {
|
|||||||
wn = 2; \
|
wn = 2; \
|
||||||
}
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Regular steel matmul dispatch
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
template <bool CHECK_AB>
|
template <bool CHECK_AB>
|
||||||
void steel_matmul_regular_axpby(
|
void steel_matmul_regular_axpby(
|
||||||
const Stream& s,
|
const Stream& s,
|
||||||
@ -296,8 +300,10 @@ void steel_matmul_regular_axpby(
|
|||||||
|
|
||||||
compute_encoder.set_bytes(params, 4);
|
compute_encoder.set_bytes(params, 4);
|
||||||
|
|
||||||
|
if (has_batch) {
|
||||||
compute_encoder.set_vector_bytes(batch_shape, 6);
|
compute_encoder.set_vector_bytes(batch_shape, 6);
|
||||||
compute_encoder.set_vector_bytes(batch_strides, 7);
|
compute_encoder.set_vector_bytes(batch_strides, 7);
|
||||||
|
}
|
||||||
|
|
||||||
if (use_out_source) {
|
if (use_out_source) {
|
||||||
int ldc = c.strides()[c.ndim() - 2];
|
int ldc = c.strides()[c.ndim() - 2];
|
||||||
@ -320,6 +326,10 @@ void steel_matmul_regular_axpby(
|
|||||||
d.add_temporaries(std::move(copies), s.index);
|
d.add_temporaries(std::move(copies), s.index);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Split k steel matmul
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
template <bool CHECK_AB = true>
|
template <bool CHECK_AB = true>
|
||||||
void steel_gemm_splitk_axpby(
|
void steel_gemm_splitk_axpby(
|
||||||
const Stream& s,
|
const Stream& s,
|
||||||
@ -466,38 +476,9 @@ void steel_gemm_splitk_axpby(
|
|||||||
d.add_temporaries(std::move(copies), s.index);
|
d.add_temporaries(std::move(copies), s.index);
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void steel_gemm_splitk(
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
const Stream& s,
|
// Split matmul routing
|
||||||
metal::Device& d,
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
const array& a,
|
|
||||||
const array& b,
|
|
||||||
array& out,
|
|
||||||
int M,
|
|
||||||
int N,
|
|
||||||
int K,
|
|
||||||
int batch_size_out,
|
|
||||||
int lda,
|
|
||||||
int ldb,
|
|
||||||
bool transpose_a,
|
|
||||||
bool transpose_b,
|
|
||||||
std::vector<array>& copies) {
|
|
||||||
return steel_gemm_splitk_axpby<false>(
|
|
||||||
/* const Stream& s = */ s,
|
|
||||||
/* metal::Device& d = */ d,
|
|
||||||
/* const array& a = */ a,
|
|
||||||
/* const array& b = */ b,
|
|
||||||
/* const array& c = */ b,
|
|
||||||
/* array& out = */ out,
|
|
||||||
/* int M = */ M,
|
|
||||||
/* int N = */ N,
|
|
||||||
/* int K = */ K,
|
|
||||||
/* int batch_size_out = */ batch_size_out,
|
|
||||||
/* int lda = */ lda,
|
|
||||||
/* int ldb = */ ldb,
|
|
||||||
/* bool transpose_a = */ transpose_a,
|
|
||||||
/* bool transpose_b = */ transpose_b,
|
|
||||||
/* std::vector<array>& copies = */ copies);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <bool CHECK_AB>
|
template <bool CHECK_AB>
|
||||||
void steel_matmul_axpby(
|
void steel_matmul_axpby(
|
||||||
@ -637,6 +618,10 @@ void steel_matmul_axpby(
|
|||||||
/* float beta = */ beta);
|
/* float beta = */ beta);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// GEMV dispatch
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
template <bool CHECK_AB = true>
|
template <bool CHECK_AB = true>
|
||||||
void gemv_axbpy(
|
void gemv_axbpy(
|
||||||
const Stream& s,
|
const Stream& s,
|
||||||
@ -812,6 +797,10 @@ inline void gemv(
|
|||||||
/* Strides B_batch_stride = */ B_batch_stride);
|
/* Strides B_batch_stride = */ B_batch_stride);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Matmul implementation
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 2);
|
assert(inputs.size() == 2);
|
||||||
if (!issubdtype(out.dtype(), floating)) {
|
if (!issubdtype(out.dtype(), floating)) {
|
||||||
@ -913,6 +902,10 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
/* Strides B_batch_stride = */ std::move(B_batch_stride));
|
/* Strides B_batch_stride = */ std::move(B_batch_stride));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// AddMM implementation
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 3);
|
assert(inputs.size() == 3);
|
||||||
if (!issubdtype(out.dtype(), floating)) {
|
if (!issubdtype(out.dtype(), floating)) {
|
||||||
@ -1043,6 +1036,10 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
/* float beta = */ beta_);
|
/* float beta = */ beta_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// BlockMaskedMM implementation
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
using namespace mlx::steel;
|
using namespace mlx::steel;
|
||||||
// assert(inputs.size() == 2);
|
// assert(inputs.size() == 2);
|
||||||
@ -1431,6 +1428,10 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
d.add_temporaries(std::move(copies), s.index);
|
d.add_temporaries(std::move(copies), s.index);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// GatherMM implementation
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
void gather_mm_rhs(
|
void gather_mm_rhs(
|
||||||
const array& a_,
|
const array& a_,
|
||||||
const array& b_,
|
const array& b_,
|
||||||
|
Loading…
Reference in New Issue
Block a user