mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-23 10:02:12 +08:00
Fix compiel with byte sized constants (#1381)
This commit is contained in:
parent
da691257ec
commit
0d302cd25b
@ -18,7 +18,8 @@ void print_constant(std::ostream& os, const array& x) {
|
||||
case complex64:
|
||||
return print_complex_constant<complex64_t>(os, x);
|
||||
case int8:
|
||||
return print_int_constant<int8_t>(os, x);
|
||||
os << static_cast<int32_t>(x.item<int8_t>());
|
||||
return;
|
||||
case int16:
|
||||
return print_int_constant<int16_t>(os, x);
|
||||
case int32:
|
||||
@ -26,7 +27,8 @@ void print_constant(std::ostream& os, const array& x) {
|
||||
case int64:
|
||||
return print_int_constant<int64_t>(os, x);
|
||||
case uint8:
|
||||
return print_int_constant<uint8_t>(os, x);
|
||||
os << static_cast<uint32_t>(x.item<uint8_t>());
|
||||
return;
|
||||
case uint16:
|
||||
return print_int_constant<uint16_t>(os, x);
|
||||
case uint32:
|
||||
|
@ -719,6 +719,20 @@ class TestCompile(mlx_tests.MLXTestCase):
|
||||
expected = fn()
|
||||
self.assertTrue(mx.array_equal(expected, out))
|
||||
|
||||
def test_dtypes(self):
|
||||
x = mx.array([0, 1, 2, 3])
|
||||
dtypes = [mx.bool_, mx.int8, mx.uint8, mx.int16, mx.uint16]
|
||||
for dtype in dtypes:
|
||||
x = x.astype(dtype)
|
||||
mx.eval(x)
|
||||
|
||||
def fn(x):
|
||||
return x * 1 + 0
|
||||
|
||||
out = mx.compile(fn)(x)
|
||||
expected = fn(x)
|
||||
self.assertTrue(mx.array_equal(expected, out))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user