mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-18 07:18:29 +08:00
fix complex reduce + nan propagation in min and max (#2377)
This commit is contained in:
@@ -2,9 +2,6 @@ cuda_skip = {
|
||||
"TestLoad.test_load_f8_e4m3",
|
||||
"TestLayers.test_quantized_embedding",
|
||||
"TestOps.test_dynamic_slicing",
|
||||
"TestReduce.test_dtypes",
|
||||
"TestReduce.test_nanpropagation",
|
||||
"TestReduce.test_nanpropagation_complex64",
|
||||
# Block masked matmul NYI
|
||||
"TestBlas.test_block_masked_matmul",
|
||||
# Gather matmul NYI
|
||||
|
@@ -153,7 +153,7 @@ class TestReduce(mlx_tests.MLXTestCase):
|
||||
x = x.transpose(1, 0, 2, 3, 4, 5, 6, 7, 8, 9)
|
||||
check(x, (1, 3, 5, 7, 9))
|
||||
|
||||
def test_nanpropagation(self):
|
||||
def test_nan_propagation(self):
|
||||
dtypes = [
|
||||
"uint8",
|
||||
"uint16",
|
||||
@@ -179,7 +179,7 @@ class TestReduce(mlx_tests.MLXTestCase):
|
||||
ref = getattr(np, op)(x_np, axis=axis)
|
||||
self.assertTrue(np.array_equal(out, ref, equal_nan=True))
|
||||
|
||||
def test_nanpropagation_complex64(self):
|
||||
def test_nan_propagation_complex64(self):
|
||||
complex_array_1 = mx.array(
|
||||
[1 + 1j, 2 + 2j, 3 + 3j, mx.nan + 4j], dtype=mx.complex64
|
||||
).reshape(2, 2)
|
||||
|
Reference in New Issue
Block a user