Fix slice data size (#1394)

* fix slice data size and add tests

* fix contiguous flag

* simplify stride and perform copy for non-contiguous arrays

* fix cpu

* comment
This commit is contained in:
Awni Hannun
2024-09-04 19:10:43 -07:00
committed by GitHub
parent 11371fe251
commit 7cca1727af
12 changed files with 129 additions and 39 deletions

View File

@@ -135,7 +135,7 @@ template <typename stride_t>
inline auto check_contiguity(
const std::vector<int>& shape,
const std::vector<stride_t>& strides) {
size_t data_size = 1;
size_t no_broadcast_data_size = 1;
size_t f_stride = 1;
size_t b_stride = 1;
bool is_row_contiguous = true;
@@ -147,11 +147,12 @@ inline auto check_contiguity(
f_stride *= shape[i];
b_stride *= shape[ri];
if (strides[i] > 0) {
data_size *= shape[i];
no_broadcast_data_size *= shape[i];
}
}
return std::make_tuple(data_size, is_row_contiguous, is_col_contiguous);
return std::make_tuple(
no_broadcast_data_size, is_row_contiguous, is_col_contiguous);
}
} // namespace mlx::core