mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
fix cpu sort
This commit is contained in:
@@ -334,7 +334,9 @@ void Sort::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
auto& in = inputs[0];
|
auto& in = inputs[0];
|
||||||
|
|
||||||
// Copy input to output
|
// Copy input to output
|
||||||
CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
|
CopyType ctype = (in.flags().contiguous && in.strides()[axis_] != 0)
|
||||||
|
? CopyType::Vector
|
||||||
|
: CopyType::General;
|
||||||
copy_cpu(in, out, ctype, stream());
|
copy_cpu(in, out, ctype, stream());
|
||||||
|
|
||||||
auto& encoder = cpu::get_command_encoder(stream());
|
auto& encoder = cpu::get_command_encoder(stream());
|
||||||
@@ -426,7 +428,9 @@ void Partition::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
auto& in = inputs[0];
|
auto& in = inputs[0];
|
||||||
|
|
||||||
// Copy input to output
|
// Copy input to output
|
||||||
CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
|
CopyType ctype = (in.flags().contiguous && in.strides()[axis_] != 0)
|
||||||
|
? CopyType::Vector
|
||||||
|
: CopyType::General;
|
||||||
copy_cpu(in, out, ctype, stream());
|
copy_cpu(in, out, ctype, stream());
|
||||||
|
|
||||||
auto& encoder = cpu::get_command_encoder(stream());
|
auto& encoder = cpu::get_command_encoder(stream());
|
||||||
|
|||||||
Reference in New Issue
Block a user