Change Load to be an IOPrimitive

This commit is contained in:
Angelos Katharopoulos 2024-05-08 18:59:20 -07:00
parent c8e2b42ced
commit b193741050
13 changed files with 101 additions and 29 deletions

View File

@ -51,7 +51,6 @@ DEFAULT(Greater)
DEFAULT(GreaterEqual) DEFAULT(GreaterEqual)
DEFAULT(Less) DEFAULT(Less)
DEFAULT(LessEqual) DEFAULT(LessEqual)
DEFAULT(Load)
DEFAULT(LogicalNot) DEFAULT(LogicalNot)
DEFAULT(LogicalAnd) DEFAULT(LogicalAnd)
DEFAULT(LogicalOr) DEFAULT(LogicalOr)

View File

@ -68,7 +68,6 @@ DEFAULT(Greater)
DEFAULT(GreaterEqual) DEFAULT(GreaterEqual)
DEFAULT(Less) DEFAULT(Less)
DEFAULT(LessEqual) DEFAULT(LessEqual)
DEFAULT(Load)
DEFAULT(Log) DEFAULT(Log)
DEFAULT(Log1p) DEFAULT(Log1p)
DEFAULT(LogicalNot) DEFAULT(LogicalNot)

View File

@ -3,4 +3,5 @@ target_sources(
PRIVATE PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/io_impl.cpp ${CMAKE_CURRENT_SOURCE_DIR}/io_impl.cpp
${CMAKE_CURRENT_SOURCE_DIR}/thread_pool.cpp ${CMAKE_CURRENT_SOURCE_DIR}/thread_pool.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
) )

View File

@ -0,0 +1,56 @@
// Copyright © 2024 Apple Inc.
#include <algorithm>
#include <cassert>
#include <utility>
#include "mlx/allocator.h"
#include "mlx/primitives.h"
namespace mlx::core {
namespace {
template <const uint8_t scalar_size>
void swap_endianness(uint8_t* data_bytes, size_t N) {
  // Reverse the byte order of each of the N elements stored contiguously in
  // data_bytes. scalar_size (the element width in bytes) is a compile-time
  // constant, so the per-element reversal can be fully unrolled.
  for (size_t i = 0; i < N; i++) {
    uint8_t* elem = data_bytes + i * scalar_size;
    std::reverse(elem, elem + scalar_size);
  }
}
} // namespace
void Load::eval_io(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  // Load takes no inputs: the output is materialized directly from the
  // reader at the byte offset recorded at graph-construction time.
  assert(inputs.size() == 0);
  array& out = outputs[0];

  // Allocate the destination buffer and read the raw bytes into it.
  out.set_data(allocator::malloc_or_wait(out.nbytes()));
  reader_->seek(offset_, std::ios_base::beg);
  reader_->read(out.data<char>(), out.nbytes());

  if (!swap_endianness_) {
    return;
  }

  // The file's endianness differs from the host's: byte-swap each element in
  // place. Only 2/4/8-byte itemsizes need swapping; 1-byte types fall through
  // untouched.
  uint8_t* bytes = out.data<uint8_t>();
  const size_t n = out.data_size();
  switch (out.itemsize()) {
    case 2:
      swap_endianness<2>(bytes, n);
      break;
    case 4:
      swap_endianness<4>(bytes, n);
      break;
    case 8:
      swap_endianness<8>(bytes, n);
      break;
  }
}
} // namespace mlx::core

View File

@ -671,10 +671,6 @@ void LessEqual::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "leq"); binary_op(inputs, out, "leq");
} }
void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void Log::eval_gpu(const std::vector<array>& inputs, array& out) { void Log::eval_gpu(const std::vector<array>& inputs, array& out) {
switch (base_) { switch (base_) {
case Base::e: case Base::e:

View File

@ -61,7 +61,6 @@ NO_GPU(Greater)
NO_GPU(GreaterEqual) NO_GPU(GreaterEqual)
NO_GPU(Less) NO_GPU(Less)
NO_GPU(LessEqual) NO_GPU(LessEqual)
NO_GPU(Load)
NO_GPU(Log) NO_GPU(Log)
NO_GPU(Log1p) NO_GPU(Log1p)
NO_GPU(LogicalNot) NO_GPU(LogicalNot)

View File

@ -32,12 +32,8 @@ array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s = {});
array load(std::string file, StreamOrDevice s = {}); array load(std::string file, StreamOrDevice s = {});
/** Load array map from .safetensors file format */ /** Load array map from .safetensors file format */
SafetensorsLoad load_safetensors( SafetensorsLoad load_safetensors(std::shared_ptr<io::Reader> in_stream);
std::shared_ptr<io::Reader> in_stream, SafetensorsLoad load_safetensors(const std::string& file);
StreamOrDevice s = {});
SafetensorsLoad load_safetensors(
const std::string& file,
StreamOrDevice s = {});
void save_safetensors( void save_safetensors(
std::shared_ptr<io::Writer> in_stream, std::shared_ptr<io::Writer> in_stream,
@ -50,7 +46,7 @@ void save_safetensors(
/** Load array map and metadata from .gguf file format */ /** Load array map and metadata from .gguf file format */
GGUFLoad load_gguf(const std::string& file, StreamOrDevice s = {}); GGUFLoad load_gguf(const std::string& file);
void save_gguf( void save_gguf(
std::string file, std::string file,

View File

@ -231,7 +231,7 @@ std::unordered_map<std::string, array> load_arrays(gguf_ctx* ctx) {
return array_map; return array_map;
} }
GGUFLoad load_gguf(const std::string& file, StreamOrDevice s) { GGUFLoad load_gguf(const std::string& file) {
gguf_ctx* ctx = gguf_open(file.data()); gguf_ctx* ctx = gguf_open(file.data());
if (!ctx) { if (!ctx) {
throw std::runtime_error("[load_gguf] gguf_init failed"); throw std::runtime_error("[load_gguf] gguf_init failed");

View File

@ -213,7 +213,8 @@ array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s) {
auto loaded_array = array( auto loaded_array = array(
shape, shape,
dtype, dtype,
std::make_shared<Load>(to_stream(s), in_stream, offset, swap_endianness), std::make_shared<Load>(
to_stream(Device::io), in_stream, offset, swap_endianness),
std::vector<array>{}); std::vector<array>{});
if (col_contiguous) { if (col_contiguous) {
loaded_array = transpose(loaded_array, s); loaded_array = transpose(loaded_array, s);

View File

@ -94,9 +94,7 @@ Dtype dtype_from_safetensor_str(std::string_view str) {
} }
/** Load array from reader in safetensor format */ /** Load array from reader in safetensor format */
SafetensorsLoad load_safetensors( SafetensorsLoad load_safetensors(std::shared_ptr<io::Reader> in_stream) {
std::shared_ptr<io::Reader> in_stream,
StreamOrDevice s) {
//////////////////////////////////////////////////////// ////////////////////////////////////////////////////////
// Open and check file // Open and check file
if (!in_stream->good() || !in_stream->is_open()) { if (!in_stream->good() || !in_stream->is_open()) {
@ -138,15 +136,18 @@ SafetensorsLoad load_safetensors(
shape, shape,
type, type,
std::make_shared<Load>( std::make_shared<Load>(
to_stream(s), in_stream, offset + data_offsets.at(0), false), to_stream(Device::io),
in_stream,
offset + data_offsets.at(0),
false),
std::vector<array>{}); std::vector<array>{});
res.insert({item.key(), loaded_array}); res.insert({item.key(), loaded_array});
} }
return {res, metadata_map}; return {res, metadata_map};
} }
SafetensorsLoad load_safetensors(const std::string& file, StreamOrDevice s) { SafetensorsLoad load_safetensors(const std::string& file) {
return load_safetensors(std::make_shared<io::FileReader>(file), s); return load_safetensors(std::make_shared<io::FileReader>(file));
} }
/** Save array to out stream in .npy format */ /** Save array to out stream in .npy format */

View File

@ -162,6 +162,26 @@ class UnaryPrimitive : public Primitive {
UnaryPrimitive& operator=(UnaryPrimitive&& other) = delete; UnaryPrimitive& operator=(UnaryPrimitive&& other) = delete;
}; };
class IOPrimitive : public Primitive {
  /**
   * Abstract base class for primitives that perform I/O (e.g. Load, which
   * reads array data from a stream). These primitives are scheduled on the
   * dedicated io device/stream, so the CPU and GPU evaluation entry points
   * are hard errors rather than fallbacks. Concrete subclasses provide the
   * actual work via an io evaluation method (see Load::eval_io).
   */
 public:
  explicit IOPrimitive(Stream stream) : Primitive(stream) {}

  // Dispatching an IO primitive to the CPU backend is a programming error.
  inline void eval_cpu(
      const std::vector<array>& inputs,
      std::vector<array>& outputs) override {
    throw std::runtime_error("IO primitives cannot be evaluated on CPU");
  }

  // Likewise for the GPU backend.
  inline void eval_gpu(
      const std::vector<array>& inputs,
      std::vector<array>& outputs) override {
    throw std::runtime_error("IO primitives cannot be evaluated on GPU");
  }
};
class Abs : public UnaryPrimitive { class Abs : public UnaryPrimitive {
public: public:
explicit Abs(Stream stream) : UnaryPrimitive(stream) {}; explicit Abs(Stream stream) : UnaryPrimitive(stream) {};
@ -1074,20 +1094,20 @@ class LessEqual : public UnaryPrimitive {
void eval(const std::vector<array>& inputs, array& out); void eval(const std::vector<array>& inputs, array& out);
}; };
class Load : public UnaryPrimitive { class Load : public IOPrimitive {
public: public:
explicit Load( explicit Load(
Stream stream, Stream stream,
std::shared_ptr<io::Reader> reader, std::shared_ptr<io::Reader> reader,
size_t offset, size_t offset,
bool swap_endianness = false) bool swap_endianness = false)
: UnaryPrimitive(stream), : IOPrimitive(stream),
reader_(reader), reader_(reader),
offset_(offset), offset_(offset),
swap_endianness_(swap_endianness) {}; swap_endianness_(swap_endianness) {};
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_io(const std::vector<array>& inputs, std::vector<array>& outputs)
void eval_gpu(const std::vector<array>& inputs, array& out) override; override;
DEFINE_PRINT(Load) DEFINE_PRINT(Load)

View File

@ -9,6 +9,7 @@
#include <unordered_set> #include <unordered_set>
#include "mlx/backend/common/cpu_impl.h" #include "mlx/backend/common/cpu_impl.h"
#include "mlx/backend/io/io_impl.h"
#include "mlx/backend/metal/metal_impl.h" #include "mlx/backend/metal/metal_impl.h"
#include "mlx/ops.h" #include "mlx/ops.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
@ -149,6 +150,9 @@ array eval_impl(std::vector<array> outputs, bool async) {
case Device::cpu: case Device::cpu:
scheduler::enqueue(stream, cpu::make_task(std::move(arr), signal)); scheduler::enqueue(stream, cpu::make_task(std::move(arr), signal));
break; break;
case Device::io:
scheduler::enqueue(stream, io::make_task(std::move(arr), signal));
break;
} }
} }
return synchronizer; return synchronizer;

View File

@ -162,10 +162,10 @@ std::pair<
std::unordered_map<std::string, std::string>> std::unordered_map<std::string, std::string>>
mlx_load_safetensor_helper(nb::object file, StreamOrDevice s) { mlx_load_safetensor_helper(nb::object file, StreamOrDevice s) {
if (nb::isinstance<nb::str>(file)) { // Assume .safetensors file path string if (nb::isinstance<nb::str>(file)) { // Assume .safetensors file path string
return load_safetensors(nb::cast<std::string>(file), s); return load_safetensors(nb::cast<std::string>(file));
} else if (is_istream_object(file)) { } else if (is_istream_object(file)) {
// If we don't own the stream and it was passed to us, eval immediately // If we don't own the stream and it was passed to us, eval immediately
auto res = load_safetensors(std::make_shared<PyFileReader>(file), s); auto res = load_safetensors(std::make_shared<PyFileReader>(file));
{ {
nb::gil_scoped_release gil; nb::gil_scoped_release gil;
for (auto& [key, arr] : std::get<0>(res)) { for (auto& [key, arr] : std::get<0>(res)) {
@ -181,7 +181,7 @@ mlx_load_safetensor_helper(nb::object file, StreamOrDevice s) {
GGUFLoad mlx_load_gguf_helper(nb::object file, StreamOrDevice s) { GGUFLoad mlx_load_gguf_helper(nb::object file, StreamOrDevice s) {
if (nb::isinstance<nb::str>(file)) { // Assume .gguf file path string if (nb::isinstance<nb::str>(file)) { // Assume .gguf file path string
return load_gguf(nb::cast<std::string>(file), s); return load_gguf(nb::cast<std::string>(file));
} }
throw std::invalid_argument("[load_gguf] Input must be a string"); throw std::invalid_argument("[load_gguf] Input must be a string");