mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 10:38:10 +08:00 
			
		
		
		
	Align mlx::core::max op nan propagation with NumPy (#2339)
* Make max op NaN propagation rules align with numpy * Adding benchmarks and testing for max op nanpropagation * Pre-commit formatting * Fix max complex64 nan propagation and add test * Improve the cpp unittest * Only check nans on non-integral types in simd_reduce_impl. * Cleanup using namespace alias * Add cpu Max nanpropagation. Fix a small fib in cpu max dispatch data types for int8/int16. * Make the max nanpropagation test more meaningful for integer types * Remove tuple unpacking syntax to comply with earlier python versions. Add cuda skip to nanpropagation tests, fix cuda implementation in a separate PR.
This commit is contained in:
		@@ -153,6 +153,63 @@ class TestReduce(mlx_tests.MLXTestCase):
 | 
			
		||||
        x = x.transpose(1, 0, 2, 3, 4, 5, 6, 7, 8, 9)
 | 
			
		||||
        check(x, (1, 3, 5, 7, 9))
 | 
			
		||||
 | 
			
		||||
    def test_nanpropagation(self):
 | 
			
		||||
        dtypes = [
 | 
			
		||||
            "uint8",
 | 
			
		||||
            "uint16",
 | 
			
		||||
            "uint32",
 | 
			
		||||
            "int8",
 | 
			
		||||
            "int16",
 | 
			
		||||
            "int32",
 | 
			
		||||
            "float16",
 | 
			
		||||
            "float32",
 | 
			
		||||
        ]
 | 
			
		||||
 | 
			
		||||
        for dtype in dtypes:
 | 
			
		||||
            with self.subTest(dtype=dtype):
 | 
			
		||||
                x = (mx.random.normal((4, 4)) * 10).astype(getattr(mx, dtype))
 | 
			
		||||
                indices = mx.random.randint(0, 4, shape=(6,)).reshape(3, 2)
 | 
			
		||||
                for idx in indices:
 | 
			
		||||
                    x[idx[0], idx[1]] = mx.nan
 | 
			
		||||
                x_np = np.array(x)
 | 
			
		||||
 | 
			
		||||
                for op in ["max"]:
 | 
			
		||||
                    for axis in [0, 1]:
 | 
			
		||||
                        out = getattr(mx, op)(x, axis=axis)
 | 
			
		||||
                        ref = getattr(np, op)(x_np, axis=axis)
 | 
			
		||||
                        self.assertTrue(np.array_equal(out, ref, equal_nan=True))
 | 
			
		||||
 | 
			
		||||
    def test_nanpropagation_complex64(self):
 | 
			
		||||
        complex_array_1 = mx.array(
 | 
			
		||||
            [1 + 1j, 2 + 2j, 3 + 3j, mx.nan + 4j], dtype=mx.complex64
 | 
			
		||||
        ).reshape(2, 2)
 | 
			
		||||
        complex_array_2 = mx.array(
 | 
			
		||||
            [1 + 1j, 2 + 2j, 3 + mx.nan * 1j, 4 + 4j], dtype=mx.complex64
 | 
			
		||||
        ).reshape(2, 2)
 | 
			
		||||
        complex_array_3 = mx.array(
 | 
			
		||||
            [1 + 1j, 2 + mx.nan * 1j, 3 + 3j, 4 + 4j], dtype=mx.complex64
 | 
			
		||||
        ).reshape(2, 2)
 | 
			
		||||
        complex_array_4 = mx.array(
 | 
			
		||||
            [mx.nan + 1j, 2 + 2j, 3 + 3j, 4 + 4j], dtype=mx.complex64
 | 
			
		||||
        ).reshape(2, 2)
 | 
			
		||||
 | 
			
		||||
        np_arrays = [
 | 
			
		||||
            np.array(complex_array_1),
 | 
			
		||||
            np.array(complex_array_2),
 | 
			
		||||
            np.array(complex_array_3),
 | 
			
		||||
            np.array(complex_array_4),
 | 
			
		||||
        ]
 | 
			
		||||
 | 
			
		||||
        for mx_arr, np_arr in zip(
 | 
			
		||||
            [complex_array_1, complex_array_2, complex_array_3, complex_array_4],
 | 
			
		||||
            np_arrays,
 | 
			
		||||
        ):
 | 
			
		||||
            for axis in [0, 1]:
 | 
			
		||||
                for op in ["max"]:
 | 
			
		||||
                    out = getattr(mx, op)(mx_arr, axis=axis)
 | 
			
		||||
                    ref = getattr(np, op)(np_arr, axis=axis)
 | 
			
		||||
                    self.assertTrue(np.array_equal(out, ref, equal_nan=True))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    mlx_tests.MLXTestRunner(failfast=True)
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user