mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
implemented Flatten Module (#149)
* implemented flatten op --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent
eebd7c275d
commit
52e1589a52
@ -41,6 +41,7 @@ Operations
|
|||||||
expand_dims
|
expand_dims
|
||||||
eye
|
eye
|
||||||
floor
|
floor
|
||||||
|
flatten
|
||||||
full
|
full
|
||||||
greater
|
greater
|
||||||
greater_equal
|
greater_equal
|
||||||
|
28
mlx/ops.cpp
28
mlx/ops.cpp
@ -277,6 +277,34 @@ array reshape(
|
|||||||
shape, a.dtype(), std::make_unique<Reshape>(to_stream(s), shape), {a});
|
shape, a.dtype(), std::make_unique<Reshape>(to_stream(s), shape), {a});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Collapse the axes in the inclusive range [start_axis, end_axis] of `a`
 * into a single axis.
 *
 * Negative axes count from the end (-1 is the last axis). After that
 * normalisation `start_axis` is clamped up to 0 and `end_axis` is clamped
 * down to the last axis, so over-wide ranges are accepted.
 *
 * A 0-dimensional array is always returned with shape {1}. If the
 * normalised range covers a single axis the input is returned unchanged.
 *
 * Throws std::invalid_argument if `start_axis` lies past the last axis or
 * if the normalised `end_axis` is smaller than the normalised `start_axis`.
 */
array flatten(
    const array& a,
    int start_axis,
    int end_axis /* = -1 */,
    StreamOrDevice s /* = {} */) {
  const auto ndim = static_cast<int>(a.ndim());

  // A scalar has no axes to merge; handle it before any axis arithmetic so
  // the clamping below never sees an empty axis range.
  if (ndim == 0) {
    return reshape(a, {1}, s);
  }

  // Wrap negative axes the same way for both ends.
  start_axis += (start_axis < 0 ? ndim : 0);
  end_axis += (end_axis < 0 ? ndim : 0);

  // Clamp to valid axis indices. `end_axis` must stay <= ndim - 1, otherwise
  // `begin() + end_axis + 1` below would point past `end()`.
  start_axis = std::max(0, start_axis);
  end_axis = std::min(ndim - 1, end_axis);

  if (start_axis >= ndim) {
    throw std::invalid_argument("[flatten] start_axis is out of bounds");
  }
  if (end_axis < start_axis) {
    throw std::invalid_argument(
        "[flatten] start_axis must be less than or equal to end_axis");
  }
  if (start_axis == end_axis) {
    // Only one axis selected: nothing to merge.
    return a;
  }

  // Keep the leading axes, let reshape infer the merged axis via -1, then
  // append the axes that follow `end_axis`.
  const auto& in_shape = a.shape();
  std::vector<int> new_shape(in_shape.begin(), in_shape.begin() + start_axis);
  new_shape.push_back(-1);
  new_shape.insert(
      new_shape.end(), in_shape.begin() + end_axis + 1, in_shape.end());
  return reshape(a, new_shape, s);
}
|
||||||
|
|
||||||
|
/**
 * Flatten every axis of `a` by delegating to the ranged overload with the
 * range [0, last axis].
 */
array flatten(const array& a, StreamOrDevice s /* = {} */) {
  // Convert to int before subtracting so a 0-dim array yields -1 directly
  // instead of relying on unsigned wrap-around followed by narrowing.
  const int last_axis = static_cast<int>(a.ndim()) - 1;
  return flatten(a, 0, last_axis, s);
}
|
||||||
|
|
||||||
array squeeze(
|
array squeeze(
|
||||||
const array& a,
|
const array& a,
|
||||||
const std::vector<int>& axes,
|
const std::vector<int>& axes,
|
||||||
|
10
mlx/ops.h
10
mlx/ops.h
@ -123,6 +123,16 @@ array triu(array x, int k, StreamOrDevice s = {});
|
|||||||
/** Reshape an array to the given shape. */
|
/** Reshape an array to the given shape. */
|
||||||
array reshape(const array& a, std::vector<int> shape, StreamOrDevice s = {});
|
array reshape(const array& a, std::vector<int> shape, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Flatten the dimensions in the inclusive range `[start_axis, end_axis]`. */
|
||||||
|
array flatten(
|
||||||
|
const array& a,
|
||||||
|
int start_axis,
|
||||||
|
int end_axis = -1,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Flatten the array to 1D. */
|
||||||
|
array flatten(const array& a, StreamOrDevice s = {});
|
||||||
|
|
||||||
/** Remove singleton dimensions at the given axes. */
|
/** Remove singleton dimensions at the given axes. */
|
||||||
array squeeze(
|
array squeeze(
|
||||||
const array& a,
|
const array& a,
|
||||||
|
@ -50,9 +50,9 @@ std::vector<int> broadcast_shapes(
|
|||||||
}
|
}
|
||||||
|
|
||||||
bool is_same_shape(const std::vector<array>& arrays) {
|
bool is_same_shape(const std::vector<array>& arrays) {
|
||||||
if (arrays.empty())
|
if (arrays.empty()) {
|
||||||
return true;
|
return true;
|
||||||
|
}
|
||||||
return std::all_of(arrays.begin() + 1, arrays.end(), [&](const array& a) {
|
return std::all_of(arrays.begin() + 1, arrays.end(), [&](const array& a) {
|
||||||
return (a.shape() == arrays[0].shape());
|
return (a.shape() == arrays[0].shape());
|
||||||
});
|
});
|
||||||
|
@ -728,6 +728,21 @@ void init_array(py::module_& m) {
|
|||||||
return power(a, to_array(v, a.dtype()));
|
return power(a, to_array(v, a.dtype()));
|
||||||
},
|
},
|
||||||
"other"_a)
|
"other"_a)
|
||||||
|
.def(
|
||||||
|
"flatten",
|
||||||
|
[](const array& a,
|
||||||
|
int start_axis,
|
||||||
|
int end_axis,
|
||||||
|
const StreamOrDevice& s) {
|
||||||
|
return flatten(a, start_axis, end_axis);
|
||||||
|
},
|
||||||
|
"start_axis"_a = 0,
|
||||||
|
"end_axis"_a = -1,
|
||||||
|
py::kw_only(),
|
||||||
|
"stream"_a = none,
|
||||||
|
R"pbdoc(
|
||||||
|
See :func:`flatten`.
|
||||||
|
)pbdoc")
|
||||||
.def(
|
.def(
|
||||||
"reshape",
|
"reshape",
|
||||||
[](const array& a, py::args shape, StreamOrDevice s) {
|
[](const array& a, py::args shape, StreamOrDevice s) {
|
||||||
|
@ -61,6 +61,33 @@ void init_ops(py::module_& m) {
|
|||||||
Returns:
|
Returns:
|
||||||
array: The reshaped array.
|
array: The reshaped array.
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"flatten",
|
||||||
|
[](const array& a,
|
||||||
|
int start_axis,
|
||||||
|
int end_axis,
|
||||||
|
const StreamOrDevice& s) { return flatten(a, start_axis, end_axis); },
|
||||||
|
"a"_a,
|
||||||
|
py::pos_only(),
|
||||||
|
"start_axis"_a = 0,
|
||||||
|
"end_axis"_a = -1,
|
||||||
|
py::kw_only(),
|
||||||
|
"stream"_a = none,
|
||||||
|
R"pbdoc(
|
||||||
|
flatten(a: array, /, start_axis: int = 0, end_axis: int = -1, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
|
Flatten an array.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
a (array): Input array.
|
||||||
|
start_axis (int, optional): The first dimension to flatten. Defaults to ``0``.
|
||||||
|
end_axis (int, optional): The last dimension to flatten. Defaults to ``-1``.
|
||||||
|
stream (Stream, optional): Stream or device. Defaults to ``None``
|
||||||
|
in which case the default stream of the default device is used.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
array: The flattened array.
|
||||||
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"squeeze",
|
"squeeze",
|
||||||
[](const array& a, const IntOrVec& v, const StreamOrDevice& s) {
|
[](const array& a, const IntOrVec& v, const StreamOrDevice& s) {
|
||||||
|
@ -1426,6 +1426,15 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
np_c = np.stack([np_a, np_b], axis=1)
|
np_c = np.stack([np_a, np_b], axis=1)
|
||||||
self.assertTrue(np.array_equal(c, np_c))
|
self.assertTrue(np.array_equal(c, np_c))
|
||||||
|
|
||||||
|
def test_flatten(self):
    # Each case pairs the keyword arguments passed to flatten with the shape
    # expected when flattening a (2, 3, 4) array.
    x = mx.zeros([2, 3, 4])
    cases = [
        ({}, [2 * 3 * 4]),
        ({"start_axis": 1}, [2, 3 * 4]),
        ({"end_axis": 1}, [2 * 3, 4]),
    ]

    # Free-function form.
    for kwargs, expected in cases:
        self.assertEqual(mx.flatten(x, **kwargs).shape, expected)

    # Array-method form.
    for kwargs, expected in cases:
        self.assertEqual(x.flatten(**kwargs).shape, expected)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
@ -58,6 +58,27 @@ TEST_CASE("test reshape") {
|
|||||||
CHECK_EQ(y.shape(), std::vector<int>{1, 5, 0});
|
CHECK_EQ(y.shape(), std::vector<int>{1, 5, 0});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test flatten") {
  using Shape = std::vector<int>;

  // Expected shapes when flattening a (2, 3, 4) array.
  const Shape unchanged{2, 3, 4};
  const Shape trailing_merged{2, 3 * 4};
  const Shape fully_flat{2 * 3 * 4};

  array x = zeros({2, 3, 4});
  CHECK_EQ(flatten(x).shape(), fully_flat);

  // Explicit, over-wide and negative axis ranges.
  CHECK_EQ(flatten(x, 1, 1).shape(), unchanged);
  CHECK_EQ(flatten(x, 1, 2).shape(), trailing_merged);
  CHECK_EQ(flatten(x, 1, 3).shape(), trailing_merged);
  CHECK_EQ(flatten(x, 1, -1).shape(), trailing_merged);
  CHECK_EQ(flatten(x, -2, -1).shape(), trailing_merged);
  CHECK_EQ(flatten(x, -3, -1).shape(), fully_flat);
  CHECK_EQ(flatten(x, -4, -1).shape(), fully_flat);

  // Check start > end throws
  CHECK_THROWS(flatten(x, 2, 1));

  // Check scalar flattens to 1D
  const Shape one_element{1};
  x = array(1);
  CHECK_EQ(flatten(x, -3, -1).shape(), one_element);
  CHECK_EQ(flatten(x, 0, 0).shape(), one_element);
}
|
||||||
|
|
||||||
TEST_CASE("test squeeze and expand") {
|
TEST_CASE("test squeeze and expand") {
|
||||||
array x = zeros({2, 1, 2, 1, 2, 1});
|
array x = zeros({2, 1, 2, 1, 2, 1});
|
||||||
CHECK_EQ(squeeze(x).shape(), std::vector<int>{2, 2, 2});
|
CHECK_EQ(squeeze(x).shape(), std::vector<int>{2, 2, 2});
|
||||||
|
Loading…
Reference in New Issue
Block a user