diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp index 23dae8d51..1c1934b83 100644 --- a/mlx/backend/metal/device.cpp +++ b/mlx/backend/metal/device.cpp @@ -265,8 +265,7 @@ MTL::Library* Device::get_library_(const std::string& source_string) { // Throw error if unable to compile library if (!mtl_lib) { std::ostringstream msg; - msg << "[metal::Device] Unable to load build metal library from source" - << "\n"; + msg << "[metal::Device] Unable to build metal library from source" << "\n"; if (error) { msg << error->localizedDescription()->utf8String() << "\n"; } @@ -285,8 +284,7 @@ MTL::Library* Device::get_library_(const MTL::StitchedLibraryDescriptor* desc) { // Throw error if unable to compile library if (!mtl_lib) { std::ostringstream msg; - msg << "[metal::Device] Unable to load build stitched metal library" - << "\n"; + msg << "[metal::Device] Unable to build stitched metal library" << "\n"; if (error) { msg << error->localizedDescription()->utf8String() << "\n"; } diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 5ef3f4724..8ad7fc425 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -426,7 +426,7 @@ array expand_dims(const array& a, int axis, StreamOrDevice s /* = {} */) { int ax = axis < 0 ? axis + out_dim : axis; if (ax < 0 || ax >= out_dim) { std::ostringstream msg; - msg << "[expand_dims] Invalid axes " << axis << " for output array with " + msg << "[expand_dims] Invalid axis " << axis << " for output array with " << a.ndim() << " dimensions."; throw std::invalid_argument(msg.str()); } @@ -452,7 +452,7 @@ array expand_dims( ax = ax < 0 ? ax + out_ndim : ax; if (ax < 0 || ax >= out_ndim) { std::ostringstream msg; - msg << "[expand_dims] Invalid axes " << ax << " for output array with " + msg << "[expand_dims] Invalid axis " << ax << " for output array with " << a.ndim() << " dimensions."; throw std::invalid_argument(msg.str()); } @@ -591,7 +591,6 @@ array slice_update( if (!has_neg_strides && upd_shape == src.shape()) { return astype(update_broadcasted, src.dtype(), s); } - return array( src.shape(), src.dtype(), diff --git a/python/src/indexing.cpp b/python/src/indexing.cpp index f89a421f4..384aa7c84 100644 --- a/python/src/indexing.cpp +++ b/python/src/indexing.cpp @@ -1,5 +1,4 @@ // Copyright © 2023-2024 Apple Inc. - #include #include @@ -767,6 +766,14 @@ auto mlx_slice_update( (!nb::isinstance(obj) && !nb::isinstance(obj))) { return std::make_pair(false, src); } + if (nb::isinstance(obj)) { + // Can't route to slice update if any arrays are present + for (auto idx : nb::cast(obj)) { + if (nb::isinstance(idx)) { + return std::make_pair(false, src); + } + } + } // Should be able to route to slice update @@ -804,14 +811,6 @@ auto mlx_slice_update( // It must be a tuple auto entries = nb::cast(obj); - // Can't route to slice update if any arrays are present - for (int i = 0; i < entries.size(); i++) { - auto idx = entries[i]; - if (nb::isinstance(idx)) { - return std::make_pair(false, src); - } - } - // Expand ellipses into a series of ':' slices auto [non_none_indices, indices] = mlx_expand_ellipsis(src.shape(), entries); @@ -828,9 +827,19 @@ auto mlx_slice_update( } // Process entries - std::vector upd_expand_dims; - int ax = 0; - for (int i = 0; i < indices.size(); ++i) { + std::vector up_reshape(src.ndim()); + int ax = src.ndim() - 1; + int up_ax = up.ndim() - 1; + for (; ax >= non_none_indices; ax--) { + if (up_ax >= 0) { + up_reshape[ax] = up.shape(up_ax); + up_ax--; + } else { + up_reshape[ax] = 1; + } + } + + for (int i = indices.size() - 1; i >= 0; --i) { auto& pyidx = indices[i]; if (nb::isinstance(pyidx)) { get_slice_params( @@ -839,18 +848,19 @@ auto mlx_slice_update( strides[ax], nb::cast(pyidx), src.shape(ax)); - ax++; + up_reshape[ax] = (up_ax >= 0) ? up.shape(up_ax--) : 1; + ax--; } else if (nb::isinstance(pyidx)) { int st = nb::cast(pyidx); st = (st < 0) ? st + src.shape(ax) : st; starts[ax] = st; stops[ax] = st + 1; - upd_expand_dims.push_back(ax); - ax++; + up_reshape[ax] = 1; + ax--; } } - up = expand_dims(up, upd_expand_dims); + up = reshape(up, std::move(up_reshape)); auto out = slice_update(src, up, starts, stops, strides); return std::make_pair(true, out); } diff --git a/python/tests/test_array.py b/python/tests/test_array.py index bc44b7e6d..86478c33b 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -1262,6 +1262,14 @@ class TestArray(mlx_tests.MLXTestCase): np.ones((3, 4, 4, 4)), np.zeros((4, 4)), 0, slice(0, 4), 3, slice(0, 4) ) + x = mx.zeros((2, 3, 4, 5, 3)) + x[..., 0] = 1.0 + self.assertTrue(mx.array_equal(x[..., 0], mx.ones((2, 3, 4, 5)))) + + x = mx.zeros((2, 3, 4, 5, 3)) + x[:, 0] = 1.0 + self.assertTrue(mx.array_equal(x[:, 0], mx.ones((2, 4, 5, 3)))) + def test_array_at(self): a = mx.array(1) a = a.at[None].add(1)