implemented Flatten Module (#149)

* implemented flatten op

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
__mo_san__
2023-12-17 06:54:37 +01:00
committed by GitHub
parent eebd7c275d
commit 52e1589a52
8 changed files with 113 additions and 2 deletions

View File

@@ -728,6 +728,21 @@ void init_array(py::module_& m) {
return power(a, to_array(v, a.dtype()));
},
"other"_a)
.def(
"flatten",
[](const array& a,
int start_axis,
int end_axis,
const StreamOrDevice& s) {
return flatten(a, start_axis, end_axis);
},
"start_axis"_a = 0,
"end_axis"_a = -1,
py::kw_only(),
"stream"_a = none,
R"pbdoc(
See :func:`flatten`.
)pbdoc")
.def(
"reshape",
[](const array& a, py::args shape, StreamOrDevice s) {