Add architecture gen to device

This commit is contained in:
Jagrit Digani 2025-06-11 09:56:01 -07:00
parent b1d95a3880
commit 53fa981caf
3 changed files with 8 additions and 2 deletions

View File

@ -297,6 +297,9 @@ Device::Device() {
device_ = load_device(); device_ = load_device();
default_library_ = load_default_library(device_); default_library_ = load_default_library(device_);
arch_ = std::string(device_->architecture()->name()->utf8String()); arch_ = std::string(device_->architecture()->name()->utf8String());
int ag_tens = arch_[arch_.size() - 3] - '0';
int ag_ones = arch_[arch_.size() - 2] - '0';
arch_gen_ = ag_tens * 10 + ag_ones;
auto arch = arch_.back(); auto arch = arch_.back();
switch (arch) { switch (arch) {
case 'p': // phone case 'p': // phone

View File

@ -177,6 +177,10 @@ class Device {
return arch_; return arch_;
} }
int get_architecture_gen() const {
return arch_gen_;
}
void new_queue(int index); void new_queue(int index);
MTL::CommandQueue* get_queue(Stream stream); MTL::CommandQueue* get_queue(Stream stream);
@ -268,6 +272,7 @@ class Device {
library_kernels_; library_kernels_;
const MTL::ResidencySet* residency_set_{nullptr}; const MTL::ResidencySet* residency_set_{nullptr};
std::string arch_; std::string arch_;
int arch_gen_;
int max_ops_per_buffer_; int max_ops_per_buffer_;
int max_mb_per_buffer_; int max_mb_per_buffer_;
}; };

View File

@ -503,8 +503,6 @@ void steel_matmul_axpby(
Strides C_batch_stride /* = {} */, Strides C_batch_stride /* = {} */,
float alpha /* = 1.0f */, float alpha /* = 1.0f */,
float beta /* = 0.0f */) { float beta /* = 0.0f */) {
using namespace mlx::steel;
if (batch_shape.empty()) { if (batch_shape.empty()) {
///////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////
// Check and collapse batch dimensions // Check and collapse batch dimensions