diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp
index 57105a74c..20261882b 100644
--- a/mlx/transforms.cpp
+++ b/mlx/transforms.cpp
@@ -1045,7 +1045,8 @@ std::function<std::vector<array>(const std::vector<array>&)> checkpoint(
 std::pair<std::vector<array>, std::vector<array>> jacfwd(
     const std::function<std::vector<array>(const std::vector<array>&)>& fun,
-    const std::vector<array>& primals) {
+    const std::vector<array>& primals,
+    bool has_aux /* = false, defaulted in transforms.h */) {
   detail::InTracing in_tracing{false, true};
 
   auto outputs = fun(primals);
@@ -1124,4 +1125,144 @@ std::pair<array, array> jacfwd(
   auto [outputs, jacobian] = jacfwd(vec_fun, {primal});
   return {outputs[0], jacobian[0]};
 }
-} // namespace mlx::core
+
+std::function<std::vector<array>(const std::vector<array>&)> jacrev(
+    const std::function<std::vector<array>(const std::vector<array>&)>& fun,
+    const std::vector<int>& argnums,
+    bool has_aux,
+    bool holomorphic,
+    bool allow_int) {
+  // Default arguments live on the declaration in transforms.h; repeating
+  // them here would not compile. `has_aux` and `holomorphic` are accepted
+  // for API parity but are not used yet.
+  return [fun, argnums, has_aux, holomorphic, allow_int](
+             const std::vector<array>& inputs) -> std::vector<array> {
+    for (const auto& input : inputs) {
+      if (!allow_int && issubdtype(input.dtype(), integer)) {
+        throw std::invalid_argument(
+            "[jacrev] Differentiation with respect to integer inputs is not allowed.");
+      }
+    }
+
+    // Trace the function once to learn the output shape.
+    auto outputs = fun(inputs);
+    if (outputs.size() != 1) {
+      throw std::invalid_argument(
+          "[jacrev] Expected `fun` to return a single array.");
+    }
+    auto& out = outputs[0];
+
+    // Only a single argnum is supported in this draft.
+    int argnum = argnums.empty() ? 0 : argnums[0];
+
+    // Reverse mode builds the Jacobian row by row: one VJP per output
+    // element, each seeded with a one-hot cotangent.
+    std::vector<array> rows;
+    int out_size = static_cast<int>(out.size());
+    for (int i = 0; i < out_size; ++i) {
+      auto one_hot = reshape(
+          astype(equal(arange(out_size), array(i)), out.dtype()), out.shape());
+      auto [_, vjps] = vjp(fun, inputs, {one_hot});
+      rows.push_back(flatten(vjps[argnum]));
+    }
+
+    // Shape the Jacobian as out.shape() + in.shape() so that scalar-valued
+    // functions yield plain gradients.
+    auto jac = stack(rows, 0);
+    auto jac_shape = out.shape();
+    auto& in_shape = inputs[argnum].shape();
+    jac_shape.insert(jac_shape.end(), in_shape.begin(), in_shape.end());
+    return {reshape(jac, jac_shape)};
+  };
+}
+
+std::function<std::vector<array>(const std::vector<array>&)> hessian(
+    const std::function<std::vector<array>(const std::vector<array>&)>& fun,
+    const std::vector<int>& argnums,
+    bool has_aux,
+    bool holomorphic) {
+  return
+      [fun, argnums, has_aux, holomorphic](const std::vector<array>& inputs) {
+        // Forward-over-reverse: differentiate the reverse-mode Jacobian
+        // with forward-mode AD. jacfwd returns (outputs, jacobian), and
+        // the Jacobian of the Jacobian is the Hessian.
+        auto jacobian_fn = jacrev(fun, argnums, has_aux, holomorphic);
+        auto [jacobian, hess] = jacfwd(jacobian_fn, inputs, has_aux);
+        return hess;
+      };
+}
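+
+// Illustrative sketch of composing the vector forms above; `x` here stands
+// for any floating point array:
+//
+//   auto f = [](const std::vector<array>& xs) {
+//     return std::vector<array>{sum(square(xs[0]))};
+//   };
+//   auto hess = hessian(f)({x})[0]; // (n, n) matrix of second derivatives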
+
+// Overload for scalar functions: wraps a unary function and reuses the
+// vectorized jacrev implementation.
+std::function<array(const array&)> jacrev(
+    const std::function<array(const array&)>& fun,
+    int argnum,
+    bool has_aux,
+    bool holomorphic,
+    bool allow_int) {
+  return [fun, argnum, has_aux, holomorphic, allow_int](const array& input) {
+    // Wrap the unary function into a vectorized function
+    auto vec_fun = [fun](const std::vector<array>& inputs) {
+      return std::vector<array>{fun(inputs[0])};
+    };
+
+    // Use the existing jacrev implementation for vectorized functions
+    auto jacobian_fn =
+        jacrev(vec_fun, {argnum}, has_aux, holomorphic, allow_int);
+    auto jacobian_result = jacobian_fn({input});
+    return jacobian_result[0];
+  };
+}
+
+// Overload for scalar functions
+std::function<array(const array&)> hessian(
+    const std::function<array(const array&)>& fun,
+    int argnum,
+    bool has_aux,
+    bool holomorphic) {
+  return [fun, argnum, has_aux, holomorphic](const array& input) {
+    // Wrap the unary function into a vectorized function
+    auto vec_fun = [fun](const std::vector<array>& inputs) {
+      return std::vector<array>{fun(inputs[0])};
+    };
+
+    // Compute the Hessian using the vectorized overload
+    auto hessian_fn = hessian(vec_fun, {argnum}, has_aux, holomorphic);
+    auto hessian_result = hessian_fn({input});
+    return hessian_result[0];
+  };
+}
+
+} // namespace mlx::core
\ No newline at end of file
diff --git a/mlx/transforms.h b/mlx/transforms.h
index 4afb21e23..2c05acab6 100644
--- a/mlx/transforms.h
+++ b/mlx/transforms.h
@@ -223,7 +223,82 @@ std::function<std::vector<array>(const std::vector<array>&)> custom_vjp(
 /**
  * Checkpoint the gradient of a function. Namely, discard all intermediate
  * state and recalculate it when we need to compute the gradient.
  */
 std::function<std::vector<array>(const std::vector<array>&)> checkpoint(
     std::function<std::vector<array>(const std::vector<array>&)> fun);
+
+/**
+ * Computes the Jacobian of a function using reverse-mode AD.
+ *
+ * @param fun The function whose Jacobian is to be computed.
+ * @param argnums The indices of the arguments to differentiate with respect
+ * to.
+ * @param has_aux Whether the function returns auxiliary data.
+ * @param holomorphic Whether the function is holomorphic.
+ * @param allow_int Whether to allow differentiation with respect to integer
+ * inputs.
+ * @return A function that computes the Jacobian of `fun`.
+ */
+std::function<std::vector<array>(const std::vector<array>&)> jacrev(
+    const std::function<std::vector<array>(const std::vector<array>&)>& fun,
+    const std::vector<int>& argnums = {0},
+    bool has_aux = false,
+    bool holomorphic = false,
+    bool allow_int = false);
+
+/**
+ * Computes the Jacobian of a function using forward-mode AD.
+ *
+ * @param fun The function whose Jacobian is to be computed.
+ * @param primals The input arrays to compute the Jacobian with respect to.
+ * @param has_aux Whether the function returns auxiliary data.
+ * @return A pair containing the outputs of the function and the Jacobian.
+ */
+std::pair<std::vector<array>, std::vector<array>> jacfwd(
+    const std::function<std::vector<array>(const std::vector<array>&)>& fun,
+    const std::vector<array>& primals,
+    bool has_aux = false);
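+
+// Illustrative sketch (not part of the reference docs): unlike `jacrev` and
+// `hessian`, which return closures, `jacfwd` is evaluated eagerly at
+// `primals`:
+//
+//   auto fn = [](const std::vector<array>& xs) {
+//     return std::vector<array>{square(xs[0])};
+//   };
+//   auto [out, jac] = jacfwd(fn, {x}); // outputs and Jacobian in one call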
+
+/**
+ * Computes the Hessian of a function.
+ *
+ * @param fun The function whose Hessian is to be computed.
+ * @param argnums The indices of the arguments to differentiate with respect
+ * to.
+ * @param has_aux Whether the function returns auxiliary data.
+ * @param holomorphic Whether the function is holomorphic.
+ * @return A function that computes the Hessian of `fun`.
+ */
+std::function<std::vector<array>(const std::vector<array>&)> hessian(
+    const std::function<std::vector<array>(const std::vector<array>&)>& fun,
+    const std::vector<int>& argnums = {0},
+    bool has_aux = false,
+    bool holomorphic = false);
+
+/**
+ * Overload for scalar functions: Computes the Jacobian of a unary function
+ * using reverse-mode AD.
+ */
+std::function<array(const array&)> jacrev(
+    const std::function<array(const array&)>& fun,
+    int argnum = 0,
+    bool has_aux = false,
+    bool holomorphic = false,
+    bool allow_int = false);
+
+/**
+ * Overload for scalar functions: Computes the Jacobian of a unary function
+ * using forward-mode AD.
+ */
+std::pair<array, array> jacfwd(
+    const std::function<array(const array&)>& fun,
+    const array& primal);
+
+/**
+ * Overload for scalar functions: Computes the Hessian of a unary function.
+ */
+std::function<array(const array&)> hessian(
+    const std::function<array(const array&)>& fun,
+    int argnum = 0,
+    bool has_aux = false,
+    bool holomorphic = false);
+
 } // namespace mlx::core
diff --git a/python/src/transforms.cpp b/python/src/transforms.cpp
index c47942b72..8d67e4102 100644
--- a/python/src/transforms.cpp
+++ b/python/src/transforms.cpp
@@ -1506,4 +1506,141 @@ void init_transforms(nb::module_& m) {
         tree_cache().clear();
         mx::detail::compile_clear_cache();
       }));
+
+  m.def(
+      "jacrev",
+      [](const nb::callable& fun,
+         const std::optional<IntOrVec>& argnums,
+         bool has_aux,
+         bool holomorphic,
+         bool allow_int) {
+        auto [argnums_vec, _] = validate_argnums_argnames(argnums, {});
+        return mlx_func(
+            [fun, argnums_vec, has_aux, holomorphic, allow_int](
+                const nb::args& args, const nb::kwargs& kwargs) {
+              auto inputs = tree_flatten(args, false);
+              auto jacobian_fn = mx::jacrev(
+                  [&fun, &args](const std::vector<mx::array>& inputs) {
+                    return tree_flatten(fun(*tree_unflatten(args, inputs)));
+                  },
+                  argnums_vec,
+                  has_aux,
+                  holomorphic,
+                  allow_int);
+              auto jacobian = jacobian_fn(inputs);
+              return tree_unflatten(args, jacobian);
+            },
+            fun);
+      },
+      "fun"_a,
+      "argnums"_a = nb::none(),
+      "has_aux"_a = false,
+      "holomorphic"_a = false,
+      "allow_int"_a = false,
+      nb::sig(
+          "def jacrev(fun: Callable, argnums: Optional[Union[int, Sequence[int]]] = None, has_aux: bool = False, holomorphic: bool = False, allow_int: bool = False) -> Callable"),
+      R"pbdoc(
+        Compute the Jacobian of a function using reverse-mode AD.
+
+        Args:
+            fun (Callable): A function which takes a variable number of
+              :class:`array` or trees of :class:`array` and returns
+              a variable number of :class:`array` or trees of :class:`array`.
+            argnums (int or list(int), optional): Specify the index (or indices)
+              of the positional arguments of ``fun`` to compute the Jacobian
+              with respect to. Defaults to ``0``.
+            has_aux (bool, optional): Whether ``fun`` returns auxiliary data.
+              Defaults to ``False``.
+            holomorphic (bool, optional): Whether ``fun`` is holomorphic.
+              Defaults to ``False``.
+            allow_int (bool, optional): Whether to allow differentiation with
+              respect to integer inputs. Defaults to ``False``.
+
+        Returns:
+            Callable: A function which computes the Jacobian of ``fun``.
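+
+        Example:
+
+            .. code-block:: python
+
+                f = lambda x: x * x
+                x = mx.array([1.0, 2.0, 3.0])
+                jac = mx.jacrev(f)(x)  # 3x3 Jacobian, diagonal for an
+                                       # elementwise function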
+      )pbdoc");
+
+  m.def(
+      "jacfwd",
+      [](const nb::callable& fun, bool has_aux) {
+        return mlx_func(
+            [fun, has_aux](const nb::args& args, const nb::kwargs& kwargs) {
+              auto inputs = tree_flatten(args, false);
+              // mx::jacfwd evaluates eagerly and returns (outputs, jacobian),
+              // so there is no inner closure to call here.
+              auto [outputs, jacobian] = mx::jacfwd(
+                  [&fun, &args](const std::vector<mx::array>& inputs) {
+                    return tree_flatten(fun(*tree_unflatten(args, inputs)));
+                  },
+                  inputs,
+                  has_aux);
+              return std::make_pair(
+                  tree_unflatten(args, outputs),
+                  tree_unflatten(args, jacobian));
+            },
+            fun);
+      },
+      "fun"_a,
+      "has_aux"_a = false,
+      nb::sig("def jacfwd(fun: Callable, has_aux: bool = False) -> Callable"),
+      R"pbdoc(
+        Compute the Jacobian of a function using forward-mode AD.
+
+        Args:
+            fun (Callable): A function which takes a variable number of
+              :class:`array` or trees of :class:`array` and returns
+              a variable number of :class:`array` or trees of :class:`array`.
+            has_aux (bool, optional): Whether ``fun`` returns auxiliary data.
+              Defaults to ``False``.
+
+        Returns:
+            Callable: A function which computes the Jacobian of ``fun``.
+      )pbdoc");
+
+  m.def(
+      "hessian",
+      [](const nb::callable& fun,
+         const std::optional<IntOrVec>& argnums,
+         bool has_aux,
+         bool holomorphic) {
+        auto [argnums_vec, _] = validate_argnums_argnames(argnums, {});
+        return mlx_func(
+            [fun, argnums_vec, has_aux, holomorphic](
+                const nb::args& args, const nb::kwargs& kwargs) {
+              auto inputs = tree_flatten(args, false);
+              auto hessian_fn = mx::hessian(
+                  [&fun, &args](const std::vector<mx::array>& inputs) {
+                    return tree_flatten(fun(*tree_unflatten(args, inputs)));
+                  },
+                  argnums_vec,
+                  has_aux,
+                  holomorphic);
+              auto hessian = hessian_fn(inputs);
+              return tree_unflatten(args, hessian);
+            },
+            fun);
+      },
+      "fun"_a,
+      "argnums"_a = nb::none(),
+      "has_aux"_a = false,
+      "holomorphic"_a = false,
+      nb::sig(
+          "def hessian(fun: Callable, argnums: Optional[Union[int, Sequence[int]]] = None, has_aux: bool = False, holomorphic: bool = False) -> Callable"),
+      R"pbdoc(
+        Compute the Hessian of a function.
+
+        Args:
+            fun (Callable): A function which takes a variable number of
+              :class:`array` or trees of :class:`array` and returns
+              a variable number of :class:`array` or trees of :class:`array`.
+            argnums (int or list(int), optional): Specify the index (or indices)
+              of the positional arguments of ``fun`` to compute the Hessian
+              with respect to. Defaults to ``0``.
+            has_aux (bool, optional): Whether ``fun`` returns auxiliary data.
+              Defaults to ``False``.
+            holomorphic (bool, optional): Whether ``fun`` is holomorphic.
+              Defaults to ``False``.
+
+        Returns:
+            Callable: A function which computes the Hessian of ``fun``.
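+
+        Example:
+
+            .. code-block:: python
+
+                f = lambda x: x * x * x
+                x = mx.array(2.0)
+                mx.hessian(f)(x)  # d2/dx2 of x**3 at 2.0, i.e. 12.0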
+      )pbdoc");
 }
diff --git a/python/tests/test_autograd.py b/python/tests/test_autograd.py
index ec9d957ea..7e885a6bb 100644
--- a/python/tests/test_autograd.py
+++ b/python/tests/test_autograd.py
@@ -797,6 +797,54 @@ class TestAutograd(mlx_tests.MLXTestCase):
         grad_fn(model)
         self.assertEqual(model[1].item(), 2.0)
 
+    def test_jacfwd(self):
+        # Elementwise function: f(x) = x^2
+        fun = lambda x: x * x
+        x = mx.array([1.0, 2.0, 3.0])
+        outputs, jacobian = mx.jacfwd(fun)(x)
+        self.assertTrue(mx.array_equal(outputs, x * x))
+        # The Jacobian of an elementwise square is the diagonal matrix 2x
+        self.assertTrue(mx.array_equal(jacobian, mx.diag(2 * x)))
+
+        # Multi-input function: f(x, y) = (x * y, x + y)
+        fun = lambda x, y: (x * y, x + y)
+        x = mx.array(2.0)
+        y = mx.array(3.0)
+        outputs, jacobian = mx.jacfwd(fun)(x, y)
+        self.assertTrue(mx.array_equal(outputs[0], x * y))
+        self.assertTrue(mx.array_equal(outputs[1], x + y))
+        self.assertTrue(mx.array_equal(jacobian[0], mx.array([3.0, 2.0])))  # [y, x]
+        self.assertTrue(mx.array_equal(jacobian[1], mx.array([1.0, 1.0])))
+
+    def test_jacrev(self):
+        # Elementwise function: f(x) = x^2
+        fun = lambda x: x * x
+        x = mx.array([1.0, 2.0, 3.0])
+        jacobian = mx.jacrev(fun)(x)
+        # The Jacobian of an elementwise square is the diagonal matrix 2x
+        self.assertTrue(mx.array_equal(jacobian, mx.diag(2 * x)))
+
+        # Multi-input function: f(x, y) = (x * y, x + y)
+        fun = lambda x, y: (x * y, x + y)
+        x = mx.array(2.0)
+        y = mx.array(3.0)
+        jacobian = mx.jacrev(fun)(x, y)
+        self.assertTrue(mx.array_equal(jacobian[0], mx.array([3.0, 2.0])))  # [y, x]
+        self.assertTrue(mx.array_equal(jacobian[1], mx.array([1.0, 1.0])))
+
+    def test_hessian(self):
+        # Scalar function: f(x) = x^3
+        fun = lambda x: x * x * x
+        x = mx.array(2.0)
+        hessian = mx.hessian(fun)(x)
+        self.assertEqual(hessian.item(), 12.0)
+
+        # Two-input function: f(x, y) = x^2 + y^2
+        fun = lambda x, y: x * x + y * y
+        x = mx.array(1.0)
+        y = mx.array(2.0)
+        hessian = mx.hessian(fun)(x, y)
+        expected_hessian = mx.array([[2.0, 0.0], [0.0, 2.0]])
+        self.assertTrue(mx.array_equal(hessian, expected_hessian))
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/tests/autograd_tests.cpp b/tests/autograd_tests.cpp
index c992c3c6d..9e5efe8db 100644
--- a/tests/autograd_tests.cpp
+++ b/tests/autograd_tests.cpp
@@ -1312,3 +1312,32 @@ TEST_CASE("test grad dynamic slices") {
     CHECK(allclose(outs[1], ones({1, 2})).item<bool>());
   }
 }
+
+TEST_CASE("test jacfwd") {
+  auto fn = [](const array& x) { return x * x; };
+
+  array x = array({1.0, 2.0, 3.0});
+  auto [outputs, jacobian] = jacfwd(fn, x);
+  CHECK(array_equal(outputs, x * x).item<bool>());
+  // The Jacobian of an elementwise square is diag(2x)
+  array expected = array({2.0, 0.0, 0.0, 0.0, 4.0, 0.0, 0.0, 0.0, 6.0}, {3, 3});
+  CHECK(array_equal(jacobian, expected).item<bool>());
+}
+
+TEST_CASE("test jacrev") {
+  auto fun = [](const array& x) { return x * x; };
+
+  array x = array({1.0, 2.0, 3.0});
+  auto jacobian_fn = jacrev(fun);
+  auto jacobian = jacobian_fn(x);
+  array expected = array({2.0, 0.0, 0.0, 0.0, 4.0, 0.0, 0.0, 0.0, 6.0}, {3, 3});
+  CHECK(array_equal(jacobian, expected).item<bool>());
+}
+
+TEST_CASE("test hessian") {
+  auto fun = [](const array& x) { return sum(x * x); };
+
+  array x = array({1.0, 2.0, 3.0});
+  auto hessian_fn = hessian(fun);
+  auto hess = hessian_fn(x);
+  array expected_hessian =
+      array({2.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 2.0}, {3, 3});
+  CHECK(array_equal(hess, expected_hessian).item<bool>());
+}
\ No newline at end of file