mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-28 03:41:14 +08:00
Change Load to be an IOPrimitive
This commit is contained in:
parent
c8e2b42ced
commit
b193741050
@ -51,7 +51,6 @@ DEFAULT(Greater)
|
||||
DEFAULT(GreaterEqual)
|
||||
DEFAULT(Less)
|
||||
DEFAULT(LessEqual)
|
||||
DEFAULT(Load)
|
||||
DEFAULT(LogicalNot)
|
||||
DEFAULT(LogicalAnd)
|
||||
DEFAULT(LogicalOr)
|
||||
|
@ -68,7 +68,6 @@ DEFAULT(Greater)
|
||||
DEFAULT(GreaterEqual)
|
||||
DEFAULT(Less)
|
||||
DEFAULT(LessEqual)
|
||||
DEFAULT(Load)
|
||||
DEFAULT(Log)
|
||||
DEFAULT(Log1p)
|
||||
DEFAULT(LogicalNot)
|
||||
|
@ -3,4 +3,5 @@ target_sources(
|
||||
PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/io_impl.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/thread_pool.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
||||
)
|
||||
|
56
mlx/backend/io/primitives.cpp
Normal file
56
mlx/backend/io/primitives.cpp
Normal file
@ -0,0 +1,56 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <utility>
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
template <const uint8_t scalar_size>
|
||||
void swap_endianness(uint8_t* data_bytes, size_t N) {
|
||||
struct Elem {
|
||||
uint8_t bytes[scalar_size];
|
||||
};
|
||||
|
||||
Elem* data = reinterpret_cast<Elem*>(data_bytes);
|
||||
|
||||
for (size_t i = 0; i < N; i++) {
|
||||
for (size_t j = 0; j < (scalar_size / 2); j++) {
|
||||
std::swap(data[i].bytes[j], data[i].bytes[scalar_size - j - 1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void Load::eval_io(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
assert(inputs.size() == 0);
|
||||
array& out = outputs[0];
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
reader_->seek(offset_, std::ios_base::beg);
|
||||
reader_->read(out.data<char>(), out.nbytes());
|
||||
|
||||
if (swap_endianness_) {
|
||||
switch (out.itemsize()) {
|
||||
case 2:
|
||||
swap_endianness<2>(out.data<uint8_t>(), out.data_size());
|
||||
break;
|
||||
case 4:
|
||||
swap_endianness<4>(out.data<uint8_t>(), out.data_size());
|
||||
break;
|
||||
case 8:
|
||||
swap_endianness<8>(out.data<uint8_t>(), out.data_size());
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@ -671,10 +671,6 @@ void LessEqual::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "leq");
|
||||
}
|
||||
|
||||
void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
eval(inputs, out);
|
||||
}
|
||||
|
||||
void Log::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
switch (base_) {
|
||||
case Base::e:
|
||||
|
@ -61,7 +61,6 @@ NO_GPU(Greater)
|
||||
NO_GPU(GreaterEqual)
|
||||
NO_GPU(Less)
|
||||
NO_GPU(LessEqual)
|
||||
NO_GPU(Load)
|
||||
NO_GPU(Log)
|
||||
NO_GPU(Log1p)
|
||||
NO_GPU(LogicalNot)
|
||||
|
10
mlx/io.h
10
mlx/io.h
@ -32,12 +32,8 @@ array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s = {});
|
||||
array load(std::string file, StreamOrDevice s = {});
|
||||
|
||||
/** Load array map from .safetensors file format */
|
||||
SafetensorsLoad load_safetensors(
|
||||
std::shared_ptr<io::Reader> in_stream,
|
||||
StreamOrDevice s = {});
|
||||
SafetensorsLoad load_safetensors(
|
||||
const std::string& file,
|
||||
StreamOrDevice s = {});
|
||||
SafetensorsLoad load_safetensors(std::shared_ptr<io::Reader> in_stream);
|
||||
SafetensorsLoad load_safetensors(const std::string& file);
|
||||
|
||||
void save_safetensors(
|
||||
std::shared_ptr<io::Writer> in_stream,
|
||||
@ -50,7 +46,7 @@ void save_safetensors(
|
||||
|
||||
/** Load array map and metadata from .gguf file format */
|
||||
|
||||
GGUFLoad load_gguf(const std::string& file, StreamOrDevice s = {});
|
||||
GGUFLoad load_gguf(const std::string& file);
|
||||
|
||||
void save_gguf(
|
||||
std::string file,
|
||||
|
@ -231,7 +231,7 @@ std::unordered_map<std::string, array> load_arrays(gguf_ctx* ctx) {
|
||||
return array_map;
|
||||
}
|
||||
|
||||
GGUFLoad load_gguf(const std::string& file, StreamOrDevice s) {
|
||||
GGUFLoad load_gguf(const std::string& file) {
|
||||
gguf_ctx* ctx = gguf_open(file.data());
|
||||
if (!ctx) {
|
||||
throw std::runtime_error("[load_gguf] gguf_init failed");
|
||||
|
@ -213,7 +213,8 @@ array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s) {
|
||||
auto loaded_array = array(
|
||||
shape,
|
||||
dtype,
|
||||
std::make_shared<Load>(to_stream(s), in_stream, offset, swap_endianness),
|
||||
std::make_shared<Load>(
|
||||
to_stream(Device::io), in_stream, offset, swap_endianness),
|
||||
std::vector<array>{});
|
||||
if (col_contiguous) {
|
||||
loaded_array = transpose(loaded_array, s);
|
||||
|
@ -94,9 +94,7 @@ Dtype dtype_from_safetensor_str(std::string_view str) {
|
||||
}
|
||||
|
||||
/** Load array from reader in safetensor format */
|
||||
SafetensorsLoad load_safetensors(
|
||||
std::shared_ptr<io::Reader> in_stream,
|
||||
StreamOrDevice s) {
|
||||
SafetensorsLoad load_safetensors(std::shared_ptr<io::Reader> in_stream) {
|
||||
////////////////////////////////////////////////////////
|
||||
// Open and check file
|
||||
if (!in_stream->good() || !in_stream->is_open()) {
|
||||
@ -138,15 +136,18 @@ SafetensorsLoad load_safetensors(
|
||||
shape,
|
||||
type,
|
||||
std::make_shared<Load>(
|
||||
to_stream(s), in_stream, offset + data_offsets.at(0), false),
|
||||
to_stream(Device::io),
|
||||
in_stream,
|
||||
offset + data_offsets.at(0),
|
||||
false),
|
||||
std::vector<array>{});
|
||||
res.insert({item.key(), loaded_array});
|
||||
}
|
||||
return {res, metadata_map};
|
||||
}
|
||||
|
||||
SafetensorsLoad load_safetensors(const std::string& file, StreamOrDevice s) {
|
||||
return load_safetensors(std::make_shared<io::FileReader>(file), s);
|
||||
SafetensorsLoad load_safetensors(const std::string& file) {
|
||||
return load_safetensors(std::make_shared<io::FileReader>(file));
|
||||
}
|
||||
|
||||
/** Save array to out stream in .npy format */
|
||||
|
@ -162,6 +162,26 @@ class UnaryPrimitive : public Primitive {
|
||||
UnaryPrimitive& operator=(UnaryPrimitive&& other) = delete;
|
||||
};
|
||||
|
||||
class IOPrimitive : public Primitive {
|
||||
/**
|
||||
* An abstract class for primitives that are doing io which usually are not
|
||||
* supposed to be evaluated on any other "device".
|
||||
*/
|
||||
public:
|
||||
explicit IOPrimitive(Stream stream) : Primitive(stream) {}
|
||||
|
||||
inline void eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) override {
|
||||
throw std::runtime_error("IO primitives cannot be evaluated on CPU");
|
||||
}
|
||||
inline void eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) override {
|
||||
throw std::runtime_error("IO primitives cannot be evaluated on GPU");
|
||||
}
|
||||
};
|
||||
|
||||
class Abs : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Abs(Stream stream) : UnaryPrimitive(stream) {};
|
||||
@ -1074,20 +1094,20 @@ class LessEqual : public UnaryPrimitive {
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class Load : public UnaryPrimitive {
|
||||
class Load : public IOPrimitive {
|
||||
public:
|
||||
explicit Load(
|
||||
Stream stream,
|
||||
std::shared_ptr<io::Reader> reader,
|
||||
size_t offset,
|
||||
bool swap_endianness = false)
|
||||
: UnaryPrimitive(stream),
|
||||
: IOPrimitive(stream),
|
||||
reader_(reader),
|
||||
offset_(offset),
|
||||
swap_endianness_(swap_endianness) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_io(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
|
||||
DEFINE_PRINT(Load)
|
||||
|
||||
|
@ -9,6 +9,7 @@
|
||||
#include <unordered_set>
|
||||
|
||||
#include "mlx/backend/common/cpu_impl.h"
|
||||
#include "mlx/backend/io/io_impl.h"
|
||||
#include "mlx/backend/metal/metal_impl.h"
|
||||
#include "mlx/ops.h"
|
||||
#include "mlx/primitives.h"
|
||||
@ -149,6 +150,9 @@ array eval_impl(std::vector<array> outputs, bool async) {
|
||||
case Device::cpu:
|
||||
scheduler::enqueue(stream, cpu::make_task(std::move(arr), signal));
|
||||
break;
|
||||
case Device::io:
|
||||
scheduler::enqueue(stream, io::make_task(std::move(arr), signal));
|
||||
break;
|
||||
}
|
||||
}
|
||||
return synchronizer;
|
||||
|
@ -162,10 +162,10 @@ std::pair<
|
||||
std::unordered_map<std::string, std::string>>
|
||||
mlx_load_safetensor_helper(nb::object file, StreamOrDevice s) {
|
||||
if (nb::isinstance<nb::str>(file)) { // Assume .safetensors file path string
|
||||
return load_safetensors(nb::cast<std::string>(file), s);
|
||||
return load_safetensors(nb::cast<std::string>(file));
|
||||
} else if (is_istream_object(file)) {
|
||||
// If we don't own the stream and it was passed to us, eval immediately
|
||||
auto res = load_safetensors(std::make_shared<PyFileReader>(file), s);
|
||||
auto res = load_safetensors(std::make_shared<PyFileReader>(file));
|
||||
{
|
||||
nb::gil_scoped_release gil;
|
||||
for (auto& [key, arr] : std::get<0>(res)) {
|
||||
@ -181,7 +181,7 @@ mlx_load_safetensor_helper(nb::object file, StreamOrDevice s) {
|
||||
|
||||
GGUFLoad mlx_load_gguf_helper(nb::object file, StreamOrDevice s) {
|
||||
if (nb::isinstance<nb::str>(file)) { // Assume .gguf file path string
|
||||
return load_gguf(nb::cast<std::string>(file), s);
|
||||
return load_gguf(nb::cast<std::string>(file));
|
||||
}
|
||||
|
||||
throw std::invalid_argument("[load_gguf] Input must be a string");
|
||||
|
Loading…
Reference in New Issue
Block a user