Optimization for general ND copies (#1421)

This commit is contained in:
Awni Hannun 2024-09-17 17:59:51 -07:00 committed by GitHub
parent 6af5ca35b2
commit 67b6bf530d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 62 additions and 23 deletions

View File

@ -73,9 +73,11 @@ void copy_gpu_inplace(
} }
}; };
auto [shape, strides_in_, strides_out_] = maybe_collapse(); auto [shape, strides_in_, strides_out_] = maybe_collapse();
int ndim = shape.size();
bool use_2d = out.data_size() > UINT32_MAX; bool use_2d = out.data_size() > UINT32_MAX;
auto& d = metal::device(s.device); auto& d = metal::device(s.device);
int work_per_thread = 1;
std::string kernel_name; std::string kernel_name;
{ {
std::ostringstream kname; std::ostringstream kname;
@ -93,9 +95,13 @@ void copy_gpu_inplace(
kname << "gg"; kname << "gg";
break; break;
} }
if ((ctype == CopyType::General || ctype == CopyType::GeneralGeneral) && if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
shape.size() <= MAX_COPY_SPECIALIZED_DIMS) { if (shape.size() <= MAX_COPY_SPECIALIZED_DIMS) {
kname << shape.size(); kname << shape.size();
} else if (shape[ndim - 1] >= 4) {
work_per_thread = 4;
kname << "n4";
}
} }
kname << "_copy"; kname << "_copy";
kname << type_to_name(in) << type_to_name(out); kname << type_to_name(in) << type_to_name(out);
@ -115,10 +121,8 @@ void copy_gpu_inplace(
compute_encoder.set_output_array(out, 1, out_offset); compute_encoder.set_output_array(out, 1, out_offset);
if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) { if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
int ndim = shape.size();
std::vector<int64_t> strides_in{strides_in_.begin(), strides_in_.end()}; std::vector<int64_t> strides_in{strides_in_.begin(), strides_in_.end()};
std::vector<int64_t> strides_out{strides_out_.begin(), strides_out_.end()}; std::vector<int64_t> strides_out{strides_out_.begin(), strides_out_.end()};
if (ndim > 3) { if (ndim > 3) {
set_vector_bytes(compute_encoder, shape, ndim, 2); set_vector_bytes(compute_encoder, shape, ndim, 2);
} }
@ -127,10 +131,6 @@ void copy_gpu_inplace(
set_vector_bytes(compute_encoder, strides_out, ndim, 4); set_vector_bytes(compute_encoder, strides_out, ndim, 4);
} }
if (ndim > MAX_COPY_SPECIALIZED_DIMS) {
compute_encoder->setBytes(&ndim, sizeof(int), 5);
}
int dim0 = ndim > 0 ? shape[ndim - 1] : 1; int dim0 = ndim > 0 ? shape[ndim - 1] : 1;
int dim1 = ndim > 1 ? shape[ndim - 2] : 1; int dim1 = ndim > 1 ? shape[ndim - 2] : 1;
@ -139,6 +139,11 @@ void copy_gpu_inplace(
data_size *= s; data_size *= s;
int rest = data_size / (dim0 * dim1); int rest = data_size / (dim0 * dim1);
if (ndim > MAX_COPY_SPECIALIZED_DIMS) {
compute_encoder->setBytes(&ndim, sizeof(int), 5);
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
}
// NB assuming thread_group_size is a power of 2 larger than 32 x 32 // NB assuming thread_group_size is a power of 2 larger than 32 x 32
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size != 1024) { if (thread_group_size != 1024) {

View File

@ -176,6 +176,8 @@ MTL::ComputePipelineState* get_copy_kernel(
<< get_template_definition( << get_template_definition(
"g3_" + lib_name, "copy_g_nd3", in_type, out_type) "g3_" + lib_name, "copy_g_nd3", in_type, out_type)
<< get_template_definition("g_" + lib_name, "copy_g", in_type, out_type) << get_template_definition("g_" + lib_name, "copy_g", in_type, out_type)
<< get_template_definition(
"gn4_" + lib_name, "copy_g", in_type, out_type, 4)
<< get_template_definition( << get_template_definition(
"gg1_" + lib_name, "copy_gg_nd1", in_type, out_type) "gg1_" + lib_name, "copy_gg_nd1", in_type, out_type)
<< get_template_definition( << get_template_definition(
@ -183,7 +185,9 @@ MTL::ComputePipelineState* get_copy_kernel(
<< get_template_definition( << get_template_definition(
"gg3_" + lib_name, "copy_gg_nd3", in_type, out_type) "gg3_" + lib_name, "copy_gg_nd3", in_type, out_type)
<< get_template_definition( << get_template_definition(
"gg_" + lib_name, "copy_gg", in_type, out_type); "gg_" + lib_name, "copy_gg", in_type, out_type)
<< get_template_definition(
"ggn4_" + lib_name, "copy_gg", in_type, out_type, 4);
lib = d.get_library(lib_name, kernel_source.str()); lib = d.get_library(lib_name, kernel_source.str());
} }
return d.get_kernel(kernel_name, lib); return d.get_kernel(kernel_name, lib);

View File

@ -71,7 +71,7 @@ template <typename T, typename U>
dst[dst_idx] = static_cast<U>(src[src_idx]); dst[dst_idx] = static_cast<U>(src[src_idx]);
} }
template <typename T, typename U> template <typename T, typename U, int N = 1>
[[kernel]] void copy_g( [[kernel]] void copy_g(
device const T* src [[buffer(0)]], device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]], device U* dst [[buffer(1)]],
@ -80,10 +80,22 @@ template <typename T, typename U>
constant const int& ndim [[buffer(5)]], constant const int& ndim [[buffer(5)]],
uint3 index [[thread_position_in_grid]], uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) { uint3 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc(index, src_shape, src_strides, ndim); auto src_idx = elem_to_loc(
{N * index.x, index.y, index.z}, src_shape, src_strides, ndim);
if (N == 1) {
int64_t dst_idx = int64_t dst_idx =
index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z); index.x + grid_dim.x * (index.y + int64_t(grid_dim.y) * index.z);
dst[dst_idx] = static_cast<U>(src[src_idx]); dst[dst_idx] = static_cast<U>(src[src_idx]);
return;
}
auto xshape = src_shape[ndim - 1];
int64_t dst_idx =
N * index.x + xshape * (index.y + int64_t(grid_dim.y) * index.z);
auto src_xstride = src_strides[ndim - 1];
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
dst[dst_idx + i] = static_cast<U>(src[src_idx]);
src_idx += src_xstride;
}
} }
template <typename T, typename U> template <typename T, typename U>
@ -122,7 +134,7 @@ template <typename T, typename U>
dst[dst_idx] = static_cast<U>(src[src_idx]); dst[dst_idx] = static_cast<U>(src[src_idx]);
} }
template <typename T, typename U> template <typename T, typename U, int N = 1>
[[kernel]] void copy_gg( [[kernel]] void copy_gg(
device const T* src [[buffer(0)]], device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]], device U* dst [[buffer(1)]],
@ -131,7 +143,22 @@ template <typename T, typename U>
constant const int64_t* dst_strides [[buffer(4)]], constant const int64_t* dst_strides [[buffer(4)]],
constant const int& ndim [[buffer(5)]], constant const int& ndim [[buffer(5)]],
uint3 index [[thread_position_in_grid]]) { uint3 index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc(index, src_shape, src_strides, ndim); auto idx = elem_to_loc_2_nd(
auto dst_idx = elem_to_loc(index, src_shape, dst_strides, ndim); {N * index.x, index.y, index.z},
dst[dst_idx] = static_cast<U>(src[src_idx]); src_shape,
src_strides,
dst_strides,
ndim);
if (N == 1) {
dst[idx.y] = static_cast<U>(src[idx.x]);
return;
}
auto src_xstride = src_strides[ndim - 1];
auto dst_xstride = dst_strides[ndim - 1];
auto xshape = src_shape[ndim - 1];
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
dst[idx.y] = static_cast<U>(src[idx.x]);
idx.x += src_xstride;
idx.y += dst_xstride;
}
} }

View File

@ -17,7 +17,9 @@
instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, itype, otype) \ instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, itype, otype) \
instantiate_kernel("gg3_copy" #tname, copy_gg_nd3, itype, otype) \ instantiate_kernel("gg3_copy" #tname, copy_gg_nd3, itype, otype) \
instantiate_kernel("g_copy" #tname, copy_g, itype, otype) \ instantiate_kernel("g_copy" #tname, copy_g, itype, otype) \
instantiate_kernel("gg_copy" #tname, copy_gg, itype, otype) instantiate_kernel("gn4_copy" #tname, copy_g, itype, otype, 4) \
instantiate_kernel("gg_copy" #tname, copy_gg, itype, otype) \
instantiate_kernel("ggn4_copy" #tname, copy_gg, itype, otype, 4)
#define instantiate_copy_itype(itname, itype) \ #define instantiate_copy_itype(itname, itype) \
instantiate_copy_all(itname ##bool_, itype, bool) \ instantiate_copy_all(itname ##bool_, itype, bool) \

View File

@ -149,15 +149,16 @@ elem_to_loc_3(uint3 elem, constant const stride_t strides[3]) {
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
// Multiple Arrays with generic dims // Multiple Arrays with generic dims
template <typename stride_t>
METAL_FUNC ulong2 elem_to_loc_2_nd( METAL_FUNC ulong2 elem_to_loc_2_nd(
uint3 elem, uint3 elem,
constant const int* shape, constant const int* shape,
constant const size_t* a_strides, constant const stride_t* a_strides,
constant const size_t* b_strides, constant const stride_t* b_strides,
int ndim) { int ndim) {
ulong2 loc = { ulong2 loc = {
elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2], ulong(elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]),
elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2]}; ulong(elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2])};
for (int d = ndim - 3; d >= 0; --d) { for (int d = ndim - 3; d >= 0; --d) {
uint l = elem.z % shape[d]; uint l = elem.z % shape[d];
loc.x += l * a_strides[d]; loc.x += l * a_strides[d];