mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	Fix compiel with byte sized constants (#1381)
This commit is contained in:
		| @@ -18,7 +18,8 @@ void print_constant(std::ostream& os, const array& x) { | ||||
|     case complex64: | ||||
|       return print_complex_constant<complex64_t>(os, x); | ||||
|     case int8: | ||||
|       return print_int_constant<int8_t>(os, x); | ||||
|       os << static_cast<int32_t>(x.item<int8_t>()); | ||||
|       return; | ||||
|     case int16: | ||||
|       return print_int_constant<int16_t>(os, x); | ||||
|     case int32: | ||||
| @@ -26,7 +27,8 @@ void print_constant(std::ostream& os, const array& x) { | ||||
|     case int64: | ||||
|       return print_int_constant<int64_t>(os, x); | ||||
|     case uint8: | ||||
|       return print_int_constant<uint8_t>(os, x); | ||||
|       os << static_cast<uint32_t>(x.item<uint8_t>()); | ||||
|       return; | ||||
|     case uint16: | ||||
|       return print_int_constant<uint16_t>(os, x); | ||||
|     case uint32: | ||||
|   | ||||
| @@ -719,6 +719,20 @@ class TestCompile(mlx_tests.MLXTestCase): | ||||
|         expected = fn() | ||||
|         self.assertTrue(mx.array_equal(expected, out)) | ||||
|  | ||||
|     def test_dtypes(self): | ||||
|         x = mx.array([0, 1, 2, 3]) | ||||
|         dtypes = [mx.bool_, mx.int8, mx.uint8, mx.int16, mx.uint16] | ||||
|         for dtype in dtypes: | ||||
|             x = x.astype(dtype) | ||||
|             mx.eval(x) | ||||
|  | ||||
|             def fn(x): | ||||
|                 return x * 1 + 0 | ||||
|  | ||||
|             out = mx.compile(fn)(x) | ||||
|             expected = fn(x) | ||||
|             self.assertTrue(mx.array_equal(expected, out)) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun