mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-21 00:31:12 +08:00
Generalize gpu backend (#2138)
* generalize gpu backend * fix no_gpu build * fix no_gpu build * generalize gpu backend
This commit is contained in:
parent
87720a8908
commit
f1606486d2
@ -49,5 +49,7 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io)
|
|||||||
if(MLX_BUILD_METAL)
|
if(MLX_BUILD_METAL)
|
||||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/metal)
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/metal)
|
||||||
else()
|
else()
|
||||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_metal)
|
target_sources(mlx
|
||||||
|
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/no_metal.cpp)
|
||||||
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_gpu)
|
||||||
endif()
|
endif()
|
||||||
|
@ -40,7 +40,8 @@ add_dependencies(mlx cpu_compiled_preamble)
|
|||||||
|
|
||||||
target_sources(
|
target_sources(
|
||||||
mlx
|
mlx
|
||||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
|
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/available.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
|
||||||
|
11
mlx/backend/cpu/available.cpp
Normal file
11
mlx/backend/cpu/available.cpp
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/cpu/available.h"
|
||||||
|
|
||||||
|
namespace mlx::core::cpu {

// Availability query for the CPU backend. This is the definition compiled
// from backend/cpu/available.cpp, and it always answers "yes".
bool is_available() {
  constexpr bool cpu_backend_present = true;
  return cpu_backend_present;
}

} // namespace mlx::core::cpu
|
9
mlx/backend/cpu/available.h
Normal file
9
mlx/backend/cpu/available.h
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
namespace mlx::core::cpu {
|
||||||
|
|
||||||
|
/* Check if the CPU backend is available. */
bool is_available();
|
||||||
|
|
||||||
|
} // namespace mlx::core::cpu
|
9
mlx/backend/gpu/available.h
Normal file
9
mlx/backend/gpu/available.h
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
namespace mlx::core::gpu {
|
||||||
|
|
||||||
|
/* Check if the GPU backend is available. */
bool is_available();
|
||||||
|
|
||||||
|
} // namespace mlx::core::gpu
|
@ -8,14 +8,11 @@
|
|||||||
#include "mlx/array.h"
|
#include "mlx/array.h"
|
||||||
#include "mlx/stream.h"
|
#include "mlx/stream.h"
|
||||||
|
|
||||||
namespace mlx::core::metal {
|
namespace mlx::core::gpu {
|
||||||
|
|
||||||
void new_stream(Stream stream);
|
void new_stream(Stream stream);
|
||||||
|
|
||||||
std::unique_ptr<void, std::function<void(void*)>> new_scoped_memory_pool();
|
|
||||||
|
|
||||||
void eval(array& arr);
|
void eval(array& arr);
|
||||||
void finalize(Stream s);
|
void finalize(Stream s);
|
||||||
void synchronize(Stream s);
|
void synchronize(Stream s);
|
||||||
|
|
||||||
} // namespace mlx::core::metal
|
} // namespace mlx::core::gpu
|
@ -93,6 +93,7 @@ target_sources(
|
|||||||
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/event.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/event.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
// Copyright © 2023-2024 Apple Inc.
|
// Copyright © 2023-2024 Apple Inc.
|
||||||
#include "mlx/backend/metal/allocator.h"
|
#include "mlx/backend/metal/allocator.h"
|
||||||
#include "mlx/backend/metal/metal.h"
|
#include "mlx/backend/metal/metal.h"
|
||||||
#include "mlx/backend/metal/metal_impl.h"
|
|
||||||
#include "mlx/backend/metal/resident.h"
|
#include "mlx/backend/metal/resident.h"
|
||||||
#include "mlx/memory.h"
|
#include "mlx/memory.h"
|
||||||
|
|
||||||
|
@ -4,15 +4,12 @@
|
|||||||
#include <filesystem>
|
#include <filesystem>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
|
||||||
#include <sys/sysctl.h>
|
|
||||||
|
|
||||||
#define NS_PRIVATE_IMPLEMENTATION
|
#define NS_PRIVATE_IMPLEMENTATION
|
||||||
#define CA_PRIVATE_IMPLEMENTATION
|
#define CA_PRIVATE_IMPLEMENTATION
|
||||||
#define MTL_PRIVATE_IMPLEMENTATION
|
#define MTL_PRIVATE_IMPLEMENTATION
|
||||||
|
|
||||||
#include "mlx/backend/metal/device.h"
|
#include "mlx/backend/metal/device.h"
|
||||||
#include "mlx/backend/metal/metal.h"
|
#include "mlx/backend/metal/metal.h"
|
||||||
#include "mlx/backend/metal/metal_impl.h"
|
|
||||||
#include "mlx/backend/metal/utils.h"
|
#include "mlx/backend/metal/utils.h"
|
||||||
#include "mlx/utils.h"
|
#include "mlx/utils.h"
|
||||||
|
|
||||||
@ -772,42 +769,4 @@ std::unique_ptr<void, std::function<void(void*)>> new_scoped_memory_pool() {
|
|||||||
NS::AutoreleasePool::alloc()->init(), dtor);
|
NS::AutoreleasePool::alloc()->init(), dtor);
|
||||||
}
|
}
|
||||||
|
|
||||||
void new_stream(Stream stream) {
|
|
||||||
if (stream.device == mlx::core::Device::gpu) {
|
|
||||||
device(stream.device).new_queue(stream.index);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const std::unordered_map<std::string, std::variant<std::string, size_t>>&
|
|
||||||
device_info() {
|
|
||||||
auto init_device_info = []()
|
|
||||||
-> std::unordered_map<std::string, std::variant<std::string, size_t>> {
|
|
||||||
auto pool = new_scoped_memory_pool();
|
|
||||||
auto raw_device = device(default_device()).mtl_device();
|
|
||||||
auto name = std::string(raw_device->name()->utf8String());
|
|
||||||
auto arch = std::string(raw_device->architecture()->name()->utf8String());
|
|
||||||
|
|
||||||
size_t memsize = 0;
|
|
||||||
size_t length = sizeof(memsize);
|
|
||||||
sysctlbyname("hw.memsize", &memsize, &length, NULL, 0);
|
|
||||||
|
|
||||||
size_t rsrc_limit = 0;
|
|
||||||
sysctlbyname("iogpu.rsrc_limit", &rsrc_limit, &length, NULL, 0);
|
|
||||||
if (rsrc_limit == 0) {
|
|
||||||
rsrc_limit = 499000;
|
|
||||||
}
|
|
||||||
|
|
||||||
return {
|
|
||||||
{"device_name", name},
|
|
||||||
{"architecture", arch},
|
|
||||||
{"max_buffer_length", raw_device->maxBufferLength()},
|
|
||||||
{"max_recommended_working_set_size",
|
|
||||||
raw_device->recommendedMaxWorkingSetSize()},
|
|
||||||
{"memory_size", memsize},
|
|
||||||
{"resource_limit", rsrc_limit}};
|
|
||||||
};
|
|
||||||
static auto device_info_ = init_device_info();
|
|
||||||
return device_info_;
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace mlx::core::metal
|
} // namespace mlx::core::metal
|
||||||
|
@ -266,4 +266,6 @@ class Device {
|
|||||||
|
|
||||||
Device& device(mlx::core::Device);
|
Device& device(mlx::core::Device);
|
||||||
|
|
||||||
|
/* Create an NS::AutoreleasePool owned by the returned pointer. The pointer
 * carries a custom deleter.
 * NOTE(review): the deleter body is not visible here -- presumably it
 * releases the pool when the pointer goes out of scope; confirm in
 * device.cpp. */
std::unique_ptr<void, std::function<void(void*)>> new_scoped_memory_pool();
|
||||||
|
|
||||||
} // namespace mlx::core::metal
|
} // namespace mlx::core::metal
|
||||||
|
102
mlx/backend/metal/eval.cpp
Normal file
102
mlx/backend/metal/eval.cpp
Normal file
@ -0,0 +1,102 @@
|
|||||||
|
// Copyright © 2023-2024 Apple Inc.
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include "mlx/backend/gpu/available.h"
|
||||||
|
#include "mlx/backend/gpu/eval.h"
|
||||||
|
#include "mlx/backend/metal/device.h"
|
||||||
|
#include "mlx/backend/metal/utils.h"
|
||||||
|
#include "mlx/primitives.h"
|
||||||
|
#include "mlx/scheduler.h"
|
||||||
|
|
||||||
|
namespace mlx::core::gpu {
|
||||||
|
|
||||||
|
// Availability query for the GPU backend as answered by the Metal build:
// this definition always reports the GPU as usable.
bool is_available() {
  constexpr bool gpu_backend_present = true;
  return gpu_backend_present;
}
|
||||||
|
|
||||||
|
// Register a newly created stream with the Metal device.
// Streams that do not target the GPU are ignored; for GPU streams a new
// queue is created on the device under the stream's index.
void new_stream(Stream stream) {
  if (stream.device != mlx::core::Device::gpu) {
    return;
  }
  auto& metal_device = metal::device(stream.device);
  metal_device.new_queue(stream.index);
}
|
||||||
|
|
||||||
|
// Throw std::runtime_error if the given command buffer finished in the
// error state; otherwise do nothing. The exception message embeds the
// buffer's localized error description.
inline void check_error(MTL::CommandBuffer* cbuf) {
  if (cbuf->status() != MTL::CommandBufferStatusError) {
    return;
  }
  std::ostringstream description;
  description << "[METAL] Command buffer execution failed: ";
  description << cbuf->error()->localizedDescription()->utf8String();
  throw std::runtime_error(description.str());
}
|
||||||
|
|
||||||
|
// Encode the GPU evaluation of `arr` on the stream of its primitive.
//
// After the primitive has been encoded, the data of the array's inputs and
// siblings is captured by the command buffer's completion handler. The
// command buffer is committed (and the scheduler notified) only when the
// device reports that it needs a commit.
void eval(array& arr) {
  auto pool = metal::new_scoped_memory_pool();
  auto stream = arr.primitive().stream();
  auto& dev = metal::device(stream.device);
  auto command_buffer = dev.get_command_buffer(stream.index);

  auto outputs = arr.outputs();
  {
    // If the array is a tracer hold a reference
    // to its inputs so they don't get donated
    std::vector<array> held_inputs;
    if (arr.is_tracer()) {
      held_inputs = arr.inputs();
    }

    debug_set_primitive_buffer_label(command_buffer, arr.primitive());
    arr.primitive().eval_gpu(arr.inputs(), outputs);
  }

  // Collect the data of every input and sibling; the set is moved into the
  // completion handler below so these references live on with the handler.
  std::unordered_set<std::shared_ptr<array::Data>> buffers;
  for (auto& input : arr.inputs()) {
    buffers.insert(input.data_shared_ptr());
  }
  for (auto& sibling : arr.siblings()) {
    buffers.insert(sibling.data_shared_ptr());
  }
  // Remove the output if it was donated to by an input
  buffers.erase(arr.data_shared_ptr());

  if (!dev.command_buffer_needs_commit(stream.index)) {
    // No commit yet: only attach the error check to the open buffer.
    command_buffer->addCompletedHandler(
        [stream, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {
          check_error(cbuf);
        });
    return;
  }

  dev.end_encoding(stream.index);
  scheduler::notify_new_task(stream);
  command_buffer->addCompletedHandler(
      [stream, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {
        scheduler::notify_task_completion(stream);
        check_error(cbuf);
      });
  dev.commit_command_buffer(stream.index);
  // Fetch the command buffer for this stream again after the commit.
  dev.get_command_buffer(stream.index);
}
|
||||||
|
|
||||||
|
// Close out the work currently encoded on stream `s`: end encoding, attach
// an error check to the command buffer, commit it, and then request the
// stream's command buffer from the device again.
void finalize(Stream s) {
  auto pool = metal::new_scoped_memory_pool();
  auto& dev = metal::device(s.device);
  auto pending = dev.get_command_buffer(s.index);

  dev.end_encoding(s.index);
  pending->addCompletedHandler(
      [s](MTL::CommandBuffer* cbuf) { check_error(cbuf); });
  dev.commit_command_buffer(s.index);

  dev.get_command_buffer(s.index);
}
|
||||||
|
|
||||||
|
// Commit the work currently encoded on stream `s` and block until the
// command buffer has completed.
//
// Throws std::runtime_error (via check_error) if the command buffer ended
// in the error state.
void synchronize(Stream s) {
  auto pool = metal::new_scoped_memory_pool();
  auto& d = metal::device(s.device);
  auto cb = d.get_command_buffer(s.index);

  // Take an extra reference for the duration of the wait and hand it to a
  // guard, so the matching release() also runs when check_error() (or any
  // call below) throws. The guard is declared after `pool`, so the release
  // still happens before the memory pool is destroyed.
  cb->retain();
  auto release_buffer = [](MTL::CommandBuffer* buf) { buf->release(); };
  std::unique_ptr<MTL::CommandBuffer, decltype(release_buffer)> cb_guard(
      cb, release_buffer);

  d.end_encoding(s.index);
  d.commit_command_buffer(s.index);
  cb->waitUntilCompleted();
  check_error(cb);
}
|
||||||
|
|
||||||
|
} // namespace mlx::core::gpu
|
@ -2,7 +2,6 @@
|
|||||||
|
|
||||||
#include "mlx/event.h"
|
#include "mlx/event.h"
|
||||||
#include "mlx/backend/metal/device.h"
|
#include "mlx/backend/metal/device.h"
|
||||||
#include "mlx/backend/metal/metal_impl.h"
|
|
||||||
#include "mlx/scheduler.h"
|
#include "mlx/scheduler.h"
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
// Copyright © 2024 Apple Inc.
|
// Copyright © 2024 Apple Inc.
|
||||||
#include "mlx/fence.h"
|
#include "mlx/fence.h"
|
||||||
#include "mlx/backend/metal/device.h"
|
#include "mlx/backend/metal/device.h"
|
||||||
#include "mlx/backend/metal/metal_impl.h"
|
|
||||||
#include "mlx/scheduler.h"
|
#include "mlx/scheduler.h"
|
||||||
#include "mlx/utils.h"
|
#include "mlx/utils.h"
|
||||||
|
|
||||||
|
@ -1,11 +1,11 @@
|
|||||||
// Copyright © 2023-2024 Apple Inc.
|
// Copyright © 2023-2024 Apple Inc.
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
|
#include <sys/sysctl.h>
|
||||||
|
|
||||||
#include "mlx/backend/metal/device.h"
|
#include "mlx/backend/metal/device.h"
|
||||||
|
#include "mlx/backend/metal/metal.h"
|
||||||
#include "mlx/backend/metal/utils.h"
|
#include "mlx/backend/metal/utils.h"
|
||||||
#include "mlx/primitives.h"
|
|
||||||
#include "mlx/scheduler.h"
|
|
||||||
#include "mlx/utils.h"
|
|
||||||
|
|
||||||
namespace mlx::core::metal {
|
namespace mlx::core::metal {
|
||||||
|
|
||||||
@ -13,85 +13,6 @@ bool is_available() {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void check_error(MTL::CommandBuffer* cbuf) {
|
|
||||||
if (cbuf->status() == MTL::CommandBufferStatusError) {
|
|
||||||
std::ostringstream msg;
|
|
||||||
msg << "[METAL] Command buffer execution failed: "
|
|
||||||
<< cbuf->error()->localizedDescription()->utf8String();
|
|
||||||
throw std::runtime_error(msg.str());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void eval(array& arr) {
|
|
||||||
auto pool = new_scoped_memory_pool();
|
|
||||||
auto s = arr.primitive().stream();
|
|
||||||
auto& d = metal::device(s.device);
|
|
||||||
auto command_buffer = d.get_command_buffer(s.index);
|
|
||||||
|
|
||||||
auto outputs = arr.outputs();
|
|
||||||
{
|
|
||||||
// If the array is a tracer hold a reference
|
|
||||||
// to its inputs so they don't get donated
|
|
||||||
std::vector<array> inputs;
|
|
||||||
if (arr.is_tracer()) {
|
|
||||||
inputs = arr.inputs();
|
|
||||||
}
|
|
||||||
|
|
||||||
debug_set_primitive_buffer_label(command_buffer, arr.primitive());
|
|
||||||
arr.primitive().eval_gpu(arr.inputs(), outputs);
|
|
||||||
}
|
|
||||||
std::unordered_set<std::shared_ptr<array::Data>> buffers;
|
|
||||||
for (auto& in : arr.inputs()) {
|
|
||||||
buffers.insert(in.data_shared_ptr());
|
|
||||||
}
|
|
||||||
for (auto& s : arr.siblings()) {
|
|
||||||
buffers.insert(s.data_shared_ptr());
|
|
||||||
}
|
|
||||||
// Remove the output if it was donated to by an input
|
|
||||||
if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) {
|
|
||||||
buffers.erase(it);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (d.command_buffer_needs_commit(s.index)) {
|
|
||||||
d.end_encoding(s.index);
|
|
||||||
scheduler::notify_new_task(s);
|
|
||||||
command_buffer->addCompletedHandler(
|
|
||||||
[s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {
|
|
||||||
scheduler::notify_task_completion(s);
|
|
||||||
check_error(cbuf);
|
|
||||||
});
|
|
||||||
d.commit_command_buffer(s.index);
|
|
||||||
d.get_command_buffer(s.index);
|
|
||||||
} else {
|
|
||||||
command_buffer->addCompletedHandler(
|
|
||||||
[s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {
|
|
||||||
check_error(cbuf);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void finalize(Stream s) {
|
|
||||||
auto pool = new_scoped_memory_pool();
|
|
||||||
auto& d = metal::device(s.device);
|
|
||||||
auto cb = d.get_command_buffer(s.index);
|
|
||||||
d.end_encoding(s.index);
|
|
||||||
cb->addCompletedHandler([s](MTL::CommandBuffer* cbuf) { check_error(cbuf); });
|
|
||||||
d.commit_command_buffer(s.index);
|
|
||||||
d.get_command_buffer(s.index);
|
|
||||||
}
|
|
||||||
|
|
||||||
void synchronize(Stream s) {
|
|
||||||
auto pool = new_scoped_memory_pool();
|
|
||||||
auto& d = metal::device(s.device);
|
|
||||||
auto cb = d.get_command_buffer(s.index);
|
|
||||||
cb->retain();
|
|
||||||
d.end_encoding(s.index);
|
|
||||||
d.commit_command_buffer(s.index);
|
|
||||||
cb->waitUntilCompleted();
|
|
||||||
check_error(cb);
|
|
||||||
cb->release();
|
|
||||||
}
|
|
||||||
|
|
||||||
void start_capture(std::string path, id object) {
|
void start_capture(std::string path, id object) {
|
||||||
auto pool = new_scoped_memory_pool();
|
auto pool = new_scoped_memory_pool();
|
||||||
|
|
||||||
@ -128,4 +49,36 @@ void stop_capture() {
|
|||||||
manager->stopCapture();
|
manager->stopCapture();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const std::unordered_map<std::string, std::variant<std::string, size_t>>&
|
||||||
|
device_info() {
|
||||||
|
auto init_device_info = []()
|
||||||
|
-> std::unordered_map<std::string, std::variant<std::string, size_t>> {
|
||||||
|
auto pool = new_scoped_memory_pool();
|
||||||
|
auto raw_device = device(default_device()).mtl_device();
|
||||||
|
auto name = std::string(raw_device->name()->utf8String());
|
||||||
|
auto arch = std::string(raw_device->architecture()->name()->utf8String());
|
||||||
|
|
||||||
|
size_t memsize = 0;
|
||||||
|
size_t length = sizeof(memsize);
|
||||||
|
sysctlbyname("hw.memsize", &memsize, &length, NULL, 0);
|
||||||
|
|
||||||
|
size_t rsrc_limit = 0;
|
||||||
|
sysctlbyname("iogpu.rsrc_limit", &rsrc_limit, &length, NULL, 0);
|
||||||
|
if (rsrc_limit == 0) {
|
||||||
|
rsrc_limit = 499000;
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
{"device_name", name},
|
||||||
|
{"architecture", arch},
|
||||||
|
{"max_buffer_length", raw_device->maxBufferLength()},
|
||||||
|
{"max_recommended_working_set_size",
|
||||||
|
raw_device->recommendedMaxWorkingSetSize()},
|
||||||
|
{"memory_size", memsize},
|
||||||
|
{"resource_limit", rsrc_limit}};
|
||||||
|
};
|
||||||
|
static auto device_info_ = init_device_info();
|
||||||
|
return device_info_;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace mlx::core::metal
|
} // namespace mlx::core::metal
|
||||||
|
@ -2,11 +2,10 @@
|
|||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include <string>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include <variant>
|
#include <variant>
|
||||||
|
|
||||||
#include "mlx/array.h"
|
|
||||||
|
|
||||||
namespace mlx::core::metal {
|
namespace mlx::core::metal {
|
||||||
|
|
||||||
/* Check if the Metal backend is available. */
|
/* Check if the Metal backend is available. */
|
||||||
|
22
mlx/backend/metal/no_metal.cpp
Normal file
22
mlx/backend/metal/no_metal.cpp
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include <stdexcept>
|
||||||
|
|
||||||
|
#include "mlx/backend/metal/metal.h"
|
||||||
|
|
||||||
|
namespace mlx::core::metal {

// Metal is not part of this build, so the backend is never available.
bool is_available() {
  return false;
}

// Capture is a no-op without the Metal backend.
void start_capture(std::string) {}
void stop_capture() {}

// There is no Metal device to describe in this build.
// Always throws std::runtime_error.
const std::unordered_map<std::string, std::variant<std::string, size_t>>&
device_info() {
  throw std::runtime_error(
      "[metal::device_info] Cannot get device info without metal backend");
}

} // namespace mlx::core::metal
|
@ -1,7 +1,6 @@
|
|||||||
// Copyright © 2024 Apple Inc.
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
#include "mlx/backend/metal/resident.h"
|
#include "mlx/backend/metal/resident.h"
|
||||||
#include "mlx/backend/metal/metal_impl.h"
|
|
||||||
|
|
||||||
namespace mlx::core::metal {
|
namespace mlx::core::metal {
|
||||||
|
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
target_sources(
|
target_sources(
|
||||||
mlx
|
mlx
|
||||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/available.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/../cpu/eval.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/../cpu/eval.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/../cpu/encoder.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/../cpu/encoder.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp)
|
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp)
|
||||||
|
11
mlx/backend/no_cpu/available.cpp
Normal file
11
mlx/backend/no_cpu/available.cpp
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/cpu/available.h"
|
||||||
|
|
||||||
|
namespace mlx::core::cpu {

// Availability query for the CPU backend. This is the definition compiled
// from backend/no_cpu/available.cpp, and it always answers "no".
bool is_available() {
  constexpr bool cpu_backend_present = false;
  return cpu_backend_present;
}

} // namespace mlx::core::cpu
|
@ -3,5 +3,5 @@ target_sources(
|
|||||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
|
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/event.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/event.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp)
|
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp)
|
@ -6,9 +6,9 @@
|
|||||||
#include "mlx/allocator.h"
|
#include "mlx/allocator.h"
|
||||||
|
|
||||||
#ifdef __APPLE__
|
#ifdef __APPLE__
|
||||||
#include "mlx/backend/no_metal/apple_memory.h"
|
#include "mlx/backend/no_gpu/apple_memory.h"
|
||||||
#elif defined(__linux__)
|
#elif defined(__linux__)
|
||||||
#include "mlx/backend/no_metal/linux_memory.h"
|
#include "mlx/backend/no_gpu/linux_memory.h"
|
||||||
#else
|
#else
|
||||||
size_t get_memory_size() {
|
size_t get_memory_size() {
|
||||||
return 0;
|
return 0;
|
28
mlx/backend/no_gpu/eval.cpp
Normal file
28
mlx/backend/no_gpu/eval.cpp
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include <stdexcept>
|
||||||
|
|
||||||
|
#include "mlx/backend/gpu/available.h"
|
||||||
|
#include "mlx/backend/gpu/eval.h"
|
||||||
|
|
||||||
|
namespace mlx::core::gpu {
|
||||||
|
|
||||||
|
bool is_available() {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
void new_stream(Stream) {}
|
||||||
|
|
||||||
|
void eval(array&) {
|
||||||
|
throw std::runtime_error("[gpu::eval] GPU backend is not available");
|
||||||
|
}
|
||||||
|
|
||||||
|
void finalize(Stream) {
|
||||||
|
throw std::runtime_error("[gpu::finalize] GPU backend is not available");
|
||||||
|
}
|
||||||
|
|
||||||
|
void synchronize(Stream) {
|
||||||
|
throw std::runtime_error("[gpu::synchronize] GPU backend is not available");
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core::gpu
|
@ -1,43 +0,0 @@
|
|||||||
// Copyright © 2023-2024 Apple Inc.
|
|
||||||
|
|
||||||
#include <stdexcept>
|
|
||||||
|
|
||||||
#include "mlx/backend/metal/metal.h"
|
|
||||||
#include "mlx/backend/metal/metal_impl.h"
|
|
||||||
namespace mlx::core::metal {
|
|
||||||
|
|
||||||
bool is_available() {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
void new_stream(Stream) {}
|
|
||||||
|
|
||||||
std::unique_ptr<void, std::function<void(void*)>> new_scoped_memory_pool() {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
void eval(array&) {
|
|
||||||
throw std::runtime_error(
|
|
||||||
"[metal::eval] Cannot eval on GPU without metal backend");
|
|
||||||
}
|
|
||||||
|
|
||||||
void finalize(Stream) {
|
|
||||||
throw std::runtime_error(
|
|
||||||
"[metal::finalize] Cannot finalize GPU without metal backend");
|
|
||||||
}
|
|
||||||
|
|
||||||
void synchronize(Stream) {
|
|
||||||
throw std::runtime_error(
|
|
||||||
"[metal::synchronize] Cannot synchronize GPU without metal backend");
|
|
||||||
}
|
|
||||||
|
|
||||||
void start_capture(std::string) {}
|
|
||||||
void stop_capture() {}
|
|
||||||
|
|
||||||
const std::unordered_map<std::string, std::variant<std::string, size_t>>&
|
|
||||||
device_info() {
|
|
||||||
throw std::runtime_error(
|
|
||||||
"[metal::device_info] Cannot get device info without metal backend");
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace mlx::core::metal
|
|
@ -1,13 +1,15 @@
|
|||||||
// Copyright © 2023 Apple Inc.
|
// Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
|
#include <stdexcept>
|
||||||
|
|
||||||
|
#include "mlx/backend/cpu/available.h"
|
||||||
|
#include "mlx/backend/gpu/available.h"
|
||||||
#include "mlx/device.h"
|
#include "mlx/device.h"
|
||||||
#include "mlx/backend/metal/metal.h"
|
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
Device& mutable_default_device() {
|
Device& mutable_default_device() {
|
||||||
static Device default_device{
|
static Device default_device{gpu::is_available() ? Device::gpu : Device::cpu};
|
||||||
metal::is_available() ? Device::gpu : Device::cpu};
|
|
||||||
return default_device;
|
return default_device;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -16,7 +18,7 @@ const Device& default_device() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void set_default_device(const Device& d) {
|
void set_default_device(const Device& d) {
|
||||||
if (!metal::is_available() && d == Device::gpu) {
|
if (!gpu::is_available() && d == Device::gpu) {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"[set_default_device] Cannot set gpu device without gpu backend.");
|
"[set_default_device] Cannot set gpu device without gpu backend.");
|
||||||
}
|
}
|
||||||
@ -31,4 +33,15 @@ bool operator!=(const Device& lhs, const Device& rhs) {
|
|||||||
return !(lhs == rhs);
|
return !(lhs == rhs);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool is_available(const Device& d) {
|
||||||
|
switch (d.type) {
|
||||||
|
case Device::cpu:
|
||||||
|
return cpu::is_available();
|
||||||
|
case Device::gpu:
|
||||||
|
return gpu::is_available();
|
||||||
|
}
|
||||||
|
// appease compiler
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -26,4 +26,6 @@ void set_default_device(const Device& d);
|
|||||||
bool operator==(const Device& lhs, const Device& rhs);
|
bool operator==(const Device& lhs, const Device& rhs);
|
||||||
bool operator!=(const Device& lhs, const Device& rhs);
|
bool operator!=(const Device& lhs, const Device& rhs);
|
||||||
|
|
||||||
|
/* Check if the backend for the given device is available. */
bool is_available(const Device& d);
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -1,12 +1,13 @@
|
|||||||
// Copyright © 2023 Apple Inc.
|
// Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
#include "mlx/scheduler.h"
|
#include "mlx/scheduler.h"
|
||||||
#include "mlx/backend/metal/metal.h"
|
#include "mlx/backend/gpu/available.h"
|
||||||
|
#include "mlx/backend/gpu/eval.h"
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
Stream default_stream(Device d) {
|
Stream default_stream(Device d) {
|
||||||
if (!metal::is_available() && d == Device::gpu) {
|
if (!gpu::is_available() && d == Device::gpu) {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"[default_stream] Cannot get gpu stream without gpu backend.");
|
"[default_stream] Cannot get gpu stream without gpu backend.");
|
||||||
}
|
}
|
||||||
@ -14,7 +15,7 @@ Stream default_stream(Device d) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void set_default_stream(Stream s) {
|
void set_default_stream(Stream s) {
|
||||||
if (!metal::is_available() && s.device == Device::gpu) {
|
if (!gpu::is_available() && s.device == Device::gpu) {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"[set_default_stream] Cannot set gpu stream without gpu backend.");
|
"[set_default_stream] Cannot set gpu stream without gpu backend.");
|
||||||
}
|
}
|
||||||
@ -26,7 +27,7 @@ Stream get_stream(int index) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
Stream new_stream(Device d) {
|
Stream new_stream(Device d) {
|
||||||
if (!metal::is_available() && d == Device::gpu) {
|
if (!gpu::is_available() && d == Device::gpu) {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"[new_stream] Cannot make gpu stream without gpu backend.");
|
"[new_stream] Cannot make gpu stream without gpu backend.");
|
||||||
}
|
}
|
||||||
@ -44,7 +45,7 @@ void synchronize(Stream s) {
|
|||||||
scheduler::enqueue(s, [p = std::move(p)]() { p->set_value(); });
|
scheduler::enqueue(s, [p = std::move(p)]() { p->set_value(); });
|
||||||
f.wait();
|
f.wait();
|
||||||
} else {
|
} else {
|
||||||
metal::synchronize(s);
|
gpu::synchronize(s);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -8,8 +8,7 @@
|
|||||||
#include <thread>
|
#include <thread>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
|
|
||||||
#include "mlx/backend/metal/metal.h"
|
#include "mlx/backend/gpu/eval.h"
|
||||||
#include "mlx/backend/metal/metal_impl.h"
|
|
||||||
#include "mlx/device.h"
|
#include "mlx/device.h"
|
||||||
#include "mlx/stream.h"
|
#include "mlx/stream.h"
|
||||||
|
|
||||||
@ -67,7 +66,7 @@ struct StreamThread {
|
|||||||
class Scheduler {
|
class Scheduler {
|
||||||
public:
|
public:
|
||||||
Scheduler() : n_active_tasks_(0) {
|
Scheduler() : n_active_tasks_(0) {
|
||||||
if (metal::is_available()) {
|
if (is_available(Device::gpu)) {
|
||||||
default_streams_.insert({Device::gpu, new_stream(Device::gpu)});
|
default_streams_.insert({Device::gpu, new_stream(Device::gpu)});
|
||||||
}
|
}
|
||||||
default_streams_.insert({Device::cpu, new_stream(Device::cpu)});
|
default_streams_.insert({Device::cpu, new_stream(Device::cpu)});
|
||||||
@ -83,7 +82,7 @@ class Scheduler {
|
|||||||
streams_.emplace_back(streams_.size(), d);
|
streams_.emplace_back(streams_.size(), d);
|
||||||
if (d == Device::gpu) {
|
if (d == Device::gpu) {
|
||||||
threads_.push_back(nullptr);
|
threads_.push_back(nullptr);
|
||||||
metal::new_stream(streams_.back());
|
gpu::new_stream(streams_.back());
|
||||||
} else {
|
} else {
|
||||||
threads_.push_back(new StreamThread{});
|
threads_.push_back(new StreamThread{});
|
||||||
}
|
}
|
||||||
|
@ -10,7 +10,7 @@
|
|||||||
#include <unordered_set>
|
#include <unordered_set>
|
||||||
|
|
||||||
#include "mlx/backend/cpu/eval.h"
|
#include "mlx/backend/cpu/eval.h"
|
||||||
#include "mlx/backend/metal/metal_impl.h"
|
#include "mlx/backend/gpu/eval.h"
|
||||||
#include "mlx/fence.h"
|
#include "mlx/fence.h"
|
||||||
#include "mlx/memory.h"
|
#include "mlx/memory.h"
|
||||||
#include "mlx/ops.h"
|
#include "mlx/ops.h"
|
||||||
@ -218,7 +218,7 @@ array eval_impl(std::vector<array> outputs, bool async) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (arr.primitive().device() == Device::gpu) {
|
if (arr.primitive().device() == Device::gpu) {
|
||||||
metal::eval(arr);
|
gpu::eval(arr);
|
||||||
} else {
|
} else {
|
||||||
cpu::eval(arr);
|
cpu::eval(arr);
|
||||||
}
|
}
|
||||||
@ -229,7 +229,7 @@ array eval_impl(std::vector<array> outputs, bool async) {
|
|||||||
// Commit any open streams
|
// Commit any open streams
|
||||||
for (auto& [_, e] : events) {
|
for (auto& [_, e] : events) {
|
||||||
if (e.stream().device == Device::gpu) {
|
if (e.stream().device == Device::gpu) {
|
||||||
metal::finalize(e.stream());
|
gpu::finalize(e.stream());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
scheduler::wait_for_one();
|
scheduler::wait_for_one();
|
||||||
@ -267,7 +267,7 @@ array eval_impl(std::vector<array> outputs, bool async) {
|
|||||||
auto s = e.stream();
|
auto s = e.stream();
|
||||||
e.signal(s);
|
e.signal(s);
|
||||||
if (s.device == Device::gpu) {
|
if (s.device == Device::gpu) {
|
||||||
metal::finalize(s);
|
gpu::finalize(s);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user