mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
redesign for faster cpu/gpu synch (#1869)
* redesign for faster cpu/gpu synch * load + more async CPU * use command encoder API and move more ops to use it * make fence back-end generic + CPU only fence * faster build * fix async eval * fixes + handle temporaries * fix / improve cpu conv * remove unused status, fix siblings * fix extensions * fix * fix no cpu build * format * comments * fix perf regression, remove unecessary abort * fix events, task limit cpu * fix waiting * fix donation / temporaries in normalization
This commit is contained in:
@@ -39,7 +39,7 @@ void AsStrided::eval(const std::vector<array>& inputs, array& out) {
|
||||
// rely on data_size anyway.
|
||||
size_t data_size = out.size();
|
||||
|
||||
return move_or_copy(in, out, strides_, flags, data_size, offset_);
|
||||
return out.copy_shared_buffer(in, strides_, flags, data_size, offset_);
|
||||
}
|
||||
|
||||
void broadcast(const array& in, array& out) {
|
||||
@@ -56,7 +56,7 @@ void broadcast(const array& in, array& out) {
|
||||
if (out.size() > in.size()) {
|
||||
flags.row_contiguous = flags.col_contiguous = false;
|
||||
}
|
||||
move_or_copy(in, out, strides, flags, in.data_size());
|
||||
out.copy_shared_buffer(in, strides, flags, in.data_size());
|
||||
}
|
||||
|
||||
void Broadcast::eval(const std::vector<array>& inputs, array& out) {
|
||||
@@ -69,7 +69,7 @@ void BroadcastAxes::eval(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
void Copy::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
move_or_copy(inputs[0], out);
|
||||
out.copy_shared_buffer(inputs[0]);
|
||||
}
|
||||
|
||||
void CustomTransforms::eval(
|
||||
@@ -78,7 +78,7 @@ void CustomTransforms::eval(
|
||||
assert(inputs.size() > outputs.size());
|
||||
for (int i = 0, j = inputs.size() - outputs.size(); i < outputs.size();
|
||||
i++, j++) {
|
||||
move_or_copy(inputs[j], outputs[i]);
|
||||
outputs[i].copy_shared_buffer(inputs[j]);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -87,7 +87,7 @@ void Depends::eval(
|
||||
std::vector<array>& outputs) {
|
||||
assert(inputs.size() > outputs.size());
|
||||
for (int i = 0; i < outputs.size(); i++) {
|
||||
move_or_copy(inputs[i], outputs[i]);
|
||||
outputs[i].copy_shared_buffer(inputs[i]);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -98,7 +98,7 @@ void ExpandDims::eval(const std::vector<array>& inputs, array& out) {
|
||||
for (auto ax : axes_) {
|
||||
strides.insert(strides.begin() + ax, 1);
|
||||
}
|
||||
move_or_copy(in, out, strides, in.flags(), in.data_size());
|
||||
out.copy_shared_buffer(in, strides, in.flags(), in.data_size());
|
||||
}
|
||||
|
||||
void NumberOfElements::eval(const std::vector<array>& inputs, array& out) {
|
||||
@@ -210,7 +210,7 @@ void shared_buffer_reshape(
|
||||
auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
|
||||
flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
|
||||
}
|
||||
move_or_copy(in, out, out_strides, flags, in.data_size());
|
||||
out.copy_shared_buffer(in, out_strides, flags, in.data_size());
|
||||
}
|
||||
|
||||
void Split::eval(
|
||||
@@ -276,12 +276,12 @@ void Squeeze::eval(const std::vector<array>& inputs, array& out) {
|
||||
strides.push_back(in.strides(i));
|
||||
}
|
||||
}
|
||||
move_or_copy(in, out, strides, in.flags(), in.data_size());
|
||||
out.copy_shared_buffer(in, strides, in.flags(), in.data_size());
|
||||
}
|
||||
|
||||
void StopGradient::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
move_or_copy(inputs[0], out);
|
||||
out.copy_shared_buffer(inputs[0]);
|
||||
}
|
||||
|
||||
void Transpose::eval(const std::vector<array>& inputs, array& out) {
|
||||
@@ -315,7 +315,7 @@ void Transpose::eval(const std::vector<array>& inputs, array& out) {
|
||||
b_stride *= out.shape(ri);
|
||||
}
|
||||
}
|
||||
move_or_copy(in, out, out_strides, flags, in.data_size());
|
||||
out.copy_shared_buffer(in, out_strides, flags, in.data_size());
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
Reference in New Issue
Block a user