mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-26 04:21:17 +08:00
Fix compiel with byte sized constants (#1381)
This commit is contained in:
parent
da691257ec
commit
0d302cd25b
@ -18,7 +18,8 @@ void print_constant(std::ostream& os, const array& x) {
|
|||||||
case complex64:
|
case complex64:
|
||||||
return print_complex_constant<complex64_t>(os, x);
|
return print_complex_constant<complex64_t>(os, x);
|
||||||
case int8:
|
case int8:
|
||||||
return print_int_constant<int8_t>(os, x);
|
os << static_cast<int32_t>(x.item<int8_t>());
|
||||||
|
return;
|
||||||
case int16:
|
case int16:
|
||||||
return print_int_constant<int16_t>(os, x);
|
return print_int_constant<int16_t>(os, x);
|
||||||
case int32:
|
case int32:
|
||||||
@ -26,7 +27,8 @@ void print_constant(std::ostream& os, const array& x) {
|
|||||||
case int64:
|
case int64:
|
||||||
return print_int_constant<int64_t>(os, x);
|
return print_int_constant<int64_t>(os, x);
|
||||||
case uint8:
|
case uint8:
|
||||||
return print_int_constant<uint8_t>(os, x);
|
os << static_cast<uint32_t>(x.item<uint8_t>());
|
||||||
|
return;
|
||||||
case uint16:
|
case uint16:
|
||||||
return print_int_constant<uint16_t>(os, x);
|
return print_int_constant<uint16_t>(os, x);
|
||||||
case uint32:
|
case uint32:
|
||||||
|
@ -719,6 +719,20 @@ class TestCompile(mlx_tests.MLXTestCase):
|
|||||||
expected = fn()
|
expected = fn()
|
||||||
self.assertTrue(mx.array_equal(expected, out))
|
self.assertTrue(mx.array_equal(expected, out))
|
||||||
|
|
||||||
|
def test_dtypes(self):
|
||||||
|
x = mx.array([0, 1, 2, 3])
|
||||||
|
dtypes = [mx.bool_, mx.int8, mx.uint8, mx.int16, mx.uint16]
|
||||||
|
for dtype in dtypes:
|
||||||
|
x = x.astype(dtype)
|
||||||
|
mx.eval(x)
|
||||||
|
|
||||||
|
def fn(x):
|
||||||
|
return x * 1 + 0
|
||||||
|
|
||||||
|
out = mx.compile(fn)(x)
|
||||||
|
expected = fn(x)
|
||||||
|
self.assertTrue(mx.array_equal(expected, out))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user