Fix compiel with byte sized constants (#1381)

This commit is contained in:
Awni Hannun 2024-08-30 17:24:35 -07:00 committed by GitHub
parent da691257ec
commit 0d302cd25b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 18 additions and 2 deletions

View File

@ -18,7 +18,8 @@ void print_constant(std::ostream& os, const array& x) {
case complex64:
return print_complex_constant<complex64_t>(os, x);
case int8:
return print_int_constant<int8_t>(os, x);
os << static_cast<int32_t>(x.item<int8_t>());
return;
case int16:
return print_int_constant<int16_t>(os, x);
case int32:
@ -26,7 +27,8 @@ void print_constant(std::ostream& os, const array& x) {
case int64:
return print_int_constant<int64_t>(os, x);
case uint8:
return print_int_constant<uint8_t>(os, x);
os << static_cast<uint32_t>(x.item<uint8_t>());
return;
case uint16:
return print_int_constant<uint16_t>(os, x);
case uint32:

View File

@ -719,6 +719,20 @@ class TestCompile(mlx_tests.MLXTestCase):
expected = fn()
self.assertTrue(mx.array_equal(expected, out))
def test_dtypes(self):
x = mx.array([0, 1, 2, 3])
dtypes = [mx.bool_, mx.int8, mx.uint8, mx.int16, mx.uint16]
for dtype in dtypes:
x = x.astype(dtype)
mx.eval(x)
def fn(x):
return x * 1 + 0
out = mx.compile(fn)(x)
expected = fn(x)
self.assertTrue(mx.array_equal(expected, out))
if __name__ == "__main__":
unittest.main()