Mirror of https://github.com/ml-explore/mlx.git (synced 2025-06-24 17:31:16 +08:00)
Even Even Faster IO (#1374)
* even more faster io
* make reader pool static
* make python reader thread safe
* one more optimization
This commit is contained in:
  parent 28be4de7c2
  commit dba2bd1105
@@ -202,15 +202,18 @@ void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
   static Stream io_stream = new_stream(Device::cpu);
   out.set_data(allocator::malloc_or_wait(out.nbytes()));
 
-  auto task = [out = out,
-               offset = offset_,
-               reader = reader_,
-               swap_endianness = swap_endianness_]() mutable {
+  auto read_task = [out = out,
+                    offset = offset_,
+                    reader = reader_,
+                    swap_endianness = swap_endianness_]() mutable {
     load(out, offset, reader, swap_endianness);
-    out.event().signal();
   };
+  auto fut = io::thread_pool().enqueue(std::move(read_task)).share();
+  auto signal_task = [out = out, fut = std::move(fut)]() {
+    fut.wait();
+    out.event().signal();
+  };
 
-  scheduler::enqueue(io_stream, std::move(task));
+  scheduler::enqueue(io_stream, std::move(signal_task));
   auto& d = metal::device(stream().device);
   d.end_encoding(stream().index);
   auto command_buffer = d.get_command_buffer(stream().index);
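The hunk above splits the old single task into a read_task that runs on the shared IO thread pool and a cheap signal_task that only waits on the read's future and signals the output array's event. Below is a minimal sketch of that future-plus-signal pattern using std::async and a stand-in Event type; none of the MLX types or APIs are used here.

// Sketch (not MLX code): heavy work on a worker, cheap follow-up waits on
// the shared future and raises a "done" flag, mimicking read_task/signal_task.
#include <future>
#include <iostream>

struct Event {
  std::promise<void> p;
  void signal() { p.set_value(); }
  void wait() { p.get_future().wait(); }
};

int main() {
  Event done;
  auto read_task = [] { /* pretend to load bytes */ return 42; };
  auto fut = std::async(std::launch::async, read_task).share();

  // The "signal task": does no IO itself, only waits and signals.
  auto signal_task = std::async(std::launch::async, [&] {
    fut.wait();
    done.signal();
  });

  done.wait(); // consumers block on the event, not on the read itself
  std::cout << "loaded value: " << fut.get() << "\n";
}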
@@ -298,11 +298,18 @@ array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s) {
 
 /** Load array from file in .npy format */
 array load(std::string file, StreamOrDevice s) {
-  return load(std::make_shared<io::ParallelFileReader>(std::move(file), 4), s);
+  return load(std::make_shared<io::ParallelFileReader>(std::move(file)), s);
 }
 
 namespace io {
 
+ThreadPool& thread_pool() {
+  static ThreadPool pool_{4};
+  return pool_;
+}
+
+ThreadPool ParallelFileReader::thread_pool_{4};
+
 void ParallelFileReader::read(char* data, size_t n) {
   while (n != 0) {
     auto m = ::read(fd_, data, std::min(n, static_cast<size_t>(INT32_MAX)));
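io::thread_pool() above is a function-local static ("Meyers singleton"): one pool, created lazily on first use and shared process-wide, while ParallelFileReader::thread_pool_ becomes a static data member shared by every reader instance. A minimal sketch of the function-local-static pattern follows; SimplePool is a placeholder for any pool type, not the MLX ThreadPool.

// Sketch (not MLX code): one lazily constructed, process-wide object.
#include <cstddef>

struct SimplePool {
  explicit SimplePool(std::size_t n) : size(n) {}
  std::size_t size;
};

SimplePool& shared_pool() {
  static SimplePool pool{4}; // constructed once, on first use; init is thread safe in C++11+
  return pool;
}

A function-local static also sidesteps static-initialization-order problems and avoids spinning up worker threads until the first load actually needs them.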
@@ -330,11 +337,18 @@ void ParallelFileReader::read(char* data, size_t n, size_t offset) {
   };
   std::vector<std::future<bool>> futs;
   while (n != 0) {
-    size_t m = std::min(batch_size_, n);
-    futs.emplace_back(thread_pool_.enqueue(readfn, offset, m, data));
-    data += m;
-    n -= m;
-    offset += m;
+    if (n < batch_size_) {
+      if (!readfn(offset, n, data)) {
+        throw std::runtime_error("[read] Unable to read from file.");
+      }
+      break;
+    } else {
+      size_t m = batch_size_;
+      futs.emplace_back(thread_pool_.enqueue(readfn, offset, m, data));
+      data += m;
+      n -= m;
+      offset += m;
+    }
   }
   for (auto& f : futs) {
     if (!f.get()) {
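The rewritten read submits only full batch_size_ chunks to the pool and handles the final partial chunk on the calling thread instead of enqueuing one more tiny task. A rough sketch of the same chunked-read idea is below, using std::async and POSIX pread rather than the MLX ThreadPool; the batch size, function name, and error text are placeholders.

// Sketch (not MLX code): split a large read into fixed-size batches and hand
// each full batch to a worker; read the tail on the calling thread.
// Usage idea: int fd = open(path, O_RDONLY); parallel_read(fd, buf, len, 0);
#include <unistd.h>

#include <cstddef>
#include <future>
#include <stdexcept>
#include <vector>

void parallel_read(int fd, char* data, size_t n, size_t offset) {
  constexpr size_t batch = 1 << 25; // 32 MiB per task
  auto readfn = [fd](size_t off, size_t size, char* buf) {
    while (size > 0) {
      // pread takes an explicit offset, so no shared file cursor is touched.
      ssize_t r = ::pread(fd, buf, size, off);
      if (r <= 0) {
        return false;
      }
      buf += r;
      size -= r;
      off += r;
    }
    return true;
  };
  std::vector<std::future<bool>> futs;
  while (n >= batch) {
    futs.push_back(std::async(std::launch::async, readfn, offset, batch, data));
    data += batch;
    n -= batch;
    offset += batch;
  }
  // Final partial batch on the calling thread, mirroring the change above.
  if (n > 0 && !readfn(offset, n, data)) {
    throw std::runtime_error("Unable to read from file.");
  }
  for (auto& f : futs) {
    if (!f.get()) {
      throw std::runtime_error("Unable to read from file.");
    }
  }
}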
@@ -14,6 +14,8 @@ namespace mlx::core {
 
 namespace io {
 
+ThreadPool& thread_pool();
+
 class Reader {
  public:
  virtual bool is_open() const = 0;
@@ -43,10 +45,8 @@ class Writer {
 
 class ParallelFileReader : public Reader {
  public:
-  explicit ParallelFileReader(std::string file_path, int num_threads)
-      : fd_(open(file_path.c_str(), O_RDONLY)),
-        label_(std::move(file_path)),
-        thread_pool_(ThreadPool(num_threads)) {}
+  explicit ParallelFileReader(std::string file_path)
+      : fd_(open(file_path.c_str(), O_RDONLY)), label_(std::move(file_path)) {}
 
   ~ParallelFileReader() override {
     close(fd_);
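With the pool no longer per instance, the single-argument constructor keeps only the open/close responsibility in the reader. A minimal RAII sketch of that constructor/destructor pairing is below; FdFile is a placeholder name and this is not MLX code.

// Sketch (not MLX code): acquire the file descriptor in the constructor,
// release it in the destructor.
#include <fcntl.h>
#include <unistd.h>

#include <string>
#include <utility>

class FdFile {
 public:
  explicit FdFile(std::string path)
      : fd_(::open(path.c_str(), O_RDONLY)), label_(std::move(path)) {}
  ~FdFile() {
    if (fd_ >= 0) {
      ::close(fd_);
    }
  }
  bool is_open() const { return fd_ >= 0; }

 private:
  int fd_;
  std::string label_;
};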
@@ -79,11 +79,10 @@ class ParallelFileReader : public Reader {
   }
 
  private:
-  // 4MB
-  static constexpr size_t batch_size_ = (1 << 22);
+  static constexpr size_t batch_size_ = 1 << 25;
+  static ThreadPool thread_pool_;
   int fd_;
   std::string label_;
-  ThreadPool thread_pool_;
 };
 
 class FileWriter : public Writer {
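For reference, batch_size_ grows from 1 << 22 bytes (4 MiB) to 1 << 25 bytes (32 MiB), so each pool task now reads an 8x larger chunk, and the pool itself becomes a single static member shared by all reader instances.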
@@ -147,7 +147,7 @@ SafetensorsLoad load_safetensors(
 }
 
 SafetensorsLoad load_safetensors(const std::string& file, StreamOrDevice s) {
-  return load_safetensors(std::make_shared<io::ParallelFileReader>(file, 4), s);
+  return load_safetensors(std::make_shared<io::ParallelFileReader>(file), s);
 }
 
 void save_safetensors(
@@ -1,3 +1,25 @@
+// This code was modified from https://github.com/progschj/ThreadPool
+// The original License is copied below:
+//
+// Copyright (c) 2012 Jakob Progsch, Václav Zeman
+// This software is provided 'as-is', without any express or implied
+// warranty. In no event will the authors be held liable for any damages
+// arising from the use of this software.
+//
+// Permission is granted to anyone to use this software for any purpose,
+// including commercial applications, and to alter it and redistribute it
+// freely, subject to the following restrictions:
+//
+// 1. The origin of this software must not be misrepresented; you must not
+// claim that you wrote the original software. If you use this software
+// in a product, an acknowledgment in the product documentation would be
+// appreciated but is not required.
+//
+// 2. Altered source versions must be plainly marked as such, and must not be
+// misrepresented as being the original software.
+//
+// 3. This notice may not be removed or altered from any source
+// distribution.
 #pragma once
 
 #include <condition_variable>
@@ -19,12 +41,8 @@ class ThreadPool {
   ~ThreadPool();
 
  private:
-  // need to keep track of threads so we can join them
   std::vector<std::thread> workers;
-  // the task queue
   std::queue<std::function<void()>> tasks;
-
-  // synchronization
   std::mutex queue_mutex;
   std::condition_variable condition;
   bool stop;
@@ -63,7 +81,6 @@ auto ThreadPool::enqueue(F&& f, Args&&... args)
   {
     std::unique_lock<std::mutex> lock(queue_mutex);
 
-    // don't allow enqueueing after stopping the pool
     if (stop) {
       throw std::runtime_error(
           "[ThreadPool::enqueue] Not allowed on stopped ThreadPool");
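ThreadPool::enqueue in this header is the progschj-style template that wraps a callable and hands back a std::future. A stripped-down sketch of that mechanism follows; worker threads are omitted and the queued task is run inline purely to show the future hookup, so this is not the MLX implementation.

// Sketch (not MLX code): wrap the callable in a std::packaged_task, keep a
// type-erased std::function in the queue, return the future to the caller.
#include <functional>
#include <future>
#include <iostream>
#include <memory>
#include <queue>

std::queue<std::function<void()>> tasks;

template <class F>
auto enqueue(F&& f) -> std::future<decltype(f())> {
  using R = decltype(f());
  // packaged_task is move-only, so a shared_ptr lets it live in std::function.
  auto task = std::make_shared<std::packaged_task<R()>>(std::forward<F>(f));
  std::future<R> res = task->get_future();
  tasks.emplace([task] { (*task)(); }); // what a worker thread would pop and run
  return res;
}

int main() {
  auto fut = enqueue([] { return 6 * 7; });
  tasks.front()(); // stand-in for a worker thread executing the task
  std::cout << fut.get() << "\n"; // prints 42
}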
@@ -138,6 +138,21 @@ class PyFileReader : public io::Reader {
 
   void read(char* data, size_t n) override {
     nb::gil_scoped_acquire gil;
+    _read(data, n);
+  }
+
+  void read(char* data, size_t n, size_t offset) override {
+    nb::gil_scoped_acquire gil;
+    seek_func_(offset, (int)std::ios_base::beg);
+    _read(data, n);
+  }
+
+  std::string label() const override {
+    return "python file object";
+  }
+
+ private:
+  void _read(char* data, size_t n) {
     auto memview = PyMemoryView_FromMemory(data, n, PyBUF_WRITE);
     nb::object bytes_read = readinto_func_(nb::handle(memview));
 
@@ -146,16 +161,6 @@ class PyFileReader : public io::Reader {
     }
   }
 
-  void read(char* data, size_t n, size_t offset) override {
-    seek(offset);
-    read(data, n);
-  }
-
-  std::string label() const override {
-    return "python file object";
-  }
-
- private:
   nb::object pyistream_;
   nb::object readinto_func_;
   nb::object seek_func_;
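The thread-safety fix for the Python reader is visible across these last two hunks: the offset overload now acquires the GIL once and performs seek-then-read while holding it, instead of calling seek() and read() as two separately locked steps between which another pool thread could move the file cursor. A small sketch of why that matters is below, with a std::mutex standing in for the GIL and FakeFile standing in for the Python file object; nothing here is MLX or nanobind code.

// Sketch (not MLX code): with a shared cursor, seek and read must happen in
// one critical section or concurrent readers corrupt each other's reads.
#include <cstring>
#include <mutex>
#include <string>

struct FakeFile {
  std::string contents;
  size_t cursor = 0;
  void seek(size_t off) { cursor = off; }
  void read(char* out, size_t n) {
    std::memcpy(out, contents.data() + cursor, n);
    cursor += n;
  }
};

std::mutex gil; // stand-in for the interpreter lock
FakeFile file{"abcdefgh"};

// Thread safe: seek and read are performed under a single lock acquisition.
void read_at(char* out, size_t n, size_t offset) {
  std::lock_guard<std::mutex> lock(gil);
  file.seek(offset);
  file.read(out, n);
}

int main() {
  char buf[3] = {};
  read_at(buf, 2, 4); // reads "ef" regardless of what other threads do
}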