mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-20 16:11:14 +08:00
Add a special case when not keeping the dims
This commit is contained in:
parent
a57a75b992
commit
a7faa04cd4
@ -108,6 +108,11 @@ inline void allocate_same_layout(
|
|||||||
array& out,
|
array& out,
|
||||||
const array& in,
|
const array& in,
|
||||||
const std::vector<int>& axes) {
|
const std::vector<int>& axes) {
|
||||||
|
if (out.ndim() < in.ndim()) {
|
||||||
|
out.set_data(allocator::malloc(out.nbytes()));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
// Calculate the transpositions applied to in in order to apply them to out.
|
// Calculate the transpositions applied to in in order to apply them to out.
|
||||||
std::vector<int> axis_order(in.ndim());
|
std::vector<int> axis_order(in.ndim());
|
||||||
std::iota(axis_order.begin(), axis_order.end(), 0);
|
std::iota(axis_order.begin(), axis_order.end(), 0);
|
||||||
|
Loading…
Reference in New Issue
Block a user