awni's commit files

This commit is contained in:
Awni Hannun
2023-11-29 10:30:41 -08:00
parent e411fcae68
commit 8ca7f9e8e9
130 changed files with 30159 additions and 0 deletions

34
python/tests/test_eval.py Normal file
View File

@@ -0,0 +1,34 @@
from functools import partial
import unittest
import mlx.core as mx
import mlx_tests
class TestEval(mlx_tests.MLXTestCase):
def test_eval(self):
arrs = [mx.ones((2, 2)) for _ in range(4)]
mx.eval(*arrs)
for x in arrs:
self.assertEqual(x.tolist(), [[1, 1], [1, 1]])
def test_retain_graph(self):
def fun(x, retain_graph):
y = 3 * x
mx.eval(y, retain_graph=retain_graph)
return 2 * y
dfun_dx_1 = mx.grad(partial(fun, retain_graph=False))
dfun_dx_2 = mx.grad(partial(fun, retain_graph=True))
with self.assertRaises(ValueError):
dfun_dx_1(mx.array(1.0))
y = dfun_dx_2(mx.array(1.0))
self.assertEqual(y.item(), 6.0)
if __name__ == "__main__":
unittest.main()