mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-14 05:06:39 +08:00
Improve repeat using broadcasting and reshape (#318)
This commit is contained in:
parent
930b159885
commit
a020a2d49d
19
mlx/ops.cpp
19
mlx/ops.cpp
@ -735,17 +735,18 @@ array repeat(const array& arr, int repeats, int axis, StreamOrDevice s) {
|
|||||||
return arr;
|
return arr;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<int> new_shape(arr.shape());
|
// Broadcast to (S_1, S_2, ..., S_axis, repeats, S_axis+1, ...)
|
||||||
new_shape[axis] *= repeats;
|
std::vector<int> shape(arr.shape());
|
||||||
|
shape.insert(shape.begin() + axis + 1, repeats);
|
||||||
|
array out = expand_dims(arr, axis + 1, s);
|
||||||
|
out = broadcast_to(out, shape, s);
|
||||||
|
|
||||||
std::vector<array> repeated_arrays;
|
// Reshape back into a contiguous array where S_axis is now S_axis * repeats
|
||||||
repeated_arrays.reserve(repeats);
|
shape.erase(shape.begin() + axis + 1);
|
||||||
|
shape[axis] *= repeats;
|
||||||
|
out = reshape(out, shape, s);
|
||||||
|
|
||||||
for (int i = 0; i < repeats; ++i) {
|
return out;
|
||||||
repeated_arrays.push_back(expand_dims(arr, -1, s));
|
|
||||||
}
|
|
||||||
array repeated = concatenate(repeated_arrays, axis + 1, s);
|
|
||||||
return reshape(repeated, new_shape, s);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
array repeat(const array& arr, int repeats, StreamOrDevice s) {
|
array repeat(const array& arr, int repeats, StreamOrDevice s) {
|
||||||
|
Loading…
Reference in New Issue
Block a user