fix double type promotion (#1901)

This commit is contained in:
Awni Hannun 2025-02-25 06:00:53 -08:00 committed by GitHub
parent 7face5d9fd
commit 28b8079e30
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 11 additions and 1 deletions

View File

@ -44,7 +44,7 @@ constexpr Dtype type_rules[num_types][num_types] = {
{int64, int64, int64, int64, float32, int64, int64, int64, int64, float16, float32, float64, bfloat16, complex64}, // int64
{float16, float16, float16, float16, float16, float16, float16, float16, float16, float16, float32, float64, float32, complex64}, // float16
{float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float64, float32, complex64}, // float32
{float64, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float64, float32, complex64}, // float64
{float64, float64, float64, float64, float64, float64, float64, float64, float64, float64, float64, float64, float64, complex64}, // float64
{bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, float32, float32, float64, bfloat16, complex64}, // bfloat16
{complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64,complex64, complex64, complex64}, // complex64
};

View File

@ -173,6 +173,16 @@ class TestDouble(mlx_tests.MLXTestCase):
mx.allclose(y, y_double.astype(mx.float32, mx.cpu), equal_nan=True)
)
def test_type_promotion(self):
import mlx.core as mx
a = mx.array([4, 8], mx.float64)
b = mx.array([4, 8], mx.int32)
with mx.stream(mx.cpu):
c = a + b
self.assertEqual(c.dtype, mx.float64)
if __name__ == "__main__":
unittest.main()