* expose residency sets as wire/unwire

* returns wired size

* fix

* runtime support check

* fix os check

* fix test

* fix no metal build

* docs

* nit

* nits in docs

* nits
This commit is contained in:
Awni Hannun
2024-10-25 09:35:33 -07:00
committed by GitHub
parent f70764a162
commit 0eb56d5be0
13 changed files with 246 additions and 14 deletions

View File

@@ -206,6 +206,9 @@ void Device::new_queue(int index) {
"[metal::Device] Failed to make new command queue.");
}
stream_map_.emplace(index, q);
if (residency_set_ != nullptr) {
q->addResidencySet(residency_set_);
}
}
int Device::get_command_buffer_ops(int index) {
@@ -351,7 +354,7 @@ MTL::Library* Device::build_library_(const std::string& source_string) {
// Throw error if unable to compile library
if (!mtl_lib) {
std::ostringstream msg;
msg << "[metal::Device] Unable to build metal library from source" << "\n";
msg << "[metal::Device] Unable to build metal library from source\n";
if (error) {
msg << error->localizedDescription()->utf8String() << "\n";
}
@@ -593,6 +596,21 @@ MTL::ComputePipelineState* Device::get_kernel(
return get_kernel_(base_name, mtl_lib, kname, func_consts, linked_functions);
}
void Device::set_residency_set(const MTL::ResidencySet* residency_set) {
if (residency_set_ != nullptr) {
throw std::runtime_error(
"[Device::set_residency_set] Can only be set once.");
}
if (residency_set == nullptr) {
return;
}
residency_set_ = residency_set;
// Attach residency set to existing command queues
for (auto& [_, stream] : stream_map_) {
stream.queue->addResidencySet(residency_set_);
}
}
Device& device(mlx::core::Device) {
static Device metal_device;
return metal_device;