Compare commits

...

2 Commits

Author SHA1 Message Date
Awni Hannun
6c5785bc2f use thread local capture mode (#2850)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
2025-12-01 19:02:47 -08:00
CCYeh
8879ee00eb Support more Numpy interfaces for masked_scatter (#2832) 2025-12-01 17:51:02 -08:00
5 changed files with 35 additions and 11 deletions

View File

@@ -179,8 +179,8 @@ assignments, ``updates`` must provide at least as many elements as there are
Boolean masks follow NumPy semantics: Boolean masks follow NumPy semantics:
- The mask shape must match the shape of the axes it indexes exactly. No mask - The mask shape must match the shape of the axes it indexes exactly. The only
broadcasting occurs. exception is a scalar boolean mask, which broadcasts to the full array.
- Any axes not covered by the mask are taken in full. - Any axes not covered by the mask are taken in full.
.. code-block:: shell .. code-block:: shell

View File

@@ -87,7 +87,7 @@ CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
return; return;
} }
CHECK_CUDA_ERROR( CHECK_CUDA_ERROR(
cudaStreamBeginCapture(enc.stream(), cudaStreamCaptureModeGlobal)); cudaStreamBeginCapture(enc.stream(), cudaStreamCaptureModeThreadLocal));
} }
CommandEncoder::CaptureContext::~CaptureContext() { CommandEncoder::CaptureContext::~CaptureContext() {

View File

@@ -3466,10 +3466,8 @@ array masked_scatter(
if (mask.dtype() != bool_) { if (mask.dtype() != bool_) {
throw std::invalid_argument("[masked_scatter] The mask has to be boolean."); throw std::invalid_argument("[masked_scatter] The mask has to be boolean.");
} }
if (mask.ndim() == 0) {
throw std::invalid_argument( if (mask.ndim() > a.ndim()) {
"[masked_scatter] Scalar masks are not supported.");
} else if (mask.ndim() > a.ndim()) {
throw std::invalid_argument( throw std::invalid_argument(
"[masked_scatter] The mask cannot have more dimensions than the target."); "[masked_scatter] The mask cannot have more dimensions than the target.");
} }

View File

@@ -766,7 +766,7 @@ auto mlx_slice_update(
const nb::object& obj, const nb::object& obj,
const ScalarOrArray& v) { const ScalarOrArray& v) {
// Can't route to slice update if not slice, tuple, or int // Can't route to slice update if not slice, tuple, or int
if (src.ndim() == 0 || if (src.ndim() == 0 || nb::isinstance<nb::bool_>(obj) ||
(!nb::isinstance<nb::slice>(obj) && !nb::isinstance<nb::tuple>(obj) && (!nb::isinstance<nb::slice>(obj) && !nb::isinstance<nb::tuple>(obj) &&
!nb::isinstance<nb::int_>(obj))) { !nb::isinstance<nb::int_>(obj))) {
return std::make_pair(false, src); return std::make_pair(false, src);
@@ -888,7 +888,9 @@ auto mlx_slice_update(
std::optional<mx::array> extract_boolean_mask(const nb::object& obj) { std::optional<mx::array> extract_boolean_mask(const nb::object& obj) {
using NDArray = nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>; using NDArray = nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>;
if (nb::isinstance<mx::array>(obj)) { if (nb::isinstance<nb::bool_>(obj)) {
return mx::array(nb::cast<bool>(obj), mx::bool_);
} else if (nb::isinstance<mx::array>(obj)) {
auto mask = nb::cast<mx::array>(obj); auto mask = nb::cast<mx::array>(obj);
if (mask.dtype() == mx::bool_) { if (mask.dtype() == mx::bool_) {
return mask; return mask;
@@ -898,6 +900,11 @@ std::optional<mx::array> extract_boolean_mask(const nb::object& obj) {
if (mask.dtype() == nb::dtype<bool>()) { if (mask.dtype() == nb::dtype<bool>()) {
return nd_array_to_mlx(mask, mx::bool_); return nd_array_to_mlx(mask, mx::bool_);
} }
} else if (nb::isinstance<nb::list>(obj)) {
auto mask = array_from_list(nb::cast<nb::list>(obj), {});
if (mask.dtype() == mx::bool_) {
return mask;
}
} }
return std::nullopt; return std::nullopt;
} }

View File

@@ -1929,8 +1929,27 @@ class TestArray(mlx_tests.MLXTestCase):
self.assertTrue(np.array_equal(a, anp)) self.assertTrue(np.array_equal(a, anp))
def test_setitem_with_boolean_mask(self): def test_setitem_with_boolean_mask(self):
mask_np = np.zeros((10, 10), dtype=bool) # Python list mask
mx.arange(1000).reshape(10, 10, 10)[mask_np] = 0 a = mx.array([1.0, 2.0, 3.0])
mask = [True, False, True]
src = mx.array([5.0, 6.0])
expected = mx.array([5.0, 2.0, 6.0])
a[mask] = src
self.assertTrue(mx.array_equal(a, expected))
# mx.array scalar mask
a = mx.array([1.0, 2.0, 3.0])
mask = mx.array(True)
expected = mx.array([5.0, 5.0, 5.0])
a[mask] = 5.0
self.assertTrue(mx.array_equal(a, expected))
# scalar mask
a = mx.array([1.0, 2.0, 3.0])
mask = True
expected = mx.array([5.0, 5.0, 5.0])
a[mask] = 5.0
self.assertTrue(mx.array_equal(a, expected))
mask_np = np.zeros((1, 10, 10), dtype=bool) mask_np = np.zeros((1, 10, 10), dtype=bool)
with self.assertRaises(ValueError): with self.assertRaises(ValueError):