diff --git a/docs/src/usage/indexing.rst b/docs/src/usage/indexing.rst index 2cc335262..669a38f19 100644 --- a/docs/src/usage/indexing.rst +++ b/docs/src/usage/indexing.rst @@ -179,8 +179,8 @@ assignments, ``updates`` must provide at least as many elements as there are Boolean masks follow NumPy semantics: -- The mask shape must match the shape of the axes it indexes exactly. No mask - broadcasting occurs. +- The mask shape must match the shape of the axes it indexes exactly. The only + exception is a scalar boolean mask, which broadcasts to the full array. - Any axes not covered by the mask are taken in full. .. code-block:: shell diff --git a/mlx/ops.cpp b/mlx/ops.cpp index fbe679937..c250ffa5d 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -3466,10 +3466,8 @@ array masked_scatter( if (mask.dtype() != bool_) { throw std::invalid_argument("[masked_scatter] The mask has to be boolean."); } - if (mask.ndim() == 0) { - throw std::invalid_argument( - "[masked_scatter] Scalar masks are not supported."); - } else if (mask.ndim() > a.ndim()) { + + if (mask.ndim() > a.ndim()) { throw std::invalid_argument( "[masked_scatter] The mask cannot have more dimensions than the target."); } diff --git a/python/src/indexing.cpp b/python/src/indexing.cpp index 5bd9f5bfb..59b0655d8 100644 --- a/python/src/indexing.cpp +++ b/python/src/indexing.cpp @@ -766,7 +766,7 @@ auto mlx_slice_update( const nb::object& obj, const ScalarOrArray& v) { // Can't route to slice update if not slice, tuple, or int - if (src.ndim() == 0 || + if (src.ndim() == 0 || nb::isinstance(obj) || (!nb::isinstance(obj) && !nb::isinstance(obj) && !nb::isinstance(obj))) { return std::make_pair(false, src); @@ -888,7 +888,9 @@ auto mlx_slice_update( std::optional extract_boolean_mask(const nb::object& obj) { using NDArray = nb::ndarray; - if (nb::isinstance(obj)) { + if (nb::isinstance(obj)) { + return mx::array(nb::cast(obj), mx::bool_); + } else if (nb::isinstance(obj)) { auto mask = nb::cast(obj); if (mask.dtype() == mx::bool_) { return mask; @@ -898,6 +900,11 @@ std::optional extract_boolean_mask(const nb::object& obj) { if (mask.dtype() == nb::dtype()) { return nd_array_to_mlx(mask, mx::bool_); } + } else if (nb::isinstance(obj)) { + auto mask = array_from_list(nb::cast(obj), {}); + if (mask.dtype() == mx::bool_) { + return mask; + } } return std::nullopt; } diff --git a/python/tests/test_array.py b/python/tests/test_array.py index 72564f73c..e2aa74f04 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -1929,8 +1929,27 @@ class TestArray(mlx_tests.MLXTestCase): self.assertTrue(np.array_equal(a, anp)) def test_setitem_with_boolean_mask(self): - mask_np = np.zeros((10, 10), dtype=bool) - mx.arange(1000).reshape(10, 10, 10)[mask_np] = 0 + # Python list mask + a = mx.array([1.0, 2.0, 3.0]) + mask = [True, False, True] + src = mx.array([5.0, 6.0]) + expected = mx.array([5.0, 2.0, 6.0]) + a[mask] = src + self.assertTrue(mx.array_equal(a, expected)) + + # mx.array scalar mask + a = mx.array([1.0, 2.0, 3.0]) + mask = mx.array(True) + expected = mx.array([5.0, 5.0, 5.0]) + a[mask] = 5.0 + self.assertTrue(mx.array_equal(a, expected)) + + # scalar mask + a = mx.array([1.0, 2.0, 3.0]) + mask = True + expected = mx.array([5.0, 5.0, 5.0]) + a[mask] = 5.0 + self.assertTrue(mx.array_equal(a, expected)) mask_np = np.zeros((1, 10, 10), dtype=bool) with self.assertRaises(ValueError):