Remove the kernel arg from get_launch_args (#2437)

This commit is contained in:
Cheng
2025-07-30 11:43:02 +09:00
committed by GitHub
parent 3adba92ebe
commit 254476718b
13 changed files with 83 additions and 125 deletions

View File

@@ -227,16 +227,15 @@ void binary_two_op_gpu_inplace(
int ndim = shape.size();
if (ndim <= 3) {
dispatch_1_2_3(ndim, [&](auto dims_constant) {
auto kernel = cu::binary_two_g_nd<
Op,
InType,
OutType,
IdxT,
dims_constant()>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, out_a, large());
get_launch_args(out_a, large());
encoder.add_kernel_node(
kernel,
cu::binary_two_g_nd<
Op,
InType,
OutType,
IdxT,
dims_constant()>,
num_blocks,
block_dims,
a.data<InType>(),
@@ -249,11 +248,10 @@ void binary_two_op_gpu_inplace(
const_param<dims_constant()>(b_strides));
});
} else {
auto kernel = cu::binary_two_g<Op, InType, OutType, IdxT>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, out_a, large());
get_launch_args(out_a, large());
encoder.add_kernel_node(
kernel,
cu::binary_two_g<Op, InType, OutType, IdxT>,
num_blocks,
block_dims,
a.data<InType>(),
@@ -280,7 +278,6 @@ void binary_two_op_gpu_inplace(
kernel = cu::binary_two_vv<Op, InType, OutType, IdxT, N_READS>;
}
auto [num_blocks, block_dims] = get_launch_args(
kernel,
out_a.data_size(),
out_a.shape(),
out_a.strides(),