diff --git a/mlx/backend/accelerate/primitives.cpp b/mlx/backend/accelerate/primitives.cpp
index 7b48e62f7..912a48a36 100644
--- a/mlx/backend/accelerate/primitives.cpp
+++ b/mlx/backend/accelerate/primitives.cpp
@@ -51,7 +51,6 @@ DEFAULT(Greater)
 DEFAULT(GreaterEqual)
 DEFAULT(Less)
 DEFAULT(LessEqual)
-DEFAULT(Load)
 DEFAULT(LogicalNot)
 DEFAULT(LogicalAnd)
 DEFAULT(LogicalOr)
diff --git a/mlx/backend/common/default_primitives.cpp b/mlx/backend/common/default_primitives.cpp
index d8ec303f1..f35c942b9 100644
--- a/mlx/backend/common/default_primitives.cpp
+++ b/mlx/backend/common/default_primitives.cpp
@@ -68,7 +68,6 @@ DEFAULT(Greater)
 DEFAULT(GreaterEqual)
 DEFAULT(Less)
 DEFAULT(LessEqual)
-DEFAULT(Load)
 DEFAULT(Log)
 DEFAULT(Log1p)
 DEFAULT(LogicalNot)
diff --git a/mlx/backend/io/CMakeLists.txt b/mlx/backend/io/CMakeLists.txt
index 69eaebfb5..752faea72 100644
--- a/mlx/backend/io/CMakeLists.txt
+++ b/mlx/backend/io/CMakeLists.txt
@@ -3,4 +3,5 @@ target_sources(
   PRIVATE
   ${CMAKE_CURRENT_SOURCE_DIR}/io_impl.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/thread_pool.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
 )
diff --git a/mlx/backend/io/primitives.cpp b/mlx/backend/io/primitives.cpp
new file mode 100644
index 000000000..b8569a7d0
--- /dev/null
+++ b/mlx/backend/io/primitives.cpp
@@ -0,0 +1,56 @@
+// Copyright © 2024 Apple Inc.
+
+#include
+#include
+#include
+
+#include "mlx/allocator.h"
+#include "mlx/primitives.h"
+
+namespace mlx::core {
+
+namespace {
+// Reverse, in place, the bytes of each of the N scalar_size-byte elements.
+template
+void swap_endianness(uint8_t* data_bytes, size_t N) {
+  struct Elem {
+    uint8_t bytes[scalar_size];
+  };
+
+  Elem* data = reinterpret_cast(data_bytes);
+
+  for (size_t i = 0; i < N; i++) {
+    for (size_t j = 0; j < (scalar_size / 2); j++) {
+      std::swap(data[i].bytes[j], data[i].bytes[scalar_size - j - 1]);
+    }
+  }
+}
+
+} // namespace
+// Read out.nbytes() bytes at offset_ from reader_ into the new output buffer.
+void Load::eval_io(
+    const std::vector& inputs,
+    std::vector& outputs) {
+  assert(inputs.size() == 0); // Load has no input arrays
+  array& out = outputs[0];
+  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+
+  reader_->seek(offset_, std::ios_base::beg);
+  reader_->read(out.data(), out.nbytes());
+
+  if (swap_endianness_) {
+    switch (out.itemsize()) { // other item sizes are left as read
+      case 2:
+        swap_endianness<2>(out.data(), out.data_size());
+        break;
+      case 4:
+        swap_endianness<4>(out.data(), out.data_size());
+        break;
+      case 8:
+        swap_endianness<8>(out.data(), out.data_size());
+        break;
+    }
+  }
+}
+
+} // namespace mlx::core
diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp
index 364132eba..33e242b10 100644
--- a/mlx/backend/metal/primitives.cpp
+++ b/mlx/backend/metal/primitives.cpp
@@ -671,10 +671,6 @@ void LessEqual::eval_gpu(const std::vector& inputs, array& out) {
   binary_op(inputs, out, "leq");
 }
 
-void Load::eval_gpu(const std::vector& inputs, array& out) {
-  eval(inputs, out);
-}
-
 void Log::eval_gpu(const std::vector& inputs, array& out) {
   switch (base_) {
     case Base::e:
diff --git a/mlx/backend/no_metal/primitives.cpp b/mlx/backend/no_metal/primitives.cpp
index 2b10e416a..bd2ee79ad 100644
--- a/mlx/backend/no_metal/primitives.cpp
+++ b/mlx/backend/no_metal/primitives.cpp
@@ -61,7 +61,6 @@ NO_GPU(Greater)
 NO_GPU(GreaterEqual)
 NO_GPU(Less)
 NO_GPU(LessEqual)
-NO_GPU(Load)
 NO_GPU(Log)
 NO_GPU(Log1p)
 NO_GPU(LogicalNot)
diff --git a/mlx/io.h b/mlx/io.h
index e30c0de34..8ec71acb8 100644
--- a/mlx/io.h
+++ b/mlx/io.h
@@ -32,12 +32,8 @@ array load(std::shared_ptr in_stream, StreamOrDevice s = {});
 array load(std::string file, StreamOrDevice s = {});
 
 /** Load array map from .safetensors file format */
-SafetensorsLoad load_safetensors(
-    std::shared_ptr in_stream,
-    StreamOrDevice s = {});
-SafetensorsLoad load_safetensors(
-    const std::string& file,
-    StreamOrDevice s = {});
+SafetensorsLoad load_safetensors(std::shared_ptr in_stream);
+SafetensorsLoad load_safetensors(const std::string& file);
 
 void save_safetensors(
     std::shared_ptr in_stream,
@@ -50,7 +46,7 @@ void save_safetensors(
 
 /** Load array map and metadata from .gguf file format */
 
-GGUFLoad load_gguf(const std::string& file, StreamOrDevice s = {});
+GGUFLoad load_gguf(const std::string& file);
 
 void save_gguf(
     std::string file,
diff --git a/mlx/io/gguf.cpp b/mlx/io/gguf.cpp
index 0193d2d09..15c5bd9d9 100644
--- a/mlx/io/gguf.cpp
+++ b/mlx/io/gguf.cpp
@@ -231,7 +231,7 @@ std::unordered_map load_arrays(gguf_ctx* ctx) {
   return array_map;
 }
 
-GGUFLoad load_gguf(const std::string& file, StreamOrDevice s) {
+GGUFLoad load_gguf(const std::string& file) {
   gguf_ctx* ctx = gguf_open(file.data());
   if (!ctx) {
     throw std::runtime_error("[load_gguf] gguf_init failed");
diff --git a/mlx/io/load.cpp b/mlx/io/load.cpp
index 294a1229f..a273476cd 100644
--- a/mlx/io/load.cpp
+++ b/mlx/io/load.cpp
@@ -213,7 +213,8 @@ array load(std::shared_ptr in_stream, StreamOrDevice s) {
   auto loaded_array = array(
       shape,
       dtype,
-      std::make_shared(to_stream(s), in_stream, offset, swap_endianness),
+      std::make_shared(
+          to_stream(Device::io), in_stream, offset, swap_endianness),
       std::vector{});
   if (col_contiguous) {
     loaded_array = transpose(loaded_array, s);
diff --git a/mlx/io/safetensor.cpp b/mlx/io/safetensor.cpp
index 69ebd46c8..46715a9f4 100644
--- a/mlx/io/safetensor.cpp
+++ b/mlx/io/safetensor.cpp
@@ -94,9 +94,7 @@ Dtype dtype_from_safetensor_str(std::string_view str) {
 }
 
 /** Load array from reader in safetensor format */
-SafetensorsLoad load_safetensors(
-    std::shared_ptr in_stream,
-    StreamOrDevice s) {
+SafetensorsLoad load_safetensors(std::shared_ptr in_stream) {
   ////////////////////////////////////////////////////////
   // Open and check file
   if (!in_stream->good() || !in_stream->is_open()) {
@@ -138,15 +136,18 @@ SafetensorsLoad load_safetensors(
         shape,
         type,
         std::make_shared(
-            to_stream(s), in_stream, offset + data_offsets.at(0), false),
+            to_stream(Device::io),
+            in_stream,
+            offset + data_offsets.at(0),
+            false),
         std::vector{});
     res.insert({item.key(), loaded_array});
   }
   return {res, metadata_map};
 }
 
-SafetensorsLoad load_safetensors(const std::string& file, StreamOrDevice s) {
-  return load_safetensors(std::make_shared(file), s);
+SafetensorsLoad load_safetensors(const std::string& file) {
+  return load_safetensors(std::make_shared(file));
 }
 
 /** Save array to out stream in .npy format */
diff --git a/mlx/primitives.h b/mlx/primitives.h
index 6bc31a31b..b1551eafb 100644
--- a/mlx/primitives.h
+++ b/mlx/primitives.h
@@ -162,6 +162,26 @@ class UnaryPrimitive : public Primitive {
   UnaryPrimitive& operator=(UnaryPrimitive&& other) = delete;
 };
 
+class IOPrimitive : public Primitive {
+  /**
+   * An abstract class for primitives that perform I/O. They are not meant to
+   * be evaluated on the CPU or GPU, so eval_cpu and eval_gpu always throw.
+   */
+ public:
+  explicit IOPrimitive(Stream stream) : Primitive(stream) {}
+
+  inline void eval_cpu(
+      const std::vector& inputs,
+      std::vector& outputs) override {
+    throw std::runtime_error("IO primitives cannot be evaluated on CPU");
+  }
+  inline void eval_gpu(
+      const std::vector& inputs,
+      std::vector& outputs) override {
+    throw std::runtime_error("IO primitives cannot be evaluated on GPU");
+  }
+};
+
 class Abs : public UnaryPrimitive {
  public:
   explicit Abs(Stream stream) : UnaryPrimitive(stream) {};
@@ -1074,20 +1094,20 @@ class LessEqual : public UnaryPrimitive {
   void eval(const std::vector& inputs, array& out);
 };
 
-class Load : public UnaryPrimitive {
+class Load : public IOPrimitive {
  public:
   explicit Load(
       Stream stream,
       std::shared_ptr reader,
       size_t offset,
       bool swap_endianness = false)
-      : UnaryPrimitive(stream),
+      : IOPrimitive(stream),
         reader_(reader),
         offset_(offset),
         swap_endianness_(swap_endianness) {};
 
-  void eval_cpu(const std::vector& inputs, array& out) override;
-  void eval_gpu(const std::vector& inputs, array& out) override;
+  void eval_io(const std::vector& inputs, std::vector& outputs)
+      override;
 
   DEFINE_PRINT(Load)
 
diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp
index c93a42fd1..bb06b1b23 100644
--- a/mlx/transforms.cpp
+++ b/mlx/transforms.cpp
@@ -9,6 +9,7 @@
 #include
 
 #include "mlx/backend/common/cpu_impl.h"
+#include "mlx/backend/io/io_impl.h"
 #include "mlx/backend/metal/metal_impl.h"
 #include "mlx/ops.h"
 #include "mlx/primitives.h"
@@ -149,6 +150,9 @@ array eval_impl(std::vector outputs, bool async) {
       case Device::cpu:
         scheduler::enqueue(stream, cpu::make_task(std::move(arr), signal));
         break;
+      case Device::io:
+        scheduler::enqueue(stream, io::make_task(std::move(arr), signal));
+        break;
     }
   }
   return synchronizer;
diff --git a/python/src/load.cpp b/python/src/load.cpp
index efad3d97d..93b4df40a 100644
--- a/python/src/load.cpp
+++ b/python/src/load.cpp
@@ -162,10 +162,10 @@ std::pair<
     std::unordered_map>
 mlx_load_safetensor_helper(nb::object file, StreamOrDevice s) {
   if (nb::isinstance(file)) { // Assume .safetensors file path string
-    return load_safetensors(nb::cast(file), s);
+    return load_safetensors(nb::cast(file));
   } else if (is_istream_object(file)) {
     // If we don't own the stream and it was passed to us, eval immediately
-    auto res = load_safetensors(std::make_shared(file), s);
+    auto res = load_safetensors(std::make_shared(file));
     {
       nb::gil_scoped_release gil;
       for (auto& [key, arr] : std::get<0>(res)) {
@@ -181,7 +181,7 @@ mlx_load_safetensor_helper(nb::object file, StreamOrDevice s) {
 
 GGUFLoad mlx_load_gguf_helper(nb::object file, StreamOrDevice s) {
   if (nb::isinstance(file)) { // Assume .gguf file path string
-    return load_gguf(nb::cast(file), s);
+    return load_gguf(nb::cast(file));
   }
 
   throw std::invalid_argument("[load_gguf] Input must be a string");