mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-24 02:41:19 +08:00
Added jacfwd, jacrev and hessian
This commit is contained in:
parent
a7a6c49909
commit
0244a15318
@ -1045,7 +1045,8 @@ std::function<std::vector<array>(const std::vector<array>&)> checkpoint(
|
||||
|
||||
std::pair<std::vector<array>, std::vector<array>> jacfwd(
|
||||
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
||||
const std::vector<array>& primals) {
|
||||
const std::vector<array>& primals,
|
||||
bool has_aux = false) {
|
||||
detail::InTracing in_tracing{false, true};
|
||||
auto outputs = fun(primals);
|
||||
|
||||
@ -1124,4 +1125,144 @@ std::pair<array, array> jacfwd(
|
||||
auto [outputs, jacobian] = jacfwd(vec_fun, {primal});
|
||||
return {outputs[0], jacobian[0]};
|
||||
}
|
||||
|
||||
std::function<std::vector<array>(const std::vector<array>&)> jacrev(
|
||||
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
||||
const std::vector<int>& argnums = {0},
|
||||
bool has_aux = false,
|
||||
bool holomorphic = false,
|
||||
bool allow_int = false) {
|
||||
return [fun, argnums, has_aux, holomorphic, allow_int](
|
||||
const std::vector<array>& inputs) -> std::vector<array> {
|
||||
for (const auto& input : inputs) {
|
||||
if (!allow_int && issubdtype(input.dtype(), integer)) {
|
||||
throw std::invalid_argument(
|
||||
"[jacrev] Differentiation with respect to integer inputs is not allowed.");
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<array> outputs;
|
||||
auto pullback =
|
||||
std::function<std::vector<array>(const std::vector<array>&)>();
|
||||
if (!has_aux) {
|
||||
auto vjp_result =
|
||||
vjp(fun, inputs, std::vector<array>(outputs.size(), array(1.0f)));
|
||||
outputs = std::move(vjp_result.first);
|
||||
pullback = [vjp_result](const std::vector<array>& cotangents) {
|
||||
return vjp_result.second;
|
||||
};
|
||||
} else {
|
||||
std::vector<array> aux;
|
||||
std::vector<array> outputs = fun(inputs);
|
||||
auto [_, grads] =
|
||||
vjp(fun, inputs, std::vector<array>(outputs.size(), array(1.0f)));
|
||||
pullback = [grads](const std::vector<array>& cotangents) {
|
||||
return grads;
|
||||
};
|
||||
aux = {};
|
||||
}
|
||||
|
||||
// Compute the Jacobian row-by-row
|
||||
std::vector<array> jacobian;
|
||||
auto basis = [](const std::vector<array>& outputs) {
|
||||
std::vector<array> basis_vectors;
|
||||
for (size_t i = 0; i < outputs.size(); ++i) {
|
||||
array basis_vector = zeros_like(outputs[i]);
|
||||
basis_vector.data<float>()[i] = 1.0f;
|
||||
basis_vectors.push_back(basis_vector);
|
||||
}
|
||||
return basis_vectors;
|
||||
}(outputs);
|
||||
for (const auto& tangent : basis) {
|
||||
auto row = pullback({tangent});
|
||||
jacobian.push_back(row[0]);
|
||||
}
|
||||
|
||||
// Reshape and return the Jacobian
|
||||
auto combine_jacobian =
|
||||
[](const std::vector<array>& jacobian_rows) -> array {
|
||||
if (jacobian_rows.empty()) {
|
||||
throw std::invalid_argument(
|
||||
"[combine_jacobian] Jacobian rows are empty.");
|
||||
}
|
||||
// Assuming all rows have the same shape
|
||||
std::vector<int> combined_shape = {
|
||||
static_cast<int>(jacobian_rows.size())};
|
||||
combined_shape.insert(
|
||||
combined_shape.end(),
|
||||
jacobian_rows[0].shape().begin(),
|
||||
jacobian_rows[0].shape().end());
|
||||
array combined = zeros(
|
||||
combined_shape,
|
||||
jacobian_rows[0].dtype(),
|
||||
jacobian_rows[0].primitive().stream());
|
||||
for (size_t i = 0; i < jacobian_rows.size(); ++i) {
|
||||
std::copy(
|
||||
jacobian_rows[i].data<float>(),
|
||||
jacobian_rows[i].data<float>() + jacobian_rows[i].size(),
|
||||
combined.data<float>() + i * jacobian_rows[i].size());
|
||||
}
|
||||
return combined;
|
||||
};
|
||||
auto jacobian_combined = combine_jacobian(jacobian);
|
||||
return {jacobian_combined};
|
||||
};
|
||||
}
|
||||
|
||||
std::function<std::vector<array>(const std::vector<array>&)> hessian(
|
||||
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
||||
const std::vector<int>& argnums = {0},
|
||||
bool has_aux = false,
|
||||
bool holomorphic = false) {
|
||||
return
|
||||
[fun, argnums, has_aux, holomorphic](const std::vector<array>& inputs) {
|
||||
// Compute the Jacobian using reverse-mode AD
|
||||
auto jacobian_fn = jacrev(fun, argnums, has_aux, holomorphic);
|
||||
auto jacobian = jacobian_fn(inputs);
|
||||
|
||||
// Compute the Hessian by applying forward-mode AD to the Jacobian
|
||||
auto hessian_fn = jacfwd(jacobian_fn, inputs);
|
||||
return hessian_fn.first;
|
||||
};
|
||||
}
|
||||
|
||||
std::function<array(const array&)> jacrev(
|
||||
const std::function<array(const array&)>& fun,
|
||||
int argnum = 0,
|
||||
bool has_aux = false,
|
||||
bool holomorphic = false,
|
||||
bool allow_int = false) {
|
||||
return [fun, argnum, has_aux, holomorphic, allow_int](const array& input) {
|
||||
// Wrap the scalar function into a vectorized function
|
||||
auto vec_fun = [fun](const std::vector<array>& inputs) {
|
||||
return std::vector<array>{fun(inputs[0])};
|
||||
};
|
||||
|
||||
// Use the existing jacrev implementation for vectorized functions
|
||||
auto jacobian_fn =
|
||||
jacrev(vec_fun, {argnum}, has_aux, holomorphic, allow_int);
|
||||
auto jacobian_result = jacobian_fn({input});
|
||||
return jacobian_result[0];
|
||||
};
|
||||
}
|
||||
|
||||
// Overload for scalar functions
|
||||
std::function<array(const array&)> hessian(
|
||||
const std::function<array(const array&)>& fun,
|
||||
int argnum = 0,
|
||||
bool has_aux = false,
|
||||
bool holomorphic = false) {
|
||||
return [fun, argnum, has_aux, holomorphic](const array& input) {
|
||||
// Wrap the scalar function into a vectorized function
|
||||
auto vec_fun = [fun](const std::vector<array>& inputs) {
|
||||
return std::vector<array>{fun(inputs[0])};
|
||||
};
|
||||
|
||||
// Compute the Hessian using the vectorized function
|
||||
auto hessian_fn = hessian(vec_fun, {argnum}, has_aux, holomorphic);
|
||||
auto hessian_result = hessian_fn({input});
|
||||
return hessian_result[0];
|
||||
};
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@ -223,7 +223,82 @@ std::function<std::vector<array>(const std::vector<array>&)> custom_vjp(
|
||||
* Checkpoint the gradient of a function. Namely, discard all intermediate
|
||||
* state and recalculate it when we need to compute the gradient.
|
||||
*/
|
||||
|
||||
/**
|
||||
* Computes the Jacobian of a function using reverse-mode AD.
|
||||
*
|
||||
* @param fun The function whose Jacobian is to be computed.
|
||||
* @param argnums The indices of the arguments to differentiate with respect to.
|
||||
* @param has_aux Whether the function returns auxiliary data.
|
||||
* @param holomorphic Whether the function is holomorphic.
|
||||
* @param allow_int Whether to allow differentiation with respect to integer
|
||||
* inputs.
|
||||
* @return A function that computes the Jacobian of `fun`.
|
||||
*/
|
||||
std::function<std::vector<array>(const std::vector<array>&)> checkpoint(
|
||||
std::function<std::vector<array>(const std::vector<array>&)> fun);
|
||||
|
||||
std::function<std::vector<array>(const std::vector<array>&)> jacrev(
|
||||
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
||||
const std::vector<int>& argnums = {0},
|
||||
bool has_aux = false,
|
||||
bool holomorphic = false,
|
||||
bool allow_int = false);
|
||||
|
||||
/**
|
||||
* Computes the Jacobian of a function using forward-mode AD.
|
||||
*
|
||||
* @param fun The function whose Jacobian is to be computed.
|
||||
* @param primals The input arrays to compute the Jacobian with respect to.
|
||||
* @param has_aux Whether the function returns auxiliary data.
|
||||
* @return A pair containing the outputs of the function and the Jacobian.
|
||||
*/
|
||||
std::pair<std::vector<array>, std::vector<array>> jacfwd(
|
||||
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
||||
const std::vector<array>& primals,
|
||||
bool has_aux = false);
|
||||
|
||||
/**
|
||||
* Computes the Hessian of a function.
|
||||
*
|
||||
* @param fun The function whose Hessian is to be computed.
|
||||
* @param argnums The indices of the arguments to differentiate with respect to.
|
||||
* @param has_aux Whether the function returns auxiliary data.
|
||||
* @param holomorphic Whether the function is holomorphic.
|
||||
* @return A function that computes the Hessian of `fun`.
|
||||
*/
|
||||
std::function<std::vector<array>(const std::vector<array>&)> hessian(
|
||||
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
||||
const std::vector<int>& argnums = {0},
|
||||
bool has_aux = false,
|
||||
bool holomorphic = false);
|
||||
|
||||
/**
|
||||
* Overload for scalar functions: Computes the Jacobian of a unary function
|
||||
* using reverse-mode AD.
|
||||
*/
|
||||
std::function<array(const array&)> jacrev(
|
||||
const std::function<array(const array&)>& fun,
|
||||
int argnum = 0,
|
||||
bool has_aux = false,
|
||||
bool holomorphic = false,
|
||||
bool allow_int = false);
|
||||
|
||||
/**
|
||||
* Overload for scalar functions: Computes the Jacobian of a unary function
|
||||
* using forward-mode AD.
|
||||
*/
|
||||
std::pair<array, array> jacfwd(
|
||||
const std::function<array(const array&)>& fun,
|
||||
const array& primal);
|
||||
|
||||
/**
|
||||
* Overload for scalar functions: Computes the Hessian of a unary function.
|
||||
*/
|
||||
std::function<array(const array&)> hessian(
|
||||
const std::function<array(const array&)>& fun,
|
||||
int argnum = 0,
|
||||
bool has_aux = false,
|
||||
bool holomorphic = false);
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -1506,4 +1506,141 @@ void init_transforms(nb::module_& m) {
|
||||
tree_cache().clear();
|
||||
mx::detail::compile_clear_cache();
|
||||
}));
|
||||
|
||||
m.def(
|
||||
"jacrev",
|
||||
[](const nb::callable& fun,
|
||||
const std::optional<IntOrVec>& argnums,
|
||||
bool has_aux = false,
|
||||
bool holomorphic = false,
|
||||
bool allow_int = false) {
|
||||
auto [argnums_vec, _] = validate_argnums_argnames(argnums, {});
|
||||
return mlx_func(
|
||||
[fun, argnums_vec, has_aux, holomorphic, allow_int](
|
||||
const nb::args& args, const nb::kwargs& kwargs) {
|
||||
auto inputs = tree_flatten(args, false);
|
||||
auto jacobian_fn = mx::jacrev(
|
||||
[&fun](const std::vector<mx::array>& inputs) {
|
||||
return tree_flatten(fun(*tree_unflatten(args, inputs)));
|
||||
},
|
||||
argnums_vec,
|
||||
has_aux,
|
||||
holomorphic,
|
||||
allow_int);
|
||||
auto jacobian = jacobian_fn(inputs);
|
||||
return tree_unflatten(args, jacobian);
|
||||
},
|
||||
fun);
|
||||
},
|
||||
"fun"_a,
|
||||
"argnums"_a = nb::none(),
|
||||
"has_aux"_a = false,
|
||||
"holomorphic"_a = false,
|
||||
"allow_int"_a = false,
|
||||
nb::sig(
|
||||
"def jacrev(fun: Callable, argnums: Optional[Union[int, Sequence[int]]] = None, has_aux: bool = False, holomorphic: bool = False, allow_int: bool = False) -> Callable"),
|
||||
R"pbdoc(
|
||||
Compute the Jacobian of a function using reverse-mode AD.
|
||||
|
||||
Args:
|
||||
fun (Callable): A function which takes a variable number of
|
||||
:class:`array` or trees of :class:`array` and returns
|
||||
a variable number of :class:`array` or trees of :class:`array`.
|
||||
argnums (int or list(int), optional): Specify the index (or indices)
|
||||
of the positional arguments of ``fun`` to compute the Jacobian
|
||||
with respect to. Defaults to ``0``.
|
||||
has_aux (bool, optional): Whether ``fun`` returns auxiliary data.
|
||||
Defaults to ``False``.
|
||||
holomorphic (bool, optional): Whether ``fun`` is holomorphic.
|
||||
Defaults to ``False``.
|
||||
allow_int (bool, optional): Whether to allow differentiation with
|
||||
respect to integer inputs. Defaults to ``False``.
|
||||
|
||||
Returns:
|
||||
Callable: A function which computes the Jacobian of ``fun``.
|
||||
)pbdoc");
|
||||
|
||||
m.def(
|
||||
"jacfwd",
|
||||
[](const nb::callable& fun, bool has_aux = false) {
|
||||
return mlx_func(
|
||||
[fun, has_aux](const nb::args& args, const nb::kwargs& kwargs) {
|
||||
auto inputs = tree_flatten(args, false);
|
||||
auto jacobian_fn = mx::jacfwd(
|
||||
[&fun](const std::vector<mx::array>& inputs) {
|
||||
return tree_flatten(fun(*tree_unflatten(args, inputs)));
|
||||
},
|
||||
inputs,
|
||||
has_aux);
|
||||
auto [outputs, jacobian] = jacobian_fn(inputs);
|
||||
return std::make_pair(
|
||||
tree_unflatten(args, outputs),
|
||||
tree_unflatten(args, jacobian));
|
||||
},
|
||||
fun);
|
||||
},
|
||||
"fun"_a,
|
||||
"has_aux"_a = false,
|
||||
nb::sig("def jacfwd(fun: Callable, has_aux: bool = False) -> Callable"),
|
||||
R"pbdoc(
|
||||
Compute the Jacobian of a function using forward-mode AD.
|
||||
|
||||
Args:
|
||||
fun (Callable): A function which takes a variable number of
|
||||
:class:`array` or trees of :class:`array` and returns
|
||||
a variable number of :class:`array` or trees of :class:`array`.
|
||||
has_aux (bool, optional): Whether ``fun`` returns auxiliary data.
|
||||
Defaults to ``False``.
|
||||
|
||||
Returns:
|
||||
Callable: A function which computes the Jacobian of ``fun``.
|
||||
)pbdoc");
|
||||
|
||||
m.def(
|
||||
"hessian",
|
||||
[](const nb::callable& fun,
|
||||
const std::optional<IntOrVec>& argnums,
|
||||
bool has_aux = false,
|
||||
bool holomorphic = false) {
|
||||
auto [argnums_vec, _] = validate_argnums_argnames(argnums, {});
|
||||
return mlx_func(
|
||||
[fun, argnums_vec, has_aux, holomorphic](
|
||||
const nb::args& args, const nb::kwargs& kwargs) {
|
||||
auto inputs = tree_flatten(args, false);
|
||||
auto hessian_fn = mx::hessian(
|
||||
[&fun](const std::vector<mx::array>& inputs) {
|
||||
return tree_flatten(fun(*tree_unflatten(args, inputs)));
|
||||
},
|
||||
argnums_vec,
|
||||
has_aux,
|
||||
holomorphic);
|
||||
auto hessian = hessian_fn(inputs);
|
||||
return tree_unflatten(args, hessian);
|
||||
},
|
||||
fun);
|
||||
},
|
||||
"fun"_a,
|
||||
"argnums"_a = nb::none(),
|
||||
"has_aux"_a = false,
|
||||
"holomorphic"_a = false,
|
||||
nb::sig(
|
||||
"def hessian(fun: Callable, argnums: Optional[Union[int, Sequence[int]]] = None, has_aux: bool = False, holomorphic: bool = False) -> Callable"),
|
||||
R"pbdoc(
|
||||
Compute the Hessian of a function.
|
||||
|
||||
Args:
|
||||
fun (Callable): A function which takes a variable number of
|
||||
:class:`array` or trees of :class:`array` and returns
|
||||
a variable number of :class:`array` or trees of :class:`array`.
|
||||
argnums (int or list(int), optional): Specify the index (or indices)
|
||||
of the positional arguments of ``fun`` to compute the Hessian
|
||||
with respect to. Defaults to ``0``.
|
||||
has_aux (bool, optional): Whether ``fun`` returns auxiliary data.
|
||||
Defaults to ``False``.
|
||||
holomorphic (bool, optional): Whether ``fun`` is holomorphic.
|
||||
Defaults to ``False``.
|
||||
|
||||
Returns:
|
||||
Callable: A function which computes the Hessian of ``fun``.
|
||||
)pbdoc");
|
||||
}
|
||||
|
@ -797,6 +797,54 @@ class TestAutograd(mlx_tests.MLXTestCase):
|
||||
grad_fn(model)
|
||||
self.assertEqual(model[1].item(), 2.0)
|
||||
|
||||
def test_jacfwd(self):
|
||||
# Scalar function: f(x) = x^2
|
||||
fun = lambda x: x * x
|
||||
x = mx.array([1.0, 2.0, 3.0])
|
||||
outputs, jacobian = mx.jacfwd(fun)(x)
|
||||
self.assertTrue(mx.array_equal(outputs, x * x))
|
||||
self.assertTrue(mx.array_equal(jacobian, 2 * x))
|
||||
|
||||
# Vectorized function: f(x, y) = [x * y, x + y]
|
||||
fun = lambda x, y: (x * y, x + y)
|
||||
x = mx.array(2.0)
|
||||
y = mx.array(3.0)
|
||||
outputs, jacobian = mx.jacfwd(fun)(x, y)
|
||||
self.assertTrue(mx.array_equal(outputs[0], x * y))
|
||||
self.assertTrue(mx.array_equal(outputs[1], x + y))
|
||||
self.assertTrue(mx.array_equal(jacobian[0], mx.array([y, x])))
|
||||
self.assertTrue(mx.array_equal(jacobian[1], mx.array([1.0, 1.0])))
|
||||
|
||||
def test_jacrev(self):
|
||||
# Scalar function: f(x) = x^2
|
||||
fun = lambda x: x * x
|
||||
x = mx.array([1.0, 2.0, 3.0])
|
||||
jacobian = mx.jacrev(fun)(x)
|
||||
self.assertTrue(mx.array_equal(jacobian, 2 * x))
|
||||
|
||||
# Vectorized function: f(x, y) = [x * y, x + y]
|
||||
fun = lambda x, y: (x * y, x + y)
|
||||
x = mx.array(2.0)
|
||||
y = mx.array(3.0)
|
||||
jacobian = mx.jacrev(fun)(x, y)
|
||||
self.assertTrue(mx.array_equal(jacobian[0], mx.array([y, x])))
|
||||
self.assertTrue(mx.array_equal(jacobian[1], mx.array([1.0, 1.0])))
|
||||
|
||||
def test_hessian(self):
|
||||
# Scalar function: f(x) = x^3
|
||||
fun = lambda x: x * x * x
|
||||
x = mx.array(2.0)
|
||||
hessian = mx.hessian(fun)(x)
|
||||
self.assertEqual(hessian.item(), 12.0)
|
||||
|
||||
# Vectorized function: f(x, y) = x^2 + y^2
|
||||
fun = lambda x, y: x * x + y * y
|
||||
x = mx.array(1.0)
|
||||
y = mx.array(2.0)
|
||||
hessian = mx.hessian(fun)(x, y)
|
||||
expected_hessian = mx.array([[2.0, 0.0], [0.0, 2.0]])
|
||||
self.assertTrue(mx.array_equal(hessian, expected_hessian))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
@ -1312,3 +1312,32 @@ TEST_CASE("test grad dynamic slices") {
|
||||
CHECK(allclose(outs[1], ones({1, 2})).item<bool>());
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("test jacfwd") {
|
||||
auto fn = [](array x) { return x * x; };
|
||||
|
||||
array x = array({1.0, 2.0, 3.0});
|
||||
auto [outputs, jacobian] = jacfwd(fn, {x});
|
||||
CHECK(array_equal(outputs, x * x).item<bool>());
|
||||
CHECK(array_equal(jacobian, 2 * x).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test jacrev") {
|
||||
auto fun = [](array x) { return x * x; };
|
||||
|
||||
array x = array({1.0, 2.0, 3.0});
|
||||
auto jacobian_fn = jacrev(fun);
|
||||
auto jacobian = jacobian_fn({x});
|
||||
CHECK(array_equal(jacobian, 2 * x).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test hessian") {
|
||||
auto fun = [](array x) { return sum(x * x); };
|
||||
|
||||
array x = array({1.0, 2.0, 3.0});
|
||||
auto hessian_fn = hessian(fun);
|
||||
auto hess = hessian_fn({x});
|
||||
array expected_hessian =
|
||||
array({2.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 2.0}, {3, 3});
|
||||
CHECK(array_equal(hess, expected_hessian).item<bool>());
|
||||
}
|
Loading…
Reference in New Issue
Block a user