Compare commits

...

5 Commits

Author SHA1 Message Date
Angelos Katharopoulos
8242d6d5ef Add locks to FileStream 2024-05-08 23:19:27 -07:00
Angelos Katharopoulos
bae159738f Working IO primitives 2024-05-08 22:17:25 -07:00
Angelos Katharopoulos
b193741050 Change Load to be an IOPrimitive 2024-05-08 18:59:20 -07:00
Angelos Katharopoulos
c8e2b42ced Add the io threadpool and task 2024-05-08 18:02:22 -07:00
Angelos Katharopoulos
be36f136de Add io device and cpu::make_task 2024-05-07 16:58:14 -07:00
27 changed files with 607 additions and 58 deletions

View File

@@ -20,6 +20,7 @@ target_sources(
)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/io)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io)
if (MLX_BUILD_ACCELERATE)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/accelerate)

View File

@@ -51,7 +51,6 @@ DEFAULT(Greater)
DEFAULT(GreaterEqual)
DEFAULT(Less)
DEFAULT(LessEqual)
DEFAULT(Load)
DEFAULT(LogicalNot)
DEFAULT(LogicalAnd)
DEFAULT(LogicalOr)

View File

@@ -55,6 +55,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/qrf.cpp
${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp
${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpu_impl.cpp
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp
)

View File

@@ -0,0 +1,47 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/common/cpu_impl.h"
#include "mlx/primitives.h"
#include "mlx/scheduler.h"
namespace mlx::core::cpu {
std::function<void()> make_task(array arr, bool signal) {
return [arr = std::move(arr), signal]() mutable {
auto stream = arr.primitive().stream();
// Wait on inputs coming from different streams/devices.
for (auto& input : arr.inputs()) {
if (input.event().valid() && input.event().stream() != stream) {
input.event().wait();
}
}
// Task computation actually starting.
scheduler::notify_new_task(stream);
// Perform the computation
auto outputs = arr.outputs();
arr.primitive().eval_cpu(arr.inputs(), outputs);
// Check if we need to detach and signal other arrays waiting for the
// result to be ready.
if (!arr.is_tracer()) {
arr.detach();
}
if (signal) {
arr.event().signal();
}
// Task computation done.
scheduler::notify_task_completion(stream);
};
}
std::function<void()> make_synchronize_task(
Stream s,
std::shared_ptr<std::promise<void>> p) {
return [p = std::move(p)]() { p->set_value(); };
}
} // namespace mlx::core::cpu

View File

@@ -0,0 +1,18 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include <functional>
#include <future>
#include <memory>
#include "mlx/array.h"
namespace mlx::core::cpu {
std::function<void()> make_task(array arr, bool signal);
std::function<void()> make_synchronize_task(
Stream s,
std::shared_ptr<std::promise<void>> p);
} // namespace mlx::core::cpu

View File

@@ -68,7 +68,6 @@ DEFAULT(Greater)
DEFAULT(GreaterEqual)
DEFAULT(Less)
DEFAULT(LessEqual)
DEFAULT(Load)
DEFAULT(Log)
DEFAULT(Log1p)
DEFAULT(LogicalNot)

View File

@@ -0,0 +1,7 @@
target_sources(
mlx
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/io_impl.cpp
${CMAKE_CURRENT_SOURCE_DIR}/thread_pool.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
)

View File

@@ -0,0 +1,72 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/io/io_impl.h"
#include "mlx/backend/io/thread_pool.h"
#include "mlx/primitives.h"
#include "mlx/scheduler.h"
namespace mlx::core::io {
namespace {
detail::ThreadPool& thread_pool() {
static std::unique_ptr<detail::ThreadPool> pool_ptr;
if (pool_ptr == nullptr) {
pool_ptr = std::make_unique<detail::ThreadPool>(4);
}
return *pool_ptr;
}
} // namespace
std::function<void()> make_task(array arr, bool signal) {
return [arr = std::move(arr), signal]() mutable {
auto stream = arr.primitive().stream();
// Wait on inputs coming from different streams/devices.
for (auto& input : arr.inputs()) {
if (input.event().valid() && input.event().stream() != stream) {
input.event().wait();
}
}
// Task computation actually starting.
scheduler::notify_new_task(stream);
// Schedule the computation
auto inputs = arr.inputs();
auto outputs = arr.outputs();
thread_pool().enqueue(
[arr = std::move(arr), inputs, outputs, signal, stream]() mutable {
// Perform the computation
arr.primitive().eval_io(inputs, outputs);
if (!arr.is_tracer()) {
arr.detach();
}
if (signal) {
thread_pool().barrier(
[arr = std::move(arr)]() { arr.event().signal(); });
}
// Task computation done.
scheduler::notify_task_completion(stream);
},
inputs,
outputs);
};
}
std::function<void()> make_synchronize_task(
Stream s,
std::shared_ptr<std::promise<void>> p) {
return [p = std::move(p)]() {
thread_pool().barrier().wait();
p->set_value();
};
}
} // namespace mlx::core::io

18
mlx/backend/io/io_impl.h Normal file
View File

@@ -0,0 +1,18 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include <functional>
#include <future>
#include <memory>
#include "mlx/array.h"
namespace mlx::core::io {
std::function<void()> make_task(array arr, bool signal);
std::function<void()> make_synchronize_task(
Stream s,
std::shared_ptr<std::promise<void>> p);
} // namespace mlx::core::io

View File

@@ -0,0 +1,60 @@
// Copyright © 2024 Apple Inc.
#include <algorithm>
#include <cassert>
#include <utility>
#include "mlx/allocator.h"
#include "mlx/primitives.h"
namespace mlx::core {
namespace {
template <const uint8_t scalar_size>
void swap_endianness(uint8_t* data_bytes, size_t N) {
struct Elem {
uint8_t bytes[scalar_size];
};
Elem* data = reinterpret_cast<Elem*>(data_bytes);
for (size_t i = 0; i < N; i++) {
for (size_t j = 0; j < (scalar_size / 2); j++) {
std::swap(data[i].bytes[j], data[i].bytes[scalar_size - j - 1]);
}
}
}
} // namespace
void Load::eval_io(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 0);
array& out = outputs[0];
out.set_data(allocator::malloc_or_wait(out.nbytes()));
{
std::lock_guard lock(*reader_);
reader_->seek(offset_, std::ios_base::beg);
reader_->read(out.data<char>(), out.nbytes());
}
if (swap_endianness_) {
switch (out.itemsize()) {
case 2:
swap_endianness<2>(out.data<uint8_t>(), out.data_size());
break;
case 4:
swap_endianness<4>(out.data<uint8_t>(), out.data_size());
break;
case 8:
swap_endianness<8>(out.data<uint8_t>(), out.data_size());
break;
}
}
}
} // namespace mlx::core

View File

@@ -0,0 +1,216 @@
// Copyright © 2024 Apple Inc.
#include <numeric>
#include "mlx/backend/io/thread_pool.h"
namespace mlx::core::io::detail {
ThreadPool::ThreadPool(int workers)
: task_queues_(workers),
queue_mutexes_(workers),
queue_cvs_(workers),
set_mutexes_(workers),
output_sets_(workers),
stop_(false) {
for (int i = 0; i < workers; i++) {
workers_.emplace_back(&ThreadPool::worker, this, i);
}
}
ThreadPool::~ThreadPool() {
stop_ = true;
for (auto& cv : queue_cvs_) {
cv.notify_one();
}
for (auto& t : workers_) {
if (t.joinable()) {
t.join();
}
}
}
std::future<void> ThreadPool::enqueue(
std::function<void()> task,
const std::vector<array>& inputs,
const std::vector<array>& outputs) {
std::vector<int> barriers;
if (!inputs.empty()) {
for (int i = 0; i < output_sets_.size(); i++) {
std::lock_guard<std::mutex> lock(set_mutexes_[i]);
for (auto& a : inputs) {
if (output_sets_[i].find(a.id()) != output_sets_[i].end()) {
barriers.push_back(i);
break;
}
}
}
}
// Case 1: Barriers is empty so try to add it to the smallest queue
if (barriers.empty()) {
auto min_queue = std::min_element(
task_queues_.begin(),
task_queues_.end(),
[](const auto& left, const auto& right) {
return left.size() < right.size();
});
int worker_idx = std::distance(task_queues_.begin(), min_queue);
add_outputs_to_worker(outputs, worker_idx);
return enqueue(
remove_outputs_when_done(std::move(task), outputs, worker_idx),
worker_idx);
}
// Case 2: Barriers has only one queue so put that into that queue
if (barriers.size() == 1) {
int worker_idx = barriers[0];
add_outputs_to_worker(outputs, worker_idx);
return enqueue(
remove_outputs_when_done(std::move(task), outputs, worker_idx),
worker_idx);
}
// Case 3: We need to add a barrier before our task and add it to the
// smallest queue of the barriers.
auto min_queue = std::min_element(
barriers.begin(), barriers.end(), [this](int left, int right) {
return task_queues_[left].size() < task_queues_[right].size();
});
int worker_idx = *min_queue;
barriers.erase(min_queue);
std::shared_future<void> queue_barrier =
barrier(barriers); // We shouldn't need shared future here
add_outputs_to_worker(outputs, worker_idx);
return enqueue(
remove_outputs_when_done(
[queue_barrier = std::move(queue_barrier),
og_task = std::move(task)]() {
queue_barrier.wait();
og_task();
},
outputs,
worker_idx),
worker_idx);
}
std::future<void> ThreadPool::enqueue(
std::function<void()> task,
int worker_idx) {
std::packaged_task<void()> pt(std::move(task));
std::future<void> result = pt.get_future();
{
std::lock_guard<std::mutex> lock(queue_mutexes_[worker_idx]);
task_queues_[worker_idx].emplace(std::move(pt));
}
queue_cvs_[worker_idx].notify_one();
return result;
}
void ThreadPool::add_outputs_to_worker(
const std::vector<array>& outputs,
int worker_idx) {
if (outputs.size() == 0) {
return;
}
std::lock_guard<std::mutex> lock(set_mutexes_[worker_idx]);
for (auto& a : outputs) {
output_sets_[worker_idx].insert(a.id());
}
}
std::function<void()> ThreadPool::remove_outputs_when_done(
std::function<void()> task,
const std::vector<array>& outputs,
int worker_idx) {
if (outputs.size() == 0) {
return task;
}
std::vector<std::uintptr_t> output_ids;
for (auto& a : outputs) {
output_ids.push_back(a.id());
}
return [og_task = std::move(task),
ids = std::move(output_ids),
worker_idx,
this]() {
og_task();
{
std::lock_guard<std::mutex> lock(set_mutexes_[worker_idx]);
for (auto id : ids) {
output_sets_[worker_idx].erase(id);
}
}
};
}
std::future<void> ThreadPool::barrier(
const std::vector<int>& worker_ids,
std::function<void()> on_barrier) {
auto workers = std::make_shared<std::atomic<int>>(worker_ids.size());
auto promise = std::make_shared<std::promise<void>>();
auto future = promise->get_future();
for (auto idx : worker_ids) {
enqueue(
[workers, promise, on_barrier = std::move(on_barrier)]() {
(*workers)--;
if (*workers <= 0) {
on_barrier();
promise->set_value();
}
},
idx);
}
return future;
}
std::future<void> ThreadPool::barrier(const std::vector<int>& worker_ids) {
auto noop = []() {};
return barrier(worker_ids, std::move(noop));
}
std::future<void> ThreadPool::barrier(std::function<void()> on_barrier) {
std::vector<int> worker_ids(workers_.size());
std::iota(worker_ids.begin(), worker_ids.end(), 0);
return barrier(worker_ids, std::move(on_barrier));
}
std::future<void> ThreadPool::barrier() {
auto noop = []() {};
return barrier(std::move(noop));
}
void ThreadPool::worker(int idx) {
while (true) {
std::packaged_task<void()> task;
{
std::unique_lock<std::mutex> lock(queue_mutexes_[idx]);
queue_cvs_[idx].wait(
lock, [this, idx]() { return stop_ || !task_queues_[idx].empty(); });
if (task_queues_[idx].empty()) {
if (stop_) {
break;
} else {
continue;
}
}
task = std::move(task_queues_[idx].front());
task_queues_[idx].pop();
}
try {
task();
} catch (...) {
// do nothing?
}
}
}
} // namespace mlx::core::io::detail

View File

@@ -0,0 +1,52 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include <future>
#include <queue>
#include <unordered_set>
#include "mlx/array.h"
namespace mlx::core::io::detail {
class ThreadPool {
public:
explicit ThreadPool(int workers);
~ThreadPool();
ThreadPool(ThreadPool&&) = delete;
ThreadPool(const ThreadPool&) = delete;
ThreadPool& operator=(ThreadPool&&) = delete;
ThreadPool& operator=(const ThreadPool&) = delete;
std::future<void> enqueue(
std::function<void()> task,
const std::vector<array>& inputs,
const std::vector<array>& outputs);
std::future<void> barrier(
const std::vector<int>& worker_ids,
std::function<void()> on_barrier);
std::future<void> barrier(const std::vector<int>& worker_ids);
std::future<void> barrier(std::function<void()> on_barrier);
std::future<void> barrier();
private:
std::future<void> enqueue(std::function<void()> task, int worker_idx);
void add_outputs_to_worker(const std::vector<array>& outputs, int worker_idx);
std::function<void()> remove_outputs_when_done(
std::function<void()> task,
const std::vector<array>& outputs,
int worker_idx);
void worker(int idx);
std::vector<std::queue<std::packaged_task<void()>>> task_queues_;
std::vector<std::mutex> queue_mutexes_;
std::vector<std::condition_variable> queue_cvs_;
std::vector<std::mutex> set_mutexes_;
std::vector<std::unordered_set<std::uintptr_t>> output_sets_;
bool stop_;
std::vector<std::thread> workers_;
};
} // namespace mlx::core::io::detail

View File

@@ -671,10 +671,6 @@ void LessEqual::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "leq");
}
void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void Log::eval_gpu(const std::vector<array>& inputs, array& out) {
switch (base_) {
case Base::e:

View File

@@ -61,7 +61,6 @@ NO_GPU(Greater)
NO_GPU(GreaterEqual)
NO_GPU(Less)
NO_GPU(LessEqual)
NO_GPU(Load)
NO_GPU(Log)
NO_GPU(Log1p)
NO_GPU(LogicalNot)

View File

@@ -8,10 +8,12 @@ struct Device {
enum class DeviceType {
cpu,
gpu,
io,
};
static constexpr DeviceType cpu = DeviceType::cpu;
static constexpr DeviceType gpu = DeviceType::gpu;
static constexpr DeviceType io = DeviceType::io;
Device(DeviceType type, int index = 0) : type(type), index(index) {};

View File

@@ -32,12 +32,8 @@ array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s = {});
array load(std::string file, StreamOrDevice s = {});
/** Load array map from .safetensors file format */
SafetensorsLoad load_safetensors(
std::shared_ptr<io::Reader> in_stream,
StreamOrDevice s = {});
SafetensorsLoad load_safetensors(
const std::string& file,
StreamOrDevice s = {});
SafetensorsLoad load_safetensors(std::shared_ptr<io::Reader> in_stream);
SafetensorsLoad load_safetensors(const std::string& file);
void save_safetensors(
std::shared_ptr<io::Writer> in_stream,
@@ -50,7 +46,7 @@ void save_safetensors(
/** Load array map and metadata from .gguf file format */
GGUFLoad load_gguf(const std::string& file, StreamOrDevice s = {});
GGUFLoad load_gguf(const std::string& file);
void save_gguf(
std::string file,

View File

@@ -231,7 +231,7 @@ std::unordered_map<std::string, array> load_arrays(gguf_ctx* ctx) {
return array_map;
}
GGUFLoad load_gguf(const std::string& file, StreamOrDevice s) {
GGUFLoad load_gguf(const std::string& file) {
gguf_ctx* ctx = gguf_open(file.data());
if (!ctx) {
throw std::runtime_error("[load_gguf] gguf_init failed");

View File

@@ -213,7 +213,8 @@ array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s) {
auto loaded_array = array(
shape,
dtype,
std::make_shared<Load>(to_stream(s), in_stream, offset, swap_endianness),
std::make_shared<Load>(
to_stream(Device::io), in_stream, offset, swap_endianness),
std::vector<array>{});
if (col_contiguous) {
loaded_array = transpose(loaded_array, s);

View File

@@ -20,6 +20,8 @@ class Reader {
std::ios_base::seekdir way = std::ios_base::beg) = 0;
virtual void read(char* data, size_t n) = 0;
virtual std::string label() const = 0;
virtual void lock() = 0;
virtual void unlock() = 0;
};
class Writer {
@@ -67,9 +69,18 @@ class FileReader : public Reader {
return "file " + label_;
}
void lock() override {
is_mutex_.lock();
}
void unlock() override {
is_mutex_.unlock();
}
private:
std::ifstream is_;
std::string label_;
std::mutex is_mutex_;
};
class FileWriter : public Writer {

View File

@@ -94,9 +94,7 @@ Dtype dtype_from_safetensor_str(std::string_view str) {
}
/** Load array from reader in safetensor format */
SafetensorsLoad load_safetensors(
std::shared_ptr<io::Reader> in_stream,
StreamOrDevice s) {
SafetensorsLoad load_safetensors(std::shared_ptr<io::Reader> in_stream) {
////////////////////////////////////////////////////////
// Open and check file
if (!in_stream->good() || !in_stream->is_open()) {
@@ -138,15 +136,18 @@ SafetensorsLoad load_safetensors(
shape,
type,
std::make_shared<Load>(
to_stream(s), in_stream, offset + data_offsets.at(0), false),
to_stream(Device::io),
in_stream,
offset + data_offsets.at(0),
false),
std::vector<array>{});
res.insert({item.key(), loaded_array});
}
return {res, metadata_map};
}
SafetensorsLoad load_safetensors(const std::string& file, StreamOrDevice s) {
return load_safetensors(std::make_shared<io::FileReader>(file), s);
SafetensorsLoad load_safetensors(const std::string& file) {
return load_safetensors(std::make_shared<io::FileReader>(file));
}
/** Save array to out stream in .npy format */

View File

@@ -106,6 +106,16 @@ std::tuple<array, array, array, int> vmap_ternary_op(
} // namespace
void Primitive::eval_io(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
std::ostringstream msg;
msg << "[Primitive::eval_io] Not implemented for ";
print(msg);
msg << ".";
throw std::invalid_argument(msg.str());
}
std::vector<array> Primitive::jvp(
const std::vector<array>&,
const std::vector<array>&,

View File

@@ -73,6 +73,16 @@ class Primitive {
const std::vector<array>& inputs,
std::vector<array>& outputs) = 0;
/**
* Some primitives are computed by an IO device (disk, network, camera etc).
*
* Like in eval_cpu/gpu the eval_io function is responsible for allocating
* the space for the array.
*/
virtual void eval_io(
const std::vector<array>& inputs,
std::vector<array>& outputs);
/**
* The Jacobian-vector product.
*/
@@ -152,6 +162,26 @@ class UnaryPrimitive : public Primitive {
UnaryPrimitive& operator=(UnaryPrimitive&& other) = delete;
};
class IOPrimitive : public Primitive {
/**
* An abstract class for primitives that are doing io which usually are not
* supposed to be evaluated on any other "device".
*/
public:
explicit IOPrimitive(Stream stream) : Primitive(stream) {}
inline void eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) override {
throw std::runtime_error("IO primitives cannot be evaluated on CPU");
}
inline void eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) override {
throw std::runtime_error("IO primitives cannot be evaluated on GPU");
}
};
class Abs : public UnaryPrimitive {
public:
explicit Abs(Stream stream) : UnaryPrimitive(stream) {};
@@ -1064,20 +1094,20 @@ class LessEqual : public UnaryPrimitive {
void eval(const std::vector<array>& inputs, array& out);
};
class Load : public UnaryPrimitive {
class Load : public IOPrimitive {
public:
explicit Load(
Stream stream,
std::shared_ptr<io::Reader> reader,
size_t offset,
bool swap_endianness = false)
: UnaryPrimitive(stream),
: IOPrimitive(stream),
reader_(reader),
offset_(offset),
swap_endianness_(swap_endianness) {};
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
void eval_io(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
DEFINE_PRINT(Load)

View File

@@ -1,6 +1,8 @@
// Copyright © 2023 Apple Inc.
#include "mlx/scheduler.h"
#include "mlx/backend/common/cpu_impl.h"
#include "mlx/backend/io/io_impl.h"
#include "mlx/backend/metal/metal.h"
namespace mlx::core {
@@ -36,10 +38,16 @@ Stream new_stream() {
void synchronize(Stream s) {
auto p = std::make_shared<std::promise<void>>();
std::future<void> f = p->get_future();
if (s.device == mlx::core::Device::cpu) {
scheduler::enqueue(s, [p = std::move(p)]() { p->set_value(); });
} else {
scheduler::enqueue(s, metal::make_synchronize_task(s, std::move(p)));
switch (s.device.type) {
case mlx::core::Device::cpu:
scheduler::enqueue(s, cpu::make_synchronize_task(s, std::move(p)));
break;
case mlx::core::Device::gpu:
scheduler::enqueue(s, metal::make_synchronize_task(s, std::move(p)));
break;
case mlx::core::Device::io:
scheduler::enqueue(s, io::make_synchronize_task(s, std::move(p)));
break;
}
f.wait();
}

View File

@@ -76,6 +76,7 @@ class Scheduler {
default_streams_.insert({Device::gpu, new_stream(Device::gpu)});
}
default_streams_.insert({Device::cpu, new_stream(Device::cpu)});
default_streams_.insert({Device::io, new_stream(Device::io)});
}
// Not copyable or moveable

View File

@@ -8,6 +8,8 @@
#include <unordered_map>
#include <unordered_set>
#include "mlx/backend/common/cpu_impl.h"
#include "mlx/backend/io/io_impl.h"
#include "mlx/backend/metal/metal_impl.h"
#include "mlx/ops.h"
#include "mlx/primitives.h"
@@ -26,6 +28,7 @@ class Synchronizer : public Primitive {
void eval_cpu(const std::vector<array>&, std::vector<array>&) override {};
void eval_gpu(const std::vector<array>&, std::vector<array>&) override {};
void eval_io(const std::vector<array>&, std::vector<array>&) override {};
DEFINE_PRINT(Synchronize);
};
@@ -137,32 +140,20 @@ array eval_impl(std::vector<array> outputs, bool async) {
std::vector<std::shared_future<void>> arr_deps;
bool signal = needs_signal.find(arr.id()) != needs_signal.end();
if (arr.primitive().device() == Device::gpu) {
if (!metal::is_available()) {
throw std::runtime_error("Metal GPU is not available.");
switch (arr.primitive().device().type) {
case Device::gpu: {
if (!metal::is_available()) {
throw std::runtime_error("Metal GPU is not available.");
}
scheduler::enqueue(stream, metal::make_task(std::move(arr), signal));
break;
}
scheduler::enqueue(stream, metal::make_task(std::move(arr), signal));
} else {
auto task = [arr = std::move(arr), stream, signal]() mutable {
for (auto& input : arr.inputs()) {
if (input.event().valid() &&
input.event().stream() != arr.primitive().stream()) {
input.event().wait();
}
}
scheduler::notify_new_task(stream);
auto outputs = arr.outputs();
arr.primitive().eval_cpu(arr.inputs(), outputs);
if (!arr.is_tracer()) {
arr.detach();
}
if (signal) {
arr.event().signal();
}
scheduler::notify_task_completion(stream);
};
scheduler::enqueue(stream, std::move(task));
case Device::cpu:
scheduler::enqueue(stream, cpu::make_task(std::move(arr), signal));
break;
case Device::io:
scheduler::enqueue(stream, io::make_task(std::move(arr), signal));
break;
}
}
return synchronizer;

View File

@@ -133,6 +133,9 @@ std::ostream& operator<<(std::ostream& os, const Device& d) {
case Device::gpu:
os << "gpu";
break;
case Device::io:
os << "io";
break;
}
os << ", " << d.index << ")";
return os;

View File

@@ -150,11 +150,21 @@ class PyFileReader : public io::Reader {
return "python file object";
}
void lock() override {
stream_mutex_.lock();
}
void unlock() override {
stream_mutex_.unlock();
}
private:
nb::object pyistream_;
nb::object readinto_func_;
nb::object seek_func_;
nb::object tell_func_;
std::mutex stream_mutex_;
};
std::pair<
@@ -162,10 +172,10 @@ std::pair<
std::unordered_map<std::string, std::string>>
mlx_load_safetensor_helper(nb::object file, StreamOrDevice s) {
if (nb::isinstance<nb::str>(file)) { // Assume .safetensors file path string
return load_safetensors(nb::cast<std::string>(file), s);
return load_safetensors(nb::cast<std::string>(file));
} else if (is_istream_object(file)) {
// If we don't own the stream and it was passed to us, eval immediately
auto res = load_safetensors(std::make_shared<PyFileReader>(file), s);
auto res = load_safetensors(std::make_shared<PyFileReader>(file));
{
nb::gil_scoped_release gil;
for (auto& [key, arr] : std::get<0>(res)) {
@@ -181,7 +191,7 @@ mlx_load_safetensor_helper(nb::object file, StreamOrDevice s) {
GGUFLoad mlx_load_gguf_helper(nb::object file, StreamOrDevice s) {
if (nb::isinstance<nb::str>(file)) { // Assume .gguf file path string
return load_gguf(nb::cast<std::string>(file), s);
return load_gguf(nb::cast<std::string>(file));
}
throw std::invalid_argument("[load_gguf] Input must be a string");