From 1f6ab6a556045961c639735efceebbee7cce814d Mon Sep 17 00:00:00 2001 From: Diogo Date: Wed, 27 Dec 2023 05:06:55 -0500 Subject: [PATCH] Safetensor support (#215) Co-authored-by: Awni Hannun --- .gitignore | 4 + CMakeLists.txt | 11 +++ docs/src/python/ops.rst | 1 + mlx/CMakeLists.txt | 3 +- mlx/backend/common/load.cpp | 2 +- mlx/io/CMakeLists.txt | 6 ++ mlx/{ => io}/load.cpp | 2 +- mlx/{ => io}/load.h | 0 mlx/io/safetensor.cpp | 189 ++++++++++++++++++++++++++++++++++++ mlx/io/safetensor.h | 32 ++++++ mlx/ops.h | 18 +++- mlx/primitives.h | 2 +- python/src/load.cpp | 151 +++++++++++++++++++++------- python/src/load.h | 15 ++- python/src/ops.cpp | 45 +++++++-- python/tests/test_load.py | 27 ++++++ tests/load_tests.cpp | 20 ++++ 17 files changed, 476 insertions(+), 52 deletions(-) create mode 100644 mlx/io/CMakeLists.txt rename mlx/{ => io}/load.cpp (99%) rename mlx/{ => io}/load.h (100%) create mode 100644 mlx/io/safetensor.cpp create mode 100644 mlx/io/safetensor.h diff --git a/.gitignore b/.gitignore index 3a30aae9e..8dfe5038e 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,10 @@ __pycache__/ # C extensions *.so +# tensor files +*.safe +*.safetensors + # Metal libraries *.metallib venv/ diff --git a/CMakeLists.txt b/CMakeLists.txt index 70293ebba..f79402bcc 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -98,6 +98,15 @@ elseif (MLX_BUILD_METAL) ${QUARTZ_LIB}) endif() +MESSAGE(STATUS "Downloading json") +FetchContent_Declare(json URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz) +FetchContent_MakeAvailable(json) +target_include_directories( + mlx PUBLIC + $ + $ +) + find_library(ACCELERATE_LIBRARY Accelerate) if (MLX_BUILD_ARM AND ACCELERATE_LIBRARY) message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}") @@ -152,6 +161,8 @@ if (MLX_BUILD_BENCHMARKS) add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/benchmarks/cpp) endif() + + # ----------------------------- Installation ----------------------------- include(GNUInstallDirs) diff --git a/docs/src/python/ops.rst b/docs/src/python/ops.rst index 7e391ec4c..0c5763290 100644 --- a/docs/src/python/ops.rst +++ b/docs/src/python/ops.rst @@ -83,6 +83,7 @@ Operations save savez savez_compressed + save_safetensors sigmoid sign sin diff --git a/mlx/CMakeLists.txt b/mlx/CMakeLists.txt index e004fc3d9..882bf93e0 100644 --- a/mlx/CMakeLists.txt +++ b/mlx/CMakeLists.txt @@ -8,7 +8,6 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp ${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp ${CMAKE_CURRENT_SOURCE_DIR}/graph_utils.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp ${CMAKE_CURRENT_SOURCE_DIR}/random.cpp ${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cpp @@ -19,7 +18,7 @@ target_sources( ) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common) - +add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io) if (MLX_BUILD_ACCELERATE) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/accelerate) else() diff --git a/mlx/backend/common/load.cpp b/mlx/backend/common/load.cpp index e68ce7f6f..6cf8ffe53 100644 --- a/mlx/backend/common/load.cpp +++ b/mlx/backend/common/load.cpp @@ -5,7 +5,7 @@ #include #include "mlx/allocator.h" -#include "mlx/load.h" +#include "mlx/io/load.h" #include "mlx/primitives.h" namespace mlx::core { diff --git a/mlx/io/CMakeLists.txt b/mlx/io/CMakeLists.txt new file mode 100644 index 000000000..f3e27b96a --- /dev/null +++ b/mlx/io/CMakeLists.txt @@ -0,0 +1,6 @@ +target_sources( + mlx + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/safetensor.cpp +) diff --git a/mlx/load.cpp b/mlx/io/load.cpp similarity index 99% rename from mlx/load.cpp rename to mlx/io/load.cpp index 8106448a4..856cf17a2 100644 --- a/mlx/load.cpp +++ b/mlx/io/load.cpp @@ -6,7 +6,7 @@ #include #include -#include "mlx/load.h" +#include "mlx/io/load.h" #include "mlx/ops.h" #include "mlx/primitives.h" #include "mlx/utils.h" diff --git a/mlx/load.h b/mlx/io/load.h similarity index 100% rename from mlx/load.h rename to mlx/io/load.h diff --git a/mlx/io/safetensor.cpp b/mlx/io/safetensor.cpp new file mode 100644 index 000000000..a690e6420 --- /dev/null +++ b/mlx/io/safetensor.cpp @@ -0,0 +1,189 @@ +#include "mlx/io/safetensor.h" + +#include + +namespace mlx::core { + +std::string dtype_to_safetensor_str(Dtype t) { + switch (t) { + case float32: + return ST_F32; + case bfloat16: + return ST_BF16; + case float16: + return ST_F16; + case int64: + return ST_I64; + case int32: + return ST_I32; + case int16: + return ST_I16; + case int8: + return ST_I8; + case uint64: + return ST_U64; + case uint32: + return ST_U32; + case uint16: + return ST_U16; + case uint8: + return ST_U8; + case bool_: + return ST_BOOL; + case complex64: + return ST_C64; + } +} + +Dtype dtype_from_safetensor_str(std::string str) { + if (str == ST_F32) { + return float32; + } else if (str == ST_F16) { + return float16; + } else if (str == ST_BF16) { + return bfloat16; + } else if (str == ST_I64) { + return int64; + } else if (str == ST_I32) { + return int32; + } else if (str == ST_I16) { + return int16; + } else if (str == ST_I8) { + return int8; + } else if (str == ST_U64) { + return uint64; + } else if (str == ST_U32) { + return uint32; + } else if (str == ST_U16) { + return uint16; + } else if (str == ST_U8) { + return uint8; + } else if (str == ST_BOOL) { + return bool_; + } else if (str == ST_C64) { + return complex64; + } else { + throw std::runtime_error("[safetensor] unsupported dtype " + str); + } +} + +/** Load array from reader in safetensor format */ +std::unordered_map load_safetensors( + std::shared_ptr in_stream, + StreamOrDevice s) { + //////////////////////////////////////////////////////// + // Open and check file + if (!in_stream->good() || !in_stream->is_open()) { + throw std::runtime_error( + "[load_safetensors] Failed to open " + in_stream->label()); + } + + uint64_t jsonHeaderLength = 0; + in_stream->read(reinterpret_cast(&jsonHeaderLength), 8); + if (jsonHeaderLength <= 0) { + throw std::runtime_error( + "[load_safetensors] Invalid json header length " + in_stream->label()); + } + // Load the json metadata + char rawJson[jsonHeaderLength]; + in_stream->read(rawJson, jsonHeaderLength); + auto metadata = json::parse(rawJson, rawJson + jsonHeaderLength); + // Should always be an object on the top-level + if (!metadata.is_object()) { + throw std::runtime_error( + "[load_safetensors] Invalid json metadata " + in_stream->label()); + } + size_t offset = jsonHeaderLength + 8; + // Load the arrays using metadata + std::unordered_map res; + for (const auto& item : metadata.items()) { + if (item.key() == "__metadata__") { + // ignore metadata for now + continue; + } + std::string dtype = item.value().at("dtype"); + std::vector shape = item.value().at("shape"); + std::vector data_offsets = item.value().at("data_offsets"); + Dtype type = dtype_from_safetensor_str(dtype); + auto loaded_array = array( + shape, + type, + std::make_unique( + to_stream(s), in_stream, offset + data_offsets.at(0), false), + std::vector{}); + res.insert({item.key(), loaded_array}); + } + return res; +} + +std::unordered_map load_safetensors( + const std::string& file, + StreamOrDevice s) { + return load_safetensors(std::make_shared(file), s); +} + +/** Save array to out stream in .npy format */ +void save_safetensors( + std::shared_ptr out_stream, + std::unordered_map a, + std::optional retain_graph_) { + //////////////////////////////////////////////////////// + // Check file + if (!out_stream->good() || !out_stream->is_open()) { + throw std::runtime_error( + "[save_safetensors] Failed to open " + out_stream->label()); + } + + //////////////////////////////////////////////////////// + // Check array map + json parent; + parent["__metadata__"] = json::object({ + {"format", "mlx"}, + }); + size_t offset = 0; + for (auto& [key, arr] : a) { + arr.eval(retain_graph_.value_or(arr.is_tracer())); + if (arr.nbytes() == 0) { + throw std::invalid_argument( + "[save_safetensors] cannot serialize an empty array key: " + key); + } + + if (!arr.flags().contiguous) { + throw std::invalid_argument( + "[save_safetensors] cannot serialize a non-contiguous array key: " + + key); + } + json child; + child["dtype"] = dtype_to_safetensor_str(arr.dtype()); + child["shape"] = arr.shape(); + child["data_offsets"] = std::vector{offset, offset + arr.nbytes()}; + parent[key] = child; + offset += arr.nbytes(); + } + + auto header = parent.dump(); + uint64_t header_len = header.length(); + out_stream->write(reinterpret_cast(&header_len), 8); + out_stream->write(header.c_str(), header_len); + for (auto& [key, arr] : a) { + out_stream->write(arr.data(), arr.nbytes()); + } +} + +void save_safetensors( + const std::string& file_, + std::unordered_map a, + std::optional retain_graph) { + // Open and check file + std::string file = file_; + + // Add .safetensors to file name if it is not there + if (file.length() < 12 || + file.substr(file.length() - 12, 12) != ".safetensors") + file += ".safetensors"; + + // Serialize array + save_safetensors(std::make_shared(file), a, retain_graph); +} + +} // namespace mlx::core diff --git a/mlx/io/safetensor.h b/mlx/io/safetensor.h new file mode 100644 index 000000000..104a226ce --- /dev/null +++ b/mlx/io/safetensor.h @@ -0,0 +1,32 @@ +// Copyright © 2023 Apple Inc. + +#pragma once + +#include + +#include "mlx/io/load.h" +#include "mlx/ops.h" +#include "mlx/primitives.h" + +using json = nlohmann::json; + +namespace mlx::core { + +#define ST_F16 "F16" +#define ST_BF16 "BF16" +#define ST_F32 "F32" + +#define ST_BOOL "BOOL" +#define ST_I8 "I8" +#define ST_I16 "I16" +#define ST_I32 "I32" +#define ST_I64 "I64" +#define ST_U8 "U8" +#define ST_U16 "U16" +#define ST_U32 "U32" +#define ST_U64 "U64" + +// Note: Complex numbers aren't in the spec yet so this could change - +// https://github.com/huggingface/safetensors/issues/389 +#define ST_C64 "C64" +} // namespace mlx::core diff --git a/mlx/ops.h b/mlx/ops.h index fe59d4e49..e1abac6fb 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -7,7 +7,7 @@ #include "array.h" #include "device.h" -#include "load.h" +#include "io/load.h" #include "stream.h" namespace mlx::core { @@ -1057,4 +1057,20 @@ array dequantize( int bits = 4, StreamOrDevice s = {}); +/** Load array map from .safetensors file format */ +std::unordered_map load_safetensors( + std::shared_ptr in_stream, + StreamOrDevice s = {}); +std::unordered_map load_safetensors( + const std::string& file, + StreamOrDevice s = {}); + +void save_safetensors( + std::shared_ptr in_stream, + std::unordered_map, + std::optional retain_graph = std::nullopt); +void save_safetensors( + const std::string& file, + std::unordered_map, + std::optional retain_graph = std::nullopt); } // namespace mlx::core diff --git a/mlx/primitives.h b/mlx/primitives.h index 0cb98c9c7..747b26c10 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -4,7 +4,7 @@ #include "array.h" #include "device.h" -#include "load.h" +#include "io/load.h" #include "stream.h" #define DEFINE_GRADS() \ diff --git a/python/src/load.cpp b/python/src/load.cpp index 1a52930b2..a63e5063e 100644 --- a/python/src/load.cpp +++ b/python/src/load.cpp @@ -6,12 +6,11 @@ #include #include #include -#include #include #include #include -#include "mlx/load.h" +#include "mlx/io/load.h" #include "mlx/ops.h" #include "mlx/utils.h" #include "python/src/load.h" @@ -161,40 +160,68 @@ class PyFileReader : public io::Reader { py::object tell_func_; }; -DictOrArray mlx_load_helper(py::object file, StreamOrDevice s) { - py::module_ zipfile = py::module_::import("zipfile"); - - // Assume .npz file if it is zipped - if (is_zip_file(zipfile, file)) { - // Output dictionary filename in zip -> loaded array - std::unordered_map array_dict; - - // Create python ZipFile object - ZipFileWrapper zipfile_object(zipfile, file); - for (const std::string& st : zipfile_object.namelist()) { - // Open zip file as a python file stream - py::object sub_file = zipfile_object.open(st); - - // Create array from python fille stream - auto arr = load(std::make_shared(sub_file), s); - - // Remove .npy from file if it is there - auto key = st; - if (st.length() > 4 && st.substr(st.length() - 4, 4) == ".npy") - key = st.substr(0, st.length() - 4); - - // Add array to dict - array_dict.insert({key, arr}); - } - +std::unordered_map mlx_load_safetensor_helper( + py::object file, + StreamOrDevice s) { + if (py::isinstance(file)) { // Assume .safetensors file path string + return {load_safetensors(py::cast(file), s)}; + } else if (is_istream_object(file)) { // If we don't own the stream and it was passed to us, eval immediately - for (auto& [key, arr] : array_dict) { + auto arr = load_safetensors(std::make_shared(file), s); + { py::gil_scoped_release gil; - arr.eval(); + for (auto& [key, arr] : arr) { + arr.eval(); + } } + return {arr}; + } - return {array_dict}; - } else if (py::isinstance(file)) { // Assume .npy file path string + throw std::invalid_argument( + "[load_safetensors] Input must be a file-like object, or string"); +} + +std::unordered_map mlx_load_npz_helper( + py::object file, + StreamOrDevice s) { + py::module_ zipfile = py::module_::import("zipfile"); + if (!is_zip_file(zipfile, file)) { + throw std::invalid_argument( + "[load_npz] Input must be a zip file or a file-like object that can be " + "opened with zipfile.ZipFile"); + } + // Output dictionary filename in zip -> loaded array + std::unordered_map array_dict; + + // Create python ZipFile object + ZipFileWrapper zipfile_object(zipfile, file); + for (const std::string& st : zipfile_object.namelist()) { + // Open zip file as a python file stream + py::object sub_file = zipfile_object.open(st); + + // Create array from python fille stream + auto arr = load(std::make_shared(sub_file), s); + + // Remove .npy from file if it is there + auto key = st; + if (st.length() > 4 && st.substr(st.length() - 4, 4) == ".npy") + key = st.substr(0, st.length() - 4); + + // Add array to dict + array_dict.insert({key, arr}); + } + + // If we don't own the stream and it was passed to us, eval immediately + for (auto& [key, arr] : array_dict) { + py::gil_scoped_release gil; + arr.eval(); + } + + return {array_dict}; +} + +array mlx_load_npy_helper(py::object file, StreamOrDevice s) { + if (py::isinstance(file)) { // Assume .npy file path string return {load(py::cast(file), s)}; } else if (is_istream_object(file)) { // If we don't own the stream and it was passed to us, eval immediately @@ -205,9 +232,41 @@ DictOrArray mlx_load_helper(py::object file, StreamOrDevice s) { } return {arr}; } - throw std::invalid_argument( - "[load] Input must be a file-like object, string, or pathlib.Path"); + "[load_npy] Input must be a file-like object, or string"); +} + +DictOrArray mlx_load_helper( + py::object file, + std::optional format, + StreamOrDevice s) { + if (!format.has_value()) { + std::string fname; + if (py::isinstance(file)) { + fname = py::cast(file); + } else if (is_istream_object(file)) { + fname = file.attr("name").cast(); + } else { + throw std::invalid_argument( + "[load] Input must be a file-like object, or string"); + } + size_t ext = fname.find_last_of('.'); + if (ext == std::string::npos) { + throw std::invalid_argument( + "[load] Could not infer file format from extension"); + } + format.emplace(fname.substr(ext + 1)); + } + + if (format.value() == "safetensors") { + return mlx_load_safetensor_helper(file, s); + } else if (format.value() == "npz") { + return mlx_load_npz_helper(file, s); + } else if (format.value() == "npy") { + return mlx_load_npy_helper(file, s); + } else { + throw std::invalid_argument("[load] Unknown file format " + format.value()); + } } /////////////////////////////////////////////////////////////////////////////// @@ -305,7 +364,7 @@ void mlx_save_helper( } throw std::invalid_argument( - "[save] Input must be a file-like object, string, or pathlib.Path"); + "[save] Input must be a file-like object, or string"); } void mlx_savez_helper( @@ -361,3 +420,25 @@ void mlx_savez_helper( return; } + +void mlx_save_safetensor_helper( + py::object file, + py::dict d, + std::optional retain_graph) { + auto arrays_map = d.cast>(); + if (py::isinstance(file)) { + save_safetensors(py::cast(file), arrays_map, retain_graph); + return; + } else if (is_ostream_object(file)) { + auto writer = std::make_shared(file); + { + py::gil_scoped_release gil; + save_safetensors(writer, arrays_map, retain_graph); + } + + return; + } + + throw std::invalid_argument( + "[save_safetensors] Input must be a file-like object, or string"); +} diff --git a/python/src/load.h b/python/src/load.h index 8f64a64d1..4dc6fcda7 100644 --- a/python/src/load.h +++ b/python/src/load.h @@ -3,6 +3,8 @@ #pragma once #include +#include +#include #include #include #include "mlx/ops.h" @@ -12,7 +14,18 @@ using namespace mlx::core; using DictOrArray = std::variant>; -DictOrArray mlx_load_helper(py::object file, StreamOrDevice s); +std::unordered_map mlx_load_safetensor_helper( + py::object file, + StreamOrDevice s); +void mlx_save_safetensor_helper( + py::object file, + py::dict d, + std::optional retain_graph = std::nullopt); + +DictOrArray mlx_load_helper( + py::object file, + std::optional format, + StreamOrDevice s); void mlx_save_helper( py::object file, array a, diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 277ef596b..f97a55bce 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -2867,11 +2867,9 @@ void init_ops(py::module_& m) { Args: file (str): File to which the array is saved arr (array): Array to be saved. - retain_graph (bool, optional): Optional argument to retain graph - during array evaluation before saving. If not provided the graph - is retained if we are during a function transformation. Default: - None - + retain_graph (bool, optional): Whether or not to retain the graph + during array evaluation. If left unspecified the graph is retained + only if saving is done in a function transformation. Default: ``None`` )pbdoc"); m.def( "savez", @@ -2932,18 +2930,45 @@ void init_ops(py::module_& m) { &mlx_load_helper, "file"_a, py::pos_only(), + "format"_a = none, py::kw_only(), "stream"_a = none, R"pbdoc( - load(file: str, /, *, stream: Union[None, Stream, Device] = None) -> Union[array, Dict[str, array]] + load(file: str, /, format: Optional[str] = None, *, stream: Union[None, Stream, Device] = None) -> Union[array, Dict[str, array]] - Load array(s) from a binary file in ``.npy`` or ``.npz`` format. + Load array(s) from a binary file in ``.npy``, ``.npz``, or ``.safetensors`` format. Args: - file (file, str): File in which the array is saved - + file (file, str): File in which the array is saved. + format (str, optional): Format of the file. If ``None``, the format + is inferred from the file extension. Supported formats: ``npy``, + ``npz``, and ``safetensors``. Default: ``None``. Returns: - result (array, dict): The loaded array if ``.npy`` file or a dict mapping name to array if ``.npz`` file + result (array, dict): + A single array if loading from a ``.npy`` file or a dict mapping + names to arrays if loading from a ``.npz`` or ``.safetensors`` file. + )pbdoc"); + m.def( + "save_safetensors", + &mlx_save_safetensor_helper, + "file"_a, + "arrays"_a, + py::pos_only(), + "retain_graph"_a = std::nullopt, + py::kw_only(), + R"pbdoc( + save_safetensors(file: str, arrays: Dict[str, array], /, retain_graph: Optional[bool] = None) + + Save array(s) to a binary file in ``.safetensors`` format. + + For more information on the format see https://huggingface.co/docs/safetensors/index. + + Args: + file (file, str): File in which the array is saved> + arrays (dict(str, array)): The dictionary of names to arrays to be saved. + retain_graph (bool, optional): Whether or not to retain the graph + during array evaluation. If left unspecified the graph is retained + only if saving is done in a function transformation. Default: ``None``. )pbdoc"); m.def( "where", diff --git a/python/tests/test_load.py b/python/tests/test_load.py index e63588d03..2ee550604 100644 --- a/python/tests/test_load.py +++ b/python/tests/test_load.py @@ -64,6 +64,33 @@ class TestLoad(mlx_tests.MLXTestCase): load_arr_mlx_npy = np.load(save_file_mlx) self.assertTrue(np.array_equal(load_arr_mlx_npy, save_arr_npy)) + def test_save_and_load_safetensors(self): + if not os.path.isdir(self.test_dir): + os.mkdir(self.test_dir) + + for dt in self.dtypes + ["bfloat16"]: + with self.subTest(dtype=dt): + for i, shape in enumerate([(1,), (23,), (1024, 1024), (4, 6, 3, 1, 2)]): + with self.subTest(shape=shape): + save_file_mlx = os.path.join( + self.test_dir, f"mlx_{dt}_{i}_fs.safetensors" + ) + save_dict = { + "test": mx.random.normal(shape=shape, dtype=getattr(mx, dt)) + if dt in ["float32", "float16", "bfloat16"] + else mx.ones(shape, dtype=getattr(mx, dt)) + } + + with open(save_file_mlx, "wb") as f: + mx.save_safetensors(f, save_dict) + with open(save_file_mlx, "rb") as f: + load_dict = mx.load(f) + + self.assertTrue("test" in load_dict) + self.assertTrue( + mx.array_equal(load_dict["test"], save_dict["test"]) + ) + def test_save_and_load_fs(self): if not os.path.isdir(self.test_dir): diff --git a/tests/load_tests.cpp b/tests/load_tests.cpp index f2489ca72..edff1aff6 100644 --- a/tests/load_tests.cpp +++ b/tests/load_tests.cpp @@ -14,6 +14,26 @@ std::string get_temp_file(const std::string& name) { return std::filesystem::temp_directory_path().append(name); } +TEST_CASE("test save_safetensors") { + std::string file_path = get_temp_file("test_arr.safetensors"); + auto map = std::unordered_map(); + map.insert({"test", array({1.0, 2.0, 3.0, 4.0})}); + map.insert({"test2", ones({2, 2})}); + save_safetensors(file_path, map); + auto safeDict = load_safetensors(file_path); + CHECK_EQ(safeDict.size(), 2); + CHECK_EQ(safeDict.count("test"), 1); + CHECK_EQ(safeDict.count("test2"), 1); + array test = safeDict.at("test"); + CHECK_EQ(test.dtype(), float32); + CHECK_EQ(test.shape(), std::vector({4})); + CHECK(array_equal(test, array({1.0, 2.0, 3.0, 4.0})).item()); + array test2 = safeDict.at("test2"); + CHECK_EQ(test2.dtype(), float32); + CHECK_EQ(test2.shape(), std::vector({2, 2})); + CHECK(array_equal(test2, ones({2, 2})).item()); +} + TEST_CASE("test single array serialization") { // Basic test {