mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Compare commits
2 Commits
6e762fe2e2
...
6c5785bc2f
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6c5785bc2f | ||
|
|
8879ee00eb |
@@ -179,8 +179,8 @@ assignments, ``updates`` must provide at least as many elements as there are
|
|||||||
|
|
||||||
Boolean masks follow NumPy semantics:
|
Boolean masks follow NumPy semantics:
|
||||||
|
|
||||||
- The mask shape must match the shape of the axes it indexes exactly. No mask
|
- The mask shape must match the shape of the axes it indexes exactly. The only
|
||||||
broadcasting occurs.
|
exception is a scalar boolean mask, which broadcasts to the full array.
|
||||||
- Any axes not covered by the mask are taken in full.
|
- Any axes not covered by the mask are taken in full.
|
||||||
|
|
||||||
.. code-block:: shell
|
.. code-block:: shell
|
||||||
|
|||||||
@@ -87,7 +87,7 @@ CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
CHECK_CUDA_ERROR(
|
CHECK_CUDA_ERROR(
|
||||||
cudaStreamBeginCapture(enc.stream(), cudaStreamCaptureModeGlobal));
|
cudaStreamBeginCapture(enc.stream(), cudaStreamCaptureModeThreadLocal));
|
||||||
}
|
}
|
||||||
|
|
||||||
CommandEncoder::CaptureContext::~CaptureContext() {
|
CommandEncoder::CaptureContext::~CaptureContext() {
|
||||||
|
|||||||
@@ -3466,10 +3466,8 @@ array masked_scatter(
|
|||||||
if (mask.dtype() != bool_) {
|
if (mask.dtype() != bool_) {
|
||||||
throw std::invalid_argument("[masked_scatter] The mask has to be boolean.");
|
throw std::invalid_argument("[masked_scatter] The mask has to be boolean.");
|
||||||
}
|
}
|
||||||
if (mask.ndim() == 0) {
|
|
||||||
throw std::invalid_argument(
|
if (mask.ndim() > a.ndim()) {
|
||||||
"[masked_scatter] Scalar masks are not supported.");
|
|
||||||
} else if (mask.ndim() > a.ndim()) {
|
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"[masked_scatter] The mask cannot have more dimensions than the target.");
|
"[masked_scatter] The mask cannot have more dimensions than the target.");
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -766,7 +766,7 @@ auto mlx_slice_update(
|
|||||||
const nb::object& obj,
|
const nb::object& obj,
|
||||||
const ScalarOrArray& v) {
|
const ScalarOrArray& v) {
|
||||||
// Can't route to slice update if not slice, tuple, or int
|
// Can't route to slice update if not slice, tuple, or int
|
||||||
if (src.ndim() == 0 ||
|
if (src.ndim() == 0 || nb::isinstance<nb::bool_>(obj) ||
|
||||||
(!nb::isinstance<nb::slice>(obj) && !nb::isinstance<nb::tuple>(obj) &&
|
(!nb::isinstance<nb::slice>(obj) && !nb::isinstance<nb::tuple>(obj) &&
|
||||||
!nb::isinstance<nb::int_>(obj))) {
|
!nb::isinstance<nb::int_>(obj))) {
|
||||||
return std::make_pair(false, src);
|
return std::make_pair(false, src);
|
||||||
@@ -888,7 +888,9 @@ auto mlx_slice_update(
|
|||||||
|
|
||||||
std::optional<mx::array> extract_boolean_mask(const nb::object& obj) {
|
std::optional<mx::array> extract_boolean_mask(const nb::object& obj) {
|
||||||
using NDArray = nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>;
|
using NDArray = nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>;
|
||||||
if (nb::isinstance<mx::array>(obj)) {
|
if (nb::isinstance<nb::bool_>(obj)) {
|
||||||
|
return mx::array(nb::cast<bool>(obj), mx::bool_);
|
||||||
|
} else if (nb::isinstance<mx::array>(obj)) {
|
||||||
auto mask = nb::cast<mx::array>(obj);
|
auto mask = nb::cast<mx::array>(obj);
|
||||||
if (mask.dtype() == mx::bool_) {
|
if (mask.dtype() == mx::bool_) {
|
||||||
return mask;
|
return mask;
|
||||||
@@ -898,6 +900,11 @@ std::optional<mx::array> extract_boolean_mask(const nb::object& obj) {
|
|||||||
if (mask.dtype() == nb::dtype<bool>()) {
|
if (mask.dtype() == nb::dtype<bool>()) {
|
||||||
return nd_array_to_mlx(mask, mx::bool_);
|
return nd_array_to_mlx(mask, mx::bool_);
|
||||||
}
|
}
|
||||||
|
} else if (nb::isinstance<nb::list>(obj)) {
|
||||||
|
auto mask = array_from_list(nb::cast<nb::list>(obj), {});
|
||||||
|
if (mask.dtype() == mx::bool_) {
|
||||||
|
return mask;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1929,8 +1929,27 @@ class TestArray(mlx_tests.MLXTestCase):
|
|||||||
self.assertTrue(np.array_equal(a, anp))
|
self.assertTrue(np.array_equal(a, anp))
|
||||||
|
|
||||||
def test_setitem_with_boolean_mask(self):
|
def test_setitem_with_boolean_mask(self):
|
||||||
mask_np = np.zeros((10, 10), dtype=bool)
|
# Python list mask
|
||||||
mx.arange(1000).reshape(10, 10, 10)[mask_np] = 0
|
a = mx.array([1.0, 2.0, 3.0])
|
||||||
|
mask = [True, False, True]
|
||||||
|
src = mx.array([5.0, 6.0])
|
||||||
|
expected = mx.array([5.0, 2.0, 6.0])
|
||||||
|
a[mask] = src
|
||||||
|
self.assertTrue(mx.array_equal(a, expected))
|
||||||
|
|
||||||
|
# mx.array scalar mask
|
||||||
|
a = mx.array([1.0, 2.0, 3.0])
|
||||||
|
mask = mx.array(True)
|
||||||
|
expected = mx.array([5.0, 5.0, 5.0])
|
||||||
|
a[mask] = 5.0
|
||||||
|
self.assertTrue(mx.array_equal(a, expected))
|
||||||
|
|
||||||
|
# scalar mask
|
||||||
|
a = mx.array([1.0, 2.0, 3.0])
|
||||||
|
mask = True
|
||||||
|
expected = mx.array([5.0, 5.0, 5.0])
|
||||||
|
a[mask] = 5.0
|
||||||
|
self.assertTrue(mx.array_equal(a, expected))
|
||||||
|
|
||||||
mask_np = np.zeros((1, 10, 10), dtype=bool)
|
mask_np = np.zeros((1, 10, 10), dtype=bool)
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
|
|||||||
Reference in New Issue
Block a user