mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-10 11:16:38 +08:00
Improve repeat using broadcasting and reshape (#318)
This commit is contained in:
parent
930b159885
commit
a020a2d49d
19
mlx/ops.cpp
19
mlx/ops.cpp
@ -735,17 +735,18 @@ array repeat(const array& arr, int repeats, int axis, StreamOrDevice s) {
|
||||
return arr;
|
||||
}
|
||||
|
||||
std::vector<int> new_shape(arr.shape());
|
||||
new_shape[axis] *= repeats;
|
||||
// Broadcast to (S_1, S_2, ..., S_axis, repeats, S_axis+1, ...)
|
||||
std::vector<int> shape(arr.shape());
|
||||
shape.insert(shape.begin() + axis + 1, repeats);
|
||||
array out = expand_dims(arr, axis + 1, s);
|
||||
out = broadcast_to(out, shape, s);
|
||||
|
||||
std::vector<array> repeated_arrays;
|
||||
repeated_arrays.reserve(repeats);
|
||||
// Reshape back into a contiguous array where S_axis is now S_axis * repeats
|
||||
shape.erase(shape.begin() + axis + 1);
|
||||
shape[axis] *= repeats;
|
||||
out = reshape(out, shape, s);
|
||||
|
||||
for (int i = 0; i < repeats; ++i) {
|
||||
repeated_arrays.push_back(expand_dims(arr, -1, s));
|
||||
}
|
||||
array repeated = concatenate(repeated_arrays, axis + 1, s);
|
||||
return reshape(repeated, new_shape, s);
|
||||
return out;
|
||||
}
|
||||
|
||||
array repeat(const array& arr, int repeats, StreamOrDevice s) {
|
||||
|
Loading…
Reference in New Issue
Block a user