Safetensor support (#215)

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Diogo 2023-12-27 05:06:55 -05:00 committed by GitHub
parent 6b0d30bb85
commit 1f6ab6a556
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 476 additions and 52 deletions

4
.gitignore vendored
View File

@ -6,6 +6,10 @@ __pycache__/
# C extensions # C extensions
*.so *.so
# tensor files
*.safe
*.safetensors
# Metal libraries # Metal libraries
*.metallib *.metallib
venv/ venv/

View File

@ -98,6 +98,15 @@ elseif (MLX_BUILD_METAL)
${QUARTZ_LIB}) ${QUARTZ_LIB})
endif() endif()
MESSAGE(STATUS "Downloading json")
FetchContent_Declare(json URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz)
FetchContent_MakeAvailable(json)
target_include_directories(
mlx PUBLIC
$<BUILD_INTERFACE:${json_SOURCE_DIR}/single_include/nlohmann>
$<INSTALL_INTERFACE:include/json>
)
find_library(ACCELERATE_LIBRARY Accelerate) find_library(ACCELERATE_LIBRARY Accelerate)
if (MLX_BUILD_ARM AND ACCELERATE_LIBRARY) if (MLX_BUILD_ARM AND ACCELERATE_LIBRARY)
message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}") message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
@ -152,6 +161,8 @@ if (MLX_BUILD_BENCHMARKS)
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/benchmarks/cpp) add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/benchmarks/cpp)
endif() endif()
# ----------------------------- Installation ----------------------------- # ----------------------------- Installation -----------------------------
include(GNUInstallDirs) include(GNUInstallDirs)

View File

@ -83,6 +83,7 @@ Operations
save save
savez savez
savez_compressed savez_compressed
save_safetensors
sigmoid sigmoid
sign sign
sin sin

View File

@ -8,7 +8,6 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp ${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
${CMAKE_CURRENT_SOURCE_DIR}/graph_utils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/graph_utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/random.cpp ${CMAKE_CURRENT_SOURCE_DIR}/random.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cpp ${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cpp
@ -19,7 +18,7 @@ target_sources(
) )
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io)
if (MLX_BUILD_ACCELERATE) if (MLX_BUILD_ACCELERATE)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/accelerate) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/accelerate)
else() else()

View File

@ -5,7 +5,7 @@
#include <utility> #include <utility>
#include "mlx/allocator.h" #include "mlx/allocator.h"
#include "mlx/load.h" #include "mlx/io/load.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
namespace mlx::core { namespace mlx::core {

6
mlx/io/CMakeLists.txt Normal file
View File

@ -0,0 +1,6 @@
target_sources(
mlx
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
${CMAKE_CURRENT_SOURCE_DIR}/safetensor.cpp
)

View File

@ -6,7 +6,7 @@
#include <limits> #include <limits>
#include <sstream> #include <sstream>
#include "mlx/load.h" #include "mlx/io/load.h"
#include "mlx/ops.h" #include "mlx/ops.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
#include "mlx/utils.h" #include "mlx/utils.h"

189
mlx/io/safetensor.cpp Normal file
View File

@ -0,0 +1,189 @@
#include "mlx/io/safetensor.h"
#include <stack>
namespace mlx::core {
std::string dtype_to_safetensor_str(Dtype t) {
switch (t) {
case float32:
return ST_F32;
case bfloat16:
return ST_BF16;
case float16:
return ST_F16;
case int64:
return ST_I64;
case int32:
return ST_I32;
case int16:
return ST_I16;
case int8:
return ST_I8;
case uint64:
return ST_U64;
case uint32:
return ST_U32;
case uint16:
return ST_U16;
case uint8:
return ST_U8;
case bool_:
return ST_BOOL;
case complex64:
return ST_C64;
}
}
Dtype dtype_from_safetensor_str(std::string str) {
if (str == ST_F32) {
return float32;
} else if (str == ST_F16) {
return float16;
} else if (str == ST_BF16) {
return bfloat16;
} else if (str == ST_I64) {
return int64;
} else if (str == ST_I32) {
return int32;
} else if (str == ST_I16) {
return int16;
} else if (str == ST_I8) {
return int8;
} else if (str == ST_U64) {
return uint64;
} else if (str == ST_U32) {
return uint32;
} else if (str == ST_U16) {
return uint16;
} else if (str == ST_U8) {
return uint8;
} else if (str == ST_BOOL) {
return bool_;
} else if (str == ST_C64) {
return complex64;
} else {
throw std::runtime_error("[safetensor] unsupported dtype " + str);
}
}
/** Load array from reader in safetensor format */
std::unordered_map<std::string, array> load_safetensors(
std::shared_ptr<io::Reader> in_stream,
StreamOrDevice s) {
////////////////////////////////////////////////////////
// Open and check file
if (!in_stream->good() || !in_stream->is_open()) {
throw std::runtime_error(
"[load_safetensors] Failed to open " + in_stream->label());
}
uint64_t jsonHeaderLength = 0;
in_stream->read(reinterpret_cast<char*>(&jsonHeaderLength), 8);
if (jsonHeaderLength <= 0) {
throw std::runtime_error(
"[load_safetensors] Invalid json header length " + in_stream->label());
}
// Load the json metadata
char rawJson[jsonHeaderLength];
in_stream->read(rawJson, jsonHeaderLength);
auto metadata = json::parse(rawJson, rawJson + jsonHeaderLength);
// Should always be an object on the top-level
if (!metadata.is_object()) {
throw std::runtime_error(
"[load_safetensors] Invalid json metadata " + in_stream->label());
}
size_t offset = jsonHeaderLength + 8;
// Load the arrays using metadata
std::unordered_map<std::string, array> res;
for (const auto& item : metadata.items()) {
if (item.key() == "__metadata__") {
// ignore metadata for now
continue;
}
std::string dtype = item.value().at("dtype");
std::vector<int> shape = item.value().at("shape");
std::vector<size_t> data_offsets = item.value().at("data_offsets");
Dtype type = dtype_from_safetensor_str(dtype);
auto loaded_array = array(
shape,
type,
std::make_unique<Load>(
to_stream(s), in_stream, offset + data_offsets.at(0), false),
std::vector<array>{});
res.insert({item.key(), loaded_array});
}
return res;
}
std::unordered_map<std::string, array> load_safetensors(
const std::string& file,
StreamOrDevice s) {
return load_safetensors(std::make_shared<io::FileReader>(file), s);
}
/** Save array to out stream in .npy format */
void save_safetensors(
std::shared_ptr<io::Writer> out_stream,
std::unordered_map<std::string, array> a,
std::optional<bool> retain_graph_) {
////////////////////////////////////////////////////////
// Check file
if (!out_stream->good() || !out_stream->is_open()) {
throw std::runtime_error(
"[save_safetensors] Failed to open " + out_stream->label());
}
////////////////////////////////////////////////////////
// Check array map
json parent;
parent["__metadata__"] = json::object({
{"format", "mlx"},
});
size_t offset = 0;
for (auto& [key, arr] : a) {
arr.eval(retain_graph_.value_or(arr.is_tracer()));
if (arr.nbytes() == 0) {
throw std::invalid_argument(
"[save_safetensors] cannot serialize an empty array key: " + key);
}
if (!arr.flags().contiguous) {
throw std::invalid_argument(
"[save_safetensors] cannot serialize a non-contiguous array key: " +
key);
}
json child;
child["dtype"] = dtype_to_safetensor_str(arr.dtype());
child["shape"] = arr.shape();
child["data_offsets"] = std::vector<size_t>{offset, offset + arr.nbytes()};
parent[key] = child;
offset += arr.nbytes();
}
auto header = parent.dump();
uint64_t header_len = header.length();
out_stream->write(reinterpret_cast<char*>(&header_len), 8);
out_stream->write(header.c_str(), header_len);
for (auto& [key, arr] : a) {
out_stream->write(arr.data<char>(), arr.nbytes());
}
}
void save_safetensors(
const std::string& file_,
std::unordered_map<std::string, array> a,
std::optional<bool> retain_graph) {
// Open and check file
std::string file = file_;
// Add .safetensors to file name if it is not there
if (file.length() < 12 ||
file.substr(file.length() - 12, 12) != ".safetensors")
file += ".safetensors";
// Serialize array
save_safetensors(std::make_shared<io::FileWriter>(file), a, retain_graph);
}
} // namespace mlx::core

32
mlx/io/safetensor.h Normal file
View File

@ -0,0 +1,32 @@
// Copyright © 2023 Apple Inc.
#pragma once
#include <json.hpp>
#include "mlx/io/load.h"
#include "mlx/ops.h"
#include "mlx/primitives.h"
using json = nlohmann::json;
namespace mlx::core {
#define ST_F16 "F16"
#define ST_BF16 "BF16"
#define ST_F32 "F32"
#define ST_BOOL "BOOL"
#define ST_I8 "I8"
#define ST_I16 "I16"
#define ST_I32 "I32"
#define ST_I64 "I64"
#define ST_U8 "U8"
#define ST_U16 "U16"
#define ST_U32 "U32"
#define ST_U64 "U64"
// Note: Complex numbers aren't in the spec yet so this could change -
// https://github.com/huggingface/safetensors/issues/389
#define ST_C64 "C64"
} // namespace mlx::core

View File

@ -7,7 +7,7 @@
#include "array.h" #include "array.h"
#include "device.h" #include "device.h"
#include "load.h" #include "io/load.h"
#include "stream.h" #include "stream.h"
namespace mlx::core { namespace mlx::core {
@ -1057,4 +1057,20 @@ array dequantize(
int bits = 4, int bits = 4,
StreamOrDevice s = {}); StreamOrDevice s = {});
/** Load array map from .safetensors file format */
std::unordered_map<std::string, array> load_safetensors(
std::shared_ptr<io::Reader> in_stream,
StreamOrDevice s = {});
std::unordered_map<std::string, array> load_safetensors(
const std::string& file,
StreamOrDevice s = {});
void save_safetensors(
std::shared_ptr<io::Writer> in_stream,
std::unordered_map<std::string, array>,
std::optional<bool> retain_graph = std::nullopt);
void save_safetensors(
const std::string& file,
std::unordered_map<std::string, array>,
std::optional<bool> retain_graph = std::nullopt);
} // namespace mlx::core } // namespace mlx::core

View File

@ -4,7 +4,7 @@
#include "array.h" #include "array.h"
#include "device.h" #include "device.h"
#include "load.h" #include "io/load.h"
#include "stream.h" #include "stream.h"
#define DEFINE_GRADS() \ #define DEFINE_GRADS() \

View File

@ -6,12 +6,11 @@
#include <cstring> #include <cstring>
#include <fstream> #include <fstream>
#include <stdexcept> #include <stdexcept>
#include <string>
#include <string_view> #include <string_view>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include "mlx/load.h" #include "mlx/io/load.h"
#include "mlx/ops.h" #include "mlx/ops.h"
#include "mlx/utils.h" #include "mlx/utils.h"
#include "python/src/load.h" #include "python/src/load.h"
@ -161,40 +160,68 @@ class PyFileReader : public io::Reader {
py::object tell_func_; py::object tell_func_;
}; };
DictOrArray mlx_load_helper(py::object file, StreamOrDevice s) { std::unordered_map<std::string, array> mlx_load_safetensor_helper(
py::module_ zipfile = py::module_::import("zipfile"); py::object file,
StreamOrDevice s) {
// Assume .npz file if it is zipped if (py::isinstance<py::str>(file)) { // Assume .safetensors file path string
if (is_zip_file(zipfile, file)) { return {load_safetensors(py::cast<std::string>(file), s)};
// Output dictionary filename in zip -> loaded array } else if (is_istream_object(file)) {
std::unordered_map<std::string, array> array_dict;
// Create python ZipFile object
ZipFileWrapper zipfile_object(zipfile, file);
for (const std::string& st : zipfile_object.namelist()) {
// Open zip file as a python file stream
py::object sub_file = zipfile_object.open(st);
// Create array from python fille stream
auto arr = load(std::make_shared<PyFileReader>(sub_file), s);
// Remove .npy from file if it is there
auto key = st;
if (st.length() > 4 && st.substr(st.length() - 4, 4) == ".npy")
key = st.substr(0, st.length() - 4);
// Add array to dict
array_dict.insert({key, arr});
}
// If we don't own the stream and it was passed to us, eval immediately // If we don't own the stream and it was passed to us, eval immediately
for (auto& [key, arr] : array_dict) { auto arr = load_safetensors(std::make_shared<PyFileReader>(file), s);
{
py::gil_scoped_release gil; py::gil_scoped_release gil;
arr.eval(); for (auto& [key, arr] : arr) {
arr.eval();
}
} }
return {arr};
}
return {array_dict}; throw std::invalid_argument(
} else if (py::isinstance<py::str>(file)) { // Assume .npy file path string "[load_safetensors] Input must be a file-like object, or string");
}
std::unordered_map<std::string, array> mlx_load_npz_helper(
py::object file,
StreamOrDevice s) {
py::module_ zipfile = py::module_::import("zipfile");
if (!is_zip_file(zipfile, file)) {
throw std::invalid_argument(
"[load_npz] Input must be a zip file or a file-like object that can be "
"opened with zipfile.ZipFile");
}
// Output dictionary filename in zip -> loaded array
std::unordered_map<std::string, array> array_dict;
// Create python ZipFile object
ZipFileWrapper zipfile_object(zipfile, file);
for (const std::string& st : zipfile_object.namelist()) {
// Open zip file as a python file stream
py::object sub_file = zipfile_object.open(st);
// Create array from python fille stream
auto arr = load(std::make_shared<PyFileReader>(sub_file), s);
// Remove .npy from file if it is there
auto key = st;
if (st.length() > 4 && st.substr(st.length() - 4, 4) == ".npy")
key = st.substr(0, st.length() - 4);
// Add array to dict
array_dict.insert({key, arr});
}
// If we don't own the stream and it was passed to us, eval immediately
for (auto& [key, arr] : array_dict) {
py::gil_scoped_release gil;
arr.eval();
}
return {array_dict};
}
array mlx_load_npy_helper(py::object file, StreamOrDevice s) {
if (py::isinstance<py::str>(file)) { // Assume .npy file path string
return {load(py::cast<std::string>(file), s)}; return {load(py::cast<std::string>(file), s)};
} else if (is_istream_object(file)) { } else if (is_istream_object(file)) {
// If we don't own the stream and it was passed to us, eval immediately // If we don't own the stream and it was passed to us, eval immediately
@ -205,9 +232,41 @@ DictOrArray mlx_load_helper(py::object file, StreamOrDevice s) {
} }
return {arr}; return {arr};
} }
throw std::invalid_argument( throw std::invalid_argument(
"[load] Input must be a file-like object, string, or pathlib.Path"); "[load_npy] Input must be a file-like object, or string");
}
DictOrArray mlx_load_helper(
py::object file,
std::optional<std::string> format,
StreamOrDevice s) {
if (!format.has_value()) {
std::string fname;
if (py::isinstance<py::str>(file)) {
fname = py::cast<std::string>(file);
} else if (is_istream_object(file)) {
fname = file.attr("name").cast<std::string>();
} else {
throw std::invalid_argument(
"[load] Input must be a file-like object, or string");
}
size_t ext = fname.find_last_of('.');
if (ext == std::string::npos) {
throw std::invalid_argument(
"[load] Could not infer file format from extension");
}
format.emplace(fname.substr(ext + 1));
}
if (format.value() == "safetensors") {
return mlx_load_safetensor_helper(file, s);
} else if (format.value() == "npz") {
return mlx_load_npz_helper(file, s);
} else if (format.value() == "npy") {
return mlx_load_npy_helper(file, s);
} else {
throw std::invalid_argument("[load] Unknown file format " + format.value());
}
} }
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
@ -305,7 +364,7 @@ void mlx_save_helper(
} }
throw std::invalid_argument( throw std::invalid_argument(
"[save] Input must be a file-like object, string, or pathlib.Path"); "[save] Input must be a file-like object, or string");
} }
void mlx_savez_helper( void mlx_savez_helper(
@ -361,3 +420,25 @@ void mlx_savez_helper(
return; return;
} }
void mlx_save_safetensor_helper(
py::object file,
py::dict d,
std::optional<bool> retain_graph) {
auto arrays_map = d.cast<std::unordered_map<std::string, array>>();
if (py::isinstance<py::str>(file)) {
save_safetensors(py::cast<std::string>(file), arrays_map, retain_graph);
return;
} else if (is_ostream_object(file)) {
auto writer = std::make_shared<PyFileWriter>(file);
{
py::gil_scoped_release gil;
save_safetensors(writer, arrays_map, retain_graph);
}
return;
}
throw std::invalid_argument(
"[save_safetensors] Input must be a file-like object, or string");
}

View File

@ -3,6 +3,8 @@
#pragma once #pragma once
#include <pybind11/pybind11.h> #include <pybind11/pybind11.h>
#include <optional>
#include <string>
#include <unordered_map> #include <unordered_map>
#include <variant> #include <variant>
#include "mlx/ops.h" #include "mlx/ops.h"
@ -12,7 +14,18 @@ using namespace mlx::core;
using DictOrArray = std::variant<array, std::unordered_map<std::string, array>>; using DictOrArray = std::variant<array, std::unordered_map<std::string, array>>;
DictOrArray mlx_load_helper(py::object file, StreamOrDevice s); std::unordered_map<std::string, array> mlx_load_safetensor_helper(
py::object file,
StreamOrDevice s);
void mlx_save_safetensor_helper(
py::object file,
py::dict d,
std::optional<bool> retain_graph = std::nullopt);
DictOrArray mlx_load_helper(
py::object file,
std::optional<std::string> format,
StreamOrDevice s);
void mlx_save_helper( void mlx_save_helper(
py::object file, py::object file,
array a, array a,

View File

@ -2867,11 +2867,9 @@ void init_ops(py::module_& m) {
Args: Args:
file (str): File to which the array is saved file (str): File to which the array is saved
arr (array): Array to be saved. arr (array): Array to be saved.
retain_graph (bool, optional): Optional argument to retain graph retain_graph (bool, optional): Whether or not to retain the graph
during array evaluation before saving. If not provided the graph during array evaluation. If left unspecified the graph is retained
is retained if we are during a function transformation. Default: only if saving is done in a function transformation. Default: ``None``
None
)pbdoc"); )pbdoc");
m.def( m.def(
"savez", "savez",
@ -2932,18 +2930,45 @@ void init_ops(py::module_& m) {
&mlx_load_helper, &mlx_load_helper,
"file"_a, "file"_a,
py::pos_only(), py::pos_only(),
"format"_a = none,
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
load(file: str, /, *, stream: Union[None, Stream, Device] = None) -> Union[array, Dict[str, array]] load(file: str, /, format: Optional[str] = None, *, stream: Union[None, Stream, Device] = None) -> Union[array, Dict[str, array]]
Load array(s) from a binary file in ``.npy`` or ``.npz`` format. Load array(s) from a binary file in ``.npy``, ``.npz``, or ``.safetensors`` format.
Args: Args:
file (file, str): File in which the array is saved file (file, str): File in which the array is saved.
format (str, optional): Format of the file. If ``None``, the format
is inferred from the file extension. Supported formats: ``npy``,
``npz``, and ``safetensors``. Default: ``None``.
Returns: Returns:
result (array, dict): The loaded array if ``.npy`` file or a dict mapping name to array if ``.npz`` file result (array, dict):
A single array if loading from a ``.npy`` file or a dict mapping
names to arrays if loading from a ``.npz`` or ``.safetensors`` file.
)pbdoc");
m.def(
"save_safetensors",
&mlx_save_safetensor_helper,
"file"_a,
"arrays"_a,
py::pos_only(),
"retain_graph"_a = std::nullopt,
py::kw_only(),
R"pbdoc(
save_safetensors(file: str, arrays: Dict[str, array], /, retain_graph: Optional[bool] = None)
Save array(s) to a binary file in ``.safetensors`` format.
For more information on the format see https://huggingface.co/docs/safetensors/index.
Args:
file (file, str): File in which the array is saved>
arrays (dict(str, array)): The dictionary of names to arrays to be saved.
retain_graph (bool, optional): Whether or not to retain the graph
during array evaluation. If left unspecified the graph is retained
only if saving is done in a function transformation. Default: ``None``.
)pbdoc"); )pbdoc");
m.def( m.def(
"where", "where",

View File

@ -64,6 +64,33 @@ class TestLoad(mlx_tests.MLXTestCase):
load_arr_mlx_npy = np.load(save_file_mlx) load_arr_mlx_npy = np.load(save_file_mlx)
self.assertTrue(np.array_equal(load_arr_mlx_npy, save_arr_npy)) self.assertTrue(np.array_equal(load_arr_mlx_npy, save_arr_npy))
def test_save_and_load_safetensors(self):
if not os.path.isdir(self.test_dir):
os.mkdir(self.test_dir)
for dt in self.dtypes + ["bfloat16"]:
with self.subTest(dtype=dt):
for i, shape in enumerate([(1,), (23,), (1024, 1024), (4, 6, 3, 1, 2)]):
with self.subTest(shape=shape):
save_file_mlx = os.path.join(
self.test_dir, f"mlx_{dt}_{i}_fs.safetensors"
)
save_dict = {
"test": mx.random.normal(shape=shape, dtype=getattr(mx, dt))
if dt in ["float32", "float16", "bfloat16"]
else mx.ones(shape, dtype=getattr(mx, dt))
}
with open(save_file_mlx, "wb") as f:
mx.save_safetensors(f, save_dict)
with open(save_file_mlx, "rb") as f:
load_dict = mx.load(f)
self.assertTrue("test" in load_dict)
self.assertTrue(
mx.array_equal(load_dict["test"], save_dict["test"])
)
def test_save_and_load_fs(self): def test_save_and_load_fs(self):
if not os.path.isdir(self.test_dir): if not os.path.isdir(self.test_dir):

View File

@ -14,6 +14,26 @@ std::string get_temp_file(const std::string& name) {
return std::filesystem::temp_directory_path().append(name); return std::filesystem::temp_directory_path().append(name);
} }
TEST_CASE("test save_safetensors") {
std::string file_path = get_temp_file("test_arr.safetensors");
auto map = std::unordered_map<std::string, array>();
map.insert({"test", array({1.0, 2.0, 3.0, 4.0})});
map.insert({"test2", ones({2, 2})});
save_safetensors(file_path, map);
auto safeDict = load_safetensors(file_path);
CHECK_EQ(safeDict.size(), 2);
CHECK_EQ(safeDict.count("test"), 1);
CHECK_EQ(safeDict.count("test2"), 1);
array test = safeDict.at("test");
CHECK_EQ(test.dtype(), float32);
CHECK_EQ(test.shape(), std::vector<int>({4}));
CHECK(array_equal(test, array({1.0, 2.0, 3.0, 4.0})).item<bool>());
array test2 = safeDict.at("test2");
CHECK_EQ(test2.dtype(), float32);
CHECK_EQ(test2.shape(), std::vector<int>({2, 2}));
CHECK(array_equal(test2, ones({2, 2})).item<bool>());
}
TEST_CASE("test single array serialization") { TEST_CASE("test single array serialization") {
// Basic test // Basic test
{ {