From fb71a82ada36fa1e83b24f062c296aed165c4f05 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 17 May 2024 21:10:03 -0700 Subject: [PATCH] Fix copy bug with many dims (#1137) --- mlx/backend/common/copy.cpp | 2 +- python/tests/test_array.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/mlx/backend/common/copy.cpp b/mlx/backend/common/copy.cpp index 2272ff325..4f9c4ea7b 100644 --- a/mlx/backend/common/copy.cpp +++ b/mlx/backend/common/copy.cpp @@ -256,7 +256,7 @@ void copy_general_general( } int size = std::accumulate( - data_shape.begin() - 5, data_shape.end(), 1, std::multiplies()); + data_shape.end() - 5, data_shape.end(), 1, std::multiplies()); for (int i = 0; i < src.size(); i += size) { stride_t src_offset = i_offset + elem_to_loc(i, data_shape, i_strides); stride_t dst_offset = o_offset + elem_to_loc(i, dst.shape(), o_strides); diff --git a/python/tests/test_array.py b/python/tests/test_array.py index e3ef4d88f..f04823c3f 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -1270,6 +1270,10 @@ class TestArray(mlx_tests.MLXTestCase): x[:, 0] = 1.0 self.assertTrue(mx.array_equal(x[:, 0], mx.ones((2, 4, 5, 3)))) + x = mx.zeros((2, 2, 2, 2, 2, 2)) + x[0, 0] = 1 + self.assertTrue(mx.array_equal(x[0, 0], mx.ones((2, 2, 2, 2)))) + def test_array_at(self): a = mx.array(1) a = a.at[None].add(1)