Remove the kernel arg from get_launch_args (#2437)

This commit is contained in:
Cheng
2025-07-30 11:43:02 +09:00
committed by GitHub
parent 3adba92ebe
commit 254476718b
13 changed files with 83 additions and 125 deletions

View File

@@ -122,37 +122,17 @@ std::pair<dim3, dim3> get_grid_and_block(int dim0, int dim1, int dim2);
// Get the num_blocks and block_dims that maximize occupancy for |kernel|,
// assuming each thread handles |work_per_thread| elements of |arr|.
template <typename T>
inline std::tuple<dim3, uint> get_launch_args(
T kernel,
std::tuple<dim3, uint> get_launch_args(
size_t size,
const Shape& shape,
const Strides& strides,
bool large,
int work_per_thread = 1) {
size_t nthreads = cuda::ceil_div(size, work_per_thread);
uint block_dim = 1024;
if (block_dim > nthreads) {
block_dim = nthreads;
}
dim3 num_blocks;
if (large) {
num_blocks = get_2d_grid_dims(shape, strides, work_per_thread);
num_blocks.x = cuda::ceil_div(num_blocks.x, block_dim);
} else {
num_blocks.x = cuda::ceil_div(nthreads, block_dim);
}
return std::make_tuple(num_blocks, block_dim);
}
int work_per_thread = 1);
template <typename T>
inline std::tuple<dim3, uint> get_launch_args(
T kernel,
const array& arr,
bool large,
int work_per_thread = 1) {
inline std::tuple<dim3, uint>
get_launch_args(const array& arr, bool large, int work_per_thread = 1) {
return get_launch_args(
kernel, arr.size(), arr.shape(), arr.strides(), large, work_per_thread);
arr.size(), arr.shape(), arr.strides(), large, work_per_thread);
}
} // namespace mlx::core