mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-20 09:33:16 +08:00
@@ -14,6 +14,26 @@ std::string get_temp_file(const std::string& name) {
|
||||
return std::filesystem::temp_directory_path().append(name);
|
||||
}
|
||||
|
||||
TEST_CASE("test save_safetensors") {
|
||||
std::string file_path = get_temp_file("test_arr.safetensors");
|
||||
auto map = std::unordered_map<std::string, array>();
|
||||
map.insert({"test", array({1.0, 2.0, 3.0, 4.0})});
|
||||
map.insert({"test2", ones({2, 2})});
|
||||
save_safetensors(file_path, map);
|
||||
auto safeDict = load_safetensors(file_path);
|
||||
CHECK_EQ(safeDict.size(), 2);
|
||||
CHECK_EQ(safeDict.count("test"), 1);
|
||||
CHECK_EQ(safeDict.count("test2"), 1);
|
||||
array test = safeDict.at("test");
|
||||
CHECK_EQ(test.dtype(), float32);
|
||||
CHECK_EQ(test.shape(), std::vector<int>({4}));
|
||||
CHECK(array_equal(test, array({1.0, 2.0, 3.0, 4.0})).item<bool>());
|
||||
array test2 = safeDict.at("test2");
|
||||
CHECK_EQ(test2.dtype(), float32);
|
||||
CHECK_EQ(test2.shape(), std::vector<int>({2, 2}));
|
||||
CHECK(array_equal(test2, ones({2, 2})).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test single array serialization") {
|
||||
// Basic test
|
||||
{
|
||||
|
Reference in New Issue
Block a user