implemented Flatten Module (#149)

* implemented flatten op

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
__mo_san__ 2023-12-17 06:54:37 +01:00 committed by GitHub
parent eebd7c275d
commit 52e1589a52
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 113 additions and 2 deletions

View File

@ -41,6 +41,7 @@ Operations
expand_dims expand_dims
eye eye
floor floor
flatten
full full
greater greater
greater_equal greater_equal

View File

@ -277,6 +277,34 @@ array reshape(
shape, a.dtype(), std::make_unique<Reshape>(to_stream(s), shape), {a}); shape, a.dtype(), std::make_unique<Reshape>(to_stream(s), shape), {a});
} }
array flatten(
const array& a,
int start_axis,
int end_axis /* = -1 */,
StreamOrDevice s /* = {} */) {
auto ndim = static_cast<int>(a.ndim());
start_axis += (start_axis < 0 ? ndim : 0);
end_axis += (end_axis < 0 ? ndim + 1 : 0);
start_axis = std::max(0, start_axis);
end_axis = std::min(ndim, end_axis);
if (end_axis < start_axis) {
throw std::invalid_argument(
"[flatten] start_axis must be less than or equal to end_axis");
}
if (start_axis == end_axis and a.ndim() != 0) {
return a;
}
std::vector<int> new_shape(a.shape().begin(), a.shape().begin() + start_axis);
new_shape.push_back(-1);
new_shape.insert(
new_shape.end(), a.shape().begin() + end_axis + 1, a.shape().end());
return reshape(a, new_shape, s);
}
array flatten(const array& a, StreamOrDevice s /* = {} */) {
return flatten(a, 0, a.ndim() - 1, s);
}
array squeeze( array squeeze(
const array& a, const array& a,
const std::vector<int>& axes, const std::vector<int>& axes,

View File

@ -123,6 +123,16 @@ array triu(array x, int k, StreamOrDevice s = {});
/** Reshape an array to the given shape. */ /** Reshape an array to the given shape. */
array reshape(const array& a, std::vector<int> shape, StreamOrDevice s = {}); array reshape(const array& a, std::vector<int> shape, StreamOrDevice s = {});
/** Flatten the dimensions in the range `[start_axis, end_axis]` . */
array flatten(
const array& a,
int start_axis,
int end_axis = -1,
StreamOrDevice s = {});
/** Flatten the array to 1D. */
array flatten(const array& a, StreamOrDevice s = {});
/** Remove singleton dimensions at the given axes. */ /** Remove singleton dimensions at the given axes. */
array squeeze( array squeeze(
const array& a, const array& a,

View File

@ -50,9 +50,9 @@ std::vector<int> broadcast_shapes(
} }
bool is_same_shape(const std::vector<array>& arrays) { bool is_same_shape(const std::vector<array>& arrays) {
if (arrays.empty()) if (arrays.empty()) {
return true; return true;
}
return std::all_of(arrays.begin() + 1, arrays.end(), [&](const array& a) { return std::all_of(arrays.begin() + 1, arrays.end(), [&](const array& a) {
return (a.shape() == arrays[0].shape()); return (a.shape() == arrays[0].shape());
}); });

View File

@ -728,6 +728,21 @@ void init_array(py::module_& m) {
return power(a, to_array(v, a.dtype())); return power(a, to_array(v, a.dtype()));
}, },
"other"_a) "other"_a)
.def(
"flatten",
[](const array& a,
int start_axis,
int end_axis,
const StreamOrDevice& s) {
return flatten(a, start_axis, end_axis);
},
"start_axis"_a = 0,
"end_axis"_a = -1,
py::kw_only(),
"stream"_a = none,
R"pbdoc(
See :func:`flatten`.
)pbdoc")
.def( .def(
"reshape", "reshape",
[](const array& a, py::args shape, StreamOrDevice s) { [](const array& a, py::args shape, StreamOrDevice s) {

View File

@ -61,6 +61,33 @@ void init_ops(py::module_& m) {
Returns: Returns:
array: The reshaped array. array: The reshaped array.
)pbdoc"); )pbdoc");
m.def(
"flatten",
[](const array& a,
int start_axis,
int end_axis,
const StreamOrDevice& s) { return flatten(a, start_axis, end_axis); },
"a"_a,
py::pos_only(),
"start_axis"_a = 0,
"end_axis"_a = -1,
py::kw_only(),
"stream"_a = none,
R"pbdoc(
flatten(a: array, /, start_axis: int = 0, end_axis: int = -1, *, stream: Union[None, Stream, Device] = None) -> array
Flatten an array.
Args:
a (array): Input array.
start_axis (int, optional): The first dimension to flatten. Defaults to ``0``.
end_axis (int, optional): The last dimension to flatten. Defaults to ``-1``.
stream (Stream, optional): Stream or device. Defaults to ``None``
in which case the default stream of the default device is used.
Returns:
array: The flattened array.
)pbdoc");
m.def( m.def(
"squeeze", "squeeze",
[](const array& a, const IntOrVec& v, const StreamOrDevice& s) { [](const array& a, const IntOrVec& v, const StreamOrDevice& s) {

View File

@ -1426,6 +1426,15 @@ class TestOps(mlx_tests.MLXTestCase):
np_c = np.stack([np_a, np_b], axis=1) np_c = np.stack([np_a, np_b], axis=1)
self.assertTrue(np.array_equal(c, np_c)) self.assertTrue(np.array_equal(c, np_c))
def test_flatten(self):
x = mx.zeros([2, 3, 4])
self.assertEqual(mx.flatten(x).shape, [2 * 3 * 4])
self.assertEqual(mx.flatten(x, start_axis=1).shape, [2, 3 * 4])
self.assertEqual(mx.flatten(x, end_axis=1).shape, [2 * 3, 4])
self.assertEqual(x.flatten().shape, [2 * 3 * 4])
self.assertEqual(x.flatten(start_axis=1).shape, [2, 3 * 4])
self.assertEqual(x.flatten(end_axis=1).shape, [2 * 3, 4])
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -58,6 +58,27 @@ TEST_CASE("test reshape") {
CHECK_EQ(y.shape(), std::vector<int>{1, 5, 0}); CHECK_EQ(y.shape(), std::vector<int>{1, 5, 0});
} }
TEST_CASE("test flatten") {
array x = zeros({2, 3, 4});
CHECK_EQ(flatten(x).shape(), std::vector<int>({2 * 3 * 4}));
CHECK_EQ(flatten(x, 1, 1).shape(), std::vector<int>({2, 3, 4}));
CHECK_EQ(flatten(x, 1, 2).shape(), std::vector<int>({2, 3 * 4}));
CHECK_EQ(flatten(x, 1, 3).shape(), std::vector<int>({2, 3 * 4}));
CHECK_EQ(flatten(x, 1, -1).shape(), std::vector<int>({2, 3 * 4}));
CHECK_EQ(flatten(x, -2, -1).shape(), std::vector<int>({2, 3 * 4}));
CHECK_EQ(flatten(x, -3, -1).shape(), std::vector<int>({2 * 3 * 4}));
CHECK_EQ(flatten(x, -4, -1).shape(), std::vector<int>({2 * 3 * 4}));
// Check start > end throws
CHECK_THROWS(flatten(x, 2, 1));
// Check scalar flattens to 1D
x = array(1);
CHECK_EQ(flatten(x, -3, -1).shape(), std::vector<int>({1}));
CHECK_EQ(flatten(x, 0, 0).shape(), std::vector<int>({1}));
}
TEST_CASE("test squeeze and expand") { TEST_CASE("test squeeze and expand") {
array x = zeros({2, 1, 2, 1, 2, 1}); array x = zeros({2, 1, 2, 1, 2, 1});
CHECK_EQ(squeeze(x).shape(), std::vector<int>{2, 2, 2}); CHECK_EQ(squeeze(x).shape(), std::vector<int>{2, 2, 2});