diff --git a/mlx/ops.cpp b/mlx/ops.cpp index eac77735b6..f4f6b922d5 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -735,17 +735,18 @@ array repeat(const array& arr, int repeats, int axis, StreamOrDevice s) { return arr; } - std::vector new_shape(arr.shape()); - new_shape[axis] *= repeats; + // Broadcast to (S_1, S_2, ..., S_axis, repeats, S_axis+1, ...) + std::vector shape(arr.shape()); + shape.insert(shape.begin() + axis + 1, repeats); + array out = expand_dims(arr, axis + 1, s); + out = broadcast_to(out, shape, s); - std::vector repeated_arrays; - repeated_arrays.reserve(repeats); + // Reshape back into a contiguous array where S_axis is now S_axis * repeats + shape.erase(shape.begin() + axis + 1); + shape[axis] *= repeats; + out = reshape(out, shape, s); - for (int i = 0; i < repeats; ++i) { - repeated_arrays.push_back(expand_dims(arr, -1, s)); - } - array repeated = concatenate(repeated_arrays, axis + 1, s); - return reshape(repeated, new_shape, s); + return out; } array repeat(const array& arr, int repeats, StreamOrDevice s) {