mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
fix scatter index bug (#1122)
This commit is contained in:
parent
56a4eaed72
commit
631dfbe673
@ -540,12 +540,12 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_slice(
|
||||
// Get slice size
|
||||
int slice_size = (end - start);
|
||||
|
||||
// Broadcast update to slide size
|
||||
// Broadcast update to slice size
|
||||
std::vector<int> up_shape_broadcast = {1, slice_size};
|
||||
up_shape_broadcast.insert(
|
||||
up_shape_broadcast.end(), src.shape().begin() + 1, src.shape().end());
|
||||
|
||||
up = broadcast_to(update, up_shape_broadcast);
|
||||
up = broadcast_to(up, up_shape_broadcast);
|
||||
|
||||
auto indices = std::vector<array>{idx};
|
||||
auto axes = std::vector<int>{0};
|
||||
|
@ -1318,6 +1318,11 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
a = a.at[idx_x, :, 0].minimum(update)
|
||||
self.assertEqualArray(a[idx_x, :, 0], mx.minimum(a[idx_x, :, 0], update))
|
||||
|
||||
update = mx.array([1.0, 2.0])[None, None, None]
|
||||
src = mx.array([1.0, 2.0])[None, :]
|
||||
src = src.at[0:1].add(update)
|
||||
self.assertTrue(mx.array_equal(src, mx.array([[2.0, 4.0]])))
|
||||
|
||||
def test_slice_negative_step(self):
|
||||
a_np = np.arange(20)
|
||||
a_mx = mx.array(a_np)
|
||||
|
Loading…
Reference in New Issue
Block a user