mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-22 04:56:41 +08:00
updated tests
This commit is contained in:
parent
f09bcc7d50
commit
18a1c335d0
@ -161,15 +161,13 @@ TEST_CASE("test jsonDeserialize") {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST_CASE("test save_safetensor") {
|
TEST_CASE("test save_safetensor") {
|
||||||
|
std::string file_path = get_temp_file("test_arr.safetensors");
|
||||||
auto map = std::unordered_map<std::string, array>();
|
auto map = std::unordered_map<std::string, array>();
|
||||||
map.insert({"test", array({1.0, 2.0, 3.0, 4.0})});
|
map.insert({"test", array({1.0, 2.0, 3.0, 4.0})});
|
||||||
map.insert({"test2", ones({2, 2})});
|
map.insert({"test2", ones({2, 2})});
|
||||||
MESSAGE("SAVING");
|
MESSAGE("SAVING");
|
||||||
save_safetensor("../../temp1", map);
|
save_safetensor(file_path, map);
|
||||||
}
|
auto safeDict = load_safetensor(file_path);
|
||||||
|
|
||||||
TEST_CASE("test load_safetensor") {
|
|
||||||
auto safeDict = load_safetensor("../../temp1.safetensors");
|
|
||||||
CHECK_EQ(safeDict.size(), 2);
|
CHECK_EQ(safeDict.size(), 2);
|
||||||
CHECK_EQ(safeDict.count("test"), 1);
|
CHECK_EQ(safeDict.count("test"), 1);
|
||||||
CHECK_EQ(safeDict.count("test2"), 1);
|
CHECK_EQ(safeDict.count("test2"), 1);
|
||||||
|
Loading…
Reference in New Issue
Block a user