mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-20 03:48:15 +08:00
implemented Flatten Module (#149)
* implemented flatten op --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
28
mlx/ops.cpp
28
mlx/ops.cpp
@@ -277,6 +277,34 @@ array reshape(
|
||||
shape, a.dtype(), std::make_unique<Reshape>(to_stream(s), shape), {a});
|
||||
}
|
||||
|
||||
array flatten(
|
||||
const array& a,
|
||||
int start_axis,
|
||||
int end_axis /* = -1 */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
auto ndim = static_cast<int>(a.ndim());
|
||||
start_axis += (start_axis < 0 ? ndim : 0);
|
||||
end_axis += (end_axis < 0 ? ndim + 1 : 0);
|
||||
start_axis = std::max(0, start_axis);
|
||||
end_axis = std::min(ndim, end_axis);
|
||||
if (end_axis < start_axis) {
|
||||
throw std::invalid_argument(
|
||||
"[flatten] start_axis must be less than or equal to end_axis");
|
||||
}
|
||||
if (start_axis == end_axis and a.ndim() != 0) {
|
||||
return a;
|
||||
}
|
||||
std::vector<int> new_shape(a.shape().begin(), a.shape().begin() + start_axis);
|
||||
new_shape.push_back(-1);
|
||||
new_shape.insert(
|
||||
new_shape.end(), a.shape().begin() + end_axis + 1, a.shape().end());
|
||||
return reshape(a, new_shape, s);
|
||||
}
|
||||
|
||||
array flatten(const array& a, StreamOrDevice s /* = {} */) {
|
||||
return flatten(a, 0, a.ndim() - 1, s);
|
||||
}
|
||||
|
||||
array squeeze(
|
||||
const array& a,
|
||||
const std::vector<int>& axes,
|
||||
|
10
mlx/ops.h
10
mlx/ops.h
@@ -123,6 +123,16 @@ array triu(array x, int k, StreamOrDevice s = {});
|
||||
/** Reshape an array to the given shape. */
|
||||
array reshape(const array& a, std::vector<int> shape, StreamOrDevice s = {});
|
||||
|
||||
/** Flatten the dimensions in the range `[start_axis, end_axis]` . */
|
||||
array flatten(
|
||||
const array& a,
|
||||
int start_axis,
|
||||
int end_axis = -1,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Flatten the array to 1D. */
|
||||
array flatten(const array& a, StreamOrDevice s = {});
|
||||
|
||||
/** Remove singleton dimensions at the given axes. */
|
||||
array squeeze(
|
||||
const array& a,
|
||||
|
@@ -50,9 +50,9 @@ std::vector<int> broadcast_shapes(
|
||||
}
|
||||
|
||||
bool is_same_shape(const std::vector<array>& arrays) {
|
||||
if (arrays.empty())
|
||||
if (arrays.empty()) {
|
||||
return true;
|
||||
|
||||
}
|
||||
return std::all_of(arrays.begin() + 1, arrays.end(), [&](const array& a) {
|
||||
return (a.shape() == arrays[0].shape());
|
||||
});
|
||||
|
Reference in New Issue
Block a user