mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-20 01:46:37 +08:00
Remove the hack around SmallVector in cpu compile (#2494)
This commit is contained in:
parent
4abb218d21
commit
888b13ed63
@ -157,10 +157,12 @@ inline void build_kernel(
|
||||
#endif
|
||||
|
||||
// Start the kernel
|
||||
os << "void " << kernel_name << "(void** args) {" << std::endl;
|
||||
os << "void " << kernel_name
|
||||
<< "(int* shape, int64_t** strides, void** args) {" << std::endl;
|
||||
|
||||
// Add the input arguments
|
||||
int cnt = 0;
|
||||
int strides_index = 1;
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
// Skip constants from the input list
|
||||
if (is_constant(i)) {
|
||||
@ -175,8 +177,8 @@ inline void build_kernel(
|
||||
<< "];" << std::endl;
|
||||
// Scalars and contiguous need no strides
|
||||
if (!is_scalar(x) && !contiguous) {
|
||||
os << " const size_t* " << xname << "_strides = (size_t*)args[" << cnt++
|
||||
<< "];" << std::endl;
|
||||
os << " const int64_t* " << xname << "_strides = strides["
|
||||
<< strides_index++ << "];" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
@ -186,10 +188,8 @@ inline void build_kernel(
|
||||
os << " " << tstr << "* " << namer.get_name(x) << " = (" << tstr
|
||||
<< "*)args[" << cnt++ << "];" << std::endl;
|
||||
}
|
||||
// Add output strides and shape to extract the indices.
|
||||
if (!contiguous) {
|
||||
os << " const int* shape = (int*)args[" << cnt++ << "];" << std::endl;
|
||||
} else {
|
||||
// Add output size
|
||||
if (contiguous) {
|
||||
os << " const size_t size = (size_t)args[" << cnt++ << "];" << std::endl;
|
||||
}
|
||||
|
||||
@ -288,17 +288,8 @@ void Compiled::eval_cpu(
|
||||
auto [contiguous, shape, strides] =
|
||||
compiled_collapse_contiguous_dims(inputs, outputs[0], is_constant_);
|
||||
|
||||
// Force allocating shape/strides on heap so we can take their data() first
|
||||
// and then std::move them.
|
||||
// TODO: Refactor code to avoid heap allocation.
|
||||
shape.grow();
|
||||
for (auto& s : strides) {
|
||||
s.grow();
|
||||
}
|
||||
|
||||
// Collect function input arguments.
|
||||
std::vector<void*> args;
|
||||
int strides_index = 1;
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
if (is_constant_(i)) {
|
||||
continue;
|
||||
@ -306,9 +297,6 @@ void Compiled::eval_cpu(
|
||||
const auto& x = inputs[i];
|
||||
encoder.set_input_array(x);
|
||||
args.push_back((void*)x.data<void>());
|
||||
if (!contiguous && !is_scalar(x)) {
|
||||
args.push_back(strides[strides_index++].data());
|
||||
}
|
||||
}
|
||||
|
||||
// Get the kernel name from the lib
|
||||
@ -343,16 +331,20 @@ void Compiled::eval_cpu(
|
||||
args.push_back(x.data<void>());
|
||||
encoder.set_output_array(x);
|
||||
}
|
||||
if (!contiguous) {
|
||||
args.push_back((void*)shape.data());
|
||||
} else {
|
||||
if (contiguous) {
|
||||
args.push_back((void*)outputs[0].data_size());
|
||||
}
|
||||
auto fun = (void (*)(void**))fn_ptr;
|
||||
auto fun = reinterpret_cast<void (*)(int*, int64_t**, void**)>(fn_ptr);
|
||||
encoder.dispatch([fun,
|
||||
args = std::move(args),
|
||||
strides = std::move(strides),
|
||||
shape = std::move(shape)]() mutable { fun(args.data()); });
|
||||
shape = std::move(shape)]() mutable {
|
||||
SmallVector<int64_t*> strides_ptrs;
|
||||
for (auto& s : strides) {
|
||||
strides_ptrs.push_back(s.data());
|
||||
}
|
||||
fun(shape.data(), strides_ptrs.data(), args.data());
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -440,6 +440,7 @@ class SmallVector {
|
||||
end_ = begin_;
|
||||
}
|
||||
|
||||
private:
|
||||
// Grows the backing store by a factor of two, and at least to {min_capacity}.
|
||||
// TODO: Move to private after removing external code using this method.
|
||||
MLX_NOINLINE void grow(size_t min_capacity = 0) {
|
||||
@ -469,7 +470,6 @@ class SmallVector {
|
||||
end_of_storage_ = new_storage + new_capacity;
|
||||
}
|
||||
|
||||
private:
|
||||
MLX_NOINLINE void free_storage() {
|
||||
std::destroy_n(begin_, end_ - begin_);
|
||||
if (is_big()) {
|
||||
|
Loading…
Reference in New Issue
Block a user