mirror of https://github.com/ml-explore/mlx.git
Collection of refactors (#2274)
* Refactor gemv into a function
* Refactor splitk step 1
* Refactor split k axpby
* Rearrange steel_gemm_regular
* Redirect steel_gemm_regular
* Add axpby routing to steel_matmul_regular
* Refactor AddMM step 1
* Redirect steel_gemm
* Update addmm
* Comments and format
* Some cleanup
* Add architecture gen to device
* Update no copy condition in normalization to account for axis size 1
This commit is contained in:
parent c8b4787e4e
commit fddb6933e1
@@ -155,26 +155,26 @@ void explicit_gemm_conv_group_ND_gpu(
   // Perform gemm
   std::vector<array> copies = {in_unfolded, wt_transpose};
   return steel_matmul_regular(
-      s,
-      d,
-      /* a = */ in_unfolded,
-      /* b = */ wt_transpose,
-      /* c = */ out,
-      /* M = */ implicit_M,
-      /* N = */ implicit_N,
-      /* K = */ implicit_K,
-      /* batch_size_out = */ groups,
-      /* a_cols = */ implicit_K * groups,
-      /* b_cols = */ implicit_K,
-      /* out_cols = */ implicit_N * groups,
-      /* a_transposed = */ false,
-      /* b_transposed = */ true,
-      /* batch_shape = */ {1},
-      /* batch_strides = */ {0},
-      /* A_batch_strides = */ size_t(implicit_K),
-      /* B_batch_strides = */ size_t(implicit_N) * implicit_K,
-      /* matrix_stride_out = */ size_t(implicit_N),
-      /*copies = */ copies);
+      /* const Stream& s = */ s,
+      /* Device& d = */ d,
+      /* const array& a = */ in_unfolded,
+      /* const array& b = */ wt_transpose,
+      /* array& c = */ out,
+      /* int M = */ implicit_M,
+      /* int N = */ implicit_N,
+      /* int K = */ implicit_K,
+      /* int batch_size_out = */ groups,
+      /* int lda = */ implicit_K * groups,
+      /* int ldb = */ implicit_K,
+      /* int ldd = */ implicit_N * groups,
+      /* bool transpose_a = */ false,
+      /* bool transpose_b = */ true,
+      /* std::vector<array>& copies = */ copies,
+      /* Shape batch_shape = */ {1},
+      /* Strides batch_strides = */ {0},
+      /* int64_t A_batch_strides = */ int64_t(implicit_K),
+      /* int64_t B_batch_strides = */ int64_t(implicit_N) * implicit_K,
+      /* int64_t matrix_stride_out = */ int64_t(implicit_N));
 }
 
 void implicit_gemm_conv_2D_gpu(
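For context, this call treats each convolution group as one entry of a strided batched GEMM over views of the unfolded input: the row stride is lda = implicit_K * groups, the per-group A offset advances by implicit_K, and the per-group output offset advances by implicit_N. A minimal CPU reference of that indexing, as a sketch (not MLX code):

#include <vector>

// Reference for the strided batched GEMM view above: group g reads an M x K
// block at column offset g * K of the unfolded input (row stride
// lda = K * groups) and writes an M x N block at column offset g * N of the
// output (row stride ldd = N * groups). B holds the transposed weights.
void grouped_gemm_reference(
    const std::vector<float>& A, // M x (K * groups), row-major
    const std::vector<float>& B, // (N * groups) x K, row-major (transposed)
    std::vector<float>& D,       // M x (N * groups), row-major
    int M, int N, int K, int groups) {
  int lda = K * groups;
  int ldd = N * groups;
  for (int g = 0; g < groups; ++g) {
    const float* a = A.data() + g * K;     // A_batch_stride = K
    const float* b = B.data() + g * N * K; // B_batch_stride = N * K
    float* d = D.data() + g * N;           // matrix_stride_out = N
    for (int i = 0; i < M; ++i) {
      for (int j = 0; j < N; ++j) {
        float acc = 0.0f;
        for (int k = 0; k < K; ++k) {
          acc += a[i * lda + k] * b[j * K + k]; // ldb = K, B transposed
        }
        d[i * ldd + j] = acc;
      }
    }
  }
}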
@@ -297,6 +297,9 @@ Device::Device() {
   device_ = load_device();
   default_library_ = load_default_library(device_);
   arch_ = std::string(device_->architecture()->name()->utf8String());
+  int ag_tens = arch_[arch_.size() - 3] - '0';
+  int ag_ones = arch_[arch_.size() - 2] - '0';
+  arch_gen_ = ag_tens * 10 + ag_ones;
   auto arch = arch_.back();
   switch (arch) {
     case 'p': // phone
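The generation number is read out of a fixed suffix of the architecture name. Assuming names of the form "applegpu_gNNx" (two generation digits followed by a class letter such as 'p' for phone; the concrete string below is illustrative), the same parse can be sketched standalone:

#include <cassert>
#include <string>

// Sketch of the digit parse above, assuming the architecture string ends in
// two generation digits plus one device-class character.
int parse_arch_gen(const std::string& arch) {
  assert(arch.size() >= 3);
  int tens = arch[arch.size() - 3] - '0';
  int ones = arch[arch.size() - 2] - '0';
  return tens * 10 + ones;
}

// e.g. parse_arch_gen("applegpu_g16p") == 16 (example string for illustration)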
@@ -177,6 +177,10 @@ class Device {
     return arch_;
   }
 
+  int get_architecture_gen() const {
+    return arch_gen_;
+  }
+
   void new_queue(int index);
 
   MTL::CommandQueue* get_queue(Stream stream);
@@ -268,6 +272,7 @@ class Device {
       library_kernels_;
   const MTL::ResidencySet* residency_set_{nullptr};
   std::string arch_;
+  int arch_gen_;
   int max_ops_per_buffer_;
   int max_mb_per_buffer_;
 };
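The new accessor lets kernel selection branch on the GPU generation. A hedged usage sketch; the threshold and variant names here are made up for illustration and are not MLX's actual dispatch:

// Hypothetical dispatch on architecture generation via the accessor above.
const char* pick_gemm_variant(const mlx::core::metal::Device& d) {
  return d.get_architecture_gen() >= 15 ? "steel_gemm_new" : "steel_gemm_old";
}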
@@ -33,8 +33,8 @@ template <
    device T* D [[buffer(3)]],
    const constant GEMMParams* params [[buffer(4)]],
    const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]],
-    const constant int* batch_shape [[buffer(6)]],
-    const constant int64_t* batch_strides [[buffer(7)]],
+    const constant int* batch_shape [[buffer(6), function_constant(has_batch)]],
+    const constant int64_t* batch_strides [[buffer(7), function_constant(has_batch)]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint3 tid [[threadgroup_position_in_grid]],
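With the batch pointers now gated on has_batch, a specialization compiled with has_batch = false compiles those arguments out, so buffers 6 and 7 need not be bound at encode time. A hedged metal-cpp sketch of specializing a kernel on a bool function constant; the kernel name and constant index are placeholders, not MLX's actual setup:

#include <Metal/Metal.hpp>

// Sketch: build a pipeline specialized on a bool function constant.
MTL::ComputePipelineState* make_pipeline(
    MTL::Device* device, MTL::Library* lib, bool has_batch) {
  auto* consts = MTL::FunctionConstantValues::alloc()->init();
  // Index 10 is illustrative; it must match the [[function_constant(...)]]
  // index declared in the shader source.
  consts->setConstantValue(&has_batch, MTL::DataTypeBool, NS::UInteger(10));
  NS::Error* error = nullptr;
  auto* fn = lib->newFunction(
      NS::String::string("gemm_kernel", NS::UTF8StringEncoding),
      consts,
      &error);
  consts->release();
  // With has_batch = false, the gated batch_shape/batch_strides parameters
  // are dead and the corresponding buffers can stay unbound.
  return device->newComputePipelineState(fn, &error);
}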
File diff suppressed because it is too large
@@ -6,7 +6,34 @@
 
 namespace mlx::core {
 
-void steel_matmul_regular(
+template <bool CHECK_AB = true>
+void steel_matmul_regular_axpby(
+    const Stream& s,
+    metal::Device& d,
+    const array& a,
+    const array& b,
+    const array& c,
+    array& out,
+    int M,
+    int N,
+    int K,
+    int batch_size_out,
+    int lda,
+    int ldb,
+    int ldd,
+    bool transpose_a,
+    bool transpose_b,
+    std::vector<array>& copies,
+    Shape batch_shape,
+    Strides batch_strides,
+    int64_t A_batch_stride,
+    int64_t B_batch_stride,
+    int64_t matrix_stride_out,
+    int64_t C_batch_stride = 0,
+    float alpha = 1.0f,
+    float beta = 0.0f);
+
+inline void steel_matmul_regular(
     const Stream& s,
     metal::Device& d,
     const array& a,
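The _axpby suffix mirrors BLAS naming: with alpha and beta the kernel computes out = alpha * (a @ b) + beta * c, and the defaults (alpha = 1, beta = 0) reduce it to a plain matmul. A scalar reference of the epilogue semantics, as a sketch:

// Reference semantics of the axpby epilogue (sketch, row-major, no strides):
//   out[i, j] = alpha * sum_k a[i, k] * b[k, j] + beta * c[i, j]
// With alpha = 1 and beta = 0 this is an ordinary matmul; with beta != 0 it
// is the AddMM pattern the refactor routes through the same kernel.
void axpby_gemm_reference(
    const float* a, const float* b, const float* c, float* out,
    int M, int N, int K, float alpha, float beta) {
  for (int i = 0; i < M; ++i) {
    for (int j = 0; j < N; ++j) {
      float acc = 0.0f;
      for (int k = 0; k < K; ++k) {
        acc += a[i * K + k] * b[k * N + j];
      }
      out[i * N + j] = alpha * acc + beta * c[i * N + j];
    }
  }
}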
@@ -21,14 +48,61 @@ void steel_matmul_regular(
     int ldd,
     bool transpose_a,
     bool transpose_b,
+    std::vector<array>& copies,
     Shape batch_shape,
     Strides batch_strides,
     int64_t A_batch_stride,
     int64_t B_batch_stride,
-    int64_t matrix_stride_out,
-    std::vector<array>& copies);
+    int64_t matrix_stride_out) {
+  return steel_matmul_regular_axpby<false>(
+      /* const Stream& s = */ s,
+      /* metal::Device& d = */ d,
+      /* const array& a = */ a,
+      /* const array& b = */ b,
+      /* const array& c = */ b,
+      /* array& out = */ out,
+      /* int M = */ M,
+      /* int N = */ N,
+      /* int K = */ K,
+      /* int batch_size_out = */ batch_size_out,
+      /* int lda = */ lda,
+      /* int ldb = */ ldb,
+      /* int ldd = */ ldd,
+      /* bool transpose_a = */ transpose_a,
+      /* bool transpose_b = */ transpose_b,
+      /* std::vector<array>& copies = */ copies,
+      /* Shape batch_shape = */ batch_shape,
+      /* Strides batch_strides = */ batch_strides,
+      /* int64_t A_batch_stride = */ A_batch_stride,
+      /* int64_t B_batch_stride = */ B_batch_stride,
+      /* int64_t matrix_stride_out = */ matrix_stride_out);
+}
 
-void steel_matmul(
+template <bool CHECK_AB = true>
+void steel_matmul_axpby(
     const Stream& s,
     metal::Device& d,
     const array& a,
+    const array& b,
+    const array& c,
+    array& out,
+    int M,
+    int N,
+    int K,
+    int batch_size_out,
+    int lda,
+    int ldb,
+    bool transpose_a,
+    bool transpose_b,
+    std::vector<array>& copies,
+    Shape batch_shape = {},
+    Strides A_batch_stride = {},
+    Strides B_batch_stride = {},
+    Strides C_batch_stride = {},
+    float alpha = 1.0f,
+    float beta = 0.0f);
+
+inline void steel_matmul(
+    const Stream& s,
+    metal::Device& d,
+    const array& a,
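Note the wrapper forwards b in the c slot: with CHECK_AB = false the addend is never read, so any array of matching type serves as a placeholder. The compile-time flag pattern, sketched in isolation (this function is illustrative, not one of the MLX kernels):

// Sketch of the CHECK_AB redirect pattern: the template flag folds the
// alpha/beta epilogue away at compile time, so the plain-matmul wrapper can
// pass a dummy addend with no runtime cost.
template <bool CHECK_AB = true>
float epilogue(float acc, float addend, float alpha, float beta) {
  if constexpr (CHECK_AB) {
    return alpha * acc + beta * addend; // axpby / AddMM path
  } else {
    return acc; // plain matmul path: addend, alpha, beta are dead
  }
}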
@@ -45,6 +119,26 @@ void steel_matmul(
     std::vector<array>& copies,
     Shape batch_shape = {},
     Strides A_batch_stride = {},
-    Strides B_batch_stride = {});
+    Strides B_batch_stride = {}) {
+  return steel_matmul_axpby<false>(
+      /* const Stream& s = */ s,
+      /* metal::Device& d = */ d,
+      /* const array& a = */ a,
+      /* const array& b = */ b,
+      /* const array& c = */ b,
+      /* array& out = */ out,
+      /* int M = */ M,
+      /* int N = */ N,
+      /* int K = */ K,
+      /* int batch_size_out = */ batch_size_out,
+      /* int lda = */ lda,
+      /* int ldb = */ ldb,
+      /* bool transpose_a = */ transpose_a,
+      /* bool transpose_b = */ transpose_b,
+      /* std::vector<array>& copies = */ copies,
+      /* Shape batch_shape = */ batch_shape,
+      /* Strides A_batch_stride = */ A_batch_stride,
+      /* Strides B_batch_stride = */ B_batch_stride);
+}
 
 } // namespace mlx::core
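With these entry points in place, an AddMM-style caller can reuse the batched matmul machinery directly. A hedged sketch of such a call against the signature declared above; all argument values are illustrative:

// Hypothetical AddMM routing sketch: alpha/beta scale the product and the
// addend c; C_batch_stride lets a broadcast addend reuse one matrix per batch.
steel_matmul_axpby(
    s, d, a, b, c, out,
    /* M = */ M, /* N = */ N, /* K = */ K,
    /* batch_size_out = */ batch_size_out,
    /* lda = */ lda, /* ldb = */ ldb,
    /* transpose_a = */ false, /* transpose_b = */ false,
    /* copies = */ copies,
    /* batch_shape = */ batch_shape,
    /* A_batch_stride = */ A_batch_stride,
    /* B_batch_stride = */ B_batch_stride,
    /* C_batch_stride = */ C_batch_stride,
    /* alpha = */ alpha,
    /* beta = */ beta);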
@@ -26,7 +26,7 @@ void RMSNorm::eval_gpu(
   bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
   if (no_copy && x.ndim() > 1) {
     auto s = x.strides()[x.ndim() - 2];
-    no_copy &= (s == 0 || s == x.shape().back());
+    no_copy &= (s == 0 || s == x.shape().back() || x.shape(-2) == 1);
   }
   if (no_copy) {
     if (x.is_donatable()) {
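The added x.shape(-2) == 1 clause covers rows that are trivially contiguous: when the second-to-last axis has size 1 it is never stepped, so any stride value on it is consistent with a contiguous row-major view. For example, a shape {1, N} view with strides {S, 1} touches the same N elements regardless of S. A standalone sketch of the predicate:

#include <cstdint>
#include <vector>

// Sketch of the no-copy test above for the last two axes of a row-major view:
// the penultimate stride only matters when that axis is actually iterated,
// i.e. when its size exceeds 1 (stride 0 marks a broadcast axis).
bool rows_are_contiguous(
    const std::vector<int>& shape, const std::vector<int64_t>& strides) {
  auto n = shape.size();
  bool ok = strides[n - 1] == 1;
  if (ok && n > 1) {
    auto s = strides[n - 2];
    ok &= (s == 0 || s == shape[n - 1] || shape[n - 2] == 1);
  }
  return ok;
}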
@@ -227,7 +227,7 @@ void LayerNorm::eval_gpu(
   bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
   if (no_copy && x.ndim() > 1) {
     auto s = x.strides()[x.ndim() - 2];
-    no_copy &= (s == 0 || s == x.shape().back());
+    no_copy &= (s == 0 || s == x.shape().back() || x.shape(-2) == 1);
   }
   if (no_copy) {
     if (x.is_donatable()) {