2023-12-01 03:12:53 +08:00
|
|
|
# Copyright © 2023 Apple Inc.
|
|
|
|
|
2023-11-30 02:30:41 +08:00
|
|
|
import unittest
|
2023-12-09 03:31:47 +08:00
|
|
|
from functools import partial
|
2023-11-30 02:30:41 +08:00
|
|
|
|
|
|
|
import mlx.core as mx
|
|
|
|
import mlx_tests
|
|
|
|
|
|
|
|
|
|
|
|
class TestEval(mlx_tests.MLXTestCase):
    """Tests for graph evaluation via ``mx.eval`` and ``mx.async_eval``."""

    def test_eval(self):
        """Evaluating several arrays in one ``mx.eval`` call materializes each."""
        arrs = [mx.ones((2, 2)) for _ in range(4)]
        mx.eval(*arrs)
        for x in arrs:
            self.assertEqual(x.tolist(), [[1, 1], [1, 1]])

    def test_retain_graph(self):
        """``mx.eval`` on an intermediate inside a differentiated function
        must not break the gradient.

        ``fun(x) = 2 * (3 * x)``, so its derivative is 6 everywhere.
        """

        def fun(x):
            y = 3 * x
            # Force evaluation in the middle of the traced function; the
            # expected gradient below is only obtained if the graph through
            # y is still usable by the backward pass.
            mx.eval(y)
            return 2 * y

        dfun_dx = mx.grad(fun)
        y = dfun_dx(mx.array(1.0))
        self.assertEqual(y.item(), 6.0)

    def test_eval_mixed(self):
        """``mx.eval`` accepts a container that mixes arrays with non-array
        Python objects (here an int and a str) and still evaluates the
        arrays it contains."""
        x = mx.array(1) + 1 + 1
        y = 0
        z = "hello"
        state = [x, y, z]
        mx.eval(state)
        self.assertEqual(x.item(), 3)

    def test_async_eval(self):
        """Arrays scheduled with ``mx.async_eval`` can be read back, and
        can also be passed to ``mx.eval`` afterwards."""
        # Read back an async-evaluated array directly.
        x = mx.array(1) + mx.array(1) + mx.array(1)
        mx.async_eval(x)
        self.assertEqual(x.item(), 3)

        # It should be safe to call eval on the array which has been async
        # eval'ed. Both calls are required for this scenario to actually
        # exercise that claim.
        x = mx.array(1) + mx.array(1) + mx.array(1)
        mx.async_eval(x)
        mx.eval(x)
        self.assertEqual(x.item(), 3)

        # Chain a second async evaluation on top of an array that was itself
        # async-evaluated, then check both results.
        x = mx.array([1, 2, 3])
        y = 2 * x
        mx.async_eval(y)
        z = 2 * y
        mx.async_eval(z)
        self.assertTrue(mx.array_equal(y, mx.array([2, 4, 6])))
        self.assertTrue(mx.array_equal(z, mx.array([4, 8, 12])))

    def test_async_eval_twice(self):
        """Two back-to-back ``mx.async_eval`` calls, the second on an array
        depending on the first, both produce correct values."""
        x = mx.array(1) + mx.array(1) + mx.array(1)
        mx.async_eval(x)
        y = x + 1
        mx.async_eval(y)
        self.assertEqual(x.item(), 3)
        # Check the dependent array as well: its inputs may still have been
        # pending when its own async evaluation was requested.
        self.assertEqual(y.item(), 4)

    def test_async_eval_in_trace(self):
        """``mx.async_eval`` inside a function being transformed by
        ``mx.grad`` or ``mx.vmap`` raises ``ValueError``."""

        def fun(x):
            y = x + 1.0
            mx.async_eval(y)
            return mx.exp(y)

        # Raises under grad
        with self.assertRaises(ValueError):
            mx.grad(fun)(mx.array(1.0))

        # Also raises under vmap
        with self.assertRaises(ValueError):
            mx.vmap(fun)(mx.ones((2, 2)))

    def test_async_eval_into_eval(self):
        """An array derived from an async-evaluated array evaluates correctly:
        abs((0 + 1 + 1) - 10) == 8."""
        x = mx.array(1)
        y = x + 1
        mx.async_eval(y)
        a = y - 10
        b = mx.abs(a)
        self.assertEqual(b.item(), 8)

    def test_async_eval_into_eval_diff_stream(self):
        """An async-evaluated array can be consumed by an op on a different,
        newly created CPU stream."""
        s = mx.new_stream(mx.cpu)
        x = mx.array(0)
        y = x - 5
        mx.async_eval(y)
        z = mx.abs(y, stream=s)
        self.assertEqual(z.item(), 5)

    def test_eval_slow_fast_multi_stream(self):
        """Combine a long dependency chain with a single cheap op on the CPU
        stream, in both operand orders.

        ``x`` is built from 20 chained adds on the default stream (the
        "slow" input), while ``y`` is a single ``abs`` (the "fast" input).
        Expected value per element: 1 + 20 + 1 == 22.
        """
        x = mx.ones((8000,))
        y = mx.abs(mx.array(-1.0))
        for _ in range(20):
            x = x + mx.array(1.0)
        z = mx.add(x, y, stream=mx.cpu)
        self.assertTrue(mx.allclose(z, mx.full((8000,), 22.0)))

        # Switch eval order
        x = mx.ones((8000,))
        y = mx.abs(mx.array(-1.0))
        for _ in range(20):
            x = x + mx.array(1.0)
        z = mx.add(y, x, stream=mx.cpu)
        self.assertTrue(mx.allclose(z, mx.full((8000,), 22.0)))
|
|
|
|
|
2023-11-30 02:30:41 +08:00
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: discover and run every TestCase defined in this
    # module (module="__main__" is unittest.main's documented default).
    unittest.main(module="__main__")
|