diff --git a/mlx/dtype.cpp b/mlx/dtype.cpp index 5fd0415ba..429fa4202 100644 --- a/mlx/dtype.cpp +++ b/mlx/dtype.cpp @@ -44,7 +44,7 @@ constexpr Dtype type_rules[num_types][num_types] = { {int64, int64, int64, int64, float32, int64, int64, int64, int64, float16, float32, float64, bfloat16, complex64}, // int64 {float16, float16, float16, float16, float16, float16, float16, float16, float16, float16, float32, float64, float32, complex64}, // float16 {float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float64, float32, complex64}, // float32 - {float64, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float64, float32, complex64}, // float64 + {float64, float64, float64, float64, float64, float64, float64, float64, float64, float64, float64, float64, float64, complex64}, // float64 {bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, float32, float32, float64, bfloat16, complex64}, // bfloat16 {complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64,complex64, complex64, complex64}, // complex64 }; diff --git a/python/tests/test_double.py b/python/tests/test_double.py index 00d8c9639..8de3f3cea 100644 --- a/python/tests/test_double.py +++ b/python/tests/test_double.py @@ -173,6 +173,16 @@ class TestDouble(mlx_tests.MLXTestCase): mx.allclose(y, y_double.astype(mx.float32, mx.cpu), equal_nan=True) ) + def test_type_promotion(self): + import mlx.core as mx + + a = mx.array([4, 8], mx.float64) + b = mx.array([4, 8], mx.int32) + + with mx.stream(mx.cpu): + c = a + b + self.assertEqual(c.dtype, mx.float64) + if __name__ == "__main__": unittest.main()