mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Add support for repeat (#278)
* add repeat function * fix styling * optimizing repeat * fixed minor issues * not sure why that folder is there xD * fixed now for sure * test repeat not repeat test * Fixed --------- Co-authored-by: Bahaa Eddin tabbakha <bahaa@Bahaas-MacBook-Pro.local>
This commit is contained in:
34
mlx/ops.cpp
34
mlx/ops.cpp
@@ -718,6 +718,40 @@ array stack(const std::vector<array>& arrays, StreamOrDevice s /* = {} */) {
|
||||
return stack(arrays, 0, s);
|
||||
}
|
||||
|
||||
/** array repeat with axis */
|
||||
array repeat(const array& arr, int repeats, int axis, StreamOrDevice s) {
|
||||
axis = normalize_axis(axis, arr.ndim());
|
||||
|
||||
if (repeats < 0) {
|
||||
throw std::invalid_argument(
|
||||
"[repeat] Number of repeats cannot be negative");
|
||||
}
|
||||
|
||||
if (repeats == 0) {
|
||||
return array({}, arr.dtype());
|
||||
}
|
||||
|
||||
if (repeats == 1) {
|
||||
return arr;
|
||||
}
|
||||
|
||||
std::vector<int> new_shape(arr.shape());
|
||||
new_shape[axis] *= repeats;
|
||||
|
||||
std::vector<array> repeated_arrays;
|
||||
repeated_arrays.reserve(repeats);
|
||||
|
||||
for (int i = 0; i < repeats; ++i) {
|
||||
repeated_arrays.push_back(expand_dims(arr, -1, s));
|
||||
}
|
||||
array repeated = concatenate(repeated_arrays, axis + 1, s);
|
||||
return reshape(repeated, new_shape, s);
|
||||
}
|
||||
|
||||
array repeat(const array& arr, int repeats, StreamOrDevice s) {
|
||||
return repeat(flatten(arr, s), repeats, 0, s);
|
||||
}
|
||||
|
||||
/** Pad an array with a constant value */
|
||||
array pad(
|
||||
const array& a,
|
||||
|
||||
@@ -214,6 +214,10 @@ array concatenate(const std::vector<array>& arrays, StreamOrDevice s = {});
|
||||
array stack(const std::vector<array>& arrays, int axis, StreamOrDevice s = {});
|
||||
array stack(const std::vector<array>& arrays, StreamOrDevice s = {});
|
||||
|
||||
/** Repeate an array along an axis. */
|
||||
array repeat(const array& arr, int repeats, int axis, StreamOrDevice s = {});
|
||||
array repeat(const array& arr, int repeats, StreamOrDevice s = {});
|
||||
|
||||
/** Permutes the dimensions according to the given axes. */
|
||||
array transpose(const array& a, std::vector<int> axes, StreamOrDevice s = {});
|
||||
inline array transpose(
|
||||
|
||||
Reference in New Issue
Block a user