Support more Numpy interfaces for masked_scatter (#2832)

This commit is contained in:
CCYeh
2025-12-02 02:51:02 +01:00
committed by GitHub
parent 6e762fe2e2
commit 8879ee00eb
4 changed files with 34 additions and 10 deletions

View File

@@ -766,7 +766,7 @@ auto mlx_slice_update(
const nb::object& obj,
const ScalarOrArray& v) {
// Can't route to slice update if not slice, tuple, or int
if (src.ndim() == 0 ||
if (src.ndim() == 0 || nb::isinstance<nb::bool_>(obj) ||
(!nb::isinstance<nb::slice>(obj) && !nb::isinstance<nb::tuple>(obj) &&
!nb::isinstance<nb::int_>(obj))) {
return std::make_pair(false, src);
@@ -888,7 +888,9 @@ auto mlx_slice_update(
std::optional<mx::array> extract_boolean_mask(const nb::object& obj) {
using NDArray = nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>;
if (nb::isinstance<mx::array>(obj)) {
if (nb::isinstance<nb::bool_>(obj)) {
return mx::array(nb::cast<bool>(obj), mx::bool_);
} else if (nb::isinstance<mx::array>(obj)) {
auto mask = nb::cast<mx::array>(obj);
if (mask.dtype() == mx::bool_) {
return mask;
@@ -898,6 +900,11 @@ std::optional<mx::array> extract_boolean_mask(const nb::object& obj) {
if (mask.dtype() == nb::dtype<bool>()) {
return nd_array_to_mlx(mask, mx::bool_);
}
} else if (nb::isinstance<nb::list>(obj)) {
auto mask = array_from_list(nb::cast<nb::list>(obj), {});
if (mask.dtype() == mx::bool_) {
return mask;
}
}
return std::nullopt;
}