Added jacfwd, jacrev and hessian

This commit is contained in:
paramthakkar123
2025-04-26 20:27:51 +05:30
parent a7a6c49909
commit 0244a15318
5 changed files with 432 additions and 2 deletions

View File

@@ -1312,3 +1312,32 @@ TEST_CASE("test grad dynamic slices") {
CHECK(allclose(outs[1], ones({1, 2})).item<bool>());
}
}
TEST_CASE("test jacfwd") {
auto fn = [](array x) { return x * x; };
array x = array({1.0, 2.0, 3.0});
auto [outputs, jacobian] = jacfwd(fn, {x});
CHECK(array_equal(outputs, x * x).item<bool>());
CHECK(array_equal(jacobian, 2 * x).item<bool>());
}
TEST_CASE("test jacrev") {
auto fun = [](array x) { return x * x; };
array x = array({1.0, 2.0, 3.0});
auto jacobian_fn = jacrev(fun);
auto jacobian = jacobian_fn({x});
CHECK(array_equal(jacobian, 2 * x).item<bool>());
}
TEST_CASE("test hessian") {
auto fun = [](array x) { return sum(x * x); };
array x = array({1.0, 2.0, 3.0});
auto hessian_fn = hessian(fun);
auto hess = hessian_fn({x});
array expected_hessian =
array({2.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 2.0}, {3, 3});
CHECK(array_equal(hess, expected_hessian).item<bool>());
}