Add Masked Scatter (#2663)

Co-authored-by: Awni Hannun <awni@apple.com>
Co-authored-by: Angelos Katharopoulos <katharas@gmail.com>
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
CCYeh
2025-11-19 23:53:32 +01:00
committed by GitHub
parent 7f4b7e553c
commit b3825ac149
26 changed files with 1099 additions and 51 deletions

View File

@@ -1928,6 +1928,18 @@ class TestArray(mlx_tests.MLXTestCase):
anp[:, idx] = 4
self.assertTrue(np.array_equal(a, anp))
def test_setitem_with_boolean_mask(self):
mask_np = np.zeros((10, 10), dtype=bool)
mx.arange(1000).reshape(10, 10, 10)[mask_np] = 0
mask_np = np.zeros((1, 10, 10), dtype=bool)
with self.assertRaises(ValueError):
mx.arange(1000).reshape(10, 10, 10)[mask_np] = 0
mask_np = np.zeros((10, 10, 1), dtype=bool)
with self.assertRaises(ValueError):
mx.arange(1000).reshape(10, 10, 10)[mask_np] = 0
def test_array_namespace(self):
a = mx.array(1.0)
api = a.__array_namespace__()