mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-15 01:19:21 +08:00
[CUDA] Faster general copy (#2873)
This commit is contained in:
@@ -95,11 +95,14 @@ void copy_general_input(
|
|||||||
const InType* in_ptr = gpu_ptr<InType>(in) + offset_in;
|
const InType* in_ptr = gpu_ptr<InType>(in) + offset_in;
|
||||||
OutType* out_ptr = gpu_ptr<OutType>(out) + offset_out;
|
OutType* out_ptr = gpu_ptr<OutType>(out) + offset_out;
|
||||||
int ndim = shape.size();
|
int ndim = shape.size();
|
||||||
int work_per_thread = 1;
|
|
||||||
|
int work_per_thread = 8;
|
||||||
auto dim0 = ndim > 0 ? shape.back() : 1;
|
auto dim0 = ndim > 0 ? shape.back() : 1;
|
||||||
auto rest = out.size() / dim0;
|
auto rest = out.size() / dim0;
|
||||||
if (dim0 >= 4) {
|
if (dim0 >= 4 && dim0 < 8) {
|
||||||
work_per_thread = 4;
|
work_per_thread = 4;
|
||||||
|
} else if (dim0 < 4) {
|
||||||
|
work_per_thread = 1;
|
||||||
}
|
}
|
||||||
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
|
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
|
||||||
auto block_dims = get_block_dims(dim0, rest, 1);
|
auto block_dims = get_block_dims(dim0, rest, 1);
|
||||||
@@ -110,7 +113,10 @@ void copy_general_input(
|
|||||||
dispatch_1_2_3(ndim, [&](auto dims_constant) {
|
dispatch_1_2_3(ndim, [&](auto dims_constant) {
|
||||||
auto kernel =
|
auto kernel =
|
||||||
cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 1>;
|
cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 1>;
|
||||||
if (work_per_thread == 4) {
|
if (work_per_thread == 8) {
|
||||||
|
kernel =
|
||||||
|
cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 8>;
|
||||||
|
} else if (work_per_thread == 4) {
|
||||||
kernel =
|
kernel =
|
||||||
cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 4>;
|
cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 4>;
|
||||||
}
|
}
|
||||||
@@ -127,7 +133,9 @@ void copy_general_input(
|
|||||||
});
|
});
|
||||||
} else { // ndim >= 4
|
} else { // ndim >= 4
|
||||||
auto kernel = cu::copy_g<InType, OutType, IdxT, 1>;
|
auto kernel = cu::copy_g<InType, OutType, IdxT, 1>;
|
||||||
if (work_per_thread == 4) {
|
if (work_per_thread == 8) {
|
||||||
|
kernel = cu::copy_g<InType, OutType, IdxT, 8>;
|
||||||
|
} else if (work_per_thread == 4) {
|
||||||
kernel = cu::copy_g<InType, OutType, IdxT, 4>;
|
kernel = cu::copy_g<InType, OutType, IdxT, 4>;
|
||||||
}
|
}
|
||||||
encoder.add_kernel_node(
|
encoder.add_kernel_node(
|
||||||
|
|||||||
@@ -7,8 +7,6 @@
|
|||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s);
|
|
||||||
|
|
||||||
void copy_gpu(const array& in, array& out, CopyType ctype) {
|
void copy_gpu(const array& in, array& out, CopyType ctype) {
|
||||||
copy_gpu(in, out, ctype, out.primitive().stream());
|
copy_gpu(in, out, ctype, out.primitive().stream());
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user