diff --git a/mlx/backend/common/copy.cpp b/mlx/backend/common/copy.cpp index 2272ff325..4f9c4ea7b 100644 --- a/mlx/backend/common/copy.cpp +++ b/mlx/backend/common/copy.cpp @@ -256,7 +256,7 @@ void copy_general_general( } int size = std::accumulate( - data_shape.begin() - 5, data_shape.end(), 1, std::multiplies()); + data_shape.end() - 5, data_shape.end(), 1, std::multiplies()); for (int i = 0; i < src.size(); i += size) { stride_t src_offset = i_offset + elem_to_loc(i, data_shape, i_strides); stride_t dst_offset = o_offset + elem_to_loc(i, dst.shape(), o_strides); diff --git a/python/tests/test_array.py b/python/tests/test_array.py index e3ef4d88f..f04823c3f 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -1270,6 +1270,10 @@ class TestArray(mlx_tests.MLXTestCase): x[:, 0] = 1.0 self.assertTrue(mx.array_equal(x[:, 0], mx.ones((2, 4, 5, 3)))) + x = mx.zeros((2, 2, 2, 2, 2, 2)) + x[0, 0] = 1 + self.assertTrue(mx.array_equal(x[0, 0], mx.ones((2, 2, 2, 2)))) + def test_array_at(self): a = mx.array(1) a = a.at[None].add(1)