mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-28 12:13:21 +08:00
Change Load to be an IOPrimitive
This commit is contained in:
parent
c8e2b42ced
commit
b193741050
@ -51,7 +51,6 @@ DEFAULT(Greater)
|
|||||||
DEFAULT(GreaterEqual)
|
DEFAULT(GreaterEqual)
|
||||||
DEFAULT(Less)
|
DEFAULT(Less)
|
||||||
DEFAULT(LessEqual)
|
DEFAULT(LessEqual)
|
||||||
DEFAULT(Load)
|
|
||||||
DEFAULT(LogicalNot)
|
DEFAULT(LogicalNot)
|
||||||
DEFAULT(LogicalAnd)
|
DEFAULT(LogicalAnd)
|
||||||
DEFAULT(LogicalOr)
|
DEFAULT(LogicalOr)
|
||||||
|
@ -68,7 +68,6 @@ DEFAULT(Greater)
|
|||||||
DEFAULT(GreaterEqual)
|
DEFAULT(GreaterEqual)
|
||||||
DEFAULT(Less)
|
DEFAULT(Less)
|
||||||
DEFAULT(LessEqual)
|
DEFAULT(LessEqual)
|
||||||
DEFAULT(Load)
|
|
||||||
DEFAULT(Log)
|
DEFAULT(Log)
|
||||||
DEFAULT(Log1p)
|
DEFAULT(Log1p)
|
||||||
DEFAULT(LogicalNot)
|
DEFAULT(LogicalNot)
|
||||||
|
@ -3,4 +3,5 @@ target_sources(
|
|||||||
PRIVATE
|
PRIVATE
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/io_impl.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/io_impl.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/thread_pool.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/thread_pool.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
||||||
)
|
)
|
||||||
|
56
mlx/backend/io/primitives.cpp
Normal file
56
mlx/backend/io/primitives.cpp
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <cassert>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
#include "mlx/allocator.h"
|
||||||
|
#include "mlx/primitives.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
template <const uint8_t scalar_size>
|
||||||
|
void swap_endianness(uint8_t* data_bytes, size_t N) {
|
||||||
|
struct Elem {
|
||||||
|
uint8_t bytes[scalar_size];
|
||||||
|
};
|
||||||
|
|
||||||
|
Elem* data = reinterpret_cast<Elem*>(data_bytes);
|
||||||
|
|
||||||
|
for (size_t i = 0; i < N; i++) {
|
||||||
|
for (size_t j = 0; j < (scalar_size / 2); j++) {
|
||||||
|
std::swap(data[i].bytes[j], data[i].bytes[scalar_size - j - 1]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void Load::eval_io(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
std::vector<array>& outputs) {
|
||||||
|
assert(inputs.size() == 0);
|
||||||
|
array& out = outputs[0];
|
||||||
|
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||||
|
|
||||||
|
reader_->seek(offset_, std::ios_base::beg);
|
||||||
|
reader_->read(out.data<char>(), out.nbytes());
|
||||||
|
|
||||||
|
if (swap_endianness_) {
|
||||||
|
switch (out.itemsize()) {
|
||||||
|
case 2:
|
||||||
|
swap_endianness<2>(out.data<uint8_t>(), out.data_size());
|
||||||
|
break;
|
||||||
|
case 4:
|
||||||
|
swap_endianness<4>(out.data<uint8_t>(), out.data_size());
|
||||||
|
break;
|
||||||
|
case 8:
|
||||||
|
swap_endianness<8>(out.data<uint8_t>(), out.data_size());
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
@ -671,10 +671,6 @@ void LessEqual::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
binary_op(inputs, out, "leq");
|
binary_op(inputs, out, "leq");
|
||||||
}
|
}
|
||||||
|
|
||||||
void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|
||||||
eval(inputs, out);
|
|
||||||
}
|
|
||||||
|
|
||||||
void Log::eval_gpu(const std::vector<array>& inputs, array& out) {
|
void Log::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
switch (base_) {
|
switch (base_) {
|
||||||
case Base::e:
|
case Base::e:
|
||||||
|
@ -61,7 +61,6 @@ NO_GPU(Greater)
|
|||||||
NO_GPU(GreaterEqual)
|
NO_GPU(GreaterEqual)
|
||||||
NO_GPU(Less)
|
NO_GPU(Less)
|
||||||
NO_GPU(LessEqual)
|
NO_GPU(LessEqual)
|
||||||
NO_GPU(Load)
|
|
||||||
NO_GPU(Log)
|
NO_GPU(Log)
|
||||||
NO_GPU(Log1p)
|
NO_GPU(Log1p)
|
||||||
NO_GPU(LogicalNot)
|
NO_GPU(LogicalNot)
|
||||||
|
10
mlx/io.h
10
mlx/io.h
@ -32,12 +32,8 @@ array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s = {});
|
|||||||
array load(std::string file, StreamOrDevice s = {});
|
array load(std::string file, StreamOrDevice s = {});
|
||||||
|
|
||||||
/** Load array map from .safetensors file format */
|
/** Load array map from .safetensors file format */
|
||||||
SafetensorsLoad load_safetensors(
|
SafetensorsLoad load_safetensors(std::shared_ptr<io::Reader> in_stream);
|
||||||
std::shared_ptr<io::Reader> in_stream,
|
SafetensorsLoad load_safetensors(const std::string& file);
|
||||||
StreamOrDevice s = {});
|
|
||||||
SafetensorsLoad load_safetensors(
|
|
||||||
const std::string& file,
|
|
||||||
StreamOrDevice s = {});
|
|
||||||
|
|
||||||
void save_safetensors(
|
void save_safetensors(
|
||||||
std::shared_ptr<io::Writer> in_stream,
|
std::shared_ptr<io::Writer> in_stream,
|
||||||
@ -50,7 +46,7 @@ void save_safetensors(
|
|||||||
|
|
||||||
/** Load array map and metadata from .gguf file format */
|
/** Load array map and metadata from .gguf file format */
|
||||||
|
|
||||||
GGUFLoad load_gguf(const std::string& file, StreamOrDevice s = {});
|
GGUFLoad load_gguf(const std::string& file);
|
||||||
|
|
||||||
void save_gguf(
|
void save_gguf(
|
||||||
std::string file,
|
std::string file,
|
||||||
|
@ -231,7 +231,7 @@ std::unordered_map<std::string, array> load_arrays(gguf_ctx* ctx) {
|
|||||||
return array_map;
|
return array_map;
|
||||||
}
|
}
|
||||||
|
|
||||||
GGUFLoad load_gguf(const std::string& file, StreamOrDevice s) {
|
GGUFLoad load_gguf(const std::string& file) {
|
||||||
gguf_ctx* ctx = gguf_open(file.data());
|
gguf_ctx* ctx = gguf_open(file.data());
|
||||||
if (!ctx) {
|
if (!ctx) {
|
||||||
throw std::runtime_error("[load_gguf] gguf_init failed");
|
throw std::runtime_error("[load_gguf] gguf_init failed");
|
||||||
|
@ -213,7 +213,8 @@ array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s) {
|
|||||||
auto loaded_array = array(
|
auto loaded_array = array(
|
||||||
shape,
|
shape,
|
||||||
dtype,
|
dtype,
|
||||||
std::make_shared<Load>(to_stream(s), in_stream, offset, swap_endianness),
|
std::make_shared<Load>(
|
||||||
|
to_stream(Device::io), in_stream, offset, swap_endianness),
|
||||||
std::vector<array>{});
|
std::vector<array>{});
|
||||||
if (col_contiguous) {
|
if (col_contiguous) {
|
||||||
loaded_array = transpose(loaded_array, s);
|
loaded_array = transpose(loaded_array, s);
|
||||||
|
@ -94,9 +94,7 @@ Dtype dtype_from_safetensor_str(std::string_view str) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/** Load array from reader in safetensor format */
|
/** Load array from reader in safetensor format */
|
||||||
SafetensorsLoad load_safetensors(
|
SafetensorsLoad load_safetensors(std::shared_ptr<io::Reader> in_stream) {
|
||||||
std::shared_ptr<io::Reader> in_stream,
|
|
||||||
StreamOrDevice s) {
|
|
||||||
////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////
|
||||||
// Open and check file
|
// Open and check file
|
||||||
if (!in_stream->good() || !in_stream->is_open()) {
|
if (!in_stream->good() || !in_stream->is_open()) {
|
||||||
@ -138,15 +136,18 @@ SafetensorsLoad load_safetensors(
|
|||||||
shape,
|
shape,
|
||||||
type,
|
type,
|
||||||
std::make_shared<Load>(
|
std::make_shared<Load>(
|
||||||
to_stream(s), in_stream, offset + data_offsets.at(0), false),
|
to_stream(Device::io),
|
||||||
|
in_stream,
|
||||||
|
offset + data_offsets.at(0),
|
||||||
|
false),
|
||||||
std::vector<array>{});
|
std::vector<array>{});
|
||||||
res.insert({item.key(), loaded_array});
|
res.insert({item.key(), loaded_array});
|
||||||
}
|
}
|
||||||
return {res, metadata_map};
|
return {res, metadata_map};
|
||||||
}
|
}
|
||||||
|
|
||||||
SafetensorsLoad load_safetensors(const std::string& file, StreamOrDevice s) {
|
SafetensorsLoad load_safetensors(const std::string& file) {
|
||||||
return load_safetensors(std::make_shared<io::FileReader>(file), s);
|
return load_safetensors(std::make_shared<io::FileReader>(file));
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Save array to out stream in .npy format */
|
/** Save array to out stream in .npy format */
|
||||||
|
@ -162,6 +162,26 @@ class UnaryPrimitive : public Primitive {
|
|||||||
UnaryPrimitive& operator=(UnaryPrimitive&& other) = delete;
|
UnaryPrimitive& operator=(UnaryPrimitive&& other) = delete;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class IOPrimitive : public Primitive {
|
||||||
|
/**
|
||||||
|
* An abstract class for primitives that are doing io which usually are not
|
||||||
|
* supposed to be evaluated on any other "device".
|
||||||
|
*/
|
||||||
|
public:
|
||||||
|
explicit IOPrimitive(Stream stream) : Primitive(stream) {}
|
||||||
|
|
||||||
|
inline void eval_cpu(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
std::vector<array>& outputs) override {
|
||||||
|
throw std::runtime_error("IO primitives cannot be evaluated on CPU");
|
||||||
|
}
|
||||||
|
inline void eval_gpu(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
std::vector<array>& outputs) override {
|
||||||
|
throw std::runtime_error("IO primitives cannot be evaluated on GPU");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
class Abs : public UnaryPrimitive {
|
class Abs : public UnaryPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit Abs(Stream stream) : UnaryPrimitive(stream) {};
|
explicit Abs(Stream stream) : UnaryPrimitive(stream) {};
|
||||||
@ -1074,20 +1094,20 @@ class LessEqual : public UnaryPrimitive {
|
|||||||
void eval(const std::vector<array>& inputs, array& out);
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
};
|
};
|
||||||
|
|
||||||
class Load : public UnaryPrimitive {
|
class Load : public IOPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit Load(
|
explicit Load(
|
||||||
Stream stream,
|
Stream stream,
|
||||||
std::shared_ptr<io::Reader> reader,
|
std::shared_ptr<io::Reader> reader,
|
||||||
size_t offset,
|
size_t offset,
|
||||||
bool swap_endianness = false)
|
bool swap_endianness = false)
|
||||||
: UnaryPrimitive(stream),
|
: IOPrimitive(stream),
|
||||||
reader_(reader),
|
reader_(reader),
|
||||||
offset_(offset),
|
offset_(offset),
|
||||||
swap_endianness_(swap_endianness) {};
|
swap_endianness_(swap_endianness) {};
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_io(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
override;
|
||||||
|
|
||||||
DEFINE_PRINT(Load)
|
DEFINE_PRINT(Load)
|
||||||
|
|
||||||
|
@ -9,6 +9,7 @@
|
|||||||
#include <unordered_set>
|
#include <unordered_set>
|
||||||
|
|
||||||
#include "mlx/backend/common/cpu_impl.h"
|
#include "mlx/backend/common/cpu_impl.h"
|
||||||
|
#include "mlx/backend/io/io_impl.h"
|
||||||
#include "mlx/backend/metal/metal_impl.h"
|
#include "mlx/backend/metal/metal_impl.h"
|
||||||
#include "mlx/ops.h"
|
#include "mlx/ops.h"
|
||||||
#include "mlx/primitives.h"
|
#include "mlx/primitives.h"
|
||||||
@ -149,6 +150,9 @@ array eval_impl(std::vector<array> outputs, bool async) {
|
|||||||
case Device::cpu:
|
case Device::cpu:
|
||||||
scheduler::enqueue(stream, cpu::make_task(std::move(arr), signal));
|
scheduler::enqueue(stream, cpu::make_task(std::move(arr), signal));
|
||||||
break;
|
break;
|
||||||
|
case Device::io:
|
||||||
|
scheduler::enqueue(stream, io::make_task(std::move(arr), signal));
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return synchronizer;
|
return synchronizer;
|
||||||
|
@ -162,10 +162,10 @@ std::pair<
|
|||||||
std::unordered_map<std::string, std::string>>
|
std::unordered_map<std::string, std::string>>
|
||||||
mlx_load_safetensor_helper(nb::object file, StreamOrDevice s) {
|
mlx_load_safetensor_helper(nb::object file, StreamOrDevice s) {
|
||||||
if (nb::isinstance<nb::str>(file)) { // Assume .safetensors file path string
|
if (nb::isinstance<nb::str>(file)) { // Assume .safetensors file path string
|
||||||
return load_safetensors(nb::cast<std::string>(file), s);
|
return load_safetensors(nb::cast<std::string>(file));
|
||||||
} else if (is_istream_object(file)) {
|
} else if (is_istream_object(file)) {
|
||||||
// If we don't own the stream and it was passed to us, eval immediately
|
// If we don't own the stream and it was passed to us, eval immediately
|
||||||
auto res = load_safetensors(std::make_shared<PyFileReader>(file), s);
|
auto res = load_safetensors(std::make_shared<PyFileReader>(file));
|
||||||
{
|
{
|
||||||
nb::gil_scoped_release gil;
|
nb::gil_scoped_release gil;
|
||||||
for (auto& [key, arr] : std::get<0>(res)) {
|
for (auto& [key, arr] : std::get<0>(res)) {
|
||||||
@ -181,7 +181,7 @@ mlx_load_safetensor_helper(nb::object file, StreamOrDevice s) {
|
|||||||
|
|
||||||
GGUFLoad mlx_load_gguf_helper(nb::object file, StreamOrDevice s) {
|
GGUFLoad mlx_load_gguf_helper(nb::object file, StreamOrDevice s) {
|
||||||
if (nb::isinstance<nb::str>(file)) { // Assume .gguf file path string
|
if (nb::isinstance<nb::str>(file)) { // Assume .gguf file path string
|
||||||
return load_gguf(nb::cast<std::string>(file), s);
|
return load_gguf(nb::cast<std::string>(file));
|
||||||
}
|
}
|
||||||
|
|
||||||
throw std::invalid_argument("[load_gguf] Input must be a string");
|
throw std::invalid_argument("[load_gguf] Input must be a string");
|
||||||
|
Loading…
Reference in New Issue
Block a user