Support more Numpy interfaces for masked_scatter (#2832)

This commit is contained in:
CCYeh
2025-12-02 02:51:02 +01:00
committed by GitHub
parent 6e762fe2e2
commit 8879ee00eb
4 changed files with 34 additions and 10 deletions

View File

@@ -1929,8 +1929,27 @@ class TestArray(mlx_tests.MLXTestCase):
self.assertTrue(np.array_equal(a, anp))
def test_setitem_with_boolean_mask(self):
mask_np = np.zeros((10, 10), dtype=bool)
mx.arange(1000).reshape(10, 10, 10)[mask_np] = 0
# Python list mask
a = mx.array([1.0, 2.0, 3.0])
mask = [True, False, True]
src = mx.array([5.0, 6.0])
expected = mx.array([5.0, 2.0, 6.0])
a[mask] = src
self.assertTrue(mx.array_equal(a, expected))
# mx.array scalar mask
a = mx.array([1.0, 2.0, 3.0])
mask = mx.array(True)
expected = mx.array([5.0, 5.0, 5.0])
a[mask] = 5.0
self.assertTrue(mx.array_equal(a, expected))
# scalar mask
a = mx.array([1.0, 2.0, 3.0])
mask = True
expected = mx.array([5.0, 5.0, 5.0])
a[mask] = 5.0
self.assertTrue(mx.array_equal(a, expected))
mask_np = np.zeros((1, 10, 10), dtype=bool)
with self.assertRaises(ValueError):