diff --git a/mlx/backend/common/compiled.cpp b/mlx/backend/common/compiled.cpp index e847017c7..cf6cb39b3 100644 --- a/mlx/backend/common/compiled.cpp +++ b/mlx/backend/common/compiled.cpp @@ -18,7 +18,8 @@ void print_constant(std::ostream& os, const array& x) { case complex64: return print_complex_constant(os, x); case int8: - return print_int_constant(os, x); + os << static_cast(x.item()); + return; case int16: return print_int_constant(os, x); case int32: @@ -26,7 +27,8 @@ void print_constant(std::ostream& os, const array& x) { case int64: return print_int_constant(os, x); case uint8: - return print_int_constant(os, x); + os << static_cast(x.item()); + return; case uint16: return print_int_constant(os, x); case uint32: diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py index 82773dbf2..bdc7a1bff 100644 --- a/python/tests/test_compile.py +++ b/python/tests/test_compile.py @@ -719,6 +719,20 @@ class TestCompile(mlx_tests.MLXTestCase): expected = fn() self.assertTrue(mx.array_equal(expected, out)) + def test_dtypes(self): + x = mx.array([0, 1, 2, 3]) + dtypes = [mx.bool_, mx.int8, mx.uint8, mx.int16, mx.uint16] + for dtype in dtypes: + x = x.astype(dtype) + mx.eval(x) + + def fn(x): + return x * 1 + 0 + + out = mx.compile(fn)(x) + expected = fn(x) + self.assertTrue(mx.array_equal(expected, out)) + if __name__ == "__main__": unittest.main()