From 873cfa292e78bf9fcf8a886334dee77bf0751adc Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Tue, 17 Jun 2025 10:51:09 -0700 Subject: [PATCH] fix copy --- mlx/backend/cuda/copy/copy_general.cu | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/mlx/backend/cuda/copy/copy_general.cu b/mlx/backend/cuda/copy/copy_general.cu index 9f50c8a31..2dc08c60a 100644 --- a/mlx/backend/cuda/copy/copy_general.cu +++ b/mlx/backend/cuda/copy/copy_general.cu @@ -63,25 +63,30 @@ void copy_general( MLX_SWITCH_BOOL(large, LARGE, { using IdxT = std::conditional_t; int ndim = shape.size(); + size_t data_size = 1; + for (auto& s : shape) + data_size *= s; if (ndim <= 3) { MLX_SWITCH_1_2_3(ndim, NDIM, { auto kernel = cu::copy_gg_nd; - auto [num_blocks, block_dims] = get_launch_args(kernel, out, large); + auto [num_blocks, block_dims] = + get_launch_args(kernel, data_size, shape, out.strides(), large); kernel<<>>( in_ptr, out_ptr, - out.size(), + data_size, const_param(shape), const_param(strides_in), const_param(strides_out)); }); } else { // ndim >= 4 auto kernel = cu::copy_gg; - auto [num_blocks, block_dims] = get_launch_args(kernel, out, large); + auto [num_blocks, block_dims] = + get_launch_args(kernel, data_size, shape, out.strides(), large); kernel<<>>( in_ptr, out_ptr, - out.size(), + data_size, const_param(shape), const_param(strides_in), const_param(strides_out),