mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Add Masked Scatter (#2663)
Co-authored-by: Awni Hannun <awni@apple.com> Co-authored-by: Angelos Katharopoulos <katharas@gmail.com> Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
#include <numeric>
|
||||
#include <optional>
|
||||
#include <sstream>
|
||||
|
||||
#include "python/src/convert.h"
|
||||
@@ -885,6 +886,22 @@ auto mlx_slice_update(
|
||||
return std::make_pair(true, out);
|
||||
}
|
||||
|
||||
std::optional<mx::array> extract_boolean_mask(const nb::object& obj) {
|
||||
using NDArray = nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>;
|
||||
if (nb::isinstance<mx::array>(obj)) {
|
||||
auto mask = nb::cast<mx::array>(obj);
|
||||
if (mask.dtype() == mx::bool_) {
|
||||
return mask;
|
||||
}
|
||||
} else if (nb::isinstance<NDArray>(obj)) {
|
||||
auto mask = nb::cast<NDArray>(obj);
|
||||
if (mask.dtype() == nb::dtype<bool>()) {
|
||||
return nd_array_to_mlx(mask, mx::bool_);
|
||||
}
|
||||
}
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
void mlx_set_item(
|
||||
mx::array& src,
|
||||
const nb::object& obj,
|
||||
@@ -895,6 +912,13 @@ void mlx_set_item(
|
||||
return;
|
||||
}
|
||||
|
||||
if (auto mask = extract_boolean_mask(obj)) {
|
||||
auto updates = to_array(v, src.dtype());
|
||||
auto result = masked_scatter(src, *mask, updates);
|
||||
src.overwrite_descriptor(result);
|
||||
return;
|
||||
}
|
||||
|
||||
auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
|
||||
if (indices.size() > 0) {
|
||||
auto out = scatter(src, indices, updates, axes);
|
||||
|
||||
Reference in New Issue
Block a user