mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-19 15:41:13 +08:00
Make check more general
This commit is contained in:
parent
a7faa04cd4
commit
d999675cb9
@ -108,11 +108,16 @@ inline void allocate_same_layout(
|
||||
array& out,
|
||||
const array& in,
|
||||
const std::vector<int>& axes) {
|
||||
if (out.ndim() < in.ndim()) {
|
||||
if (in.flags().row_contiguous) {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
return;
|
||||
}
|
||||
|
||||
if (out.ndim() < in.ndim()) {
|
||||
throw std::runtime_error(
|
||||
"Reduction without keepdims only supported for row-contiguous inputs");
|
||||
}
|
||||
|
||||
// Calculate the transpositions applied to in in order to apply them to out.
|
||||
std::vector<int> axis_order(in.ndim());
|
||||
std::iota(axis_order.begin(), axis_order.end(), 0);
|
||||
|
Loading…
Reference in New Issue
Block a user