From f4d876d35fe3b401755506c9b7b3e97313ef489d Mon Sep 17 00:00:00 2001 From: dc-dc-dc Date: Mon, 18 Dec 2023 17:11:25 -0500 Subject: [PATCH] load changes were not needed after all --- mlx/backend/common/load.cpp | 7 ++----- mlx/load.cpp | 3 +-- mlx/primitives.h | 2 -- mlx/safetensor.cpp | 1 - tests/load_tests.cpp | 1 - 5 files changed, 3 insertions(+), 11 deletions(-) diff --git a/mlx/backend/common/load.cpp b/mlx/backend/common/load.cpp index d54774dc0..e68ce7f6f 100644 --- a/mlx/backend/common/load.cpp +++ b/mlx/backend/common/load.cpp @@ -31,13 +31,10 @@ void swap_endianess(uint8_t* data_bytes, size_t N) { void Load::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 0); - if (len_ == 0) { - len_ = out.nbytes(); - } - out.set_data(allocator::malloc_or_wait(len_)); + out.set_data(allocator::malloc_or_wait(out.nbytes())); reader_->seek(offset_, std::ios_base::beg); - reader_->read(out.data(), len_); + reader_->read(out.data(), out.nbytes()); if (swap_endianness_) { switch (out.itemsize()) { diff --git a/mlx/load.cpp b/mlx/load.cpp index fdf5f6f17..8106448a4 100644 --- a/mlx/load.cpp +++ b/mlx/load.cpp @@ -225,8 +225,7 @@ array load(std::shared_ptr in_stream, StreamOrDevice s) { auto loaded_array = array( shape, dtype, - std::make_unique( - to_stream(s), in_stream, offset, in_stream->tell(), swap_endianness), + std::make_unique(to_stream(s), in_stream, offset, swap_endianness), std::vector{}); if (col_contiguous) { loaded_array = transpose(loaded_array, s); diff --git a/mlx/primitives.h b/mlx/primitives.h index c993fd6d7..0cb98c9c7 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -826,7 +826,6 @@ class Load : public Primitive { Stream stream, std::shared_ptr reader, size_t offset, - size_t len, bool swap_endianness = false) : Primitive(stream), reader_(reader), @@ -842,7 +841,6 @@ class Load : public Primitive { void eval(const std::vector& inputs, array& out); std::shared_ptr reader_; size_t offset_; - size_t len_; bool swap_endianness_; }; diff --git a/mlx/safetensor.cpp b/mlx/safetensor.cpp index 85b69ada6..208829d7f 100644 --- a/mlx/safetensor.cpp +++ b/mlx/safetensor.cpp @@ -307,7 +307,6 @@ std::unordered_map load_safetensor( to_stream(s), in_stream, offset + data_offsets->at(0)->getNumber(), - offset + data_offsets->at(1)->getNumber(), false), std::vector{}); res.insert({key, loaded_array}); diff --git a/tests/load_tests.cpp b/tests/load_tests.cpp index cac6999bd..60e06d670 100644 --- a/tests/load_tests.cpp +++ b/tests/load_tests.cpp @@ -165,7 +165,6 @@ TEST_CASE("test save_safetensor") { auto map = std::unordered_map(); map.insert({"test", array({1.0, 2.0, 3.0, 4.0})}); map.insert({"test2", ones({2, 2})}); - MESSAGE("SAVING"); save_safetensor(file_path, map); auto safeDict = load_safetensor(file_path); CHECK_EQ(safeDict.size(), 2);