diff --git a/mlx/backend/cpu/compiled.cpp b/mlx/backend/cpu/compiled.cpp
index 8aa296619..a21819ac0 100644
--- a/mlx/backend/cpu/compiled.cpp
+++ b/mlx/backend/cpu/compiled.cpp
@@ -157,10 +157,12 @@ inline void build_kernel(
 #endif
 
   // Start the kernel
-  os << "void " << kernel_name << "(void** args) {" << std::endl;
+  os << "void " << kernel_name
+     << "(int* shape, int64_t** strides, void** args) {" << std::endl;
 
   // Add the input arguments
   int cnt = 0;
+  int strides_index = 1;
   for (size_t i = 0; i < inputs.size(); ++i) {
     // Skip constants from the input list
     if (is_constant(i)) {
@@ -175,8 +177,8 @@ inline void build_kernel(
        << "];" << std::endl;
     // Scalars and contiguous need no strides
     if (!is_scalar(x) && !contiguous) {
-      os << "  const size_t* " << xname << "_strides = (size_t*)args[" << cnt++
-         << "];" << std::endl;
+      os << "  const int64_t* " << xname << "_strides = strides["
+         << strides_index++ << "];" << std::endl;
     }
   }
 
@@ -186,10 +188,8 @@ inline void build_kernel(
     os << "  " << tstr << "* " << namer.get_name(x) << " = (" << tstr
        << "*)args[" << cnt++ << "];" << std::endl;
   }
-  // Add output strides and shape to extract the indices.
-  if (!contiguous) {
-    os << "  const int* shape = (int*)args[" << cnt++ << "];" << std::endl;
-  } else {
+  // Add output size
+  if (contiguous) {
     os << "  const size_t size = (size_t)args[" << cnt++ << "];" << std::endl;
   }
 
@@ -288,17 +288,8 @@ void Compiled::eval_cpu(
   auto [contiguous, shape, strides] =
       compiled_collapse_contiguous_dims(inputs, outputs[0], is_constant_);
 
-  // Force allocating shape/strides on heap so we can take their data() first
-  // and then std::move them.
-  // TODO: Refactor code to avoid heap allocation.
-  shape.grow();
-  for (auto& s : strides) {
-    s.grow();
-  }
-
   // Collect function input arguments.
   std::vector<void*> args;
-  int strides_index = 1;
   for (size_t i = 0; i < inputs.size(); ++i) {
     if (is_constant_(i)) {
       continue;
@@ -306,9 +297,6 @@
     const auto& x = inputs[i];
     encoder.set_input_array(x);
     args.push_back((void*)x.data<void>());
-    if (!contiguous && !is_scalar(x)) {
-      args.push_back(strides[strides_index++].data());
-    }
   }
 
   // Get the kernel name from the lib
@@ -343,16 +331,20 @@ void Compiled::eval_cpu(
     args.push_back(x.data<void>());
     encoder.set_output_array(x);
   }
-  if (!contiguous) {
-    args.push_back((void*)shape.data());
-  } else {
+  if (contiguous) {
     args.push_back((void*)outputs[0].data_size());
   }
-  auto fun = (void (*)(void**))fn_ptr;
+  auto fun = reinterpret_cast<void (*)(int*, int64_t**, void**)>(fn_ptr);
   encoder.dispatch([fun,
                     args = std::move(args),
                     strides = std::move(strides),
-                    shape = std::move(shape)]() mutable { fun(args.data()); });
+                    shape = std::move(shape)]() mutable {
+    SmallVector<int64_t*> strides_ptrs;
+    for (auto& s : strides) {
+      strides_ptrs.push_back(s.data());
+    }
+    fun(shape.data(), strides_ptrs.data(), args.data());
+  });
 }
 
 } // namespace mlx::core
diff --git a/mlx/small_vector.h b/mlx/small_vector.h
index fc4c1f06c..0a3371058 100644
--- a/mlx/small_vector.h
+++ b/mlx/small_vector.h
@@ -440,6 +440,7 @@ class SmallVector {
     end_ = begin_;
   }
 
+ private:
   // Grows the backing store by a factor of two, and at least to {min_capacity}.
   // TODO: Move to private after removing external code using this method.
   MLX_NOINLINE void grow(size_t min_capacity = 0) {
@@ -469,7 +470,6 @@
     end_of_storage_ = new_storage + new_capacity;
   }
 
- private:
   MLX_NOINLINE void free_storage() {
     std::destroy_n(begin_, end_ - begin_);
     if (is_big()) {
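
For context, a minimal, self-contained sketch of the calling convention this patch introduces: shape and per-array strides travel in dedicated `shape`/`strides` parameters, while `args` carries only data pointers. Nothing below is part of the patch; `example_kernel` and the concrete shapes/strides are hypothetical stand-ins for what `build_kernel` would generate, and the assumption that `strides[0]` belongs to the output (since `strides_index` starts at 1 for inputs) is inferred from the diff, not confirmed by it.

#include <cstdint>
#include <cstdio>
#include <vector>

// Hypothetical stand-in for a generated kernel with one non-contiguous
// 2-D float input and one contiguous float output. strides[0] is assumed
// to hold the output strides (unused here); the input's strides are at
// index 1, matching strides_index starting at 1 in build_kernel.
void example_kernel(int* shape, int64_t** strides, void** args) {
  const float* in = (const float*)args[0];
  const int64_t* in_strides = strides[1];
  float* out = (float*)args[1];
  for (int i = 0; i < shape[0]; ++i) {
    for (int j = 0; j < shape[1]; ++j) {
      // Output is written contiguously; input is gathered via its strides.
      out[i * shape[1] + j] = in[i * in_strides[0] + j * in_strides[1]];
    }
  }
}

int main() {
  // A 2x3 view into a 2x4 row-major buffer (last column skipped), so the
  // input is non-contiguous: row stride 4, column stride 1.
  float in[8] = {0, 1, 2, 3, 4, 5, 6, 7};
  float out[6] = {};
  int shape[2] = {2, 3};
  std::vector<int64_t> out_strides = {3, 1};
  std::vector<int64_t> in_strides = {4, 1};
  // Pack pointers the way the dispatched lambda in eval_cpu now does.
  std::vector<int64_t*> strides_ptrs = {out_strides.data(), in_strides.data()};
  std::vector<void*> args = {(void*)in, (void*)out};
  example_kernel(shape, strides_ptrs.data(), args.data());
  std::printf("%g %g %g | %g %g %g\n",
              out[0], out[1], out[2], out[3], out[4], out[5]);
  return 0;
}

One practical consequence of this layout: the data pointers handed to the kernel are now taken from `shape` and `strides` inside the dispatched lambda, after the `std::move`s, so the SmallVectors no longer need to be forced onto the heap up front. That is why the `shape.grow()`/`s.grow()` workaround disappears and `SmallVector::grow()` can finally move behind `private:` as its TODO intended.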