mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-07 00:54:37 +08:00
Compare commits
5 Commits
sockets-di
...
io-dev
Author | SHA1 | Date | |
---|---|---|---|
![]() |
8242d6d5ef | ||
![]() |
bae159738f | ||
![]() |
b193741050 | ||
![]() |
c8e2b42ced | ||
![]() |
be36f136de |
@@ -20,6 +20,7 @@ target_sources(
|
||||
)
|
||||
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/io)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io)
|
||||
if (MLX_BUILD_ACCELERATE)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/accelerate)
|
||||
|
@@ -51,7 +51,6 @@ DEFAULT(Greater)
|
||||
DEFAULT(GreaterEqual)
|
||||
DEFAULT(Less)
|
||||
DEFAULT(LessEqual)
|
||||
DEFAULT(Load)
|
||||
DEFAULT(LogicalNot)
|
||||
DEFAULT(LogicalAnd)
|
||||
DEFAULT(LogicalOr)
|
||||
|
@@ -55,6 +55,7 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/qrf.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cpu_impl.cpp
|
||||
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp
|
||||
)
|
||||
|
||||
|
47
mlx/backend/common/cpu_impl.cpp
Normal file
47
mlx/backend/common/cpu_impl.cpp
Normal file
@@ -0,0 +1,47 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/cpu_impl.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/scheduler.h"
|
||||
|
||||
namespace mlx::core::cpu {
|
||||
|
||||
/**
 * Build the closure that evaluates `arr` with its primitive's `eval_cpu`.
 *
 * When invoked, the closure waits on input events that belong to a different
 * stream, notifies the scheduler that a task started, evaluates the
 * primitive, detaches `arr` unless it is a tracer, signals `arr`'s event when
 * `signal` is true and finally notifies the scheduler of the completion.
 */
std::function<void()> make_task(array arr, bool signal) {
  auto cpu_task = [arr = std::move(arr), signal]() mutable {
    auto stream = arr.primitive().stream();

    // Only block on input events that belong to another stream.
    for (auto& input : arr.inputs()) {
      if (!input.event().valid()) {
        continue;
      }
      if (input.event().stream() != stream) {
        input.event().wait();
      }
    }

    // From here on the task counts as running.
    scheduler::notify_new_task(stream);

    // Evaluate the primitive into the array's outputs.
    auto outputs = arr.outputs();
    arr.primitive().eval_cpu(arr.inputs(), outputs);

    // Tracers are left attached; every other array is detached once computed.
    if (!arr.is_tracer()) {
      arr.detach();
    }

    // Let arrays waiting on this result know that it is ready.
    if (signal) {
      arr.event().signal();
    }

    scheduler::notify_task_completion(stream);
  };
  return cpu_task;
}
|
||||
|
||||
/**
 * Build the closure used to synchronize with a CPU stream: when it runs it
 * fulfils `p`. The stream argument is not used by this implementation.
 */
std::function<void()> make_synchronize_task(
    Stream s,
    std::shared_ptr<std::promise<void>> p) {
  auto fulfil = [promise = std::move(p)]() { promise->set_value(); };
  return fulfil;
}
|
||||
|
||||
} // namespace mlx::core::cpu
|
18
mlx/backend/common/cpu_impl.h
Normal file
18
mlx/backend/common/cpu_impl.h
Normal file
@@ -0,0 +1,18 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <functional>
|
||||
#include <future>
|
||||
#include <memory>
|
||||
|
||||
#include "mlx/array.h"
|
||||
|
||||
namespace mlx::core::cpu {
|
||||
|
||||
std::function<void()> make_task(array arr, bool signal);
|
||||
std::function<void()> make_synchronize_task(
|
||||
Stream s,
|
||||
std::shared_ptr<std::promise<void>> p);
|
||||
|
||||
} // namespace mlx::core::cpu
|
@@ -68,7 +68,6 @@ DEFAULT(Greater)
|
||||
DEFAULT(GreaterEqual)
|
||||
DEFAULT(Less)
|
||||
DEFAULT(LessEqual)
|
||||
DEFAULT(Load)
|
||||
DEFAULT(Log)
|
||||
DEFAULT(Log1p)
|
||||
DEFAULT(LogicalNot)
|
||||
|
7
mlx/backend/io/CMakeLists.txt
Normal file
7
mlx/backend/io/CMakeLists.txt
Normal file
@@ -0,0 +1,7 @@
|
||||
# Sources of the IO backend: task construction, the worker thread pool and the
# eval_io implementations of the IO primitives.
target_sources(
  mlx
  PRIVATE
    ${CMAKE_CURRENT_SOURCE_DIR}/io_impl.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/thread_pool.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
)
|
72
mlx/backend/io/io_impl.cpp
Normal file
72
mlx/backend/io/io_impl.cpp
Normal file
@@ -0,0 +1,72 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/io/io_impl.h"
|
||||
#include "mlx/backend/io/thread_pool.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/scheduler.h"
|
||||
|
||||
namespace mlx::core::io {
|
||||
|
||||
namespace {
|
||||
|
||||
/**
 * Return the process-wide IO thread pool (4 workers), creating it on first
 * use.
 *
 * The pool is a function-local static object so that its construction is
 * synchronized by the language: concurrent first calls from several threads
 * cannot both observe "not created yet" and race on building the pool, which
 * the previous hand-written null check allowed.
 */
detail::ThreadPool& thread_pool() {
  static detail::ThreadPool pool(4);
  return pool;
}
|
||||
|
||||
} // namespace
|
||||
|
||||
/**
 * Build the closure that evaluates `arr` with its primitive's `eval_io`.
 *
 * The closure itself only waits on inputs from other streams and then hands
 * the actual evaluation to the IO thread pool, passing the array's inputs and
 * outputs along so the pool can place the job. When `signal` is true the
 * array's event is signaled from a barrier enqueued on the pool.
 */
std::function<void()> make_task(array arr, bool signal) {
  return [arr = std::move(arr), signal]() mutable {
    auto stream = arr.primitive().stream();

    // Only block on input events that belong to another stream.
    for (auto& input : arr.inputs()) {
      if (!input.event().valid()) {
        continue;
      }
      if (input.event().stream() != stream) {
        input.event().wait();
      }
    }

    // From here on the task counts as running.
    scheduler::notify_new_task(stream);

    // Copy inputs/outputs before `arr` is moved into the pool job.
    auto inputs = arr.inputs();
    auto outputs = arr.outputs();

    auto io_job =
        [arr = std::move(arr), inputs, outputs, signal, stream]() mutable {
          // The evaluation proper, executed by a pool worker.
          arr.primitive().eval_io(inputs, outputs);

          if (!arr.is_tracer()) {
            arr.detach();
          }

          if (signal) {
            // Signal the event through a pool barrier rather than inline.
            auto fire_event = [arr = std::move(arr)]() {
              arr.event().signal();
            };
            thread_pool().barrier(std::move(fire_event));
          }

          scheduler::notify_task_completion(stream);
        };

    thread_pool().enqueue(std::move(io_job), inputs, outputs);
  };
}
|
||||
|
||||
/**
 * Build the closure used to synchronize with an IO stream: when it runs it
 * enqueues a barrier on the IO thread pool, waits for that barrier and only
 * then fulfils `p`. The stream argument is not used by this implementation.
 */
std::function<void()> make_synchronize_task(
    Stream s,
    std::shared_ptr<std::promise<void>> p) {
  auto sync_task = [promise = std::move(p)]() {
    auto pool_drained = thread_pool().barrier();
    pool_drained.wait();
    promise->set_value();
  };
  return sync_task;
}
|
||||
|
||||
} // namespace mlx::core::io
|
18
mlx/backend/io/io_impl.h
Normal file
18
mlx/backend/io/io_impl.h
Normal file
@@ -0,0 +1,18 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <functional>
|
||||
#include <future>
|
||||
#include <memory>
|
||||
|
||||
#include "mlx/array.h"
|
||||
|
||||
namespace mlx::core::io {
|
||||
|
||||
std::function<void()> make_task(array arr, bool signal);
|
||||
std::function<void()> make_synchronize_task(
|
||||
Stream s,
|
||||
std::shared_ptr<std::promise<void>> p);
|
||||
|
||||
} // namespace mlx::core::io
|
60
mlx/backend/io/primitives.cpp
Normal file
60
mlx/backend/io/primitives.cpp
Normal file
@@ -0,0 +1,60 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <utility>
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
/**
 * Reverse the byte order of each of the `N` consecutive elements stored at
 * `data_bytes`, where every element is `scalar_size` bytes wide. The buffer
 * is modified in place.
 */
template <const uint8_t scalar_size>
void swap_endianness(uint8_t* data_bytes, size_t N) {
  uint8_t* elem_begin = data_bytes;
  for (size_t remaining = N; remaining > 0; --remaining) {
    // Mirroring the bytes of one element flips its endianness.
    std::reverse(elem_begin, elem_begin + scalar_size);
    elem_begin += scalar_size;
  }
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void Load::eval_io(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
assert(inputs.size() == 0);
|
||||
array& out = outputs[0];
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
{
|
||||
std::lock_guard lock(*reader_);
|
||||
|
||||
reader_->seek(offset_, std::ios_base::beg);
|
||||
reader_->read(out.data<char>(), out.nbytes());
|
||||
}
|
||||
|
||||
if (swap_endianness_) {
|
||||
switch (out.itemsize()) {
|
||||
case 2:
|
||||
swap_endianness<2>(out.data<uint8_t>(), out.data_size());
|
||||
break;
|
||||
case 4:
|
||||
swap_endianness<4>(out.data<uint8_t>(), out.data_size());
|
||||
break;
|
||||
case 8:
|
||||
swap_endianness<8>(out.data<uint8_t>(), out.data_size());
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
216
mlx/backend/io/thread_pool.cpp
Normal file
216
mlx/backend/io/thread_pool.cpp
Normal file
@@ -0,0 +1,216 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include <numeric>
|
||||
|
||||
#include "mlx/backend/io/thread_pool.h"
|
||||
|
||||
namespace mlx::core::io::detail {
|
||||
|
||||
/**
 * Create a pool with `workers` threads. Every worker gets its own task queue,
 * queue mutex, condition variable and output-id set; the threads are started
 * once all of that per-worker state exists.
 */
ThreadPool::ThreadPool(int workers)
    : task_queues_(workers),
      queue_mutexes_(workers),
      queue_cvs_(workers),
      set_mutexes_(workers),
      output_sets_(workers),
      stop_(false) {
  int next_worker = 0;
  while (next_worker < workers) {
    workers_.emplace_back(&ThreadPool::worker, this, next_worker);
    ++next_worker;
  }
}
|
||||
|
||||
/**
 * Ask every worker to stop, wake them up and join them. Workers keep running
 * tasks until their queue is empty before they exit.
 */
ThreadPool::~ThreadPool() {
  {
    // `stop_` is a plain bool that each worker reads while holding its own
    // queue mutex. Writing it while holding *every* queue mutex orders the
    // write with all of those reads (no data race) and guarantees no worker
    // is between evaluating its wait predicate and blocking, so the
    // notification below cannot be lost and join() cannot hang.
    // No other code path holds more than one queue mutex at a time, so
    // taking them all here cannot deadlock.
    std::vector<std::unique_lock<std::mutex>> locks;
    locks.reserve(queue_mutexes_.size());
    for (auto& m : queue_mutexes_) {
      locks.emplace_back(m);
    }
    stop_ = true;
  }

  for (auto& cv : queue_cvs_) {
    cv.notify_one();
  }

  for (auto& t : workers_) {
    if (t.joinable()) {
      t.join();
    }
  }
}
|
||||
|
||||
std::future<void> ThreadPool::enqueue(
|
||||
std::function<void()> task,
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<array>& outputs) {
|
||||
std::vector<int> barriers;
|
||||
if (!inputs.empty()) {
|
||||
for (int i = 0; i < output_sets_.size(); i++) {
|
||||
std::lock_guard<std::mutex> lock(set_mutexes_[i]);
|
||||
|
||||
for (auto& a : inputs) {
|
||||
if (output_sets_[i].find(a.id()) != output_sets_[i].end()) {
|
||||
barriers.push_back(i);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Case 1: Barriers is empty so try to add it to the smallest queue
|
||||
if (barriers.empty()) {
|
||||
auto min_queue = std::min_element(
|
||||
task_queues_.begin(),
|
||||
task_queues_.end(),
|
||||
[](const auto& left, const auto& right) {
|
||||
return left.size() < right.size();
|
||||
});
|
||||
int worker_idx = std::distance(task_queues_.begin(), min_queue);
|
||||
|
||||
add_outputs_to_worker(outputs, worker_idx);
|
||||
return enqueue(
|
||||
remove_outputs_when_done(std::move(task), outputs, worker_idx),
|
||||
worker_idx);
|
||||
}
|
||||
|
||||
// Case 2: Barriers has only one queue so put that into that queue
|
||||
if (barriers.size() == 1) {
|
||||
int worker_idx = barriers[0];
|
||||
add_outputs_to_worker(outputs, worker_idx);
|
||||
return enqueue(
|
||||
remove_outputs_when_done(std::move(task), outputs, worker_idx),
|
||||
worker_idx);
|
||||
}
|
||||
|
||||
// Case 3: We need to add a barrier before our task and add it to the
|
||||
// smallest queue of the barriers.
|
||||
auto min_queue = std::min_element(
|
||||
barriers.begin(), barriers.end(), [this](int left, int right) {
|
||||
return task_queues_[left].size() < task_queues_[right].size();
|
||||
});
|
||||
int worker_idx = *min_queue;
|
||||
barriers.erase(min_queue);
|
||||
std::shared_future<void> queue_barrier =
|
||||
barrier(barriers); // We shouldn't need shared future here
|
||||
add_outputs_to_worker(outputs, worker_idx);
|
||||
return enqueue(
|
||||
remove_outputs_when_done(
|
||||
[queue_barrier = std::move(queue_barrier),
|
||||
og_task = std::move(task)]() {
|
||||
queue_barrier.wait();
|
||||
og_task();
|
||||
},
|
||||
outputs,
|
||||
worker_idx),
|
||||
worker_idx);
|
||||
}
|
||||
|
||||
std::future<void> ThreadPool::enqueue(
|
||||
std::function<void()> task,
|
||||
int worker_idx) {
|
||||
std::packaged_task<void()> pt(std::move(task));
|
||||
std::future<void> result = pt.get_future();
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(queue_mutexes_[worker_idx]);
|
||||
task_queues_[worker_idx].emplace(std::move(pt));
|
||||
}
|
||||
queue_cvs_[worker_idx].notify_one();
|
||||
return result;
|
||||
}
|
||||
|
||||
void ThreadPool::add_outputs_to_worker(
|
||||
const std::vector<array>& outputs,
|
||||
int worker_idx) {
|
||||
if (outputs.size() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::lock_guard<std::mutex> lock(set_mutexes_[worker_idx]);
|
||||
for (auto& a : outputs) {
|
||||
output_sets_[worker_idx].insert(a.id());
|
||||
}
|
||||
}
|
||||
|
||||
std::function<void()> ThreadPool::remove_outputs_when_done(
|
||||
std::function<void()> task,
|
||||
const std::vector<array>& outputs,
|
||||
int worker_idx) {
|
||||
if (outputs.size() == 0) {
|
||||
return task;
|
||||
}
|
||||
|
||||
std::vector<std::uintptr_t> output_ids;
|
||||
for (auto& a : outputs) {
|
||||
output_ids.push_back(a.id());
|
||||
}
|
||||
|
||||
return [og_task = std::move(task),
|
||||
ids = std::move(output_ids),
|
||||
worker_idx,
|
||||
this]() {
|
||||
og_task();
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(set_mutexes_[worker_idx]);
|
||||
for (auto id : ids) {
|
||||
output_sets_[worker_idx].erase(id);
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
std::future<void> ThreadPool::barrier(
|
||||
const std::vector<int>& worker_ids,
|
||||
std::function<void()> on_barrier) {
|
||||
auto workers = std::make_shared<std::atomic<int>>(worker_ids.size());
|
||||
auto promise = std::make_shared<std::promise<void>>();
|
||||
auto future = promise->get_future();
|
||||
|
||||
for (auto idx : worker_ids) {
|
||||
enqueue(
|
||||
[workers, promise, on_barrier = std::move(on_barrier)]() {
|
||||
(*workers)--;
|
||||
if (*workers <= 0) {
|
||||
on_barrier();
|
||||
promise->set_value();
|
||||
}
|
||||
},
|
||||
idx);
|
||||
}
|
||||
|
||||
return future;
|
||||
}
|
||||
|
||||
std::future<void> ThreadPool::barrier(const std::vector<int>& worker_ids) {
|
||||
auto noop = []() {};
|
||||
return barrier(worker_ids, std::move(noop));
|
||||
}
|
||||
|
||||
std::future<void> ThreadPool::barrier(std::function<void()> on_barrier) {
|
||||
std::vector<int> worker_ids(workers_.size());
|
||||
std::iota(worker_ids.begin(), worker_ids.end(), 0);
|
||||
return barrier(worker_ids, std::move(on_barrier));
|
||||
}
|
||||
|
||||
std::future<void> ThreadPool::barrier() {
|
||||
auto noop = []() {};
|
||||
return barrier(std::move(noop));
|
||||
}
|
||||
|
||||
void ThreadPool::worker(int idx) {
|
||||
while (true) {
|
||||
std::packaged_task<void()> task;
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(queue_mutexes_[idx]);
|
||||
queue_cvs_[idx].wait(
|
||||
lock, [this, idx]() { return stop_ || !task_queues_[idx].empty(); });
|
||||
if (task_queues_[idx].empty()) {
|
||||
if (stop_) {
|
||||
break;
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
task = std::move(task_queues_[idx].front());
|
||||
task_queues_[idx].pop();
|
||||
}
|
||||
try {
|
||||
task();
|
||||
} catch (...) {
|
||||
// do nothing?
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core::io::detail
|
52
mlx/backend/io/thread_pool.h
Normal file
52
mlx/backend/io/thread_pool.h
Normal file
@@ -0,0 +1,52 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <future>
|
||||
#include <queue>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "mlx/array.h"
|
||||
|
||||
namespace mlx::core::io::detail {
|
||||
|
||||
class ThreadPool {
|
||||
public:
|
||||
explicit ThreadPool(int workers);
|
||||
~ThreadPool();
|
||||
|
||||
ThreadPool(ThreadPool&&) = delete;
|
||||
ThreadPool(const ThreadPool&) = delete;
|
||||
ThreadPool& operator=(ThreadPool&&) = delete;
|
||||
ThreadPool& operator=(const ThreadPool&) = delete;
|
||||
|
||||
std::future<void> enqueue(
|
||||
std::function<void()> task,
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<array>& outputs);
|
||||
std::future<void> barrier(
|
||||
const std::vector<int>& worker_ids,
|
||||
std::function<void()> on_barrier);
|
||||
std::future<void> barrier(const std::vector<int>& worker_ids);
|
||||
std::future<void> barrier(std::function<void()> on_barrier);
|
||||
std::future<void> barrier();
|
||||
|
||||
private:
|
||||
std::future<void> enqueue(std::function<void()> task, int worker_idx);
|
||||
void add_outputs_to_worker(const std::vector<array>& outputs, int worker_idx);
|
||||
std::function<void()> remove_outputs_when_done(
|
||||
std::function<void()> task,
|
||||
const std::vector<array>& outputs,
|
||||
int worker_idx);
|
||||
void worker(int idx);
|
||||
|
||||
std::vector<std::queue<std::packaged_task<void()>>> task_queues_;
|
||||
std::vector<std::mutex> queue_mutexes_;
|
||||
std::vector<std::condition_variable> queue_cvs_;
|
||||
std::vector<std::mutex> set_mutexes_;
|
||||
std::vector<std::unordered_set<std::uintptr_t>> output_sets_;
|
||||
bool stop_;
|
||||
std::vector<std::thread> workers_;
|
||||
};
|
||||
|
||||
} // namespace mlx::core::io::detail
|
@@ -671,10 +671,6 @@ void LessEqual::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "leq");
|
||||
}
|
||||
|
||||
void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
eval(inputs, out);
|
||||
}
|
||||
|
||||
void Log::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
switch (base_) {
|
||||
case Base::e:
|
||||
|
@@ -61,7 +61,6 @@ NO_GPU(Greater)
|
||||
NO_GPU(GreaterEqual)
|
||||
NO_GPU(Less)
|
||||
NO_GPU(LessEqual)
|
||||
NO_GPU(Load)
|
||||
NO_GPU(Log)
|
||||
NO_GPU(Log1p)
|
||||
NO_GPU(LogicalNot)
|
||||
|
@@ -8,10 +8,12 @@ struct Device {
|
||||
enum class DeviceType {
|
||||
cpu,
|
||||
gpu,
|
||||
io,
|
||||
};
|
||||
|
||||
static constexpr DeviceType cpu = DeviceType::cpu;
|
||||
static constexpr DeviceType gpu = DeviceType::gpu;
|
||||
static constexpr DeviceType io = DeviceType::io;
|
||||
|
||||
Device(DeviceType type, int index = 0) : type(type), index(index) {};
|
||||
|
||||
|
10
mlx/io.h
10
mlx/io.h
@@ -32,12 +32,8 @@ array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s = {});
|
||||
array load(std::string file, StreamOrDevice s = {});
|
||||
|
||||
/** Load array map from .safetensors file format */
|
||||
SafetensorsLoad load_safetensors(
|
||||
std::shared_ptr<io::Reader> in_stream,
|
||||
StreamOrDevice s = {});
|
||||
SafetensorsLoad load_safetensors(
|
||||
const std::string& file,
|
||||
StreamOrDevice s = {});
|
||||
SafetensorsLoad load_safetensors(std::shared_ptr<io::Reader> in_stream);
|
||||
SafetensorsLoad load_safetensors(const std::string& file);
|
||||
|
||||
void save_safetensors(
|
||||
std::shared_ptr<io::Writer> in_stream,
|
||||
@@ -50,7 +46,7 @@ void save_safetensors(
|
||||
|
||||
/** Load array map and metadata from .gguf file format */
|
||||
|
||||
GGUFLoad load_gguf(const std::string& file, StreamOrDevice s = {});
|
||||
GGUFLoad load_gguf(const std::string& file);
|
||||
|
||||
void save_gguf(
|
||||
std::string file,
|
||||
|
@@ -231,7 +231,7 @@ std::unordered_map<std::string, array> load_arrays(gguf_ctx* ctx) {
|
||||
return array_map;
|
||||
}
|
||||
|
||||
GGUFLoad load_gguf(const std::string& file, StreamOrDevice s) {
|
||||
GGUFLoad load_gguf(const std::string& file) {
|
||||
gguf_ctx* ctx = gguf_open(file.data());
|
||||
if (!ctx) {
|
||||
throw std::runtime_error("[load_gguf] gguf_init failed");
|
||||
|
@@ -213,7 +213,8 @@ array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s) {
|
||||
auto loaded_array = array(
|
||||
shape,
|
||||
dtype,
|
||||
std::make_shared<Load>(to_stream(s), in_stream, offset, swap_endianness),
|
||||
std::make_shared<Load>(
|
||||
to_stream(Device::io), in_stream, offset, swap_endianness),
|
||||
std::vector<array>{});
|
||||
if (col_contiguous) {
|
||||
loaded_array = transpose(loaded_array, s);
|
||||
|
@@ -20,6 +20,8 @@ class Reader {
|
||||
std::ios_base::seekdir way = std::ios_base::beg) = 0;
|
||||
virtual void read(char* data, size_t n) = 0;
|
||||
virtual std::string label() const = 0;
|
||||
virtual void lock() = 0;
|
||||
virtual void unlock() = 0;
|
||||
};
|
||||
|
||||
class Writer {
|
||||
@@ -67,9 +69,18 @@ class FileReader : public Reader {
|
||||
return "file " + label_;
|
||||
}
|
||||
|
||||
void lock() override {
|
||||
is_mutex_.lock();
|
||||
}
|
||||
|
||||
void unlock() override {
|
||||
is_mutex_.unlock();
|
||||
}
|
||||
|
||||
private:
|
||||
std::ifstream is_;
|
||||
std::string label_;
|
||||
std::mutex is_mutex_;
|
||||
};
|
||||
|
||||
class FileWriter : public Writer {
|
||||
|
@@ -94,9 +94,7 @@ Dtype dtype_from_safetensor_str(std::string_view str) {
|
||||
}
|
||||
|
||||
/** Load array from reader in safetensor format */
|
||||
SafetensorsLoad load_safetensors(
|
||||
std::shared_ptr<io::Reader> in_stream,
|
||||
StreamOrDevice s) {
|
||||
SafetensorsLoad load_safetensors(std::shared_ptr<io::Reader> in_stream) {
|
||||
////////////////////////////////////////////////////////
|
||||
// Open and check file
|
||||
if (!in_stream->good() || !in_stream->is_open()) {
|
||||
@@ -138,15 +136,18 @@ SafetensorsLoad load_safetensors(
|
||||
shape,
|
||||
type,
|
||||
std::make_shared<Load>(
|
||||
to_stream(s), in_stream, offset + data_offsets.at(0), false),
|
||||
to_stream(Device::io),
|
||||
in_stream,
|
||||
offset + data_offsets.at(0),
|
||||
false),
|
||||
std::vector<array>{});
|
||||
res.insert({item.key(), loaded_array});
|
||||
}
|
||||
return {res, metadata_map};
|
||||
}
|
||||
|
||||
SafetensorsLoad load_safetensors(const std::string& file, StreamOrDevice s) {
|
||||
return load_safetensors(std::make_shared<io::FileReader>(file), s);
|
||||
SafetensorsLoad load_safetensors(const std::string& file) {
|
||||
return load_safetensors(std::make_shared<io::FileReader>(file));
|
||||
}
|
||||
|
||||
/** Save array to out stream in .npy format */
|
||||
|
@@ -106,6 +106,16 @@ std::tuple<array, array, array, int> vmap_ternary_op(
|
||||
|
||||
} // namespace
|
||||
|
||||
void Primitive::eval_io(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
std::ostringstream msg;
|
||||
msg << "[Primitive::eval_io] Not implemented for ";
|
||||
print(msg);
|
||||
msg << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
std::vector<array> Primitive::jvp(
|
||||
const std::vector<array>&,
|
||||
const std::vector<array>&,
|
||||
|
@@ -73,6 +73,16 @@ class Primitive {
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) = 0;
|
||||
|
||||
/**
|
||||
* Some primitives are computed by an IO device (disk, network, camera etc).
|
||||
*
|
||||
* Like in eval_cpu/gpu the eval_io function is responsible for allocating
|
||||
* the space for the array.
|
||||
*/
|
||||
virtual void eval_io(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs);
|
||||
|
||||
/**
|
||||
* The Jacobian-vector product.
|
||||
*/
|
||||
@@ -152,6 +162,26 @@ class UnaryPrimitive : public Primitive {
|
||||
UnaryPrimitive& operator=(UnaryPrimitive&& other) = delete;
|
||||
};
|
||||
|
||||
class IOPrimitive : public Primitive {
|
||||
/**
|
||||
* An abstract class for primitives that are doing io which usually are not
|
||||
* supposed to be evaluated on any other "device".
|
||||
*/
|
||||
public:
|
||||
explicit IOPrimitive(Stream stream) : Primitive(stream) {}
|
||||
|
||||
inline void eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) override {
|
||||
throw std::runtime_error("IO primitives cannot be evaluated on CPU");
|
||||
}
|
||||
inline void eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) override {
|
||||
throw std::runtime_error("IO primitives cannot be evaluated on GPU");
|
||||
}
|
||||
};
|
||||
|
||||
class Abs : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Abs(Stream stream) : UnaryPrimitive(stream) {};
|
||||
@@ -1064,20 +1094,20 @@ class LessEqual : public UnaryPrimitive {
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class Load : public UnaryPrimitive {
|
||||
class Load : public IOPrimitive {
|
||||
public:
|
||||
explicit Load(
|
||||
Stream stream,
|
||||
std::shared_ptr<io::Reader> reader,
|
||||
size_t offset,
|
||||
bool swap_endianness = false)
|
||||
: UnaryPrimitive(stream),
|
||||
: IOPrimitive(stream),
|
||||
reader_(reader),
|
||||
offset_(offset),
|
||||
swap_endianness_(swap_endianness) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_io(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
|
||||
DEFINE_PRINT(Load)
|
||||
|
||||
|
@@ -1,6 +1,8 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#include "mlx/scheduler.h"
|
||||
#include "mlx/backend/common/cpu_impl.h"
|
||||
#include "mlx/backend/io/io_impl.h"
|
||||
#include "mlx/backend/metal/metal.h"
|
||||
|
||||
namespace mlx::core {
|
||||
@@ -36,10 +38,16 @@ Stream new_stream() {
|
||||
void synchronize(Stream s) {
|
||||
auto p = std::make_shared<std::promise<void>>();
|
||||
std::future<void> f = p->get_future();
|
||||
if (s.device == mlx::core::Device::cpu) {
|
||||
scheduler::enqueue(s, [p = std::move(p)]() { p->set_value(); });
|
||||
} else {
|
||||
scheduler::enqueue(s, metal::make_synchronize_task(s, std::move(p)));
|
||||
switch (s.device.type) {
|
||||
case mlx::core::Device::cpu:
|
||||
scheduler::enqueue(s, cpu::make_synchronize_task(s, std::move(p)));
|
||||
break;
|
||||
case mlx::core::Device::gpu:
|
||||
scheduler::enqueue(s, metal::make_synchronize_task(s, std::move(p)));
|
||||
break;
|
||||
case mlx::core::Device::io:
|
||||
scheduler::enqueue(s, io::make_synchronize_task(s, std::move(p)));
|
||||
break;
|
||||
}
|
||||
f.wait();
|
||||
}
|
||||
|
@@ -76,6 +76,7 @@ class Scheduler {
|
||||
default_streams_.insert({Device::gpu, new_stream(Device::gpu)});
|
||||
}
|
||||
default_streams_.insert({Device::cpu, new_stream(Device::cpu)});
|
||||
default_streams_.insert({Device::io, new_stream(Device::io)});
|
||||
}
|
||||
|
||||
// Not copyable or moveable
|
||||
|
@@ -8,6 +8,8 @@
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "mlx/backend/common/cpu_impl.h"
|
||||
#include "mlx/backend/io/io_impl.h"
|
||||
#include "mlx/backend/metal/metal_impl.h"
|
||||
#include "mlx/ops.h"
|
||||
#include "mlx/primitives.h"
|
||||
@@ -26,6 +28,7 @@ class Synchronizer : public Primitive {
|
||||
|
||||
void eval_cpu(const std::vector<array>&, std::vector<array>&) override {};
|
||||
void eval_gpu(const std::vector<array>&, std::vector<array>&) override {};
|
||||
void eval_io(const std::vector<array>&, std::vector<array>&) override {};
|
||||
|
||||
DEFINE_PRINT(Synchronize);
|
||||
};
|
||||
@@ -137,32 +140,20 @@ array eval_impl(std::vector<array> outputs, bool async) {
|
||||
std::vector<std::shared_future<void>> arr_deps;
|
||||
bool signal = needs_signal.find(arr.id()) != needs_signal.end();
|
||||
|
||||
if (arr.primitive().device() == Device::gpu) {
|
||||
if (!metal::is_available()) {
|
||||
throw std::runtime_error("Metal GPU is not available.");
|
||||
switch (arr.primitive().device().type) {
|
||||
case Device::gpu: {
|
||||
if (!metal::is_available()) {
|
||||
throw std::runtime_error("Metal GPU is not available.");
|
||||
}
|
||||
scheduler::enqueue(stream, metal::make_task(std::move(arr), signal));
|
||||
break;
|
||||
}
|
||||
scheduler::enqueue(stream, metal::make_task(std::move(arr), signal));
|
||||
} else {
|
||||
auto task = [arr = std::move(arr), stream, signal]() mutable {
|
||||
for (auto& input : arr.inputs()) {
|
||||
if (input.event().valid() &&
|
||||
input.event().stream() != arr.primitive().stream()) {
|
||||
input.event().wait();
|
||||
}
|
||||
}
|
||||
scheduler::notify_new_task(stream);
|
||||
auto outputs = arr.outputs();
|
||||
arr.primitive().eval_cpu(arr.inputs(), outputs);
|
||||
if (!arr.is_tracer()) {
|
||||
arr.detach();
|
||||
}
|
||||
if (signal) {
|
||||
arr.event().signal();
|
||||
}
|
||||
|
||||
scheduler::notify_task_completion(stream);
|
||||
};
|
||||
scheduler::enqueue(stream, std::move(task));
|
||||
case Device::cpu:
|
||||
scheduler::enqueue(stream, cpu::make_task(std::move(arr), signal));
|
||||
break;
|
||||
case Device::io:
|
||||
scheduler::enqueue(stream, io::make_task(std::move(arr), signal));
|
||||
break;
|
||||
}
|
||||
}
|
||||
return synchronizer;
|
||||
|
@@ -133,6 +133,9 @@ std::ostream& operator<<(std::ostream& os, const Device& d) {
|
||||
case Device::gpu:
|
||||
os << "gpu";
|
||||
break;
|
||||
case Device::io:
|
||||
os << "io";
|
||||
break;
|
||||
}
|
||||
os << ", " << d.index << ")";
|
||||
return os;
|
||||
|
@@ -150,11 +150,21 @@ class PyFileReader : public io::Reader {
|
||||
return "python file object";
|
||||
}
|
||||
|
||||
void lock() override {
|
||||
stream_mutex_.lock();
|
||||
}
|
||||
|
||||
void unlock() override {
|
||||
stream_mutex_.unlock();
|
||||
}
|
||||
|
||||
private:
|
||||
nb::object pyistream_;
|
||||
nb::object readinto_func_;
|
||||
nb::object seek_func_;
|
||||
nb::object tell_func_;
|
||||
|
||||
std::mutex stream_mutex_;
|
||||
};
|
||||
|
||||
std::pair<
|
||||
@@ -162,10 +172,10 @@ std::pair<
|
||||
std::unordered_map<std::string, std::string>>
|
||||
mlx_load_safetensor_helper(nb::object file, StreamOrDevice s) {
|
||||
if (nb::isinstance<nb::str>(file)) { // Assume .safetensors file path string
|
||||
return load_safetensors(nb::cast<std::string>(file), s);
|
||||
return load_safetensors(nb::cast<std::string>(file));
|
||||
} else if (is_istream_object(file)) {
|
||||
// If we don't own the stream and it was passed to us, eval immediately
|
||||
auto res = load_safetensors(std::make_shared<PyFileReader>(file), s);
|
||||
auto res = load_safetensors(std::make_shared<PyFileReader>(file));
|
||||
{
|
||||
nb::gil_scoped_release gil;
|
||||
for (auto& [key, arr] : std::get<0>(res)) {
|
||||
@@ -181,7 +191,7 @@ mlx_load_safetensor_helper(nb::object file, StreamOrDevice s) {
|
||||
|
||||
GGUFLoad mlx_load_gguf_helper(nb::object file, StreamOrDevice s) {
|
||||
if (nb::isinstance<nb::str>(file)) { // Assume .gguf file path string
|
||||
return load_gguf(nb::cast<std::string>(file), s);
|
||||
return load_gguf(nb::cast<std::string>(file));
|
||||
}
|
||||
|
||||
throw std::invalid_argument("[load_gguf] Input must be a string");
|
||||
|
Reference in New Issue
Block a user