mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-03 18:18:15 +08:00
fix grad, more tests
This commit is contained in:
@@ -8,13 +8,42 @@ using namespace mlx::core;
|
||||
|
||||
std::vector<array> simple_fun(const std::vector<array>& inputs) {
|
||||
return std::vector<array>{inputs[0] + inputs[1]};
|
||||
};
|
||||
}
|
||||
|
||||
TEST_CASE("test simple compile") {
|
||||
auto compfn = compile(simple_fun);
|
||||
auto out = compfn({array(1.0), array(2.0)})[0];
|
||||
auto out = compfn({array(1.0f), array(2.0f)})[0];
|
||||
CHECK_EQ(out.item<float>(), 3.0f);
|
||||
|
||||
out = compfn({array(1.0), array(2.0)})[0];
|
||||
out = compfn({array(1.0f), array(2.0f)})[0];
|
||||
CHECK_EQ(out.item<float>(), 3.0f);
|
||||
|
||||
// Change the shapes
|
||||
out = compfn({array({1.0f, 2.0f}), array(2.0f)})[0];
|
||||
CHECK(array_equal(out, array({3.0f, 4.0f})).item<bool>());
|
||||
|
||||
out = compfn({array(2.0f), array({1.0f, 2.0f})})[0];
|
||||
CHECK(array_equal(out, array({3.0f, 4.0f})).item<bool>());
|
||||
|
||||
// Change the types
|
||||
out = compfn({array(2, int32), array({1.0f, 2.0f})})[0];
|
||||
CHECK(array_equal(out, array({3.0f, 4.0f})).item<bool>());
|
||||
|
||||
out = compfn({array(2.0f), array({1, 2}, int32)})[0];
|
||||
CHECK(array_equal(out, array({3.0f, 4.0f})).item<bool>());
|
||||
}
|
||||
|
||||
std::vector<array> fun1(const std::vector<array>& inputs) {
|
||||
auto loss = [](std::vector<array> ins) { return exp(ins[0] + ins[1]); };
|
||||
return grad(loss)(inputs);
|
||||
}
|
||||
|
||||
TEST_CASE("test compile with grad") {
|
||||
auto x = array(1.0f);
|
||||
auto y = array(1.0f);
|
||||
auto grads_expected = fun1({x, y});
|
||||
auto grads_compile = compile(fun1)({x, y});
|
||||
CHECK_EQ(grads_compile[0].item<float>(), grads_expected[0].item<float>());
|
||||
}
|
||||
|
||||
TEST_CASE("test nested compile") {}
|
||||
|
||||
Reference in New Issue
Block a user