mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-12 23:39:04 +08:00
[CUDA] Faster general copy (#2873)
This commit is contained in:
@@ -95,11 +95,14 @@ void copy_general_input(
|
||||
const InType* in_ptr = gpu_ptr<InType>(in) + offset_in;
|
||||
OutType* out_ptr = gpu_ptr<OutType>(out) + offset_out;
|
||||
int ndim = shape.size();
|
||||
int work_per_thread = 1;
|
||||
|
||||
int work_per_thread = 8;
|
||||
auto dim0 = ndim > 0 ? shape.back() : 1;
|
||||
auto rest = out.size() / dim0;
|
||||
if (dim0 >= 4) {
|
||||
if (dim0 >= 4 && dim0 < 8) {
|
||||
work_per_thread = 4;
|
||||
} else if (dim0 < 4) {
|
||||
work_per_thread = 1;
|
||||
}
|
||||
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
|
||||
auto block_dims = get_block_dims(dim0, rest, 1);
|
||||
@@ -110,7 +113,10 @@ void copy_general_input(
|
||||
dispatch_1_2_3(ndim, [&](auto dims_constant) {
|
||||
auto kernel =
|
||||
cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 1>;
|
||||
if (work_per_thread == 4) {
|
||||
if (work_per_thread == 8) {
|
||||
kernel =
|
||||
cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 8>;
|
||||
} else if (work_per_thread == 4) {
|
||||
kernel =
|
||||
cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 4>;
|
||||
}
|
||||
@@ -127,7 +133,9 @@ void copy_general_input(
|
||||
});
|
||||
} else { // ndim >= 4
|
||||
auto kernel = cu::copy_g<InType, OutType, IdxT, 1>;
|
||||
if (work_per_thread == 4) {
|
||||
if (work_per_thread == 8) {
|
||||
kernel = cu::copy_g<InType, OutType, IdxT, 8>;
|
||||
} else if (work_per_thread == 4) {
|
||||
kernel = cu::copy_g<InType, OutType, IdxT, 4>;
|
||||
}
|
||||
encoder.add_kernel_node(
|
||||
|
||||
@@ -7,8 +7,6 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s);
|
||||
|
||||
void copy_gpu(const array& in, array& out, CopyType ctype) {
|
||||
copy_gpu(in, out, ctype, out.primitive().stream());
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user