commit c47089fb1c
Author: Jagrit Digani
Date:   2025-06-13 16:45:40 +12:00
Committed by: GitHub (GPG Key ID: B5690EEEBB952194)

7 changed files with 720 additions and 642 deletions

@@ -155,26 +155,26 @@ void explicit_gemm_conv_group_ND_gpu(
   // Perform gemm
   std::vector<array> copies = {in_unfolded, wt_transpose};
   return steel_matmul_regular(
-      s,
-      d,
-      /* a = */ in_unfolded,
-      /* b = */ wt_transpose,
-      /* c = */ out,
-      /* M = */ implicit_M,
-      /* N = */ implicit_N,
-      /* K = */ implicit_K,
-      /* batch_size_out = */ groups,
-      /* a_cols = */ implicit_K * groups,
-      /* b_cols = */ implicit_K,
-      /* out_cols = */ implicit_N * groups,
-      /* a_transposed = */ false,
-      /* b_transposed = */ true,
-      /* batch_shape = */ {1},
-      /* batch_strides = */ {0},
-      /* A_batch_strides = */ size_t(implicit_K),
-      /* B_batch_strides = */ size_t(implicit_N) * implicit_K,
-      /* matrix_stride_out = */ size_t(implicit_N),
-      /*copies = */ copies);
+      /* const Stream& s = */ s,
+      /* Device& d = */ d,
+      /* const array& a = */ in_unfolded,
+      /* const array& b = */ wt_transpose,
+      /* array& c = */ out,
+      /* int M = */ implicit_M,
+      /* int N = */ implicit_N,
+      /* int K = */ implicit_K,
+      /* int batch_size_out = */ groups,
+      /* int lda = */ implicit_K * groups,
+      /* int ldb = */ implicit_K,
+      /* int ldd = */ implicit_N * groups,
+      /* bool transpose_a = */ false,
+      /* bool transpose_b = */ true,
+      /* std::vector<array>& copies = */ copies,
+      /* Shape batch_shape = */ {1},
+      /* Strides batch_strides = */ {0},
+      /* int64_t A_batch_strides = */ int64_t(implicit_K),
+      /* int64_t B_batch_strides = */ int64_t(implicit_N) * implicit_K,
+      /* int64_t matrix_stride_out = */ int64_t(implicit_N));
 }
 
 void implicit_gemm_conv_2D_gpu(
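
A note on the leading dimensions in this call: each group is one matrix in the GEMM batch, but the unfolded input keeps all groups packed along its columns, which is why lda is implicit_K * groups while the per-group A batch stride is only implicit_K. A hypothetical CPU sketch of that addressing (the names here are illustrative, not from this commit):

    #include <cstdint>

    // Group g of the unfolded input: A is M x (K * groups), row-major.
    // Group g's K columns start at offset g * K within each row.
    float a_elem(const float* A, int g, int i, int k, int K, int groups) {
      int lda = K * groups;       // full row length of the unfolded input
      int64_t a_batch_stride = K; // hop to the next group's columns
      return A[int64_t(i) * lda + g * a_batch_stride + k];
    }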

@@ -297,6 +297,9 @@ Device::Device() {
   device_ = load_device();
   default_library_ = load_default_library(device_);
   arch_ = std::string(device_->architecture()->name()->utf8String());
+  int ag_tens = arch_[arch_.size() - 3] - '0';
+  int ag_ones = arch_[arch_.size() - 2] - '0';
+  arch_gen_ = ag_tens * 10 + ag_ones;
   auto arch = arch_.back();
   switch (arch) {
     case 'p': // phone
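
For context, the added lines parse the GPU generation out of the architecture name: the two characters just before the final variant letter are read as decimal digits. A standalone sketch, assuming names shaped like "applegpu_g15p" (the exact naming scheme is an assumption; the fixed-position indexing is what the code above relies on):

    #include <string>

    // Assumes the architecture name ends in two generation digits followed
    // by one variant letter, e.g. "applegpu_g15p" -> 15.
    int parse_arch_gen(const std::string& arch) {
      int tens = arch[arch.size() - 3] - '0';
      int ones = arch[arch.size() - 2] - '0';
      return tens * 10 + ones;
    }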

@@ -177,6 +177,10 @@ class Device {
     return arch_;
   }
 
+  int get_architecture_gen() const {
+    return arch_gen_;
+  }
+
   void new_queue(int index);
   MTL::CommandQueue* get_queue(Stream stream);
 
@@ -268,6 +272,7 @@ class Device {
       library_kernels_;
   const MTL::ResidencySet* residency_set_{nullptr};
   std::string arch_;
+  int arch_gen_;
   int max_ops_per_buffer_;
   int max_mb_per_buffer_;
 };
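
A hypothetical call site for the new accessor (the cutoff value and the dispatch itself are illustrative, not part of this commit):

    // Pick a code path by GPU generation.
    auto& d = metal::device(s.device);
    if (d.get_architecture_gen() >= 15) {
      // newer-generation kernel variant
    }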

@@ -33,8 +33,8 @@ template <
     device T* D [[buffer(3)]],
     const constant GEMMParams* params [[buffer(4)]],
     const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]],
-    const constant int* batch_shape [[buffer(6)]],
-    const constant int64_t* batch_strides [[buffer(7)]],
+    const constant int* batch_shape [[buffer(6), function_constant(has_batch)]],
+    const constant int64_t* batch_strides [[buffer(7), function_constant(has_batch)]],
     uint simd_lane_id [[thread_index_in_simdgroup]],
     uint simd_group_id [[simdgroup_index_in_threadgroup]],
     uint3 tid [[threadgroup_position_in_grid]],
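
Gating buffers 6 and 7 on has_batch means an unbatched specialization of the kernel no longer declares, or needs bound, the batch shape and stride buffers. On the host side this is driven by function constants at pipeline-build time; a minimal metal-cpp sketch, where the kernel name and the constant index 10 are assumptions for illustration:

    // Specialize a library function on the has_batch function constant.
    MTL::Function* make_gemm_function(MTL::Library* library, bool has_batch) {
      auto* consts = MTL::FunctionConstantValues::alloc()->init();
      consts->setConstantValue(&has_batch, MTL::DataTypeBool, NS::UInteger(10));
      NS::Error* error = nullptr;
      return library->newFunction(
          NS::String::string("steel_gemm_fused", NS::UTF8StringEncoding),
          consts,
          &error);
    }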

(File diff suppressed because it is too large.)

@@ -6,7 +6,34 @@
 
 namespace mlx::core {
 
-void steel_matmul_regular(
+template <bool CHECK_AB = true>
+void steel_matmul_regular_axpby(
+    const Stream& s,
+    metal::Device& d,
+    const array& a,
+    const array& b,
+    const array& c,
+    array& out,
+    int M,
+    int N,
+    int K,
+    int batch_size_out,
+    int lda,
+    int ldb,
+    int ldd,
+    bool transpose_a,
+    bool transpose_b,
+    std::vector<array>& copies,
+    Shape batch_shape,
+    Strides batch_strides,
+    int64_t A_batch_stride,
+    int64_t B_batch_stride,
+    int64_t matrix_stride_out,
+    int64_t C_batch_stride = 0,
+    float alpha = 1.0f,
+    float beta = 0.0f);
+
+inline void steel_matmul_regular(
     const Stream& s,
     metal::Device& d,
     const array& a,
@@ -21,14 +48,61 @@ void steel_matmul_regular(
     int ldd,
     bool transpose_a,
     bool transpose_b,
+    std::vector<array>& copies,
     Shape batch_shape,
     Strides batch_strides,
     int64_t A_batch_stride,
     int64_t B_batch_stride,
-    int64_t matrix_stride_out,
-    std::vector<array>& copies);
+    int64_t matrix_stride_out) {
+  return steel_matmul_regular_axpby<false>(
+      /* const Stream& s = */ s,
+      /* metal::Device& d = */ d,
+      /* const array& a = */ a,
+      /* const array& b = */ b,
+      /* const array& c = */ b,
+      /* array& out = */ out,
+      /* int M = */ M,
+      /* int N = */ N,
+      /* int K = */ K,
+      /* int batch_size_out = */ batch_size_out,
+      /* int lda = */ lda,
+      /* int ldb = */ ldb,
+      /* int ldd = */ ldd,
+      /* bool transpose_a = */ transpose_a,
+      /* bool transpose_b = */ transpose_b,
+      /* std::vector<array>& copies = */ copies,
+      /* Shape batch_shape = */ batch_shape,
+      /* Strides batch_strides = */ batch_strides,
+      /* int64_t A_batch_stride = */ A_batch_stride,
+      /* int64_t B_batch_stride = */ B_batch_stride,
+      /* int64_t matrix_stride_out = */ matrix_stride_out);
+}
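
On this shim: instantiating the axpby path with CHECK_AB = false and the defaults alpha = 1, beta = 0 means the c operand is never read, which is presumably why the wrapper binds b as a placeholder for c. The name suggests BLAS-style semantics, out = alpha * (a @ b) + beta * c; that reading is inferred from the defaults rather than stated in the diff. A plain CPU reference of the inferred contract:

    #include <vector>

    // Reference semantics sketch: out = alpha * (A @ B) + beta * C for
    // row-major M x K by K x N operands stored as flat vectors.
    std::vector<float> gemm_axpby(
        const std::vector<float>& A,
        const std::vector<float>& B,
        const std::vector<float>& C,
        int M, int N, int K,
        float alpha = 1.0f, float beta = 0.0f) {
      std::vector<float> out(M * N, 0.0f);
      for (int i = 0; i < M; ++i) {
        for (int j = 0; j < N; ++j) {
          float acc = 0.0f;
          for (int k = 0; k < K; ++k) {
            acc += A[i * K + k] * B[k * N + j];
          }
          // With the old-API defaults (alpha = 1, beta = 0), C is never read.
          out[i * N + j] = alpha * acc + (beta != 0.0f ? beta * C[i * N + j] : 0.0f);
        }
      }
      return out;
    }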
 
-void steel_matmul(
+template <bool CHECK_AB = true>
+void steel_matmul_axpby(
+    const Stream& s,
+    metal::Device& d,
+    const array& a,
+    const array& b,
+    const array& c,
+    array& out,
+    int M,
+    int N,
+    int K,
+    int batch_size_out,
+    int lda,
+    int ldb,
+    bool transpose_a,
+    bool transpose_b,
+    std::vector<array>& copies,
+    Shape batch_shape = {},
+    Strides A_batch_stride = {},
+    Strides B_batch_stride = {},
+    Strides C_batch_stride = {},
+    float alpha = 1.0f,
+    float beta = 0.0f);
+
+inline void steel_matmul(
     const Stream& s,
     metal::Device& d,
     const array& a,
@@ -45,6 +119,26 @@ void steel_matmul(
     std::vector<array>& copies,
     Shape batch_shape = {},
     Strides A_batch_stride = {},
-    Strides B_batch_stride = {});
+    Strides B_batch_stride = {}) {
+  return steel_matmul_axpby<false>(
+      /* const Stream& s = */ s,
+      /* metal::Device& d = */ d,
+      /* const array& a = */ a,
+      /* const array& b = */ b,
+      /* const array& c = */ b,
+      /* array& out = */ out,
+      /* int M = */ M,
+      /* int N = */ N,
+      /* int K = */ K,
+      /* int batch_size_out = */ batch_size_out,
+      /* int lda = */ lda,
+      /* int ldb = */ ldb,
+      /* bool transpose_a = */ transpose_a,
+      /* bool transpose_b = */ transpose_b,
+      /* std::vector<array>& copies = */ copies,
+      /* Shape batch_shape = */ batch_shape,
+      /* Strides A_batch_stride = */ A_batch_stride,
+      /* Strides B_batch_stride = */ B_batch_stride);
+}
 
 } // namespace mlx::core
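
And a hypothetical call into the new batched entry point with a live epilogue, following the declaration above (this call site is not in the diff; the scalar values are for illustration):

    // out = 0.5 * c + 2.0 * (a @ b)
    steel_matmul_axpby<true>(
        /* const Stream& s = */ s,
        /* metal::Device& d = */ d,
        /* const array& a = */ a,
        /* const array& b = */ b,
        /* const array& c = */ c,
        /* array& out = */ out,
        /* int M = */ M,
        /* int N = */ N,
        /* int K = */ K,
        /* int batch_size_out = */ batch_size_out,
        /* int lda = */ lda,
        /* int ldb = */ ldb,
        /* bool transpose_a = */ false,
        /* bool transpose_b = */ false,
        /* std::vector<array>& copies = */ copies,
        /* Shape batch_shape = */ batch_shape,
        /* Strides A_batch_stride = */ A_batch_stride,
        /* Strides B_batch_stride = */ B_batch_stride,
        /* Strides C_batch_stride = */ C_batch_stride,
        /* float alpha = */ 2.0f,
        /* float beta = */ 0.5f);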

@@ -26,7 +26,7 @@ void RMSNorm::eval_gpu(
   bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
   if (no_copy && x.ndim() > 1) {
     auto s = x.strides()[x.ndim() - 2];
-    no_copy &= (s == 0 || s == x.shape().back());
+    no_copy &= (s == 0 || s == x.shape().back() || x.shape(-2) == 1);
   }
   if (no_copy) {
     if (x.is_donatable()) {
@@ -227,7 +227,7 @@ void LayerNorm::eval_gpu(
   bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
   if (no_copy && x.ndim() > 1) {
     auto s = x.strides()[x.ndim() - 2];
-    no_copy &= (s == 0 || s == x.shape().back());
+    no_copy &= (s == 0 || s == x.shape().back() || x.shape(-2) == 1);
   }
   if (no_copy) {
     if (x.is_donatable()) {
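
The relaxed predicate admits one more layout in both RMSNorm and LayerNorm: when the second-to-last axis has size 1, its stride never participates in addressing, so the input is still row-contiguous whatever that stride holds. Isolated as a sketch that mirrors the check:

    #include <cstdint>

    // A contiguous-rows input can skip the defensive copy when the
    // second-to-last axis broadcasts (stride 0), is densely packed
    // (stride == row length), or has a single row (stride unused).
    bool rows_usable(int64_t stride_m2, int64_t row_len, int64_t rows) {
      return stride_m2 == 0 || stride_m2 == row_len || rows == 1;
    }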