Change Load to be an IOPrimitive

Angelos Katharopoulos 2024-05-08 18:59:20 -07:00
parent c8e2b42ced
commit b193741050
13 changed files with 101 additions and 29 deletions

View File

@@ -51,7 +51,6 @@ DEFAULT(Greater)
 DEFAULT(GreaterEqual)
 DEFAULT(Less)
 DEFAULT(LessEqual)
-DEFAULT(Load)
 DEFAULT(LogicalNot)
 DEFAULT(LogicalAnd)
 DEFAULT(LogicalOr)

View File

@@ -68,7 +68,6 @@ DEFAULT(Greater)
 DEFAULT(GreaterEqual)
 DEFAULT(Less)
 DEFAULT(LessEqual)
-DEFAULT(Load)
 DEFAULT(Log)
 DEFAULT(Log1p)
 DEFAULT(LogicalNot)

View File

@@ -3,4 +3,5 @@ target_sources(
   PRIVATE
   ${CMAKE_CURRENT_SOURCE_DIR}/io_impl.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/thread_pool.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
 )

View File

@@ -0,0 +1,56 @@
+// Copyright © 2024 Apple Inc.
+
+#include <algorithm>
+#include <cassert>
+#include <utility>
+
+#include "mlx/allocator.h"
+#include "mlx/primitives.h"
+
+namespace mlx::core {
+
+namespace {
+
+template <const uint8_t scalar_size>
+void swap_endianness(uint8_t* data_bytes, size_t N) {
+  struct Elem {
+    uint8_t bytes[scalar_size];
+  };
+
+  Elem* data = reinterpret_cast<Elem*>(data_bytes);
+
+  for (size_t i = 0; i < N; i++) {
+    for (size_t j = 0; j < (scalar_size / 2); j++) {
+      std::swap(data[i].bytes[j], data[i].bytes[scalar_size - j - 1]);
+    }
+  }
+}
+
+} // namespace
+
+void Load::eval_io(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  assert(inputs.size() == 0);
+  array& out = outputs[0];
+  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+
+  reader_->seek(offset_, std::ios_base::beg);
+  reader_->read(out.data<char>(), out.nbytes());
+
+  if (swap_endianness_) {
+    switch (out.itemsize()) {
+      case 2:
+        swap_endianness<2>(out.data<uint8_t>(), out.data_size());
+        break;
+      case 4:
+        swap_endianness<4>(out.data<uint8_t>(), out.data_size());
+        break;
+      case 8:
+        swap_endianness<8>(out.data<uint8_t>(), out.data_size());
+        break;
+    }
+  }
+}
+
+} // namespace mlx::core
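
The new file above implements Load::eval_io: it allocates the output buffer, seeks the reader to the stored offset, reads the raw bytes, and byte-swaps each scalar in place when the file's endianness differs from the host's. A minimal standalone sketch of the same byte-swapping idea; the main driver and sample values are illustrative only, not part of the commit:

// Illustrative sketch: byte-swap a buffer of 4-byte scalars in place,
// mirroring the swap_endianness helper added in this commit.
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <utility>

template <uint8_t scalar_size>
void swap_endianness(uint8_t* data_bytes, size_t N) {
  for (size_t i = 0; i < N; i++) {
    uint8_t* elem = data_bytes + i * scalar_size;
    for (size_t j = 0; j < scalar_size / 2; j++) {
      std::swap(elem[j], elem[scalar_size - j - 1]);
    }
  }
}

int main() {
  uint32_t vals[2] = {0x11223344u, 0xaabbccddu};
  swap_endianness<4>(reinterpret_cast<uint8_t*>(vals), 2);
  std::printf("%08x %08x\n", vals[0], vals[1]); // prints 44332211 ddccbbaa
}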

View File

@@ -671,10 +671,6 @@ void LessEqual::eval_gpu(const std::vector<array>& inputs, array& out) {
   binary_op(inputs, out, "leq");
 }
 
-void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
-  eval(inputs, out);
-}
-
 void Log::eval_gpu(const std::vector<array>& inputs, array& out) {
   switch (base_) {
     case Base::e:

View File

@@ -61,7 +61,6 @@ NO_GPU(Greater)
 NO_GPU(GreaterEqual)
 NO_GPU(Less)
 NO_GPU(LessEqual)
-NO_GPU(Load)
 NO_GPU(Log)
 NO_GPU(Log1p)
 NO_GPU(LogicalNot)

View File

@@ -32,12 +32,8 @@ array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s = {});
 array load(std::string file, StreamOrDevice s = {});
 
 /** Load array map from .safetensors file format */
-SafetensorsLoad load_safetensors(
-    std::shared_ptr<io::Reader> in_stream,
-    StreamOrDevice s = {});
-SafetensorsLoad load_safetensors(
-    const std::string& file,
-    StreamOrDevice s = {});
+SafetensorsLoad load_safetensors(std::shared_ptr<io::Reader> in_stream);
+SafetensorsLoad load_safetensors(const std::string& file);
 
 void save_safetensors(
     std::shared_ptr<io::Writer> in_stream,
@@ -50,7 +46,7 @@ void save_safetensors(
 
 /** Load array map and metadata from .gguf file format */
-GGUFLoad load_gguf(const std::string& file, StreamOrDevice s = {});
+GGUFLoad load_gguf(const std::string& file);
 
 void save_gguf(
     std::string file,
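
The two hunks above drop the StreamOrDevice parameter from the safetensors and gguf loaders: the Load primitives they build are now pinned to the io device, so a target stream no longer makes sense. A hedged call-site sketch under the new signatures; the file names are placeholders, and the structured bindings assume SafetensorsLoad and GGUFLoad are pairs of array and metadata maps:

#include "mlx/mlx.h"

int main() {
  using namespace mlx::core;
  // No StreamOrDevice argument anymore; IO runs on the io device.
  auto [tensors, str_metadata] = load_safetensors("model.safetensors");
  auto [arrays, gguf_metadata] = load_gguf("model.gguf");
}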

View File

@@ -231,7 +231,7 @@ std::unordered_map<std::string, array> load_arrays(gguf_ctx* ctx) {
   return array_map;
 }
 
-GGUFLoad load_gguf(const std::string& file, StreamOrDevice s) {
+GGUFLoad load_gguf(const std::string& file) {
   gguf_ctx* ctx = gguf_open(file.data());
   if (!ctx) {
     throw std::runtime_error("[load_gguf] gguf_init failed");

View File

@@ -213,7 +213,8 @@ array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s) {
   auto loaded_array = array(
       shape,
       dtype,
-      std::make_shared<Load>(to_stream(s), in_stream, offset, swap_endianness),
+      std::make_shared<Load>(
+          to_stream(Device::io), in_stream, offset, swap_endianness),
       std::vector<array>{});
   if (col_contiguous) {
     loaded_array = transpose(loaded_array, s);
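
Note the split above: the Load primitive itself is now pinned to Device::io via to_stream(Device::io), while the transpose applied to column-contiguous data still honors the caller-provided stream s. Raw bytes always come off the io stream; any follow-up compute keeps following the requested device.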

View File

@@ -94,9 +94,7 @@ Dtype dtype_from_safetensor_str(std::string_view str) {
 }
 
 /** Load array from reader in safetensor format */
-SafetensorsLoad load_safetensors(
-    std::shared_ptr<io::Reader> in_stream,
-    StreamOrDevice s) {
+SafetensorsLoad load_safetensors(std::shared_ptr<io::Reader> in_stream) {
   ////////////////////////////////////////////////////////
   // Open and check file
   if (!in_stream->good() || !in_stream->is_open()) {
@@ -138,15 +136,18 @@ SafetensorsLoad load_safetensors(
         shape,
         type,
         std::make_shared<Load>(
-            to_stream(s), in_stream, offset + data_offsets.at(0), false),
+            to_stream(Device::io),
+            in_stream,
+            offset + data_offsets.at(0),
+            false),
         std::vector<array>{});
     res.insert({item.key(), loaded_array});
   }
   return {res, metadata_map};
 }
 
-SafetensorsLoad load_safetensors(const std::string& file, StreamOrDevice s) {
-  return load_safetensors(std::make_shared<io::FileReader>(file), s);
+SafetensorsLoad load_safetensors(const std::string& file) {
+  return load_safetensors(std::make_shared<io::FileReader>(file));
 }
 
 /** Save array to out stream in .npy format */

View File

@@ -162,6 +162,26 @@ class UnaryPrimitive : public Primitive {
   UnaryPrimitive& operator=(UnaryPrimitive&& other) = delete;
 };
 
+class IOPrimitive : public Primitive {
+  /**
+   * An abstract class for primitives that perform io and usually are not
+   * supposed to be evaluated on any other "device".
+   */
+ public:
+  explicit IOPrimitive(Stream stream) : Primitive(stream) {}
+
+  inline void eval_cpu(
+      const std::vector<array>& inputs,
+      std::vector<array>& outputs) override {
+    throw std::runtime_error("IO primitives cannot be evaluated on CPU");
+  }
+
+  inline void eval_gpu(
+      const std::vector<array>& inputs,
+      std::vector<array>& outputs) override {
+    throw std::runtime_error("IO primitives cannot be evaluated on GPU");
+  }
+};
+
 class Abs : public UnaryPrimitive {
  public:
   explicit Abs(Stream stream) : UnaryPrimitive(stream) {};
@@ -1074,20 +1094,20 @@ class LessEqual : public UnaryPrimitive {
   void eval(const std::vector<array>& inputs, array& out);
 };
 
-class Load : public UnaryPrimitive {
+class Load : public IOPrimitive {
  public:
   explicit Load(
       Stream stream,
       std::shared_ptr<io::Reader> reader,
       size_t offset,
       bool swap_endianness = false)
-      : UnaryPrimitive(stream),
+      : IOPrimitive(stream),
         reader_(reader),
         offset_(offset),
         swap_endianness_(swap_endianness) {};
 
-  void eval_cpu(const std::vector<array>& inputs, array& out) override;
-  void eval_gpu(const std::vector<array>& inputs, array& out) override;
+  void eval_io(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override;
 
   DEFINE_PRINT(Load)
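
The first hunk above adds the IOPrimitive base; the second makes Load its first subclass. The pattern: an IO primitive implements only eval_io and inherits eval_cpu/eval_gpu overrides that throw. A hypothetical further subclass following the same pattern is sketched below; the SaveToStream name, the writer member, and the io::Writer::write call are assumptions for illustration, and only IOPrimitive and the eval_io signature come from this commit:

// Hypothetical sketch (not in the commit): a write-side IO primitive
// built on the same IOPrimitive base as Load.
class SaveToStream : public IOPrimitive {
 public:
  explicit SaveToStream(Stream stream, std::shared_ptr<io::Writer> writer)
      : IOPrimitive(stream), writer_(std::move(writer)) {}

  // The only evaluation path an IO primitive implements; the inherited
  // eval_cpu/eval_gpu throw if the graph is ever scheduled elsewhere.
  void eval_io(const std::vector<array>& inputs, std::vector<array>& outputs)
      override {
    writer_->write(inputs[0].data<char>(), inputs[0].nbytes());
  }

  DEFINE_PRINT(SaveToStream)

 private:
  std::shared_ptr<io::Writer> writer_;
};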

View File

@@ -9,6 +9,7 @@
 #include <unordered_set>
 
 #include "mlx/backend/common/cpu_impl.h"
+#include "mlx/backend/io/io_impl.h"
 #include "mlx/backend/metal/metal_impl.h"
 #include "mlx/ops.h"
 #include "mlx/primitives.h"
@@ -149,6 +150,9 @@ array eval_impl(std::vector<array> outputs, bool async) {
       case Device::cpu:
         scheduler::enqueue(stream, cpu::make_task(std::move(arr), signal));
         break;
+      case Device::io:
+        scheduler::enqueue(stream, io::make_task(std::move(arr), signal));
+        break;
     }
   }
   return synchronizer;
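
With the scheduler aware of Device::io, evaluating a graph that contains a Load now enqueues an io task instead of a CPU task. A hedged end-to-end sketch; the .npy path is a placeholder, and io::make_task is assumed to mirror cpu::make_task by running eval_io on the IO backend's thread pool:

#include "mlx/mlx.h"

int main() {
  using namespace mlx::core;
  // Lazy: builds a Load primitive pinned to Device::io.
  auto arr = load("weights.npy");
  // eval_impl sees the io stream and dispatches through io::make_task.
  eval({arr});
}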

View File

@@ -162,10 +162,10 @@ std::pair<
     std::unordered_map<std::string, std::string>>
 mlx_load_safetensor_helper(nb::object file, StreamOrDevice s) {
   if (nb::isinstance<nb::str>(file)) { // Assume .safetensors file path string
-    return load_safetensors(nb::cast<std::string>(file), s);
+    return load_safetensors(nb::cast<std::string>(file));
   } else if (is_istream_object(file)) {
     // If we don't own the stream and it was passed to us, eval immediately
-    auto res = load_safetensors(std::make_shared<PyFileReader>(file), s);
+    auto res = load_safetensors(std::make_shared<PyFileReader>(file));
     {
       nb::gil_scoped_release gil;
       for (auto& [key, arr] : std::get<0>(res)) {
@@ -181,7 +181,7 @@ mlx_load_safetensor_helper(nb::object file, StreamOrDevice s) {
 
 GGUFLoad mlx_load_gguf_helper(nb::object file, StreamOrDevice s) {
   if (nb::isinstance<nb::str>(file)) { // Assume .gguf file path string
-    return load_gguf(nb::cast<std::string>(file), s);
+    return load_gguf(nb::cast<std::string>(file));
  }
 
   throw std::invalid_argument("[load_gguf] Input must be a string");