mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 12:06:42 +08:00
Remove the hack around SmallVector in cpu compile (#2494)
This commit is contained in:
parent
4abb218d21
commit
888b13ed63
@ -157,10 +157,12 @@ inline void build_kernel(
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
// Start the kernel
|
// Start the kernel
|
||||||
os << "void " << kernel_name << "(void** args) {" << std::endl;
|
os << "void " << kernel_name
|
||||||
|
<< "(int* shape, int64_t** strides, void** args) {" << std::endl;
|
||||||
|
|
||||||
// Add the input arguments
|
// Add the input arguments
|
||||||
int cnt = 0;
|
int cnt = 0;
|
||||||
|
int strides_index = 1;
|
||||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||||
// Skip constants from the input list
|
// Skip constants from the input list
|
||||||
if (is_constant(i)) {
|
if (is_constant(i)) {
|
||||||
@ -175,8 +177,8 @@ inline void build_kernel(
|
|||||||
<< "];" << std::endl;
|
<< "];" << std::endl;
|
||||||
// Scalars and contiguous need no strides
|
// Scalars and contiguous need no strides
|
||||||
if (!is_scalar(x) && !contiguous) {
|
if (!is_scalar(x) && !contiguous) {
|
||||||
os << " const size_t* " << xname << "_strides = (size_t*)args[" << cnt++
|
os << " const int64_t* " << xname << "_strides = strides["
|
||||||
<< "];" << std::endl;
|
<< strides_index++ << "];" << std::endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -186,10 +188,8 @@ inline void build_kernel(
|
|||||||
os << " " << tstr << "* " << namer.get_name(x) << " = (" << tstr
|
os << " " << tstr << "* " << namer.get_name(x) << " = (" << tstr
|
||||||
<< "*)args[" << cnt++ << "];" << std::endl;
|
<< "*)args[" << cnt++ << "];" << std::endl;
|
||||||
}
|
}
|
||||||
// Add output strides and shape to extract the indices.
|
// Add output size
|
||||||
if (!contiguous) {
|
if (contiguous) {
|
||||||
os << " const int* shape = (int*)args[" << cnt++ << "];" << std::endl;
|
|
||||||
} else {
|
|
||||||
os << " const size_t size = (size_t)args[" << cnt++ << "];" << std::endl;
|
os << " const size_t size = (size_t)args[" << cnt++ << "];" << std::endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -288,17 +288,8 @@ void Compiled::eval_cpu(
|
|||||||
auto [contiguous, shape, strides] =
|
auto [contiguous, shape, strides] =
|
||||||
compiled_collapse_contiguous_dims(inputs, outputs[0], is_constant_);
|
compiled_collapse_contiguous_dims(inputs, outputs[0], is_constant_);
|
||||||
|
|
||||||
// Force allocating shape/strides on heap so we can take their data() first
|
|
||||||
// and then std::move them.
|
|
||||||
// TODO: Refactor code to avoid heap allocation.
|
|
||||||
shape.grow();
|
|
||||||
for (auto& s : strides) {
|
|
||||||
s.grow();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Collect function input arguments.
|
// Collect function input arguments.
|
||||||
std::vector<void*> args;
|
std::vector<void*> args;
|
||||||
int strides_index = 1;
|
|
||||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||||
if (is_constant_(i)) {
|
if (is_constant_(i)) {
|
||||||
continue;
|
continue;
|
||||||
@ -306,9 +297,6 @@ void Compiled::eval_cpu(
|
|||||||
const auto& x = inputs[i];
|
const auto& x = inputs[i];
|
||||||
encoder.set_input_array(x);
|
encoder.set_input_array(x);
|
||||||
args.push_back((void*)x.data<void>());
|
args.push_back((void*)x.data<void>());
|
||||||
if (!contiguous && !is_scalar(x)) {
|
|
||||||
args.push_back(strides[strides_index++].data());
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the kernel name from the lib
|
// Get the kernel name from the lib
|
||||||
@ -343,16 +331,20 @@ void Compiled::eval_cpu(
|
|||||||
args.push_back(x.data<void>());
|
args.push_back(x.data<void>());
|
||||||
encoder.set_output_array(x);
|
encoder.set_output_array(x);
|
||||||
}
|
}
|
||||||
if (!contiguous) {
|
if (contiguous) {
|
||||||
args.push_back((void*)shape.data());
|
|
||||||
} else {
|
|
||||||
args.push_back((void*)outputs[0].data_size());
|
args.push_back((void*)outputs[0].data_size());
|
||||||
}
|
}
|
||||||
auto fun = (void (*)(void**))fn_ptr;
|
auto fun = reinterpret_cast<void (*)(int*, int64_t**, void**)>(fn_ptr);
|
||||||
encoder.dispatch([fun,
|
encoder.dispatch([fun,
|
||||||
args = std::move(args),
|
args = std::move(args),
|
||||||
strides = std::move(strides),
|
strides = std::move(strides),
|
||||||
shape = std::move(shape)]() mutable { fun(args.data()); });
|
shape = std::move(shape)]() mutable {
|
||||||
|
SmallVector<int64_t*> strides_ptrs;
|
||||||
|
for (auto& s : strides) {
|
||||||
|
strides_ptrs.push_back(s.data());
|
||||||
|
}
|
||||||
|
fun(shape.data(), strides_ptrs.data(), args.data());
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -440,6 +440,7 @@ class SmallVector {
|
|||||||
end_ = begin_;
|
end_ = begin_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
// Grows the backing store by a factor of two, and at least to {min_capacity}.
|
// Grows the backing store by a factor of two, and at least to {min_capacity}.
|
||||||
// TODO: Move to private after removing external code using this method.
|
// TODO: Move to private after removing external code using this method.
|
||||||
MLX_NOINLINE void grow(size_t min_capacity = 0) {
|
MLX_NOINLINE void grow(size_t min_capacity = 0) {
|
||||||
@ -469,7 +470,6 @@ class SmallVector {
|
|||||||
end_of_storage_ = new_storage + new_capacity;
|
end_of_storage_ = new_storage + new_capacity;
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
|
||||||
MLX_NOINLINE void free_storage() {
|
MLX_NOINLINE void free_storage() {
|
||||||
std::destroy_n(begin_, end_ - begin_);
|
std::destroy_n(begin_, end_ - begin_);
|
||||||
if (is_big()) {
|
if (is_big()) {
|
||||||
|
Loading…
Reference in New Issue
Block a user