Mirror of https://github.com/ml-explore/mlx.git (synced 2025-12-16 01:49:05 +08:00)
Remove the kernel arg from get_launch_args (#2437)
This commit is contained in:
@@ -71,12 +71,10 @@ void copy_general(
|
||||
data_size *= s;
|
||||
if (ndim <= 3) {
|
||||
dispatch_1_2_3(ndim, [&](auto ndim_constant) {
|
||||
auto kernel =
|
||||
cu::copy_gg_nd<InType, OutType, IdxT, ndim_constant()>;
|
||||
auto [num_blocks, block_dims] = get_launch_args(
|
||||
kernel, data_size, shape, out.strides(), large());
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(data_size, shape, out.strides(), large());
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
cu::copy_gg_nd<InType, OutType, IdxT, ndim_constant()>,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
in_ptr,
|
||||
@@ -87,11 +85,10 @@ void copy_general(
|
||||
const_param<ndim_constant()>(strides_out));
|
||||
});
|
||||
} else { // ndim >= 4
|
||||
auto kernel = cu::copy_gg<InType, OutType, IdxT>;
|
||||
auto [num_blocks, block_dims] = get_launch_args(
|
||||
kernel, data_size, shape, out.strides(), large());
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(data_size, shape, out.strides(), large());
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
cu::copy_gg<InType, OutType, IdxT>,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
in_ptr,
|
||||
|
||||
Reference in New Issue
Block a user