diff --git a/mlx/backend/cuda/copy/copy_general_input.cu b/mlx/backend/cuda/copy/copy_general_input.cu
index 42a027ec5..d85f4a67c 100644
--- a/mlx/backend/cuda/copy/copy_general_input.cu
+++ b/mlx/backend/cuda/copy/copy_general_input.cu
@@ -95,11 +95,14 @@ void copy_general_input(
       const InType* in_ptr = gpu_ptr<InType>(in) + offset_in;
       OutType* out_ptr = gpu_ptr<OutType>(out) + offset_out;
       int ndim = shape.size();
-      int work_per_thread = 1;
+
+      int work_per_thread = 8;
       auto dim0 = ndim > 0 ? shape.back() : 1;
       auto rest = out.size() / dim0;
-      if (dim0 >= 4) {
+      if (dim0 >= 4 && dim0 < 8) {
         work_per_thread = 4;
+      } else if (dim0 < 4) {
+        work_per_thread = 1;
       }
       dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
       auto block_dims = get_block_dims(dim0, rest, 1);
@@ -110,7 +113,10 @@ void copy_general_input(
         dispatch_1_2_3(ndim, [&](auto dims_constant) {
           auto kernel = cu::copy_g_nd<InType, OutType, IdxT, dims_constant()>;
-          if (work_per_thread == 4) {
+          if (work_per_thread == 8) {
+            kernel =
+                cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 8>;
+          } else if (work_per_thread == 4) {
             kernel =
                 cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 4>;
           }
@@ -127,7 +133,9 @@ void copy_general_input(
         });
       } else { // ndim >= 4
         auto kernel = cu::copy_g<InType, OutType, IdxT>;
-        if (work_per_thread == 4) {
+        if (work_per_thread == 8) {
+          kernel = cu::copy_g<InType, OutType, IdxT, 8>;
+        } else if (work_per_thread == 4) {
           kernel = cu::copy_g<InType, OutType, IdxT, 4>;
         }
         encoder.add_kernel_node(
diff --git a/mlx/backend/gpu/copy.cpp b/mlx/backend/gpu/copy.cpp
index 1ed6e2345..403d07d92 100644
--- a/mlx/backend/gpu/copy.cpp
+++ b/mlx/backend/gpu/copy.cpp
@@ -7,8 +7,6 @@
 
 namespace mlx::core {
 
-void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s);
-
 void copy_gpu(const array& in, array& out, CopyType ctype) {
   copy_gpu(in, out, ctype, out.primitive().stream());
 }