Fix compiel with byte sized constants (#1381)

This commit is contained in:
Awni Hannun
2024-08-30 17:24:35 -07:00
committed by GitHub
parent da691257ec
commit 0d302cd25b
2 changed files with 18 additions and 2 deletions

View File

@@ -719,6 +719,20 @@ class TestCompile(mlx_tests.MLXTestCase):
expected = fn()
self.assertTrue(mx.array_equal(expected, out))
def test_dtypes(self):
x = mx.array([0, 1, 2, 3])
dtypes = [mx.bool_, mx.int8, mx.uint8, mx.int16, mx.uint16]
for dtype in dtypes:
x = x.astype(dtype)
mx.eval(x)
def fn(x):
return x * 1 + 0
out = mx.compile(fn)(x)
expected = fn(x)
self.assertTrue(mx.array_equal(expected, out))
if __name__ == "__main__":
unittest.main()