mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Flatten and unflatten (#1692)
* flatten and unflatten * fix grad * fix shape infer * use squeeze + unsqueeze in get_item
This commit is contained in:
@@ -151,9 +151,7 @@ void NumberOfElements::eval(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
std::pair<bool, Strides> Reshape::prepare_reshape(
|
||||
const array& in,
|
||||
const array& out) {
|
||||
std::pair<bool, Strides> prepare_reshape(const array& in, const array& out) {
|
||||
// Special case for empty arrays or row contiguous arrays
|
||||
if (in.size() == 0 || in.flags().row_contiguous) {
|
||||
return {false, out.strides()};
|
||||
@@ -190,7 +188,7 @@ std::pair<bool, Strides> Reshape::prepare_reshape(
|
||||
return {copy_necessary, out_strides};
|
||||
}
|
||||
|
||||
void Reshape::shared_buffer_reshape(
|
||||
void shared_buffer_reshape(
|
||||
const array& in,
|
||||
const Strides& out_strides,
|
||||
array& out) {
|
||||
|
||||
Reference in New Issue
Block a user