mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 10:26:56 +08:00
Added jacfwd, jacrev and hessian
This commit is contained in:
@@ -1312,3 +1312,32 @@ TEST_CASE("test grad dynamic slices") {
|
||||
CHECK(allclose(outs[1], ones({1, 2})).item<bool>());
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("test jacfwd") {
|
||||
auto fn = [](array x) { return x * x; };
|
||||
|
||||
array x = array({1.0, 2.0, 3.0});
|
||||
auto [outputs, jacobian] = jacfwd(fn, {x});
|
||||
CHECK(array_equal(outputs, x * x).item<bool>());
|
||||
CHECK(array_equal(jacobian, 2 * x).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test jacrev") {
|
||||
auto fun = [](array x) { return x * x; };
|
||||
|
||||
array x = array({1.0, 2.0, 3.0});
|
||||
auto jacobian_fn = jacrev(fun);
|
||||
auto jacobian = jacobian_fn({x});
|
||||
CHECK(array_equal(jacobian, 2 * x).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test hessian") {
|
||||
auto fun = [](array x) { return sum(x * x); };
|
||||
|
||||
array x = array({1.0, 2.0, 3.0});
|
||||
auto hessian_fn = hessian(fun);
|
||||
auto hess = hessian_fn({x});
|
||||
array expected_hessian =
|
||||
array({2.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 2.0}, {3, 3});
|
||||
CHECK(array_equal(hess, expected_hessian).item<bool>());
|
||||
}
|
Reference in New Issue
Block a user