mlx/python/tests/test_eval.py

50 lines
1.2 KiB
Python
Raw Normal View History

2023-12-01 03:12:53 +08:00
# Copyright © 2023 Apple Inc.
2023-11-30 02:30:41 +08:00
import unittest
2023-12-09 03:31:47 +08:00
from functools import partial
2023-11-30 02:30:41 +08:00
import mlx.core as mx
import mlx_tests
class TestEval(mlx_tests.MLXTestCase):
    """Tests for explicit evaluation via ``mx.eval`` and ``mx.async_eval``."""

    def test_eval(self):
        """Evaluating several arrays in one ``mx.eval`` call yields their values."""
        arrs = [mx.ones((2, 2)) for _ in range(4)]
        mx.eval(*arrs)
        for x in arrs:
            self.assertEqual(x.tolist(), [[1, 1], [1, 1]])

    def test_retain_graph(self):
        """``mx.eval`` on an intermediate inside a differentiated function
        must not break the gradient: d/dx (2 * (3 * x)) == 6."""

        def fun(x):
            y = 3 * x
            # Forcing evaluation of an intermediate here must still leave it
            # part of the graph that mx.grad differentiates through.
            mx.eval(y)
            return 2 * y

        dfun_dx = mx.grad(fun)
        y = dfun_dx(mx.array(1.0))
        self.assertEqual(y.item(), 6.0)

    def test_eval_mixed(self):
        """``mx.eval`` accepts a list mixing arrays with non-array values
        (an int and a str here) without raising, and evaluates the array."""
        x = mx.array(1) + 1 + 1
        y = 0
        z = "hello"
        state = [x, y, z]
        mx.eval(state)
        self.assertEqual(x.item(), 3)

    def test_async_eval(self):
        """``mx.async_eval`` produces correct results once its handle is
        waited on, and an explicit ``mx.eval`` on an array that has been
        submitted for async evaluation is safe."""
        x = mx.array(1) + mx.array(1) + mx.array(1)
        sync = mx.async_eval(x)
        sync.wait()
        self.assertEqual(x.item(), 3)

        # It should be safe to call eval on the array which has been async
        # eval'ed. The handle is kept alive but deliberately not waited on,
        # so the explicit eval below may overlap the async evaluation.
        x = mx.array(1) + mx.array(1) + mx.array(1)
        sync = mx.async_eval(x)
        mx.eval(x)
        self.assertEqual(x.item(), 3)
2023-11-30 02:30:41 +08:00
if __name__ == "__main__":
    # Executed directly as a script: hand control to the unittest runner,
    # which collects the TestCase classes defined in this module.
    unittest.main()