mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
redesign for faster cpu/gpu synch (#1869)
* redesign for faster cpu/gpu synch * load + more async CPU * use command encoder API and move more ops to use it * make fence back-end generic + CPU only fence * faster build * fix async eval * fixes + handle temporaries * fix / improve cpu conv * remove unused status, fix siblings * fix extensions * fix * fix no cpu build * format * comments * fix perf regression, remove unecessary abort * fix events, task limit cpu * fix waiting * fix donation / temporaries in normalization
This commit is contained in:
@@ -7,11 +7,11 @@
|
||||
#include <sstream>
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/common/load.h"
|
||||
#include "mlx/backend/common/slicing.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/cpu/arange.h"
|
||||
#include "mlx/backend/cpu/copy.h"
|
||||
#include "mlx/backend/cpu/encoder.h"
|
||||
#include "mlx/backend/cpu/threefry.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
@@ -22,39 +22,58 @@ void reshape(const array& in, array& out) {
|
||||
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
|
||||
if (copy_necessary) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
copy_inplace(in, out, CopyType::General);
|
||||
copy_inplace(in, out, CopyType::General, out.primitive().stream());
|
||||
} else {
|
||||
shared_buffer_reshape(in, out_strides, out);
|
||||
}
|
||||
}
|
||||
|
||||
int64_t compute_dynamic_offset(
|
||||
static std::pair<array, bool> compute_dynamic_offset(
|
||||
const array& indices,
|
||||
const Strides& strides,
|
||||
const std::vector<int>& axes) {
|
||||
auto compute_offset = [&strides, &axes](const auto* indices) {
|
||||
int64_t offset = 0;
|
||||
for (int i = 0; i < axes.size(); ++i) {
|
||||
offset += indices[i] * strides[axes[i]];
|
||||
}
|
||||
return offset;
|
||||
};
|
||||
const std::vector<int>& axes,
|
||||
Stream stream) {
|
||||
array offset({1}, int64, nullptr, {});
|
||||
bool donate = indices.is_donatable() &&
|
||||
(indices.data_size() * indices.itemsize()) >= offset.itemsize();
|
||||
if (donate) {
|
||||
offset.copy_shared_buffer(indices);
|
||||
} else {
|
||||
offset.set_data(allocator::malloc_or_wait(offset.itemsize()));
|
||||
}
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(indices);
|
||||
encoder.set_output_array(offset);
|
||||
auto compute_offset =
|
||||
[strides, axes, offset = offset.data<int64_t>()](const auto* indices) {
|
||||
int64_t offset_ = 0;
|
||||
for (int i = 0; i < axes.size(); ++i) {
|
||||
offset_ += indices[i] * strides[axes[i]];
|
||||
}
|
||||
offset[0] = offset_;
|
||||
};
|
||||
switch (indices.dtype()) {
|
||||
case int8:
|
||||
case uint8:
|
||||
return compute_offset(indices.data<uint8_t>());
|
||||
encoder.dispatch(compute_offset, indices.data<uint8_t>());
|
||||
break;
|
||||
case int16:
|
||||
case uint16:
|
||||
return compute_offset(indices.data<uint16_t>());
|
||||
encoder.dispatch(compute_offset, indices.data<uint16_t>());
|
||||
break;
|
||||
case int32:
|
||||
case uint32:
|
||||
return compute_offset(indices.data<uint32_t>());
|
||||
encoder.dispatch(compute_offset, indices.data<uint32_t>());
|
||||
break;
|
||||
case int64:
|
||||
case uint64:
|
||||
return compute_offset(indices.data<uint64_t>());
|
||||
encoder.dispatch(compute_offset, indices.data<uint64_t>());
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error("Invalid indices type.");
|
||||
}
|
||||
return {offset, donate};
|
||||
}
|
||||
|
||||
void AsStrided::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
@@ -104,14 +123,59 @@ void Transpose::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
|
||||
void Arange::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
arange(inputs, out, start_, step_);
|
||||
assert(inputs.size() == 0);
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
switch (out.dtype()) {
|
||||
case bool_:
|
||||
throw std::runtime_error("Bool type unsupported for arange.");
|
||||
break;
|
||||
case uint8:
|
||||
arange<uint8_t>(start_, start_ + step_, out, out.size(), stream());
|
||||
break;
|
||||
case uint16:
|
||||
arange<uint16_t>(start_, start_ + step_, out, out.size(), stream());
|
||||
break;
|
||||
case uint32:
|
||||
arange<uint32_t>(start_, start_ + step_, out, out.size(), stream());
|
||||
break;
|
||||
case uint64:
|
||||
arange<uint64_t>(start_, start_ + step_, out, out.size(), stream());
|
||||
break;
|
||||
case int8:
|
||||
arange<int8_t>(start_, start_ + step_, out, out.size(), stream());
|
||||
break;
|
||||
case int16:
|
||||
arange<int16_t>(start_, start_ + step_, out, out.size(), stream());
|
||||
break;
|
||||
case int32:
|
||||
arange<int32_t>(start_, start_ + step_, out, out.size(), stream());
|
||||
break;
|
||||
case int64:
|
||||
arange<int64_t>(start_, start_ + step_, out, out.size(), stream());
|
||||
break;
|
||||
case float16:
|
||||
arange<float16_t>(start_, start_ + step_, out, out.size(), stream());
|
||||
break;
|
||||
case float32:
|
||||
arange<float>(start_, start_ + step_, out, out.size(), stream());
|
||||
break;
|
||||
case float64:
|
||||
arange<double>(start_, start_ + step_, out, out.size(), stream());
|
||||
break;
|
||||
case bfloat16:
|
||||
arange<bfloat16_t>(start_, start_ + step_, out, out.size(), stream());
|
||||
break;
|
||||
case complex64:
|
||||
arange<complex64_t>(start_, start_ + step_, out, out.size(), stream());
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
void AsType::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
|
||||
copy(in, out, ctype);
|
||||
copy(in, out, ctype, stream());
|
||||
}
|
||||
|
||||
void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
@@ -134,7 +198,7 @@ void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
size_t data_offset = strides[axis_] * sizes[i];
|
||||
out_slice.copy_shared_buffer(
|
||||
out, strides, flags, out_slice.size(), data_offset);
|
||||
copy_inplace(inputs[i], out_slice, CopyType::GeneralGeneral);
|
||||
copy_inplace(inputs[i], out_slice, CopyType::GeneralGeneral, stream());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -145,7 +209,7 @@ void Contiguous::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
(allow_col_major_ && in.flags().col_contiguous)) {
|
||||
out.copy_shared_buffer(in);
|
||||
} else {
|
||||
copy(in, out, CopyType::General);
|
||||
copy(in, out, CopyType::General, stream());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -169,14 +233,7 @@ void Full::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
} else {
|
||||
ctype = CopyType::General;
|
||||
}
|
||||
copy(in, out, ctype);
|
||||
}
|
||||
|
||||
void Load::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 0);
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
load(out, offset_, reader_, swap_endianness_);
|
||||
copy(in, out, ctype, stream());
|
||||
}
|
||||
|
||||
void Pad::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
@@ -192,7 +249,7 @@ void Pad::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(val.dtype() == in.dtype() && in.dtype() == out.dtype());
|
||||
|
||||
// Fill output with val
|
||||
copy(val, out, CopyType::Scalar);
|
||||
copy(val, out, CopyType::Scalar, stream());
|
||||
|
||||
// Find offset for start of input values
|
||||
size_t data_offset = 0;
|
||||
@@ -207,7 +264,7 @@ void Pad::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
out, out.strides(), out.flags(), out_slice.size(), data_offset);
|
||||
|
||||
// Copy input values into the slice
|
||||
copy_inplace(in, out_slice, CopyType::GeneralGeneral);
|
||||
copy_inplace(in, out_slice, CopyType::GeneralGeneral, stream());
|
||||
}
|
||||
|
||||
void RandomBits::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
@@ -223,39 +280,49 @@ void RandomBits::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
auto kptr = inputs[0].data<uint32_t>();
|
||||
auto cptr = out.data<char>();
|
||||
size_t out_skip = (bytes_per_key + 4 - 1) / 4;
|
||||
auto half_size = out_skip / 2;
|
||||
bool even = out_skip % 2 == 0;
|
||||
for (int i = 0; i < num_keys; ++i, cptr += bytes_per_key) {
|
||||
auto ptr = reinterpret_cast<uint32_t*>(cptr);
|
||||
// Get ith key
|
||||
auto kidx = 2 * i;
|
||||
auto k1_elem = elem_to_loc(kidx, keys.shape(), keys.strides());
|
||||
auto k2_elem = elem_to_loc(kidx + 1, keys.shape(), keys.strides());
|
||||
auto key = std::make_pair(kptr[k1_elem], kptr[k2_elem]);
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
encoder.set_input_array(inputs[0]);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([kptr,
|
||||
cptr,
|
||||
bytes_per_key,
|
||||
num_keys,
|
||||
kshape = keys.shape(),
|
||||
kstrides = keys.strides()]() mutable {
|
||||
size_t out_skip = (bytes_per_key + 4 - 1) / 4;
|
||||
auto half_size = out_skip / 2;
|
||||
bool even = out_skip % 2 == 0;
|
||||
for (int i = 0; i < num_keys; ++i, cptr += bytes_per_key) {
|
||||
auto ptr = reinterpret_cast<uint32_t*>(cptr);
|
||||
// Get ith key
|
||||
auto kidx = 2 * i;
|
||||
auto k1_elem = elem_to_loc(kidx, kshape, kstrides);
|
||||
auto k2_elem = elem_to_loc(kidx + 1, kshape, kstrides);
|
||||
auto key = std::make_pair(kptr[k1_elem], kptr[k2_elem]);
|
||||
|
||||
std::pair<uintptr_t, uintptr_t> count{0, half_size + !even};
|
||||
for (; count.first + 1 < half_size; count.first++, count.second++) {
|
||||
std::tie(ptr[count.first], ptr[count.second]) =
|
||||
random::threefry2x32_hash(key, count);
|
||||
}
|
||||
if (count.first < half_size) {
|
||||
auto rb = random::threefry2x32_hash(key, count);
|
||||
ptr[count.first++] = rb.first;
|
||||
if (bytes_per_key % 4 > 0) {
|
||||
std::copy(
|
||||
reinterpret_cast<char*>(&rb.second),
|
||||
reinterpret_cast<char*>(&rb.second) + bytes_per_key % 4,
|
||||
cptr + 4 * count.second);
|
||||
} else {
|
||||
ptr[count.second] = rb.second;
|
||||
std::pair<uintptr_t, uintptr_t> count{0, half_size + !even};
|
||||
for (; count.first + 1 < half_size; count.first++, count.second++) {
|
||||
std::tie(ptr[count.first], ptr[count.second]) =
|
||||
random::threefry2x32_hash(key, count);
|
||||
}
|
||||
if (count.first < half_size) {
|
||||
auto rb = random::threefry2x32_hash(key, count);
|
||||
ptr[count.first++] = rb.first;
|
||||
if (bytes_per_key % 4 > 0) {
|
||||
std::copy(
|
||||
reinterpret_cast<char*>(&rb.second),
|
||||
reinterpret_cast<char*>(&rb.second) + bytes_per_key % 4,
|
||||
cptr + 4 * count.second);
|
||||
} else {
|
||||
ptr[count.second] = rb.second;
|
||||
}
|
||||
}
|
||||
if (!even) {
|
||||
count.second = 0;
|
||||
ptr[half_size] = random::threefry2x32_hash(key, count).first;
|
||||
}
|
||||
}
|
||||
if (!even) {
|
||||
count.second = 0;
|
||||
ptr[half_size] = random::threefry2x32_hash(key, count).first;
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
void Reshape::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
@@ -269,16 +336,23 @@ void DynamicSlice::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
auto& in = inputs[0];
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
auto i_offset = compute_dynamic_offset(inputs[1], in.strides(), axes_);
|
||||
auto [in_offset, donated] =
|
||||
compute_dynamic_offset(inputs[1], in.strides(), axes_, stream());
|
||||
copy_inplace(
|
||||
/* const array& src = */ in,
|
||||
/* array& dst = */ out,
|
||||
/* const Shape& data_shape = */ out.shape(),
|
||||
/* const Strides& i_strides = */ in.strides(),
|
||||
/* const Strides& o_strides = */ out.strides(),
|
||||
/* int64_t i_offset = */ i_offset,
|
||||
/* int64_t i_offset = */ 0,
|
||||
/* int64_t o_offset = */ 0,
|
||||
/* CopyType ctype = */ CopyType::GeneralGeneral);
|
||||
/* CopyType ctype = */ CopyType::GeneralGeneral,
|
||||
stream(),
|
||||
/* const std::optional<array>& dynamic_i_offset = */ in_offset,
|
||||
/* const std::optional<array>& dynamic_o_offset = */ std::nullopt);
|
||||
if (!donated) {
|
||||
cpu::get_command_encoder(stream()).add_temporary(std::move(in_offset));
|
||||
}
|
||||
}
|
||||
|
||||
void DynamicSliceUpdate::eval_cpu(
|
||||
@@ -296,9 +370,10 @@ void DynamicSliceUpdate::eval_cpu(
|
||||
auto ctype = in.flags().contiguous && in.size() == in.data_size()
|
||||
? CopyType::Vector
|
||||
: CopyType::General;
|
||||
copy(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype);
|
||||
copy(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream());
|
||||
|
||||
auto o_offset = compute_dynamic_offset(inputs[2], out.strides(), axes_);
|
||||
auto [out_offset, donated] =
|
||||
compute_dynamic_offset(inputs[2], out.strides(), axes_, stream());
|
||||
copy_inplace(
|
||||
/* const array& src = */ upd,
|
||||
/* array& dst = */ out,
|
||||
@@ -306,8 +381,14 @@ void DynamicSliceUpdate::eval_cpu(
|
||||
/* const std::vector<stride_t>& i_strides = */ upd.strides(),
|
||||
/* const std::vector<stride_t>& o_strides = */ out.strides(),
|
||||
/* int64_t i_offset = */ 0,
|
||||
/* int64_t o_offset = */ o_offset,
|
||||
/* CopyType ctype = */ CopyType::GeneralGeneral);
|
||||
/* int64_t o_offset = */ 0,
|
||||
/* CopyType ctype = */ CopyType::GeneralGeneral,
|
||||
stream(),
|
||||
/* const std::optional<array>& dynamic_i_offset = */ std::nullopt,
|
||||
/* const std::optional<array>& dynamic_o_offset = */ out_offset);
|
||||
if (!donated) {
|
||||
cpu::get_command_encoder(stream()).add_temporary(std::move(out_offset));
|
||||
}
|
||||
}
|
||||
|
||||
void SliceUpdate::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
@@ -329,7 +410,7 @@ void SliceUpdate::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
auto ctype = in.flags().contiguous && in.size() == in.data_size()
|
||||
? CopyType::Vector
|
||||
: CopyType::General;
|
||||
copy(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype);
|
||||
copy(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream());
|
||||
|
||||
// Calculate out strides, initial offset and if copy needs to be made
|
||||
auto [data_offset, out_strides] =
|
||||
@@ -344,7 +425,8 @@ void SliceUpdate::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
/* const std::vector<stride_t>& o_strides = */ out_strides,
|
||||
/* int64_t i_offset = */ 0,
|
||||
/* int64_t o_offset = */ data_offset,
|
||||
/* CopyType ctype = */ CopyType::GeneralGeneral);
|
||||
/* CopyType ctype = */ CopyType::GeneralGeneral,
|
||||
stream());
|
||||
}
|
||||
|
||||
void View::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
@@ -372,9 +454,9 @@ void View::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
if (in.dtype() == bool_) {
|
||||
auto in_tmp = array(in.shape(), uint8, nullptr, {});
|
||||
in_tmp.copy_shared_buffer(in);
|
||||
copy_inplace(in_tmp, tmp, CopyType::General);
|
||||
copy_inplace(in_tmp, tmp, CopyType::General, stream());
|
||||
} else {
|
||||
copy_inplace(in, tmp, CopyType::General);
|
||||
copy_inplace(in, tmp, CopyType::General, stream());
|
||||
}
|
||||
|
||||
auto flags = out.flags();
|
||||
@@ -382,7 +464,7 @@ void View::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
flags.row_contiguous = true;
|
||||
auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
|
||||
flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
|
||||
out.move_shared_buffer(tmp, out.strides(), flags, out.size());
|
||||
out.copy_shared_buffer(tmp, out.strides(), flags, out.size());
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user