mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 00:04:41 +08:00
Fix compiel with byte sized constants (#1381)
This commit is contained in:
@@ -719,6 +719,20 @@ class TestCompile(mlx_tests.MLXTestCase):
|
||||
expected = fn()
|
||||
self.assertTrue(mx.array_equal(expected, out))
|
||||
|
||||
def test_dtypes(self):
|
||||
x = mx.array([0, 1, 2, 3])
|
||||
dtypes = [mx.bool_, mx.int8, mx.uint8, mx.int16, mx.uint16]
|
||||
for dtype in dtypes:
|
||||
x = x.astype(dtype)
|
||||
mx.eval(x)
|
||||
|
||||
def fn(x):
|
||||
return x * 1 + 0
|
||||
|
||||
out = mx.compile(fn)(x)
|
||||
expected = fn(x)
|
||||
self.assertTrue(mx.array_equal(expected, out))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user