mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
WIP (tests)
This commit is contained in:
@@ -145,7 +145,7 @@ TEST_CASE("test jvp") {
|
||||
|
||||
// No dependence between input and output
|
||||
{
|
||||
auto fun = [](array in) { return array({1.0, 1.0}); };
|
||||
auto fun = [](array /* in */) { return array({1.0, 1.0}); };
|
||||
auto out = jvp(fun, array(1.0f), array(1.0f)).second;
|
||||
CHECK(array_equal(out, zeros({2})).item<bool>());
|
||||
}
|
||||
@@ -195,7 +195,7 @@ TEST_CASE("test vjp") {
|
||||
|
||||
// No dependence between input and output
|
||||
{
|
||||
auto fun = [](array in) { return array(1.); };
|
||||
auto fun = [](array /* in */) { return array(1.); };
|
||||
auto out = vjp(fun, zeros({2}), array(1.)).second;
|
||||
CHECK(array_equal(out, zeros({2})).item<bool>());
|
||||
}
|
||||
|
||||
@@ -44,7 +44,7 @@ TEST_CASE("test export basic functions") {
|
||||
}
|
||||
|
||||
TEST_CASE("test export function with no inputs") {
|
||||
auto fun = [](std::vector<array> x) -> std::vector<array> {
|
||||
auto fun = [](std::vector<array> /* x */) -> std::vector<array> {
|
||||
return {zeros({2, 2})};
|
||||
};
|
||||
|
||||
|
||||
@@ -168,7 +168,7 @@ TEST_CASE("test gguf metadata") {
|
||||
CHECK_EQ(loaded_metadata.count("meta"), 1);
|
||||
auto& strs = std::get<std::vector<std::string>>(loaded_metadata["meta"]);
|
||||
CHECK_EQ(strs.size(), 3);
|
||||
for (int i = 0; i < strs.size(); ++i) {
|
||||
for (int i = 0; i < std::ssize(strs); ++i) {
|
||||
CHECK_EQ(strs[i], data[i]);
|
||||
}
|
||||
}
|
||||
@@ -187,7 +187,7 @@ TEST_CASE("test gguf metadata") {
|
||||
CHECK_EQ(loaded_metadata.size(), 4);
|
||||
auto& strs = std::get<std::vector<std::string>>(loaded_metadata["meta1"]);
|
||||
CHECK_EQ(strs.size(), 3);
|
||||
for (int i = 0; i < strs.size(); ++i) {
|
||||
for (int i = 0; i < std::ssize(strs); ++i) {
|
||||
CHECK_EQ(strs[i], data[i]);
|
||||
}
|
||||
auto& arr = std::get<array>(loaded_metadata["meta2"]);
|
||||
|
||||
@@ -1668,7 +1668,7 @@ TEST_CASE("test error functions") {
|
||||
-0.1124629160182849,
|
||||
-0.5204998778130465,
|
||||
-0.7969082124228322};
|
||||
for (int i = 0; i < vals.size(); ++i) {
|
||||
for (int i = 0; i < std::ssize(vals); ++i) {
|
||||
x = array(vals.begin()[i]);
|
||||
CHECK_EQ(erf(x).item<float>(), doctest::Approx(expected.begin()[i]));
|
||||
}
|
||||
@@ -1686,7 +1686,7 @@ TEST_CASE("test error functions") {
|
||||
-0.08885599049425769,
|
||||
-0.4769362762044699,
|
||||
-1.1630871536766743};
|
||||
for (int i = 0; i < vals.size(); ++i) {
|
||||
for (int i = 0; i < std::ssize(vals); ++i) {
|
||||
x = array(vals.begin()[i]);
|
||||
CHECK_EQ(erfinv(x).item<float>(), doctest::Approx(expected.begin()[i]));
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user