mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-19 19:38:16 +08:00
Added jacfwd, jacrev and hessian
This commit is contained in:
@@ -797,6 +797,54 @@ class TestAutograd(mlx_tests.MLXTestCase):
|
||||
grad_fn(model)
|
||||
self.assertEqual(model[1].item(), 2.0)
|
||||
|
||||
def test_jacfwd(self):
|
||||
# Scalar function: f(x) = x^2
|
||||
fun = lambda x: x * x
|
||||
x = mx.array([1.0, 2.0, 3.0])
|
||||
outputs, jacobian = mx.jacfwd(fun)(x)
|
||||
self.assertTrue(mx.array_equal(outputs, x * x))
|
||||
self.assertTrue(mx.array_equal(jacobian, 2 * x))
|
||||
|
||||
# Vectorized function: f(x, y) = [x * y, x + y]
|
||||
fun = lambda x, y: (x * y, x + y)
|
||||
x = mx.array(2.0)
|
||||
y = mx.array(3.0)
|
||||
outputs, jacobian = mx.jacfwd(fun)(x, y)
|
||||
self.assertTrue(mx.array_equal(outputs[0], x * y))
|
||||
self.assertTrue(mx.array_equal(outputs[1], x + y))
|
||||
self.assertTrue(mx.array_equal(jacobian[0], mx.array([y, x])))
|
||||
self.assertTrue(mx.array_equal(jacobian[1], mx.array([1.0, 1.0])))
|
||||
|
||||
def test_jacrev(self):
|
||||
# Scalar function: f(x) = x^2
|
||||
fun = lambda x: x * x
|
||||
x = mx.array([1.0, 2.0, 3.0])
|
||||
jacobian = mx.jacrev(fun)(x)
|
||||
self.assertTrue(mx.array_equal(jacobian, 2 * x))
|
||||
|
||||
# Vectorized function: f(x, y) = [x * y, x + y]
|
||||
fun = lambda x, y: (x * y, x + y)
|
||||
x = mx.array(2.0)
|
||||
y = mx.array(3.0)
|
||||
jacobian = mx.jacrev(fun)(x, y)
|
||||
self.assertTrue(mx.array_equal(jacobian[0], mx.array([y, x])))
|
||||
self.assertTrue(mx.array_equal(jacobian[1], mx.array([1.0, 1.0])))
|
||||
|
||||
def test_hessian(self):
|
||||
# Scalar function: f(x) = x^3
|
||||
fun = lambda x: x * x * x
|
||||
x = mx.array(2.0)
|
||||
hessian = mx.hessian(fun)(x)
|
||||
self.assertEqual(hessian.item(), 12.0)
|
||||
|
||||
# Vectorized function: f(x, y) = x^2 + y^2
|
||||
fun = lambda x, y: x * x + y * y
|
||||
x = mx.array(1.0)
|
||||
y = mx.array(2.0)
|
||||
hessian = mx.hessian(fun)(x, y)
|
||||
expected_hessian = mx.array([[2.0, 0.0], [0.0, 2.0]])
|
||||
self.assertTrue(mx.array_equal(hessian, expected_hessian))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user