mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Even Even Faster IO (#1374)
* even more faster io * make reader pool static * make python reader thread safe * one more optimization
This commit is contained in:
parent
28be4de7c2
commit
dba2bd1105
@ -202,15 +202,18 @@ void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
static Stream io_stream = new_stream(Device::cpu);
|
static Stream io_stream = new_stream(Device::cpu);
|
||||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||||
|
|
||||||
auto task = [out = out,
|
auto read_task = [out = out,
|
||||||
offset = offset_,
|
offset = offset_,
|
||||||
reader = reader_,
|
reader = reader_,
|
||||||
swap_endianness = swap_endianness_]() mutable {
|
swap_endianness = swap_endianness_]() mutable {
|
||||||
load(out, offset, reader, swap_endianness);
|
load(out, offset, reader, swap_endianness);
|
||||||
|
};
|
||||||
|
auto fut = io::thread_pool().enqueue(std::move(read_task)).share();
|
||||||
|
auto signal_task = [out = out, fut = std::move(fut)]() {
|
||||||
|
fut.wait();
|
||||||
out.event().signal();
|
out.event().signal();
|
||||||
};
|
};
|
||||||
|
scheduler::enqueue(io_stream, std::move(signal_task));
|
||||||
scheduler::enqueue(io_stream, std::move(task));
|
|
||||||
auto& d = metal::device(stream().device);
|
auto& d = metal::device(stream().device);
|
||||||
d.end_encoding(stream().index);
|
d.end_encoding(stream().index);
|
||||||
auto command_buffer = d.get_command_buffer(stream().index);
|
auto command_buffer = d.get_command_buffer(stream().index);
|
||||||
|
@ -298,11 +298,18 @@ array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s) {
|
|||||||
|
|
||||||
/** Load array from file in .npy format */
|
/** Load array from file in .npy format */
|
||||||
array load(std::string file, StreamOrDevice s) {
|
array load(std::string file, StreamOrDevice s) {
|
||||||
return load(std::make_shared<io::ParallelFileReader>(std::move(file), 4), s);
|
return load(std::make_shared<io::ParallelFileReader>(std::move(file)), s);
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace io {
|
namespace io {
|
||||||
|
|
||||||
|
ThreadPool& thread_pool() {
|
||||||
|
static ThreadPool pool_{4};
|
||||||
|
return pool_;
|
||||||
|
}
|
||||||
|
|
||||||
|
ThreadPool ParallelFileReader::thread_pool_{4};
|
||||||
|
|
||||||
void ParallelFileReader::read(char* data, size_t n) {
|
void ParallelFileReader::read(char* data, size_t n) {
|
||||||
while (n != 0) {
|
while (n != 0) {
|
||||||
auto m = ::read(fd_, data, std::min(n, static_cast<size_t>(INT32_MAX)));
|
auto m = ::read(fd_, data, std::min(n, static_cast<size_t>(INT32_MAX)));
|
||||||
@ -330,12 +337,19 @@ void ParallelFileReader::read(char* data, size_t n, size_t offset) {
|
|||||||
};
|
};
|
||||||
std::vector<std::future<bool>> futs;
|
std::vector<std::future<bool>> futs;
|
||||||
while (n != 0) {
|
while (n != 0) {
|
||||||
size_t m = std::min(batch_size_, n);
|
if (n < batch_size_) {
|
||||||
|
if (!readfn(offset, n, data)) {
|
||||||
|
throw std::runtime_error("[read] Unable to read from file.");
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
} else {
|
||||||
|
size_t m = batch_size_;
|
||||||
futs.emplace_back(thread_pool_.enqueue(readfn, offset, m, data));
|
futs.emplace_back(thread_pool_.enqueue(readfn, offset, m, data));
|
||||||
data += m;
|
data += m;
|
||||||
n -= m;
|
n -= m;
|
||||||
offset += m;
|
offset += m;
|
||||||
}
|
}
|
||||||
|
}
|
||||||
for (auto& f : futs) {
|
for (auto& f : futs) {
|
||||||
if (!f.get()) {
|
if (!f.get()) {
|
||||||
throw std::runtime_error("[read] Unable to read from file.");
|
throw std::runtime_error("[read] Unable to read from file.");
|
||||||
|
@ -14,6 +14,8 @@ namespace mlx::core {
|
|||||||
|
|
||||||
namespace io {
|
namespace io {
|
||||||
|
|
||||||
|
ThreadPool& thread_pool();
|
||||||
|
|
||||||
class Reader {
|
class Reader {
|
||||||
public:
|
public:
|
||||||
virtual bool is_open() const = 0;
|
virtual bool is_open() const = 0;
|
||||||
@ -43,10 +45,8 @@ class Writer {
|
|||||||
|
|
||||||
class ParallelFileReader : public Reader {
|
class ParallelFileReader : public Reader {
|
||||||
public:
|
public:
|
||||||
explicit ParallelFileReader(std::string file_path, int num_threads)
|
explicit ParallelFileReader(std::string file_path)
|
||||||
: fd_(open(file_path.c_str(), O_RDONLY)),
|
: fd_(open(file_path.c_str(), O_RDONLY)), label_(std::move(file_path)) {}
|
||||||
label_(std::move(file_path)),
|
|
||||||
thread_pool_(ThreadPool(num_threads)) {}
|
|
||||||
|
|
||||||
~ParallelFileReader() override {
|
~ParallelFileReader() override {
|
||||||
close(fd_);
|
close(fd_);
|
||||||
@ -79,11 +79,10 @@ class ParallelFileReader : public Reader {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// 4MB
|
static constexpr size_t batch_size_ = 1 << 25;
|
||||||
static constexpr size_t batch_size_ = (1 << 22);
|
static ThreadPool thread_pool_;
|
||||||
int fd_;
|
int fd_;
|
||||||
std::string label_;
|
std::string label_;
|
||||||
ThreadPool thread_pool_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
class FileWriter : public Writer {
|
class FileWriter : public Writer {
|
||||||
|
@ -147,7 +147,7 @@ SafetensorsLoad load_safetensors(
|
|||||||
}
|
}
|
||||||
|
|
||||||
SafetensorsLoad load_safetensors(const std::string& file, StreamOrDevice s) {
|
SafetensorsLoad load_safetensors(const std::string& file, StreamOrDevice s) {
|
||||||
return load_safetensors(std::make_shared<io::ParallelFileReader>(file, 4), s);
|
return load_safetensors(std::make_shared<io::ParallelFileReader>(file), s);
|
||||||
}
|
}
|
||||||
|
|
||||||
void save_safetensors(
|
void save_safetensors(
|
||||||
|
@ -1,3 +1,25 @@
|
|||||||
|
// This code was modified from https://github.com/progschj/ThreadPool
|
||||||
|
// The original License is copied below:
|
||||||
|
//
|
||||||
|
// Copyright (c) 2012 Jakob Progsch, Václav Zeman
|
||||||
|
// This software is provided 'as-is', without any express or implied
|
||||||
|
// warranty. In no event will the authors be held liable for any damages
|
||||||
|
// arising from the use of this software.
|
||||||
|
//
|
||||||
|
// Permission is granted to anyone to use this software for any purpose,
|
||||||
|
// including commercial applications, and to alter it and redistribute it
|
||||||
|
// freely, subject to the following restrictions:
|
||||||
|
//
|
||||||
|
// 1. The origin of this software must not be misrepresented; you must not
|
||||||
|
// claim that you wrote the original software. If you use this software
|
||||||
|
// in a product, an acknowledgment in the product documentation would be
|
||||||
|
// appreciated but is not required.
|
||||||
|
//
|
||||||
|
// 2. Altered source versions must be plainly marked as such, and must not be
|
||||||
|
// misrepresented as being the original software.
|
||||||
|
//
|
||||||
|
// 3. This notice may not be removed or altered from any source
|
||||||
|
// distribution.
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <condition_variable>
|
#include <condition_variable>
|
||||||
@ -19,12 +41,8 @@ class ThreadPool {
|
|||||||
~ThreadPool();
|
~ThreadPool();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// need to keep track of threads so we can join them
|
|
||||||
std::vector<std::thread> workers;
|
std::vector<std::thread> workers;
|
||||||
// the task queue
|
|
||||||
std::queue<std::function<void()>> tasks;
|
std::queue<std::function<void()>> tasks;
|
||||||
|
|
||||||
// synchronization
|
|
||||||
std::mutex queue_mutex;
|
std::mutex queue_mutex;
|
||||||
std::condition_variable condition;
|
std::condition_variable condition;
|
||||||
bool stop;
|
bool stop;
|
||||||
@ -63,7 +81,6 @@ auto ThreadPool::enqueue(F&& f, Args&&... args)
|
|||||||
{
|
{
|
||||||
std::unique_lock<std::mutex> lock(queue_mutex);
|
std::unique_lock<std::mutex> lock(queue_mutex);
|
||||||
|
|
||||||
// don't allow enqueueing after stopping the pool
|
|
||||||
if (stop) {
|
if (stop) {
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
"[ThreadPool::enqueue] Not allowed on stopped ThreadPool");
|
"[ThreadPool::enqueue] Not allowed on stopped ThreadPool");
|
||||||
|
@ -138,6 +138,21 @@ class PyFileReader : public io::Reader {
|
|||||||
|
|
||||||
void read(char* data, size_t n) override {
|
void read(char* data, size_t n) override {
|
||||||
nb::gil_scoped_acquire gil;
|
nb::gil_scoped_acquire gil;
|
||||||
|
_read(data, n);
|
||||||
|
}
|
||||||
|
|
||||||
|
void read(char* data, size_t n, size_t offset) override {
|
||||||
|
nb::gil_scoped_acquire gil;
|
||||||
|
seek_func_(offset, (int)std::ios_base::beg);
|
||||||
|
_read(data, n);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string label() const override {
|
||||||
|
return "python file object";
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
void _read(char* data, size_t n) {
|
||||||
auto memview = PyMemoryView_FromMemory(data, n, PyBUF_WRITE);
|
auto memview = PyMemoryView_FromMemory(data, n, PyBUF_WRITE);
|
||||||
nb::object bytes_read = readinto_func_(nb::handle(memview));
|
nb::object bytes_read = readinto_func_(nb::handle(memview));
|
||||||
|
|
||||||
@ -146,16 +161,6 @@ class PyFileReader : public io::Reader {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void read(char* data, size_t n, size_t offset) override {
|
|
||||||
seek(offset);
|
|
||||||
read(data, n);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string label() const override {
|
|
||||||
return "python file object";
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
nb::object pyistream_;
|
nb::object pyistream_;
|
||||||
nb::object readinto_func_;
|
nb::object readinto_func_;
|
||||||
nb::object seek_func_;
|
nb::object seek_func_;
|
||||||
|
Loading…
Reference in New Issue
Block a user