mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 09:51:17 +08:00
fix double type promotion (#1901)
This commit is contained in:
parent
7face5d9fd
commit
28b8079e30
@ -44,7 +44,7 @@ constexpr Dtype type_rules[num_types][num_types] = {
|
|||||||
{int64, int64, int64, int64, float32, int64, int64, int64, int64, float16, float32, float64, bfloat16, complex64}, // int64
|
{int64, int64, int64, int64, float32, int64, int64, int64, int64, float16, float32, float64, bfloat16, complex64}, // int64
|
||||||
{float16, float16, float16, float16, float16, float16, float16, float16, float16, float16, float32, float64, float32, complex64}, // float16
|
{float16, float16, float16, float16, float16, float16, float16, float16, float16, float16, float32, float64, float32, complex64}, // float16
|
||||||
{float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float64, float32, complex64}, // float32
|
{float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float64, float32, complex64}, // float32
|
||||||
{float64, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float64, float32, complex64}, // float64
|
{float64, float64, float64, float64, float64, float64, float64, float64, float64, float64, float64, float64, float64, complex64}, // float64
|
||||||
{bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, float32, float32, float64, bfloat16, complex64}, // bfloat16
|
{bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, float32, float32, float64, bfloat16, complex64}, // bfloat16
|
||||||
{complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64,complex64, complex64, complex64}, // complex64
|
{complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64,complex64, complex64, complex64}, // complex64
|
||||||
};
|
};
|
||||||
|
@ -173,6 +173,16 @@ class TestDouble(mlx_tests.MLXTestCase):
|
|||||||
mx.allclose(y, y_double.astype(mx.float32, mx.cpu), equal_nan=True)
|
mx.allclose(y, y_double.astype(mx.float32, mx.cpu), equal_nan=True)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_type_promotion(self):
|
||||||
|
import mlx.core as mx
|
||||||
|
|
||||||
|
a = mx.array([4, 8], mx.float64)
|
||||||
|
b = mx.array([4, 8], mx.int32)
|
||||||
|
|
||||||
|
with mx.stream(mx.cpu):
|
||||||
|
c = a + b
|
||||||
|
self.assertEqual(c.dtype, mx.float64)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user