Added jacfwd, jacrev and hessian

This commit is contained in:
paramthakkar123
2025-04-26 20:27:51 +05:30
parent a7a6c49909
commit 0244a15318
5 changed files with 432 additions and 2 deletions

View File

@@ -1506,4 +1506,141 @@ void init_transforms(nb::module_& m) {
tree_cache().clear();
mx::detail::compile_clear_cache();
}));
// Python binding for `jacrev`: wraps a Python callable `fun` and returns a new
// Python callable that flattens its positional arguments, hands them to
// `mx::jacrev`, and unflattens the result.
//
// Keyword arguments given to the returned callable are forwarded to `fun`
// unchanged; they are not part of the flattened inputs passed to `mx::jacrev`.
m.def(
    "jacrev",
    [](const nb::callable& fun,
       const std::optional<IntOrVec>& argnums,
       bool has_aux = false,
       bool holomorphic = false,
       bool allow_int = false) {
      auto [argnums_vec, _] = validate_argnums_argnames(argnums, {});
      return mlx_func(
          // `argnums_vec` is a structured binding; C++17 does not allow a
          // lambda to capture one directly, so copy it via an init-capture.
          [fun,
           argnums_vec = argnums_vec,
           has_aux,
           holomorphic,
           allow_int](const nb::args& args, const nb::kwargs& kwargs) {
            auto inputs = tree_flatten(args, false);
            auto jacobian_fn = mx::jacrev(
                // `args` and `kwargs` must be captured explicitly: the inner
                // lambda rebuilds the positional-argument tree from the
                // flat arrays it receives before calling back into Python.
                [&fun, &args, &kwargs](
                    const std::vector<mx::array>& primals) {
                  return tree_flatten(
                      fun(*tree_unflatten(args, primals), **kwargs));
                },
                argnums_vec,
                has_aux,
                holomorphic,
                allow_int);
            auto jacobian = jacobian_fn(inputs);
            // NOTE(review): the result is unflattened with the structure of
            // `args`; confirm this matches what mx::jacrev returns when
            // `argnums` selects only a subset of the arguments.
            return tree_unflatten(args, jacobian);
          },
          fun);
    },
    "fun"_a,
    "argnums"_a = nb::none(),
    "has_aux"_a = false,
    "holomorphic"_a = false,
    "allow_int"_a = false,
    nb::sig(
        "def jacrev(fun: Callable, argnums: Optional[Union[int, Sequence[int]]] = None, has_aux: bool = False, holomorphic: bool = False, allow_int: bool = False) -> Callable"),
    R"pbdoc(
Compute the Jacobian of a function using reverse-mode AD.
Args:
fun (Callable): A function which takes a variable number of
:class:`array` or trees of :class:`array` and returns
a variable number of :class:`array` or trees of :class:`array`.
argnums (int or list(int), optional): Specify the index (or indices)
of the positional arguments of ``fun`` to compute the Jacobian
with respect to. Defaults to ``0``.
has_aux (bool, optional): Whether ``fun`` returns auxiliary data.
Defaults to ``False``.
holomorphic (bool, optional): Whether ``fun`` is holomorphic.
Defaults to ``False``.
allow_int (bool, optional): Whether to allow differentiation with
respect to integer inputs. Defaults to ``False``.
Returns:
Callable: A function which computes the Jacobian of ``fun``.
)pbdoc");
// Python binding for `jacfwd`: wraps a Python callable `fun` and returns a new
// Python callable that flattens its positional arguments, hands them to
// `mx::jacfwd`, and returns an `(outputs, jacobian)` pair.
//
// Keyword arguments given to the returned callable are forwarded to `fun`
// unchanged; they are not part of the flattened inputs passed to `mx::jacfwd`.
m.def(
    "jacfwd",
    [](const nb::callable& fun, bool has_aux = false) {
      return mlx_func(
          [fun, has_aux](const nb::args& args, const nb::kwargs& kwargs) {
            auto inputs = tree_flatten(args, false);
            auto jacobian_fn = mx::jacfwd(
                // `args` and `kwargs` must be captured explicitly: the inner
                // lambda rebuilds the positional-argument tree from the
                // flat arrays it receives before calling back into Python.
                [&fun, &args, &kwargs](
                    const std::vector<mx::array>& primals) {
                  return tree_flatten(
                      fun(*tree_unflatten(args, primals), **kwargs));
                },
                // NOTE(review): `inputs` is passed both here and to
                // `jacobian_fn` below; confirm against the mx::jacfwd
                // signature that both are intended.
                inputs,
                has_aux);
            auto [outputs, jacobian] = jacobian_fn(inputs);
            // NOTE(review): both results are unflattened with the structure
            // of `args`; confirm that is the right tree for the outputs.
            return std::make_pair(
                tree_unflatten(args, outputs),
                tree_unflatten(args, jacobian));
          },
          fun);
    },
    "fun"_a,
    "has_aux"_a = false,
    nb::sig("def jacfwd(fun: Callable, has_aux: bool = False) -> Callable"),
    R"pbdoc(
Compute the Jacobian of a function using forward-mode AD.
Args:
fun (Callable): A function which takes a variable number of
:class:`array` or trees of :class:`array` and returns
a variable number of :class:`array` or trees of :class:`array`.
has_aux (bool, optional): Whether ``fun`` returns auxiliary data.
Defaults to ``False``.
Returns:
Callable: A function which returns a tuple of the outputs of ``fun``
and the Jacobian of ``fun``.
)pbdoc");
// Python binding for `hessian`: wraps a Python callable `fun` and returns a
// new Python callable that flattens its positional arguments, hands them to
// `mx::hessian`, and unflattens the result.
//
// Keyword arguments given to the returned callable are forwarded to `fun`
// unchanged; they are not part of the flattened inputs passed to `mx::hessian`.
m.def(
    "hessian",
    [](const nb::callable& fun,
       const std::optional<IntOrVec>& argnums,
       bool has_aux = false,
       bool holomorphic = false) {
      auto [argnums_vec, _] = validate_argnums_argnames(argnums, {});
      return mlx_func(
          // `argnums_vec` is a structured binding; C++17 does not allow a
          // lambda to capture one directly, so copy it via an init-capture.
          [fun, argnums_vec = argnums_vec, has_aux, holomorphic](
              const nb::args& args, const nb::kwargs& kwargs) {
            auto inputs = tree_flatten(args, false);
            auto hessian_fn = mx::hessian(
                // `args` and `kwargs` must be captured explicitly: the inner
                // lambda rebuilds the positional-argument tree from the
                // flat arrays it receives before calling back into Python.
                [&fun, &args, &kwargs](
                    const std::vector<mx::array>& primals) {
                  return tree_flatten(
                      fun(*tree_unflatten(args, primals), **kwargs));
                },
                argnums_vec,
                has_aux,
                holomorphic);
            auto hessian = hessian_fn(inputs);
            // NOTE(review): the result is unflattened with the structure of
            // `args`; confirm this matches what mx::hessian returns when
            // `argnums` selects only a subset of the arguments.
            return tree_unflatten(args, hessian);
          },
          fun);
    },
    "fun"_a,
    "argnums"_a = nb::none(),
    "has_aux"_a = false,
    "holomorphic"_a = false,
    nb::sig(
        "def hessian(fun: Callable, argnums: Optional[Union[int, Sequence[int]]] = None, has_aux: bool = False, holomorphic: bool = False) -> Callable"),
    R"pbdoc(
Compute the Hessian of a function.
Args:
fun (Callable): A function which takes a variable number of
:class:`array` or trees of :class:`array` and returns
a variable number of :class:`array` or trees of :class:`array`.
argnums (int or list(int), optional): Specify the index (or indices)
of the positional arguments of ``fun`` to compute the Hessian
with respect to. Defaults to ``0``.
has_aux (bool, optional): Whether ``fun`` returns auxiliary data.
Defaults to ``False``.
holomorphic (bool, optional): Whether ``fun`` is holomorphic.
Defaults to ``False``.
Returns:
Callable: A function which computes the Hessian of ``fun``.
)pbdoc");
}

View File

@@ -797,6 +797,54 @@ class TestAutograd(mlx_tests.MLXTestCase):
grad_fn(model)
self.assertEqual(model[1].item(), 2.0)
def test_jacfwd(self):
    """Check ``mx.jacfwd`` on an elementwise square and on a two-output function."""

    # Case 1: elementwise square applied to a length-3 vector.
    # NOTE(review): the expected Jacobian here has the same shape as the
    # input (2 * x) rather than a 3x3 matrix -- confirm against the
    # semantics of mx.jacfwd.
    def square(v):
        return v * v

    point = mx.array([1.0, 2.0, 3.0])
    values, jac = mx.jacfwd(square)(point)
    self.assertTrue(mx.array_equal(values, point * point))
    self.assertTrue(mx.array_equal(jac, 2 * point))

    # Case 2: two scalar inputs mapped to the pair (a * b, a + b).
    def prod_and_sum(a, b):
        return (a * b, a + b)

    lhs = mx.array(2.0)
    rhs = mx.array(3.0)
    values, jac = mx.jacfwd(prod_and_sum)(lhs, rhs)
    expected_product = lhs * rhs
    expected_sum = lhs + rhs
    self.assertTrue(mx.array_equal(values[0], expected_product))
    self.assertTrue(mx.array_equal(values[1], expected_sum))
    # Row for a * b is (b, a); row for a + b is (1, 1).
    self.assertTrue(mx.array_equal(jac[0], mx.array([rhs, lhs])))
    self.assertTrue(mx.array_equal(jac[1], mx.array([1.0, 1.0])))
def test_jacrev(self):
    """Check ``mx.jacrev`` on an elementwise square and on a two-output function."""

    # Case 1: elementwise square applied to a length-3 vector.
    # NOTE(review): the expected Jacobian here has the same shape as the
    # input (2 * x) rather than a 3x3 matrix -- confirm against the
    # semantics of mx.jacrev.
    def square(v):
        return v * v

    point = mx.array([1.0, 2.0, 3.0])
    jac = mx.jacrev(square)(point)
    self.assertTrue(mx.array_equal(jac, 2 * point))

    # Case 2: two scalar inputs mapped to the pair (a * b, a + b).
    def prod_and_sum(a, b):
        return (a * b, a + b)

    lhs = mx.array(2.0)
    rhs = mx.array(3.0)
    jac = mx.jacrev(prod_and_sum)(lhs, rhs)
    # Row for a * b is (b, a); row for a + b is (1, 1).
    self.assertTrue(mx.array_equal(jac[0], mx.array([rhs, lhs])))
    self.assertTrue(mx.array_equal(jac[1], mx.array([1.0, 1.0])))
def test_hessian(self):
    """Check ``mx.hessian`` on a cubic of one variable and a sum of squares."""

    # Case 1: f(x) = x^3 at x = 2, whose second derivative 6x equals 12.
    def cube(v):
        return v * v * v

    point = mx.array(2.0)
    second_derivative = mx.hessian(cube)(point)
    self.assertEqual(second_derivative.item(), 12.0)

    # Case 2: f(a, b) = a^2 + b^2, expected Hessian 2 * identity.
    # NOTE(review): ``argnums`` is left at its default here while a full
    # 2x2 matrix over both arguments is expected -- confirm against the
    # semantics of mx.hessian.
    def sum_of_squares(a, b):
        return a * a + b * b

    lhs = mx.array(1.0)
    rhs = mx.array(2.0)
    result = mx.hessian(sum_of_squares)(lhs, rhs)
    expected = mx.array([[2.0, 0.0], [0.0, 2.0]])
    self.assertTrue(mx.array_equal(result, expected))
# Run the test suite when this file is executed directly.
if __name__ == "__main__":
    unittest.main()