From 42baa095d1f7659c24d2cb5de75464faa834f2b0 Mon Sep 17 00:00:00 2001 From: dc-dc-dc Date: Mon, 18 Dec 2023 14:33:23 -0500 Subject: [PATCH] its working --- mlx/backend/common/load.cpp | 8 ++++++-- mlx/load.cpp | 3 ++- mlx/primitives.h | 2 ++ mlx/safetensor.cpp | 16 +++++++++++++++- tests/load_tests.cpp | 4 ++++ 5 files changed, 29 insertions(+), 4 deletions(-) diff --git a/mlx/backend/common/load.cpp b/mlx/backend/common/load.cpp index e68ce7f6f..a00118872 100644 --- a/mlx/backend/common/load.cpp +++ b/mlx/backend/common/load.cpp @@ -31,10 +31,14 @@ void swap_endianess(uint8_t* data_bytes, size_t N) { void Load::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 0); - out.set_data(allocator::malloc_or_wait(out.nbytes())); + if (len_ == 0) { + len_ = out.nbytes(); + } + printf("Load::eval: offset= %ld len_ = %ld\n", offset_, len_); + out.set_data(allocator::malloc_or_wait(len_)); reader_->seek(offset_, std::ios_base::beg); - reader_->read(out.data(), out.nbytes()); + reader_->read(out.data(), len_); if (swap_endianness_) { switch (out.itemsize()) { diff --git a/mlx/load.cpp b/mlx/load.cpp index 8106448a4..fdf5f6f17 100644 --- a/mlx/load.cpp +++ b/mlx/load.cpp @@ -225,7 +225,8 @@ array load(std::shared_ptr in_stream, StreamOrDevice s) { auto loaded_array = array( shape, dtype, - std::make_unique(to_stream(s), in_stream, offset, swap_endianness), + std::make_unique( + to_stream(s), in_stream, offset, in_stream->tell(), swap_endianness), std::vector{}); if (col_contiguous) { loaded_array = transpose(loaded_array, s); diff --git a/mlx/primitives.h b/mlx/primitives.h index 0cb98c9c7..c993fd6d7 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -826,6 +826,7 @@ class Load : public Primitive { Stream stream, std::shared_ptr reader, size_t offset, + size_t len, bool swap_endianness = false) : Primitive(stream), reader_(reader), @@ -841,6 +842,7 @@ class Load : public Primitive { void eval(const std::vector& inputs, array& out); std::shared_ptr reader_; size_t offset_; + size_t len_; bool swap_endianness_; }; diff --git a/mlx/safetensor.cpp b/mlx/safetensor.cpp index 072e46177..a6c0ce69d 100644 --- a/mlx/safetensor.cpp +++ b/mlx/safetensor.cpp @@ -189,6 +189,7 @@ std::unordered_map load_safetensor( throw std::runtime_error( "[load_safetensor] Invalid json metadata " + in_stream->label()); } + size_t offset = jsonHeaderLength + 8; // Parse the json raw data std::unordered_map res; for (auto& [key, obj] : *metadata.getObject()) { @@ -204,9 +205,22 @@ std::unordered_map load_safetensor( data_offsets_vec.push_back(offset->getNumber()); } if (dtype == "F32") { - res.insert({key, zeros(shape_vec, s)}); + auto loaded_array = array( + shape_vec, + float32, + std::make_unique( + to_stream(s), + in_stream, + offset + data_offsets->at(0)->getNumber(), + offset + data_offsets->at(1)->getNumber(), + false), + std::vector{}); + res.insert({key, loaded_array}); } } + // for (auto& [key, arr] : res) { + // arr.eval(); + // } return res; } diff --git a/tests/load_tests.cpp b/tests/load_tests.cpp index 198fcddd0..35ebe0533 100644 --- a/tests/load_tests.cpp +++ b/tests/load_tests.cpp @@ -145,6 +145,10 @@ TEST_CASE("test load_safetensor") { array test = safeDict.at("test"); CHECK_EQ(test.dtype(), float32); CHECK_EQ(test.shape(), std::vector({4})); + array b = array({1.0, 2.0, 3.0, 4.0}); + MESSAGE("test: " << test); + MESSAGE("b: " << b); + CHECK(array_equal(test, b).item()); } TEST_CASE("test single array serialization") {