Add Masked Scatter (#2663)

Co-authored-by: Awni Hannun <awni@apple.com>
Co-authored-by: Angelos Katharopoulos <katharas@gmail.com>
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
CCYeh
2025-11-19 23:53:32 +01:00
committed by GitHub
parent 7f4b7e553c
commit b3825ac149
26 changed files with 1099 additions and 51 deletions

View File

@@ -1,5 +1,6 @@
// Copyright © 2023-2024 Apple Inc.
#include <numeric>
#include <optional>
#include <sstream>
#include "python/src/convert.h"
@@ -885,6 +886,22 @@ auto mlx_slice_update(
return std::make_pair(true, out);
}
std::optional<mx::array> extract_boolean_mask(const nb::object& obj) {
using NDArray = nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>;
if (nb::isinstance<mx::array>(obj)) {
auto mask = nb::cast<mx::array>(obj);
if (mask.dtype() == mx::bool_) {
return mask;
}
} else if (nb::isinstance<NDArray>(obj)) {
auto mask = nb::cast<NDArray>(obj);
if (mask.dtype() == nb::dtype<bool>()) {
return nd_array_to_mlx(mask, mx::bool_);
}
}
return std::nullopt;
}
void mlx_set_item(
mx::array& src,
const nb::object& obj,
@@ -895,6 +912,13 @@ void mlx_set_item(
return;
}
if (auto mask = extract_boolean_mask(obj)) {
auto updates = to_array(v, src.dtype());
auto result = masked_scatter(src, *mask, updates);
src.overwrite_descriptor(result);
return;
}
auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
if (indices.size() > 0) {
auto out = scatter(src, indices, updates, axes);