Add io device and cpu::make_task

Commit be36f136de (parent 9814a2ae12).
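In short, this commit (1) adds an io entry to Device::DeviceType, (2) gives Primitive a virtual eval_io hook that throws until a primitive overrides it, and (3) moves the CPU evaluation task out of eval_impl into a reusable cpu::make_task, with a matching cpu::make_synchronize_task. A minimal sketch of how the new CPU helper is consumed, mirroring the call sites changed below (the wrapper function and the fully qualified names are illustrative, not part of the commit):

// Illustrative sketch only: assumes an `arr` to evaluate on the CPU stream
// `stream` and a `signal` flag computed by the caller, as in eval_impl below.
#include "mlx/backend/common/cpu_impl.h"
#include "mlx/scheduler.h"

void enqueue_cpu_eval(mlx::core::array arr, mlx::core::Stream stream, bool signal) {
  // make_task packages the input-event waits, the eval_cpu call, the
  // detach/signal step, and the scheduler bookkeeping into one callable.
  mlx::core::scheduler::enqueue(
      stream, mlx::core::cpu::make_task(std::move(arr), signal));
}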
@@ -55,6 +55,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/qrf.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/cpu_impl.cpp
   ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp
 )
mlx/backend/common/cpu_impl.cpp  (new file, 48 lines)
@@ -0,0 +1,48 @@
+// Copyright © 2023-2024 Apple Inc.
+
+#include "mlx/backend/common/cpu_impl.h"
+#include "mlx/primitives.h"
+#include "mlx/scheduler.h"
+
+namespace mlx::core::cpu {
+
+std::function<void()> make_task(array arr, bool signal) {
+  return [arr = std::move(arr), signal]() mutable {
+    auto stream = arr.primitive().stream();
+
+    // Wait on inputs coming from different streams/devices.
+    for (auto& input : arr.inputs()) {
+      if (input.event().valid() &&
+          input.event().stream() != arr.primitive().stream()) {
+        input.event().wait();
+      }
+    }
+
+    // Task computation actually starting.
+    scheduler::notify_new_task(stream);
+
+    // Perform the computation
+    auto outputs = arr.outputs();
+    arr.primitive().eval_cpu(arr.inputs(), outputs);
+
+    // Check if we need to detach and signal other arrays waiting for the
+    // result to be ready.
+    if (!arr.is_tracer()) {
+      arr.detach();
+    }
+    if (signal) {
+      arr.event().signal();
+    }
+
+    // Task computation done.
+    scheduler::notify_task_completion(stream);
+  };
+}
+
+std::function<void()> make_synchronize_task(
+    Stream s,
+    std::shared_ptr<std::promise<void>> p) {
+  return [p = std::move(p)]() { p->set_value(); };
+}
+
+} // namespace mlx::core::cpu
mlx/backend/common/cpu_impl.h  (new file, 18 lines)
@@ -0,0 +1,18 @@
+// Copyright © 2024 Apple Inc.
+
+#pragma once
+
+#include <functional>
+#include <future>
+#include <memory>
+
+#include "mlx/array.h"
+
+namespace mlx::core::cpu {
+
+std::function<void()> make_task(array arr, bool signal);
+std::function<void()> make_synchronize_task(
+    Stream s,
+    std::shared_ptr<std::promise<void>> p);
+
+} // namespace mlx::core::cpu
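For the second helper, a sketch of the promise/future pattern it is meant for, mirroring the new synchronize() body shown later in this commit (the surrounding function and the CPU stream `s` are assumed):

// Block the caller until everything already queued on CPU stream `s` has run.
auto p = std::make_shared<std::promise<void>>();
std::future<void> f = p->get_future();
mlx::core::scheduler::enqueue(
    s, mlx::core::cpu::make_synchronize_task(s, std::move(p)));
f.wait();  // the promise is fulfilled once the task runs on the stream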
@@ -8,10 +8,12 @@ struct Device {
   enum class DeviceType {
     cpu,
     gpu,
+    io,
   };
 
   static constexpr DeviceType cpu = DeviceType::cpu;
   static constexpr DeviceType gpu = DeviceType::gpu;
+  static constexpr DeviceType io = DeviceType::io;
 
   Device(DeviceType type, int index = 0) : type(type), index(index) {};
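For illustration, the new device type is used like the existing shorthands; a small sketch (not part of the commit, and the header path is assumed):

#include "mlx/device.h"  // assumed location of Device; not shown in this diff

using mlx::core::Device;

bool is_io(const Device& d) {
  // Device compares against a DeviceType, just as the old synchronize()
  // code compared against Device::cpu.
  return d == Device::io;
}

Device default_io_device() {
  return Device(Device::io);  // index defaults to 0
}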
@@ -106,6 +106,16 @@ std::tuple<array, array, array, int> vmap_ternary_op(
 
 } // namespace
 
+void Primitive::eval_io(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  std::ostringstream msg;
+  msg << "[Primitive::eval_io] Not implemented for ";
+  print(msg);
+  msg << ".";
+  throw std::invalid_argument(msg.str());
+}
+
 std::vector<array> Primitive::jvp(
     const std::vector<array>&,
     const std::vector<array>&,
@@ -73,6 +73,16 @@ class Primitive {
       const std::vector<array>& inputs,
       std::vector<array>& outputs) = 0;
 
+  /**
+   * Some primitives are computed by an IO device (disk, network, camera etc).
+   *
+   * Like in eval_cpu/gpu the eval_io function is responsible for allocating
+   * the space for the array.
+   */
+  virtual void eval_io(
+      const std::vector<array>& inputs,
+      std::vector<array>& outputs);
+
   /**
    * The Jacobian-vector product.
    */
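As a sketch of how a primitive could opt into the new hook (everything below is hypothetical: DiskLoad, fname_, and offset_ are invented for illustration, and the remaining pure-virtual members of Primitive are omitted), eval_io allocates the output's storage itself, just as eval_cpu/eval_gpu do:

#include <fstream>

class DiskLoad : public Primitive {
 public:
  DiskLoad(Stream stream, std::string fname, size_t offset)
      : Primitive(stream), fname_(std::move(fname)), offset_(offset) {}

  void eval_io(const std::vector<array>& inputs, std::vector<array>& outputs)
      override {
    auto& out = outputs[0];
    // Like eval_cpu/eval_gpu, eval_io is responsible for allocating the
    // output buffer before filling it.
    out.set_data(allocator::malloc_or_wait(out.nbytes()));
    std::ifstream f(fname_, std::ios::binary);
    f.seekg(offset_);
    f.read(out.data<char>(), out.nbytes());
  }

  // eval_cpu, eval_gpu, print, etc. omitted for brevity.

 private:
  std::string fname_;
  size_t offset_;
};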
@@ -1,6 +1,7 @@
 // Copyright © 2023 Apple Inc.
 
 #include "mlx/scheduler.h"
+#include "mlx/backend/common/cpu_impl.h"
 #include "mlx/backend/metal/metal.h"
 
 namespace mlx::core {
@@ -36,10 +37,13 @@ Stream new_stream() {
 void synchronize(Stream s) {
   auto p = std::make_shared<std::promise<void>>();
   std::future<void> f = p->get_future();
-  if (s.device == mlx::core::Device::cpu) {
-    scheduler::enqueue(s, [p = std::move(p)]() { p->set_value(); });
-  } else {
-    scheduler::enqueue(s, metal::make_synchronize_task(s, std::move(p)));
+  switch (s.device.type) {
+    case mlx::core::Device::cpu:
+      scheduler::enqueue(s, cpu::make_synchronize_task(s, std::move(p)));
+      break;
+    case mlx::core::Device::gpu:
+      scheduler::enqueue(s, metal::make_synchronize_task(s, std::move(p)));
+      break;
   }
   f.wait();
 }
@@ -8,6 +8,7 @@
 #include <unordered_map>
 #include <unordered_set>
 
+#include "mlx/backend/common/cpu_impl.h"
 #include "mlx/backend/metal/metal_impl.h"
 #include "mlx/ops.h"
 #include "mlx/primitives.h"
@@ -137,32 +138,17 @@ array eval_impl(std::vector<array> outputs, bool async) {
     std::vector<std::shared_future<void>> arr_deps;
     bool signal = needs_signal.find(arr.id()) != needs_signal.end();
 
-    if (arr.primitive().device() == Device::gpu) {
-      if (!metal::is_available()) {
-        throw std::runtime_error("Metal GPU is not available.");
-      }
-      scheduler::enqueue(stream, metal::make_task(std::move(arr), signal));
-    } else {
-      auto task = [arr = std::move(arr), stream, signal]() mutable {
-        for (auto& input : arr.inputs()) {
-          if (input.event().valid() &&
-              input.event().stream() != arr.primitive().stream()) {
-            input.event().wait();
-          }
-        }
-        scheduler::notify_new_task(stream);
-        auto outputs = arr.outputs();
-        arr.primitive().eval_cpu(arr.inputs(), outputs);
-        if (!arr.is_tracer()) {
-          arr.detach();
-        }
-        if (signal) {
-          arr.event().signal();
-        }
-
-        scheduler::notify_task_completion(stream);
-      };
-      scheduler::enqueue(stream, std::move(task));
+    switch (arr.primitive().device().type) {
+      case Device::gpu: {
+        if (!metal::is_available()) {
+          throw std::runtime_error("Metal GPU is not available.");
+        }
+        scheduler::enqueue(stream, metal::make_task(std::move(arr), signal));
+        break;
+      }
+      case Device::cpu:
+        scheduler::enqueue(stream, cpu::make_task(std::move(arr), signal));
+        break;
     }
   }
   return synchronizer;
@@ -133,6 +133,9 @@ std::ostream& operator<<(std::ostream& os, const Device& d) {
     case Device::gpu:
      os << "gpu";
      break;
+    case Device::io:
+      os << "io";
+      break;
   }
   os << ", " << d.index << ")";
   return os;