mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Add Masked Scatter (#2663)
Co-authored-by: Awni Hannun <awni@apple.com> Co-authored-by: Angelos Katharopoulos <katharas@gmail.com> Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
@@ -1928,6 +1928,18 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
anp[:, idx] = 4
|
||||
self.assertTrue(np.array_equal(a, anp))
|
||||
|
||||
def test_setitem_with_boolean_mask(self):
|
||||
mask_np = np.zeros((10, 10), dtype=bool)
|
||||
mx.arange(1000).reshape(10, 10, 10)[mask_np] = 0
|
||||
|
||||
mask_np = np.zeros((1, 10, 10), dtype=bool)
|
||||
with self.assertRaises(ValueError):
|
||||
mx.arange(1000).reshape(10, 10, 10)[mask_np] = 0
|
||||
|
||||
mask_np = np.zeros((10, 10, 1), dtype=bool)
|
||||
with self.assertRaises(ValueError):
|
||||
mx.arange(1000).reshape(10, 10, 10)[mask_np] = 0
|
||||
|
||||
def test_array_namespace(self):
|
||||
a = mx.array(1.0)
|
||||
api = a.__array_namespace__()
|
||||
|
||||
Reference in New Issue
Block a user