mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 00:04:41 +08:00
fix double type promotion (#1901)
This commit is contained in:
@@ -173,6 +173,16 @@ class TestDouble(mlx_tests.MLXTestCase):
|
||||
mx.allclose(y, y_double.astype(mx.float32, mx.cpu), equal_nan=True)
|
||||
)
|
||||
|
||||
def test_type_promotion(self):
|
||||
import mlx.core as mx
|
||||
|
||||
a = mx.array([4, 8], mx.float64)
|
||||
b = mx.array([4, 8], mx.int32)
|
||||
|
||||
with mx.stream(mx.cpu):
|
||||
c = a + b
|
||||
self.assertEqual(c.dtype, mx.float64)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user