Even Even Faster IO (#1374)

* even more faster io

* make reader pool static

* make python reader thread safe

* one more optimization
This commit is contained in:
Awni Hannun 2024-08-29 16:05:40 -07:00 committed by GitHub
parent 28be4de7c2
commit dba2bd1105
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 73 additions and 35 deletions

View File

@ -202,15 +202,18 @@ void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
static Stream io_stream = new_stream(Device::cpu);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto task = [out = out,
offset = offset_,
reader = reader_,
swap_endianness = swap_endianness_]() mutable {
auto read_task = [out = out,
offset = offset_,
reader = reader_,
swap_endianness = swap_endianness_]() mutable {
load(out, offset, reader, swap_endianness);
};
auto fut = io::thread_pool().enqueue(std::move(read_task)).share();
auto signal_task = [out = out, fut = std::move(fut)]() {
fut.wait();
out.event().signal();
};
scheduler::enqueue(io_stream, std::move(task));
scheduler::enqueue(io_stream, std::move(signal_task));
auto& d = metal::device(stream().device);
d.end_encoding(stream().index);
auto command_buffer = d.get_command_buffer(stream().index);

View File

@ -298,11 +298,18 @@ array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s) {
/** Load array from file in .npy format */
array load(std::string file, StreamOrDevice s) {
return load(std::make_shared<io::ParallelFileReader>(std::move(file), 4), s);
return load(std::make_shared<io::ParallelFileReader>(std::move(file)), s);
}
namespace io {
/**
 * Return a reference to the process-wide I/O thread pool.
 *
 * The pool is a function-local static, so it is constructed the first time
 * this function is called and the same instance is handed back on every
 * subsequent call. It is constructed with the argument 4.
 */
ThreadPool& thread_pool() {
  static ThreadPool shared_pool{4};
  return shared_pool;
}
// Out-of-class definition of the static pool shared by all ParallelFileReader
// instances. It is constructed with 4, the same value callers used to pass as
// `num_threads` when each reader owned its own pool.
ThreadPool ParallelFileReader::thread_pool_{4};
void ParallelFileReader::read(char* data, size_t n) {
while (n != 0) {
auto m = ::read(fd_, data, std::min(n, static_cast<size_t>(INT32_MAX)));
@ -330,11 +337,18 @@ void ParallelFileReader::read(char* data, size_t n, size_t offset) {
};
std::vector<std::future<bool>> futs;
while (n != 0) {
size_t m = std::min(batch_size_, n);
futs.emplace_back(thread_pool_.enqueue(readfn, offset, m, data));
data += m;
n -= m;
offset += m;
if (n < batch_size_) {
if (!readfn(offset, n, data)) {
throw std::runtime_error("[read] Unable to read from file.");
}
break;
} else {
size_t m = batch_size_;
futs.emplace_back(thread_pool_.enqueue(readfn, offset, m, data));
data += m;
n -= m;
offset += m;
}
}
for (auto& f : futs) {
if (!f.get()) {

View File

@ -14,6 +14,8 @@ namespace mlx::core {
namespace io {
ThreadPool& thread_pool();
class Reader {
public:
virtual bool is_open() const = 0;
@ -43,10 +45,8 @@ class Writer {
class ParallelFileReader : public Reader {
public:
explicit ParallelFileReader(std::string file_path, int num_threads)
: fd_(open(file_path.c_str(), O_RDONLY)),
label_(std::move(file_path)),
thread_pool_(ThreadPool(num_threads)) {}
explicit ParallelFileReader(std::string file_path)
: fd_(open(file_path.c_str(), O_RDONLY)), label_(std::move(file_path)) {}
~ParallelFileReader() override {
close(fd_);
@ -79,11 +79,10 @@ class ParallelFileReader : public Reader {
}
private:
// 32MB (1 << 25); previously 4MB (1 << 22)
static constexpr size_t batch_size_ = (1 << 22);
static constexpr size_t batch_size_ = 1 << 25;
static ThreadPool thread_pool_;
int fd_;
std::string label_;
ThreadPool thread_pool_;
};
class FileWriter : public Writer {

View File

@ -147,7 +147,7 @@ SafetensorsLoad load_safetensors(
}
SafetensorsLoad load_safetensors(const std::string& file, StreamOrDevice s) {
return load_safetensors(std::make_shared<io::ParallelFileReader>(file, 4), s);
return load_safetensors(std::make_shared<io::ParallelFileReader>(file), s);
}
void save_safetensors(

View File

@ -1,3 +1,25 @@
// This code was modified from https://github.com/progschj/ThreadPool
// The original License is copied below:
//
// Copyright (c) 2012 Jakob Progsch, Václav Zeman
// This software is provided 'as-is', without any express or implied
// warranty. In no event will the authors be held liable for any damages
// arising from the use of this software.
//
// Permission is granted to anyone to use this software for any purpose,
// including commercial applications, and to alter it and redistribute it
// freely, subject to the following restrictions:
//
// 1. The origin of this software must not be misrepresented; you must not
// claim that you wrote the original software. If you use this software
// in a product, an acknowledgment in the product documentation would be
// appreciated but is not required.
//
// 2. Altered source versions must be plainly marked as such, and must not be
// misrepresented as being the original software.
//
// 3. This notice may not be removed or altered from any source
// distribution.
#pragma once
#include <condition_variable>
@ -19,12 +41,8 @@ class ThreadPool {
~ThreadPool();
private:
// need to keep track of threads so we can join them
std::vector<std::thread> workers;
// the task queue
std::queue<std::function<void()>> tasks;
// synchronization
std::mutex queue_mutex;
std::condition_variable condition;
bool stop;
@ -63,7 +81,6 @@ auto ThreadPool::enqueue(F&& f, Args&&... args)
{
std::unique_lock<std::mutex> lock(queue_mutex);
// don't allow enqueueing after stopping the pool
if (stop) {
throw std::runtime_error(
"[ThreadPool::enqueue] Not allowed on stopped ThreadPool");

View File

@ -138,6 +138,21 @@ class PyFileReader : public io::Reader {
// Read `n` bytes into `data` from the wrapped Python file object, without
// seeking first.
void read(char* data, size_t n) override {
  // Acquire the Python GIL before calling into the Python object; it is
  // released again when `gil_guard` goes out of scope.
  nb::gil_scoped_acquire gil_guard;
  _read(data, n);
}
// Seek the wrapped Python file object to `offset`, measured from the start of
// the stream, then read `n` bytes into `data`.
void read(char* data, size_t n, size_t offset) override {
  // The GIL is acquired once and held across both the seek and the read.
  nb::gil_scoped_acquire gil;
  // The Python seek callable takes an integer `whence`; convert the seekdir
  // constant with a named cast rather than a C-style cast.
  seek_func_(offset, static_cast<int>(std::ios_base::beg));
  _read(data, n);
}
// Human-readable name for this reader; always the same fixed string.
std::string label() const override {
  return std::string("python file object");
}
private:
void _read(char* data, size_t n) {
auto memview = PyMemoryView_FromMemory(data, n, PyBUF_WRITE);
nb::object bytes_read = readinto_func_(nb::handle(memview));
@ -146,16 +161,6 @@ class PyFileReader : public io::Reader {
}
}
void read(char* data, size_t n, size_t offset) override {
seek(offset);
read(data, n);
}
std::string label() const override {
return "python file object";
}
private:
nb::object pyistream_;
nb::object readinto_func_;
nb::object seek_func_;