# File: mlx/python/tests/test_compile.py (28 lines, 567 B, Python)
# Copyright © 2023-2024 Apple Inc.
import unittest
import mlx.core as mx
import mlx_tests
class TestCompile(mlx_tests.MLXTestCase):
    """Tests for ``mx.compile``."""

    def test_simple_compile(self):
        """Compile a two-argument add and check its result on repeated calls."""

        def fun(x, y):
            return x + y

        # NOTE(review): the function is compiled twice in a row; presumably
        # this exercises re-compiling an already-compiled function -- confirm
        # the intent before removing the duplicate call.
        compiled_fn = mx.compile(fun)
        compiled_fn = mx.compile(fun)

        x = mx.array(1.0)
        y = mx.array(1.0)

        out = compiled_fn(x, y)
        self.assertEqual(out.item(), 2.0)

        # Try again: a second call with the same inputs must give the same
        # result.
        out = compiled_fn(x, y)
        self.assertEqual(out.item(), 2.0)
# (page-extraction residue: blame timestamp 2024-01-15 06:26:53 +08:00)
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()