Remove the kernel arg from get_launch_args (#2437)

This commit is contained in:
Cheng
2025-07-30 11:43:02 +09:00
committed by GitHub
parent 3adba92ebe
commit 254476718b
13 changed files with 83 additions and 125 deletions

View File

@@ -63,12 +63,9 @@ void copy_general_input(
int ndim = shape.size();
if (ndim <= 3) {
dispatch_1_2_3(ndim, [&](auto dims_constant) {
auto kernel =
cu::copy_g_nd<InType, OutType, IdxT, dims_constant()>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, out, large());
auto [num_blocks, block_dims] = get_launch_args(out, large());
encoder.add_kernel_node(
kernel,
cu::copy_g_nd<InType, OutType, IdxT, dims_constant()>,
num_blocks,
block_dims,
in_ptr,
@@ -78,11 +75,9 @@ void copy_general_input(
const_param<dims_constant()>(strides_in));
});
} else { // ndim >= 4
auto kernel = cu::copy_g<InType, OutType, IdxT>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, out, large());
auto [num_blocks, block_dims] = get_launch_args(out, large());
encoder.add_kernel_node(
kernel,
cu::copy_g<InType, OutType, IdxT>,
num_blocks,
block_dims,
in_ptr,