Mirror of https://github.com/ml-explore/mlx.git

Optimization for general ND copies (#1421)
commit 67b6bf530d
parent 6af5ca35b2
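In brief: this commit adds a work-per-thread fast path for general ND copies in the Metal backend. The copy_g and copy_gg kernel templates gain an int N = 1 parameter; when a copy has more dimensions than MAX_COPY_SPECIALIZED_DIMS and its innermost axis holds at least 4 elements, the host dispatches new "gn4"/"ggn4" instantiations with N = 4 and shrinks the grid's x extent to match, so each thread copies up to 4 consecutive elements along the last axis. elem_to_loc_2_nd is also generalized over the stride type so copy_gg can resolve both offsets in one call.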
@@ -73,9 +73,11 @@ void copy_gpu_inplace(
     }
   };
   auto [shape, strides_in_, strides_out_] = maybe_collapse();
+  int ndim = shape.size();
 
   bool use_2d = out.data_size() > UINT32_MAX;
   auto& d = metal::device(s.device);
+  int work_per_thread = 1;
   std::string kernel_name;
   {
     std::ostringstream kname;
@@ -93,9 +95,13 @@ void copy_gpu_inplace(
         kname << "gg";
         break;
     }
-    if ((ctype == CopyType::General || ctype == CopyType::GeneralGeneral) &&
-        shape.size() <= MAX_COPY_SPECIALIZED_DIMS) {
+    if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
+      if (shape.size() <= MAX_COPY_SPECIALIZED_DIMS) {
        kname << shape.size();
+      } else if (shape[ndim - 1] >= 4) {
+        work_per_thread = 4;
+        kname << "n4";
+      }
     }
     kname << "_copy";
     kname << type_to_name(in) << type_to_name(out);
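The kernel-name logic above now distinguishes three general-copy flavors. A minimal sketch of the selection, using a hypothetical helper; the cutoff 3 merely stands in for MAX_COPY_SPECIALIZED_DIMS, whose value is not shown in this diff:

    #include <string>

    // Hypothetical helper mirroring the kname logic above: "g"/"gg" prefix,
    // then either a fixed-ndim specialization or the new "n4" work-per-thread
    // variant, then the "_copy" suffix (type names follow in the real code).
    std::string copy_kernel_base_name(bool general_general, int ndim, int last_dim) {
      std::string name = general_general ? "gg" : "g";
      if (ndim <= 3) {                  // stand-in for MAX_COPY_SPECIALIZED_DIMS
        name += std::to_string(ndim);   // e.g. "g2_copy", "gg3_copy"
      } else if (last_dim >= 4) {
        name += "n4";                   // e.g. "gn4_copy", "ggn4_copy"
      }
      return name + "_copy";
    }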
@@ -115,10 +121,8 @@ void copy_gpu_inplace(
   compute_encoder.set_output_array(out, 1, out_offset);
 
   if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
-    int ndim = shape.size();
     std::vector<int64_t> strides_in{strides_in_.begin(), strides_in_.end()};
     std::vector<int64_t> strides_out{strides_out_.begin(), strides_out_.end()};
-
     if (ndim > 3) {
       set_vector_bytes(compute_encoder, shape, ndim, 2);
     }
@@ -127,10 +131,6 @@ void copy_gpu_inplace(
       set_vector_bytes(compute_encoder, strides_out, ndim, 4);
     }
 
-    if (ndim > MAX_COPY_SPECIALIZED_DIMS) {
-      compute_encoder->setBytes(&ndim, sizeof(int), 5);
-    }
-
     int dim0 = ndim > 0 ? shape[ndim - 1] : 1;
     int dim1 = ndim > 1 ? shape[ndim - 2] : 1;
 
@@ -139,6 +139,11 @@ void copy_gpu_inplace(
       data_size *= s;
     int rest = data_size / (dim0 * dim1);
 
+    if (ndim > MAX_COPY_SPECIALIZED_DIMS) {
+      compute_encoder->setBytes(&ndim, sizeof(int), 5);
+      dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
+    }
+
     // NB assuming thread_group_size is a power of 2 larger than 32 x 32
     NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
     if (thread_group_size != 1024) {
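The host divides the grid's innermost extent by work_per_thread, rounding up, so the thread count matches the per-thread workload; the kernel's bounds check covers the ragged tail. A worked example with invented sizes:

    #include <cstdio>

    int main() {
      // Mirrors the host-side grid math in the hunk above; shape values are
      // invented for illustration: shape = {5, 6, 1026}.
      int dim0 = 1026, dim1 = 6, work_per_thread = 4;
      int data_size = 5 * 6 * 1026;
      int rest = data_size / (dim0 * dim1);                   // leading dims: 5
      dim0 = (dim0 + work_per_thread - 1) / work_per_thread;  // ceil-div: 257
      std::printf("grid = %d x %d x %d\n", dim0, dim1, rest); // 257 x 6 x 5
    }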
@@ -176,6 +176,8 @@ MTL::ComputePipelineState* get_copy_kernel(
       << get_template_definition(
              "g3_" + lib_name, "copy_g_nd3", in_type, out_type)
       << get_template_definition("g_" + lib_name, "copy_g", in_type, out_type)
+      << get_template_definition(
+             "gn4_" + lib_name, "copy_g", in_type, out_type, 4)
       << get_template_definition(
              "gg1_" + lib_name, "copy_gg_nd1", in_type, out_type)
       << get_template_definition(
@@ -183,7 +185,9 @@ MTL::ComputePipelineState* get_copy_kernel(
       << get_template_definition(
              "gg3_" + lib_name, "copy_gg_nd3", in_type, out_type)
       << get_template_definition(
-             "gg_" + lib_name, "copy_gg", in_type, out_type);
+             "gg_" + lib_name, "copy_gg", in_type, out_type)
+      << get_template_definition(
+             "ggn4_" + lib_name, "copy_gg", in_type, out_type, 4);
   lib = d.get_library(lib_name, kernel_source.str());
 }
 return d.get_kernel(kernel_name, lib);
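These two extra get_template_definition calls appear to register the N = 4 instantiations with the JIT-built library: copy_g<in_type, out_type, 4> under "gn4_" + lib_name and copy_gg<in_type, out_type, 4> under "ggn4_" + lib_name, matching the instantiate_kernel additions further down so both build paths expose the same "gn4"/"ggn4" kernel names the dispatch code constructs.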
@@ -71,7 +71,7 @@ template <typename T, typename U>
   dst[dst_idx] = static_cast<U>(src[src_idx]);
 }
 
-template <typename T, typename U>
+template <typename T, typename U, int N = 1>
 [[kernel]] void copy_g(
     device const T* src [[buffer(0)]],
     device U* dst [[buffer(1)]],
@@ -80,10 +80,22 @@ template <typename T, typename U>
     constant const int& ndim [[buffer(5)]],
     uint3 index [[thread_position_in_grid]],
     uint3 grid_dim [[threads_per_grid]]) {
-  auto src_idx = elem_to_loc(index, src_shape, src_strides, ndim);
+  auto src_idx = elem_to_loc(
+      {N * index.x, index.y, index.z}, src_shape, src_strides, ndim);
+  if (N == 1) {
     int64_t dst_idx =
-      index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
+        index.x + grid_dim.x * (index.y + int64_t(grid_dim.y) * index.z);
     dst[dst_idx] = static_cast<U>(src[src_idx]);
+    return;
+  }
+  auto xshape = src_shape[ndim - 1];
+  int64_t dst_idx =
+      N * index.x + xshape * (index.y + int64_t(grid_dim.y) * index.z);
+  auto src_xstride = src_strides[ndim - 1];
+  for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
+    dst[dst_idx + i] = static_cast<U>(src[src_idx]);
+    src_idx += src_xstride;
+  }
 }
 
 template <typename T, typename U>
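In the N > 1 path each copy_g thread handles N consecutive logical elements of the last axis: the flattened destination is contiguous there, while the source advances by its own innermost stride. A host-side C++ model of one thread's work, assuming the offsets have already been resolved (parameter names are illustrative, not from the kernel):

    #include <cstdint>

    // Host-side model of one copy_g<T, U, N> thread at position x along the
    // innermost axis. src_base is the source offset of logical element N*x in
    // this row (what elem_to_loc computes on device); the destination row is
    // contiguous, so dst_row + N*x + i addresses the output directly.
    template <typename T, typename U, int N>
    void copy_g_thread_model(
        const T* src,
        U* dst,
        int64_t src_base,
        int64_t src_xstride, // src_strides[ndim - 1]
        int xshape,          // src_shape[ndim - 1]
        int64_t dst_row,     // offset of this row's first element in dst
        int x) {
      int64_t src_idx = src_base;
      // The bounds check lets the last thread in a row copy a partial tail.
      for (int i = 0; i < N && N * x + i < xshape; ++i) {
        dst[dst_row + int64_t(N) * x + i] = static_cast<U>(src[src_idx]);
        src_idx += src_xstride;
      }
    }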
@@ -122,7 +134,7 @@ template <typename T, typename U>
   dst[dst_idx] = static_cast<U>(src[src_idx]);
 }
 
-template <typename T, typename U>
+template <typename T, typename U, int N = 1>
 [[kernel]] void copy_gg(
     device const T* src [[buffer(0)]],
     device U* dst [[buffer(1)]],
@@ -131,7 +143,22 @@ template <typename T, typename U>
     constant const int64_t* dst_strides [[buffer(4)]],
     constant const int& ndim [[buffer(5)]],
     uint3 index [[thread_position_in_grid]]) {
-  auto src_idx = elem_to_loc(index, src_shape, src_strides, ndim);
-  auto dst_idx = elem_to_loc(index, src_shape, dst_strides, ndim);
-  dst[dst_idx] = static_cast<U>(src[src_idx]);
+  auto idx = elem_to_loc_2_nd(
+      {N * index.x, index.y, index.z},
+      src_shape,
+      src_strides,
+      dst_strides,
+      ndim);
+  if (N == 1) {
+    dst[idx.y] = static_cast<U>(src[idx.x]);
+    return;
+  }
+  auto src_xstride = src_strides[ndim - 1];
+  auto dst_xstride = dst_strides[ndim - 1];
+  auto xshape = src_shape[ndim - 1];
+  for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
+    dst[idx.y] = static_cast<U>(src[idx.x]);
+    idx.x += src_xstride;
+    idx.y += dst_xstride;
+  }
 }
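copy_gg applies the same scheme with both sides strided: elem_to_loc_2_nd resolves the (source, destination) offset pair for the thread's first element in a single traversal, and the loop then advances each offset by its own innermost stride, so neither side needs to be contiguous along the last axis.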
@@ -17,7 +17,9 @@
   instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, itype, otype) \
   instantiate_kernel("gg3_copy" #tname, copy_gg_nd3, itype, otype) \
   instantiate_kernel("g_copy" #tname, copy_g, itype, otype) \
-  instantiate_kernel("gg_copy" #tname, copy_gg, itype, otype)
+  instantiate_kernel("gn4_copy" #tname, copy_g, itype, otype, 4) \
+  instantiate_kernel("gg_copy" #tname, copy_gg, itype, otype) \
+  instantiate_kernel("ggn4_copy" #tname, copy_gg, itype, otype, 4)
 
 #define instantiate_copy_itype(itname, itype) \
   instantiate_copy_all(itname ##bool_, itype, bool) \
@@ -149,15 +149,16 @@ elem_to_loc_3(uint3 elem, constant const stride_t strides[3]) {
 ///////////////////////////////////////////////////////////////////////////////
 // Multiple Arrays with generic dims
 
+template <typename stride_t>
 METAL_FUNC ulong2 elem_to_loc_2_nd(
     uint3 elem,
     constant const int* shape,
-    constant const size_t* a_strides,
-    constant const size_t* b_strides,
+    constant const stride_t* a_strides,
+    constant const stride_t* b_strides,
     int ndim) {
   ulong2 loc = {
-      elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2],
-      elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2]};
+      ulong(elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]),
+      ulong(elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2])};
   for (int d = ndim - 3; d >= 0; --d) {
     uint l = elem.z % shape[d];
     loc.x += l * a_strides[d];
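elem_to_loc_2_nd unravels one 3D grid position into a pair of element offsets at once: elem.x and elem.y index the two innermost axes directly, and elem.z packs all leading axes, peeled off by repeated modulo/divide. A host-side C++ reference sketch; the loop tail (the b_strides update and the division of elem.z) is truncated in the diff above and completed here by symmetry:

    #include <cstdint>
    #include <utility>

    // Host model of elem_to_loc_2_nd: resolve one source (a) and one
    // destination (b) element offset for grid position {ex, ey, ez}.
    template <typename stride_t>
    std::pair<uint64_t, uint64_t> elem_to_loc_2_nd_model(
        uint32_t ex, uint32_t ey, uint32_t ez,
        const int* shape,
        const stride_t* a_strides,
        const stride_t* b_strides,
        int ndim) {
      // x walks the innermost axis, y the next one out (requires ndim >= 2).
      uint64_t a = uint64_t(ex * a_strides[ndim - 1] + ey * a_strides[ndim - 2]);
      uint64_t b = uint64_t(ex * b_strides[ndim - 1] + ey * b_strides[ndim - 2]);
      // z packs every remaining leading axis, peeled off innermost-first.
      for (int d = ndim - 3; d >= 0; --d) {
        uint32_t l = ez % shape[d];
        a += l * a_strides[d];
        b += l * b_strides[d]; // symmetric update, truncated in the diff
        ez /= shape[d];
      }
      return {a, b};
    }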