diff --git a/mlx/CMakeLists.txt b/mlx/CMakeLists.txt index abf46a7d5..465954d6f 100644 --- a/mlx/CMakeLists.txt +++ b/mlx/CMakeLists.txt @@ -49,5 +49,7 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io) if(MLX_BUILD_METAL) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/metal) else() - add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_metal) + target_sources(mlx + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/no_metal.cpp) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_gpu) endif() diff --git a/mlx/backend/cpu/CMakeLists.txt b/mlx/backend/cpu/CMakeLists.txt index 152f33b17..96b3f1313 100644 --- a/mlx/backend/cpu/CMakeLists.txt +++ b/mlx/backend/cpu/CMakeLists.txt @@ -40,7 +40,8 @@ add_dependencies(mlx cpu_compiled_preamble) target_sources( mlx - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/available.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp ${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp diff --git a/mlx/backend/cpu/available.cpp b/mlx/backend/cpu/available.cpp new file mode 100644 index 000000000..0449d49b9 --- /dev/null +++ b/mlx/backend/cpu/available.cpp @@ -0,0 +1,11 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cpu/available.h" + +namespace mlx::core::cpu { + +bool is_available() { + return true; +} + +} // namespace mlx::core::cpu diff --git a/mlx/backend/cpu/available.h b/mlx/backend/cpu/available.h new file mode 100644 index 000000000..1df95def2 --- /dev/null +++ b/mlx/backend/cpu/available.h @@ -0,0 +1,9 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +namespace mlx::core::cpu { + +bool is_available(); + +} // namespace mlx::core::cpu diff --git a/mlx/backend/gpu/available.h b/mlx/backend/gpu/available.h new file mode 100644 index 000000000..476c7acf2 --- /dev/null +++ b/mlx/backend/gpu/available.h @@ -0,0 +1,9 @@ +// Copyright © 2025 Apple Inc. 
+ +#pragma once + +namespace mlx::core::gpu { + +bool is_available(); + +} // namespace mlx::core::gpu diff --git a/mlx/backend/metal/metal_impl.h b/mlx/backend/gpu/eval.h similarity index 63% rename from mlx/backend/metal/metal_impl.h rename to mlx/backend/gpu/eval.h index 9ca8d2f80..f646c2ec9 100644 --- a/mlx/backend/metal/metal_impl.h +++ b/mlx/backend/gpu/eval.h @@ -8,14 +8,11 @@ #include "mlx/array.h" #include "mlx/stream.h" -namespace mlx::core::metal { +namespace mlx::core::gpu { void new_stream(Stream stream); - -std::unique_ptr> new_scoped_memory_pool(); - void eval(array& arr); void finalize(Stream s); void synchronize(Stream s); -} // namespace mlx::core::metal +} // namespace mlx::core::gpu diff --git a/mlx/backend/metal/CMakeLists.txt b/mlx/backend/metal/CMakeLists.txt index 332c560f8..d0c872451 100644 --- a/mlx/backend/metal/CMakeLists.txt +++ b/mlx/backend/metal/CMakeLists.txt @@ -93,6 +93,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp ${CMAKE_CURRENT_SOURCE_DIR}/event.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp ${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp diff --git a/mlx/backend/metal/allocator.cpp b/mlx/backend/metal/allocator.cpp index 0a69dd261..5d8bd90d5 100644 --- a/mlx/backend/metal/allocator.cpp +++ b/mlx/backend/metal/allocator.cpp @@ -1,7 +1,6 @@ // Copyright © 2023-2024 Apple Inc. 
#include "mlx/backend/metal/allocator.h" #include "mlx/backend/metal/metal.h" -#include "mlx/backend/metal/metal_impl.h" #include "mlx/backend/metal/resident.h" #include "mlx/memory.h" diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp index cb851b57e..ebc3cc77f 100644 --- a/mlx/backend/metal/device.cpp +++ b/mlx/backend/metal/device.cpp @@ -4,15 +4,12 @@ #include #include -#include - #define NS_PRIVATE_IMPLEMENTATION #define CA_PRIVATE_IMPLEMENTATION #define MTL_PRIVATE_IMPLEMENTATION #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/metal.h" -#include "mlx/backend/metal/metal_impl.h" #include "mlx/backend/metal/utils.h" #include "mlx/utils.h" @@ -772,42 +769,4 @@ std::unique_ptr> new_scoped_memory_pool() { NS::AutoreleasePool::alloc()->init(), dtor); } -void new_stream(Stream stream) { - if (stream.device == mlx::core::Device::gpu) { - device(stream.device).new_queue(stream.index); - } -} - -const std::unordered_map>& -device_info() { - auto init_device_info = []() - -> std::unordered_map> { - auto pool = new_scoped_memory_pool(); - auto raw_device = device(default_device()).mtl_device(); - auto name = std::string(raw_device->name()->utf8String()); - auto arch = std::string(raw_device->architecture()->name()->utf8String()); - - size_t memsize = 0; - size_t length = sizeof(memsize); - sysctlbyname("hw.memsize", &memsize, &length, NULL, 0); - - size_t rsrc_limit = 0; - sysctlbyname("iogpu.rsrc_limit", &rsrc_limit, &length, NULL, 0); - if (rsrc_limit == 0) { - rsrc_limit = 499000; - } - - return { - {"device_name", name}, - {"architecture", arch}, - {"max_buffer_length", raw_device->maxBufferLength()}, - {"max_recommended_working_set_size", - raw_device->recommendedMaxWorkingSetSize()}, - {"memory_size", memsize}, - {"resource_limit", rsrc_limit}}; - }; - static auto device_info_ = init_device_info(); - return device_info_; -} - } // namespace mlx::core::metal diff --git a/mlx/backend/metal/device.h b/mlx/backend/metal/device.h 
index d60635e39..26c9a0a28 100644 --- a/mlx/backend/metal/device.h +++ b/mlx/backend/metal/device.h @@ -266,4 +266,6 @@ class Device { Device& device(mlx::core::Device); +std::unique_ptr> new_scoped_memory_pool(); + } // namespace mlx::core::metal diff --git a/mlx/backend/metal/eval.cpp b/mlx/backend/metal/eval.cpp new file mode 100644 index 000000000..49783200a --- /dev/null +++ b/mlx/backend/metal/eval.cpp @@ -0,0 +1,102 @@ +// Copyright © 2023-2024 Apple Inc. +#include + +#include "mlx/backend/gpu/available.h" +#include "mlx/backend/gpu/eval.h" +#include "mlx/backend/metal/device.h" +#include "mlx/backend/metal/utils.h" +#include "mlx/primitives.h" +#include "mlx/scheduler.h" + +namespace mlx::core::gpu { + +bool is_available() { + return true; +} + +void new_stream(Stream stream) { + if (stream.device == mlx::core::Device::gpu) { + metal::device(stream.device).new_queue(stream.index); + } +} + +inline void check_error(MTL::CommandBuffer* cbuf) { + if (cbuf->status() == MTL::CommandBufferStatusError) { + std::ostringstream msg; + msg << "[METAL] Command buffer execution failed: " + << cbuf->error()->localizedDescription()->utf8String(); + throw std::runtime_error(msg.str()); + } +} + +void eval(array& arr) { + auto pool = metal::new_scoped_memory_pool(); + auto s = arr.primitive().stream(); + auto& d = metal::device(s.device); + auto command_buffer = d.get_command_buffer(s.index); + + auto outputs = arr.outputs(); + { + // If the array is a tracer hold a reference + // to its inputs so they don't get donated + std::vector inputs; + if (arr.is_tracer()) { + inputs = arr.inputs(); + } + + debug_set_primitive_buffer_label(command_buffer, arr.primitive()); + arr.primitive().eval_gpu(arr.inputs(), outputs); + } + std::unordered_set> buffers; + for (auto& in : arr.inputs()) { + buffers.insert(in.data_shared_ptr()); + } + for (auto& s : arr.siblings()) { + buffers.insert(s.data_shared_ptr()); + } + // Remove the output if it was donated to by an input + if (auto it = 
buffers.find(arr.data_shared_ptr()); it != buffers.end()) { + buffers.erase(it); + } + + if (d.command_buffer_needs_commit(s.index)) { + d.end_encoding(s.index); + scheduler::notify_new_task(s); + command_buffer->addCompletedHandler( + [s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) { + scheduler::notify_task_completion(s); + check_error(cbuf); + }); + d.commit_command_buffer(s.index); + d.get_command_buffer(s.index); + } else { + command_buffer->addCompletedHandler( + [s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) { + check_error(cbuf); + }); + } +} + +void finalize(Stream s) { + auto pool = metal::new_scoped_memory_pool(); + auto& d = metal::device(s.device); + auto cb = d.get_command_buffer(s.index); + d.end_encoding(s.index); + cb->addCompletedHandler([s](MTL::CommandBuffer* cbuf) { check_error(cbuf); }); + d.commit_command_buffer(s.index); + d.get_command_buffer(s.index); +} + +void synchronize(Stream s) { + auto pool = metal::new_scoped_memory_pool(); + auto& d = metal::device(s.device); + auto cb = d.get_command_buffer(s.index); + cb->retain(); + d.end_encoding(s.index); + d.commit_command_buffer(s.index); + cb->waitUntilCompleted(); + check_error(cb); + cb->release(); +} + +} // namespace mlx::core::gpu diff --git a/mlx/backend/metal/event.cpp b/mlx/backend/metal/event.cpp index 246d6bcc5..eb7f1b58a 100644 --- a/mlx/backend/metal/event.cpp +++ b/mlx/backend/metal/event.cpp @@ -2,7 +2,6 @@ #include "mlx/event.h" #include "mlx/backend/metal/device.h" -#include "mlx/backend/metal/metal_impl.h" #include "mlx/scheduler.h" namespace mlx::core { diff --git a/mlx/backend/metal/fence.cpp b/mlx/backend/metal/fence.cpp index e784d34ae..d4a88d983 100644 --- a/mlx/backend/metal/fence.cpp +++ b/mlx/backend/metal/fence.cpp @@ -1,7 +1,6 @@ // Copyright © 2024 Apple Inc. 
#include "mlx/fence.h" #include "mlx/backend/metal/device.h" -#include "mlx/backend/metal/metal_impl.h" #include "mlx/scheduler.h" #include "mlx/utils.h" diff --git a/mlx/backend/metal/metal.cpp b/mlx/backend/metal/metal.cpp index a9a1bc4f6..888207322 100644 --- a/mlx/backend/metal/metal.cpp +++ b/mlx/backend/metal/metal.cpp @@ -1,11 +1,11 @@ // Copyright © 2023-2024 Apple Inc. #include +#include + #include "mlx/backend/metal/device.h" +#include "mlx/backend/metal/metal.h" #include "mlx/backend/metal/utils.h" -#include "mlx/primitives.h" -#include "mlx/scheduler.h" -#include "mlx/utils.h" namespace mlx::core::metal { @@ -13,85 +13,6 @@ bool is_available() { return true; } -inline void check_error(MTL::CommandBuffer* cbuf) { - if (cbuf->status() == MTL::CommandBufferStatusError) { - std::ostringstream msg; - msg << "[METAL] Command buffer execution failed: " - << cbuf->error()->localizedDescription()->utf8String(); - throw std::runtime_error(msg.str()); - } -} - -void eval(array& arr) { - auto pool = new_scoped_memory_pool(); - auto s = arr.primitive().stream(); - auto& d = metal::device(s.device); - auto command_buffer = d.get_command_buffer(s.index); - - auto outputs = arr.outputs(); - { - // If the array is a tracer hold a reference - // to its inputs so they don't get donated - std::vector inputs; - if (arr.is_tracer()) { - inputs = arr.inputs(); - } - - debug_set_primitive_buffer_label(command_buffer, arr.primitive()); - arr.primitive().eval_gpu(arr.inputs(), outputs); - } - std::unordered_set> buffers; - for (auto& in : arr.inputs()) { - buffers.insert(in.data_shared_ptr()); - } - for (auto& s : arr.siblings()) { - buffers.insert(s.data_shared_ptr()); - } - // Remove the output if it was donated to by an input - if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) { - buffers.erase(it); - } - - if (d.command_buffer_needs_commit(s.index)) { - d.end_encoding(s.index); - scheduler::notify_new_task(s); - command_buffer->addCompletedHandler( - 
[s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) { - scheduler::notify_task_completion(s); - check_error(cbuf); - }); - d.commit_command_buffer(s.index); - d.get_command_buffer(s.index); - } else { - command_buffer->addCompletedHandler( - [s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) { - check_error(cbuf); - }); - } -} - -void finalize(Stream s) { - auto pool = new_scoped_memory_pool(); - auto& d = metal::device(s.device); - auto cb = d.get_command_buffer(s.index); - d.end_encoding(s.index); - cb->addCompletedHandler([s](MTL::CommandBuffer* cbuf) { check_error(cbuf); }); - d.commit_command_buffer(s.index); - d.get_command_buffer(s.index); -} - -void synchronize(Stream s) { - auto pool = new_scoped_memory_pool(); - auto& d = metal::device(s.device); - auto cb = d.get_command_buffer(s.index); - cb->retain(); - d.end_encoding(s.index); - d.commit_command_buffer(s.index); - cb->waitUntilCompleted(); - check_error(cb); - cb->release(); -} - void start_capture(std::string path, id object) { auto pool = new_scoped_memory_pool(); @@ -128,4 +49,36 @@ void stop_capture() { manager->stopCapture(); } +const std::unordered_map>& +device_info() { + auto init_device_info = []() + -> std::unordered_map> { + auto pool = new_scoped_memory_pool(); + auto raw_device = device(default_device()).mtl_device(); + auto name = std::string(raw_device->name()->utf8String()); + auto arch = std::string(raw_device->architecture()->name()->utf8String()); + + size_t memsize = 0; + size_t length = sizeof(memsize); + sysctlbyname("hw.memsize", &memsize, &length, NULL, 0); + + size_t rsrc_limit = 0; + sysctlbyname("iogpu.rsrc_limit", &rsrc_limit, &length, NULL, 0); + if (rsrc_limit == 0) { + rsrc_limit = 499000; + } + + return { + {"device_name", name}, + {"architecture", arch}, + {"max_buffer_length", raw_device->maxBufferLength()}, + {"max_recommended_working_set_size", + raw_device->recommendedMaxWorkingSetSize()}, + {"memory_size", memsize}, + {"resource_limit", 
rsrc_limit}};
+  };
+  static auto device_info_ = init_device_info();
+  return device_info_;
+}
+
 } // namespace mlx::core::metal
diff --git a/mlx/backend/metal/metal.h b/mlx/backend/metal/metal.h
index d162007d1..af2995b63 100644
--- a/mlx/backend/metal/metal.h
+++ b/mlx/backend/metal/metal.h
@@ -2,11 +2,10 @@
 
 #pragma once
 
+#include <string>
 #include <unordered_map>
 #include <variant>
 
-#include "mlx/array.h"
-
 namespace mlx::core::metal {
 
 /* Check if the Metal backend is available. */
diff --git a/mlx/backend/metal/no_metal.cpp b/mlx/backend/metal/no_metal.cpp
new file mode 100644
index 000000000..d8608d1f8
--- /dev/null
+++ b/mlx/backend/metal/no_metal.cpp
@@ -0,0 +1,25 @@
+// Copyright © 2025 Apple Inc.
+
+#include <stdexcept>
+
+#include "mlx/backend/metal/metal.h"
+
+namespace mlx::core::metal {
+
+// Stub definitions of the public metal:: API for builds without Metal.
+
+bool is_available() {
+  return false;
+}
+
+// Capture requests are accepted and ignored when Metal is not built.
+void start_capture(std::string) {}
+void stop_capture() {}
+
+const std::unordered_map<std::string, std::variant<std::string, size_t>>&
+device_info() {
+  throw std::runtime_error(
+      "[metal::device_info] Cannot get device info without metal backend");
+}
+
+} // namespace mlx::core::metal
diff --git a/mlx/backend/metal/resident.cpp b/mlx/backend/metal/resident.cpp
index 0a9e1b861..798824c2f 100644
--- a/mlx/backend/metal/resident.cpp
+++ b/mlx/backend/metal/resident.cpp
@@ -1,7 +1,6 @@
 // Copyright © 2024 Apple Inc.
#include "mlx/backend/metal/resident.h" -#include "mlx/backend/metal/metal_impl.h" namespace mlx::core::metal { diff --git a/mlx/backend/no_cpu/CMakeLists.txt b/mlx/backend/no_cpu/CMakeLists.txt index e1524ec63..2e6960829 100644 --- a/mlx/backend/no_cpu/CMakeLists.txt +++ b/mlx/backend/no_cpu/CMakeLists.txt @@ -1,6 +1,7 @@ target_sources( mlx - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/available.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../cpu/eval.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../cpu/encoder.cpp ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp) diff --git a/mlx/backend/no_cpu/available.cpp b/mlx/backend/no_cpu/available.cpp new file mode 100644 index 000000000..04c1bac8e --- /dev/null +++ b/mlx/backend/no_cpu/available.cpp @@ -0,0 +1,11 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cpu/available.h" + +namespace mlx::core::cpu { + +bool is_available() { + return false; +} + +} // namespace mlx::core::cpu diff --git a/mlx/backend/no_metal/CMakeLists.txt b/mlx/backend/no_gpu/CMakeLists.txt similarity index 100% rename from mlx/backend/no_metal/CMakeLists.txt rename to mlx/backend/no_gpu/CMakeLists.txt diff --git a/mlx/backend/no_metal/allocator.cpp b/mlx/backend/no_gpu/allocator.cpp similarity index 100% rename from mlx/backend/no_metal/allocator.cpp rename to mlx/backend/no_gpu/allocator.cpp diff --git a/mlx/backend/no_metal/apple_memory.h b/mlx/backend/no_gpu/apple_memory.h similarity index 100% rename from mlx/backend/no_metal/apple_memory.h rename to mlx/backend/no_gpu/apple_memory.h diff --git a/mlx/backend/no_gpu/eval.cpp b/mlx/backend/no_gpu/eval.cpp new file mode 100644 index 000000000..8bff86a98 --- /dev/null +++ b/mlx/backend/no_gpu/eval.cpp @@ -0,0 +1,28 @@ +// Copyright © 2025 Apple Inc. 
+ +#include + +#include "mlx/backend/gpu/available.h" +#include "mlx/backend/gpu/eval.h" + +namespace mlx::core::gpu { + +bool is_available() { + return false; +} + +void new_stream(Stream) {} + +void eval(array&) { + throw std::runtime_error("[gpu::eval] GPU backend is not available"); +} + +void finalize(Stream) { + throw std::runtime_error("[gpu::finalize] GPU backend is not available"); +} + +void synchronize(Stream) { + throw std::runtime_error("[gpu::synchronize] GPU backend is not available"); +} + +} // namespace mlx::core::gpu diff --git a/mlx/backend/no_metal/event.cpp b/mlx/backend/no_gpu/event.cpp similarity index 100% rename from mlx/backend/no_metal/event.cpp rename to mlx/backend/no_gpu/event.cpp diff --git a/mlx/backend/no_metal/fence.cpp b/mlx/backend/no_gpu/fence.cpp similarity index 100% rename from mlx/backend/no_metal/fence.cpp rename to mlx/backend/no_gpu/fence.cpp diff --git a/mlx/backend/no_metal/linux_memory.h b/mlx/backend/no_gpu/linux_memory.h similarity index 100% rename from mlx/backend/no_metal/linux_memory.h rename to mlx/backend/no_gpu/linux_memory.h diff --git a/mlx/backend/no_metal/primitives.cpp b/mlx/backend/no_gpu/primitives.cpp similarity index 100% rename from mlx/backend/no_metal/primitives.cpp rename to mlx/backend/no_gpu/primitives.cpp diff --git a/mlx/backend/no_metal/metal.cpp b/mlx/backend/no_metal/metal.cpp deleted file mode 100644 index ef9af8800..000000000 --- a/mlx/backend/no_metal/metal.cpp +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright © 2023-2024 Apple Inc. 
- -#include - -#include "mlx/backend/metal/metal.h" -#include "mlx/backend/metal/metal_impl.h" -namespace mlx::core::metal { - -bool is_available() { - return false; -} - -void new_stream(Stream) {} - -std::unique_ptr> new_scoped_memory_pool() { - return nullptr; -} - -void eval(array&) { - throw std::runtime_error( - "[metal::eval] Cannot eval on GPU without metal backend"); -} - -void finalize(Stream) { - throw std::runtime_error( - "[metal::finalize] Cannot finalize GPU without metal backend"); -} - -void synchronize(Stream) { - throw std::runtime_error( - "[metal::synchronize] Cannot synchronize GPU without metal backend"); -} - -void start_capture(std::string) {} -void stop_capture() {} - -const std::unordered_map>& -device_info() { - throw std::runtime_error( - "[metal::device_info] Cannot get device info without metal backend"); -}; - -} // namespace mlx::core::metal diff --git a/mlx/device.cpp b/mlx/device.cpp index 20d8675d8..5e9e7a430 100644 --- a/mlx/device.cpp +++ b/mlx/device.cpp @@ -1,13 +1,15 @@ // Copyright © 2023 Apple Inc. +#include + +#include "mlx/backend/cpu/available.h" +#include "mlx/backend/gpu/available.h" #include "mlx/device.h" -#include "mlx/backend/metal/metal.h" namespace mlx::core { Device& mutable_default_device() { - static Device default_device{ - metal::is_available() ? Device::gpu : Device::cpu}; + static Device default_device{gpu::is_available() ? 
Device::gpu : Device::cpu};
   return default_device;
 }
 
@@ -16,7 +18,7 @@ const Device& default_device() {
 }
 
 void set_default_device(const Device& d) {
-  if (!metal::is_available() && d == Device::gpu) {
+  if (!gpu::is_available() && d == Device::gpu) {
     throw std::invalid_argument(
         "[set_default_device] Cannot set gpu device without gpu backend.");
   }
@@ -31,4 +33,16 @@ bool operator!=(const Device& lhs, const Device& rhs) {
   return !(lhs == rhs);
 }
 
+// Report whether the backend that serves device `d` is available.
+bool is_available(const Device& d) {
+  switch (d.type) {
+    case Device::cpu:
+      return cpu::is_available();
+    case Device::gpu:
+      return gpu::is_available();
+  }
+  // Not reached for cpu/gpu; ensures no path flows off the end of the function.
+  return false;
+}
+
 } // namespace mlx::core
diff --git a/mlx/device.h b/mlx/device.h
index a11e40e9d..80c624c1c 100644
--- a/mlx/device.h
+++ b/mlx/device.h
@@ -26,4 +26,6 @@ void set_default_device(const Device& d);
 bool operator==(const Device& lhs, const Device& rhs);
 bool operator!=(const Device& lhs, const Device& rhs);
 
+bool is_available(const Device& d);
+
 } // namespace mlx::core
diff --git a/mlx/scheduler.cpp b/mlx/scheduler.cpp
index 7bd128c10..b19f6434a 100644
--- a/mlx/scheduler.cpp
+++ b/mlx/scheduler.cpp
@@ -1,12 +1,13 @@
 // Copyright © 2023 Apple Inc.
#include "mlx/scheduler.h" -#include "mlx/backend/metal/metal.h" +#include "mlx/backend/gpu/available.h" +#include "mlx/backend/gpu/eval.h" namespace mlx::core { Stream default_stream(Device d) { - if (!metal::is_available() && d == Device::gpu) { + if (!gpu::is_available() && d == Device::gpu) { throw std::invalid_argument( "[default_stream] Cannot get gpu stream without gpu backend."); } @@ -14,7 +15,7 @@ Stream default_stream(Device d) { } void set_default_stream(Stream s) { - if (!metal::is_available() && s.device == Device::gpu) { + if (!gpu::is_available() && s.device == Device::gpu) { throw std::invalid_argument( "[set_default_stream] Cannot set gpu stream without gpu backend."); } @@ -26,7 +27,7 @@ Stream get_stream(int index) { } Stream new_stream(Device d) { - if (!metal::is_available() && d == Device::gpu) { + if (!gpu::is_available() && d == Device::gpu) { throw std::invalid_argument( "[new_stream] Cannot make gpu stream without gpu backend."); } @@ -44,7 +45,7 @@ void synchronize(Stream s) { scheduler::enqueue(s, [p = std::move(p)]() { p->set_value(); }); f.wait(); } else { - metal::synchronize(s); + gpu::synchronize(s); } } diff --git a/mlx/scheduler.h b/mlx/scheduler.h index b2c6b842b..877fdd5f6 100644 --- a/mlx/scheduler.h +++ b/mlx/scheduler.h @@ -8,8 +8,7 @@ #include #include -#include "mlx/backend/metal/metal.h" -#include "mlx/backend/metal/metal_impl.h" +#include "mlx/backend/gpu/eval.h" #include "mlx/device.h" #include "mlx/stream.h" @@ -67,7 +66,7 @@ struct StreamThread { class Scheduler { public: Scheduler() : n_active_tasks_(0) { - if (metal::is_available()) { + if (is_available(Device::gpu)) { default_streams_.insert({Device::gpu, new_stream(Device::gpu)}); } default_streams_.insert({Device::cpu, new_stream(Device::cpu)}); @@ -83,7 +82,7 @@ class Scheduler { streams_.emplace_back(streams_.size(), d); if (d == Device::gpu) { threads_.push_back(nullptr); - metal::new_stream(streams_.back()); + gpu::new_stream(streams_.back()); } else { 
threads_.push_back(new StreamThread{}); } diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp index f9a5de031..2d9942eda 100644 --- a/mlx/transforms.cpp +++ b/mlx/transforms.cpp @@ -10,7 +10,7 @@ #include #include "mlx/backend/cpu/eval.h" -#include "mlx/backend/metal/metal_impl.h" +#include "mlx/backend/gpu/eval.h" #include "mlx/fence.h" #include "mlx/memory.h" #include "mlx/ops.h" @@ -218,7 +218,7 @@ array eval_impl(std::vector outputs, bool async) { } if (arr.primitive().device() == Device::gpu) { - metal::eval(arr); + gpu::eval(arr); } else { cpu::eval(arr); } @@ -229,7 +229,7 @@ array eval_impl(std::vector outputs, bool async) { // Commit any open streams for (auto& [_, e] : events) { if (e.stream().device == Device::gpu) { - metal::finalize(e.stream()); + gpu::finalize(e.stream()); } } scheduler::wait_for_one(); @@ -267,7 +267,7 @@ array eval_impl(std::vector outputs, bool async) { auto s = e.stream(); e.signal(s); if (s.device == Device::gpu) { - metal::finalize(s); + gpu::finalize(s); } }