mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-19 02:38:09 +08:00
Add tile op (#438)
This commit is contained in:
30
mlx/ops.cpp
30
mlx/ops.cpp
@@ -753,6 +753,36 @@ array repeat(const array& arr, int repeats, StreamOrDevice s) {
|
||||
return repeat(flatten(arr, s), repeats, 0, s);
|
||||
}
|
||||
|
||||
array tile(
|
||||
const array& arr,
|
||||
std::vector<int> reps,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
auto shape = arr.shape();
|
||||
if (reps.size() < shape.size()) {
|
||||
reps.insert(reps.begin(), shape.size() - reps.size(), 1);
|
||||
}
|
||||
if (reps.size() > shape.size()) {
|
||||
shape.insert(shape.begin(), reps.size() - shape.size(), 1);
|
||||
}
|
||||
|
||||
std::vector<int> expand_shape;
|
||||
std::vector<int> broad_shape;
|
||||
std::vector<int> final_shape;
|
||||
for (int i = 0; i < shape.size(); i++) {
|
||||
if (reps[i] != 1) {
|
||||
expand_shape.push_back(1);
|
||||
broad_shape.push_back(reps[i]);
|
||||
}
|
||||
expand_shape.push_back(shape[i]);
|
||||
broad_shape.push_back(shape[i]);
|
||||
final_shape.push_back(reps[i] * shape[i]);
|
||||
}
|
||||
|
||||
auto x = reshape(arr, expand_shape, s);
|
||||
x = broadcast_to(x, broad_shape, s);
|
||||
return reshape(x, final_shape, s);
|
||||
}
|
||||
|
||||
/** Pad an array with a constant value */
|
||||
array pad(
|
||||
const array& a,
|
||||
|
@@ -218,6 +218,8 @@ array stack(const std::vector<array>& arrays, StreamOrDevice s = {});
|
||||
array repeat(const array& arr, int repeats, int axis, StreamOrDevice s = {});
|
||||
array repeat(const array& arr, int repeats, StreamOrDevice s = {});
|
||||
|
||||
array tile(const array& arr, std::vector<int> reps, StreamOrDevice s = {});
|
||||
|
||||
/** Permutes the dimensions according to the given axes. */
|
||||
array transpose(const array& a, std::vector<int> axes, StreamOrDevice s = {});
|
||||
inline array transpose(
|
||||
|
Reference in New Issue
Block a user