Flatten and unflatten (#1692)

* flatten and unflatten

* fix grad

* fix shape infer

* use squeeze + unsqueeze in get_item
This commit is contained in:
Awni Hannun
2024-12-11 21:51:37 -08:00
committed by GitHub
parent 0bf19037ca
commit 4e1e9520e1
19 changed files with 363 additions and 93 deletions

View File

@@ -151,9 +151,7 @@ void NumberOfElements::eval(const std::vector<array>& inputs, array& out) {
}
}
std::pair<bool, Strides> Reshape::prepare_reshape(
const array& in,
const array& out) {
std::pair<bool, Strides> prepare_reshape(const array& in, const array& out) {
// Special case for empty arrays or row contiguous arrays
if (in.size() == 0 || in.flags().row_contiguous) {
return {false, out.strides()};
@@ -190,7 +188,7 @@ std::pair<bool, Strides> Reshape::prepare_reshape(
return {copy_necessary, out_strides};
}
void Reshape::shared_buffer_reshape(
void shared_buffer_reshape(
const array& in,
const Strides& out_strides,
array& out) {