From 0d302cd25bd762646d880cc3c8c814809849955a Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 30 Aug 2024 17:24:35 -0700 Subject: [PATCH] Fix compiel with byte sized constants (#1381) --- mlx/backend/common/compiled.cpp | 6 ++++-- python/tests/test_compile.py | 14 ++++++++++++++ 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/mlx/backend/common/compiled.cpp b/mlx/backend/common/compiled.cpp index e847017c7..cf6cb39b3 100644 --- a/mlx/backend/common/compiled.cpp +++ b/mlx/backend/common/compiled.cpp @@ -18,7 +18,8 @@ void print_constant(std::ostream& os, const array& x) { case complex64: return print_complex_constant(os, x); case int8: - return print_int_constant(os, x); + os << static_cast(x.item()); + return; case int16: return print_int_constant(os, x); case int32: @@ -26,7 +27,8 @@ void print_constant(std::ostream& os, const array& x) { case int64: return print_int_constant(os, x); case uint8: - return print_int_constant(os, x); + os << static_cast(x.item()); + return; case uint16: return print_int_constant(os, x); case uint32: diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py index 82773dbf2..bdc7a1bff 100644 --- a/python/tests/test_compile.py +++ b/python/tests/test_compile.py @@ -719,6 +719,20 @@ class TestCompile(mlx_tests.MLXTestCase): expected = fn() self.assertTrue(mx.array_equal(expected, out)) + def test_dtypes(self): + x = mx.array([0, 1, 2, 3]) + dtypes = [mx.bool_, mx.int8, mx.uint8, mx.int16, mx.uint16] + for dtype in dtypes: + x = x.astype(dtype) + mx.eval(x) + + def fn(x): + return x * 1 + 0 + + out = mx.compile(fn)(x) + expected = fn(x) + self.assertTrue(mx.array_equal(expected, out)) + if __name__ == "__main__": unittest.main()