implemented Flatten Module (#149)

* implemented flatten op

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
__mo_san__
2023-12-17 06:54:37 +01:00
committed by GitHub
parent eebd7c275d
commit 52e1589a52
8 changed files with 113 additions and 2 deletions

View File

@@ -277,6 +277,34 @@ array reshape(
shape, a.dtype(), std::make_unique<Reshape>(to_stream(s), shape), {a});
}
array flatten(
const array& a,
int start_axis,
int end_axis /* = -1 */,
StreamOrDevice s /* = {} */) {
auto ndim = static_cast<int>(a.ndim());
start_axis += (start_axis < 0 ? ndim : 0);
end_axis += (end_axis < 0 ? ndim + 1 : 0);
start_axis = std::max(0, start_axis);
end_axis = std::min(ndim, end_axis);
if (end_axis < start_axis) {
throw std::invalid_argument(
"[flatten] start_axis must be less than or equal to end_axis");
}
if (start_axis == end_axis and a.ndim() != 0) {
return a;
}
std::vector<int> new_shape(a.shape().begin(), a.shape().begin() + start_axis);
new_shape.push_back(-1);
new_shape.insert(
new_shape.end(), a.shape().begin() + end_axis + 1, a.shape().end());
return reshape(a, new_shape, s);
}
array flatten(const array& a, StreamOrDevice s /* = {} */) {
return flatten(a, 0, a.ndim() - 1, s);
}
array squeeze(
const array& a,
const std::vector<int>& axes,

View File

@@ -123,6 +123,16 @@ array triu(array x, int k, StreamOrDevice s = {});
/** Reshape an array to the given shape. */
array reshape(const array& a, std::vector<int> shape, StreamOrDevice s = {});
/** Flatten the dimensions in the range `[start_axis, end_axis]` . */
array flatten(
const array& a,
int start_axis,
int end_axis = -1,
StreamOrDevice s = {});
/** Flatten the array to 1D. */
array flatten(const array& a, StreamOrDevice s = {});
/** Remove singleton dimensions at the given axes. */
array squeeze(
const array& a,

View File

@@ -50,9 +50,9 @@ std::vector<int> broadcast_shapes(
}
bool is_same_shape(const std::vector<array>& arrays) {
if (arrays.empty())
if (arrays.empty()) {
return true;
}
return std::all_of(arrays.begin() + 1, arrays.end(), [&](const array& a) {
return (a.shape() == arrays[0].shape());
});