diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp index 425274361..88835eb75 100644 --- a/mlx/backend/metal/device.cpp +++ b/mlx/backend/metal/device.cpp @@ -297,6 +297,9 @@ Device::Device() { device_ = load_device(); default_library_ = load_default_library(device_); arch_ = std::string(device_->architecture()->name()->utf8String()); + int ag_tens = arch_[arch_.size() - 3] - '0'; + int ag_ones = arch_[arch_.size() - 2] - '0'; + arch_gen_ = ag_tens * 10 + ag_ones; auto arch = arch_.back(); switch (arch) { case 'p': // phone diff --git a/mlx/backend/metal/device.h b/mlx/backend/metal/device.h index 5bfcc6649..f87a8c48b 100644 --- a/mlx/backend/metal/device.h +++ b/mlx/backend/metal/device.h @@ -177,6 +177,10 @@ class Device { return arch_; } + int get_architecture_gen() const { + return arch_gen_; + } + void new_queue(int index); MTL::CommandQueue* get_queue(Stream stream); @@ -268,6 +272,7 @@ class Device { library_kernels_; const MTL::ResidencySet* residency_set_{nullptr}; std::string arch_; + int arch_gen_; int max_ops_per_buffer_; int max_mb_per_buffer_; }; diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp index 2a7a8dd94..be7f3e2f8 100644 --- a/mlx/backend/metal/matmul.cpp +++ b/mlx/backend/metal/matmul.cpp @@ -503,8 +503,6 @@ void steel_matmul_axpby( Strides C_batch_stride /* = {} */, float alpha /* = 1.0f */, float beta /* = 0.0f */) { - using namespace mlx::steel; - if (batch_shape.empty()) { ///////////////////////////////////////////////////////////////////////////// // Check and collapse batch dimensions