mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-22 11:14:32 +08:00
Fix cpu compile (#934)
* fix one cpu bug, test for another * format hooks * simplify contiguity check for cpu compile * fix * add back donation * comment
This commit is contained in:
@@ -229,14 +229,7 @@ void Compiled::eval_gpu(
|
||||
|
||||
// Figure out which kernel we are using
|
||||
auto& output_shape = outputs[0].shape();
|
||||
bool contiguous = true;
|
||||
for (auto& x : inputs) {
|
||||
if ((!x.flags().row_contiguous || x.shape() != output_shape) &&
|
||||
!is_scalar(x)) {
|
||||
contiguous = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
bool contiguous = compiled_check_contiguity(inputs, output_shape);
|
||||
|
||||
// Collapse contiguous dims to route to a faster kernel if possible. Also
|
||||
// handle all broadcasting.
|
||||
@@ -317,28 +310,8 @@ void Compiled::eval_gpu(
|
||||
}
|
||||
}
|
||||
|
||||
// Allocate space for the outputs possibly with input donation
|
||||
{
|
||||
int o = 0;
|
||||
for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {
|
||||
auto& in = inputs[i];
|
||||
// Conditions for donation
|
||||
// - Row contiguous
|
||||
// - Donatable
|
||||
// - Correct size
|
||||
// - Not a constant
|
||||
if (in.flags().row_contiguous && in.nbytes() == outputs[o].nbytes() &&
|
||||
in.is_donatable() &&
|
||||
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
|
||||
outputs[o].move_shared_buffer(
|
||||
in, outputs[o].strides(), in.flags(), in.data_size());
|
||||
o++;
|
||||
}
|
||||
}
|
||||
for (; o < outputs.size(); ++o) {
|
||||
outputs[o].set_data(allocator::malloc_or_wait(outputs[o].nbytes()));
|
||||
}
|
||||
}
|
||||
compiled_allocate_outputs(
|
||||
inputs, outputs, inputs_, constant_ids_, contiguous, true);
|
||||
|
||||
// Put the outputs in
|
||||
for (auto& x : outputs) {
|
||||
|
Reference in New Issue
Block a user