fix large ops (#1620)

This commit is contained in:
Awni Hannun 2024-11-24 09:17:10 -08:00 committed by GitHub
parent bb303c45a5
commit 211411faf2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 37 additions and 25 deletions

View File

@ -85,6 +85,7 @@ void binary_op_gpu_inplace(
auto ndim = shape.size();
int work_per_thread;
if (bopt == BinaryOpType::General) {
large |= (a.data_size() > UINT32_MAX || b.data_size() > UINT32_MAX);
work_per_thread = large ? 4 : 2;
} else {
work_per_thread = 1;

View File

@ -77,7 +77,7 @@ void copy_gpu_inplace(
bool large;
if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
// Allow for negative strides
large = out.data_size() > INT32_MAX;
large = in.data_size() > INT32_MAX || out.data_size() > INT32_MAX;
} else {
large = out.data_size() > UINT32_MAX;
}
@ -134,13 +134,13 @@ void copy_gpu_inplace(
compute_encoder.set_vector_bytes(strides_out, ndim, 4);
}
int dim0 = ndim > 0 ? shape[ndim - 1] : 1;
int dim1 = ndim > 1 ? shape[ndim - 2] : 1;
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
size_t data_size = 1;
for (auto& s : shape)
data_size *= s;
int rest = data_size / (dim0 * dim1);
size_t rest = data_size / (dim0 * dim1);
if (ndim > MAX_COPY_SPECIALIZED_DIMS) {
compute_encoder.set_bytes(ndim, 5);

View File

@ -69,7 +69,7 @@ template <typename T, typename U, typename Op>
c[offset] = Op()(a[offset], b[offset]);
}
template <typename T, typename U, typename Op>
template <typename T, typename U, typename Op, typename IdxT = size_t>
[[kernel]] void binary_g_nd1(
device const T* a,
device const T* b,
@ -77,8 +77,8 @@ template <typename T, typename U, typename Op>
constant const size_t& a_stride,
constant const size_t& b_stride,
uint index [[thread_position_in_grid]]) {
auto a_idx = elem_to_loc_1<size_t, uint>(index, a_stride);
auto b_idx = elem_to_loc_1<size_t, uint>(index, b_stride);
auto a_idx = elem_to_loc_1<size_t, IdxT>(index, a_stride);
auto b_idx = elem_to_loc_1<size_t, IdxT>(index, b_stride);
c[index] = Op()(a[a_idx], b[b_idx]);
}

View File

@ -19,7 +19,8 @@
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
instantiate_kernel("gn2_" #op #tname, binary_g, itype, otype, op, 2, uint) \
instantiate_kernel("gn4large_" #op #tname, binary_g, itype, otype, op, 4) \
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op) \
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op, uint) \
instantiate_kernel("g1large_" #op #tname, binary_g_nd1, itype, otype, op) \
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op, uint) \
instantiate_kernel("g2large_" #op #tname, binary_g_nd2, itype, otype, op) \
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op, uint) \

View File

@ -90,7 +90,7 @@ template <typename T, typename U, typename Op>
d[offset] = out[1];
}
template <typename T, typename U, typename Op>
template <typename T, typename U, typename Op, typename IdxT = size_t>
[[kernel]] void binary_g_nd1(
device const T* a,
device const T* b,
@ -99,8 +99,8 @@ template <typename T, typename U, typename Op>
constant const size_t& a_stride,
constant const size_t& b_stride,
uint index [[thread_position_in_grid]]) {
auto a_idx = elem_to_loc_1<size_t, uint>(index, a_stride);
auto b_idx = elem_to_loc_1<size_t, uint>(index, b_stride);
auto a_idx = elem_to_loc_1<size_t, IdxT>(index, a_stride);
auto b_idx = elem_to_loc_1<size_t, IdxT>(index, b_stride);
auto out = Op()(a[a_idx], b[b_idx]);
c[index] = out[0];
d[index] = out[1];

View File

@ -17,9 +17,10 @@
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
instantiate_kernel("gn2_" #op #tname, binary_g, itype, otype, op, 2, uint) \
instantiate_kernel("gn4large_" #op #tname, binary_g, itype, otype, op, 4) \
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op) \
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op, uint) \
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op, uint) \
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op, uint) \
instantiate_kernel("g1large_" #op #tname, binary_g_nd1, itype, otype, op) \
instantiate_kernel("g2large_" #op #tname, binary_g_nd2, itype, otype, op) \
instantiate_kernel("g3large_" #op #tname, binary_g_nd3, itype, otype, op)

View File

@ -36,13 +36,13 @@ template <typename T, typename U>
dst[offset] = static_cast<U>(src[offset]);
}
template <typename T, typename U>
template <typename T, typename U, typename IdxT = int64_t>
[[kernel]] void copy_g_nd1(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t& src_stride [[buffer(3)]],
uint index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_1<int64_t, int>(index, src_stride);
auto src_idx = elem_to_loc_1<int64_t, IdxT>(index, src_stride);
dst[index] = static_cast<U>(src[src_idx]);
}
@ -97,15 +97,15 @@ template <typename T, typename U, int N = 1, typename IdxT = int64_t>
}
}
template <typename T, typename U>
template <typename T, typename U, typename IdxT = int64_t>
[[kernel]] void copy_gg_nd1(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t& src_stride [[buffer(3)]],
constant const int64_t& dst_stride [[buffer(4)]],
uint index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_1<int64_t, int>(index, src_stride);
auto dst_idx = elem_to_loc_1<int64_t, int>(index, dst_stride);
auto src_idx = elem_to_loc_1<int64_t, IdxT>(index, src_stride);
auto dst_idx = elem_to_loc_1<int64_t, IdxT>(index, dst_stride);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}

View File

@ -9,16 +9,18 @@
instantiate_kernel("v_copy" #tname, copy_v, itype, otype) \
instantiate_kernel("s2_copy" #tname, copy_s2, itype, otype) \
instantiate_kernel("v2_copy" #tname, copy_v2, itype, otype) \
instantiate_kernel("g1_copy" #tname, copy_g_nd1, itype, otype) \
instantiate_kernel("g1_copy" #tname, copy_g_nd1, itype, otype, int) \
instantiate_kernel("g2_copy" #tname, copy_g_nd2, itype, otype, int) \
instantiate_kernel("g3_copy" #tname, copy_g_nd3, itype, otype, int) \
instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, itype, otype) \
instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, itype, otype, int) \
instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, itype, otype, int) \
instantiate_kernel("gg3_copy" #tname, copy_gg_nd3, itype, otype, int) \
instantiate_kernel("gn2_copy" #tname, copy_g, itype, otype, 2, int) \
instantiate_kernel("ggn2_copy" #tname, copy_gg, itype, otype, 2, int) \
instantiate_kernel("g1large_copy" #tname, copy_g_nd1, itype, otype) \
instantiate_kernel("g2large_copy" #tname, copy_g_nd2, itype, otype) \
instantiate_kernel("g3large_copy" #tname, copy_g_nd3, itype, otype) \
instantiate_kernel("gg1large_copy" #tname, copy_gg_nd1, itype, otype) \
instantiate_kernel("gg2large_copy" #tname, copy_gg_nd2, itype, otype) \
instantiate_kernel("gg3large_copy" #tname, copy_gg_nd3, itype, otype) \
instantiate_kernel("gn4large_copy" #tname, copy_g, itype, otype, 4) \

View File

@ -22,7 +22,7 @@ template <typename T, typename Op>
d[offset] = Op()(a[offset], b[offset], c[offset]);
}
template <typename T, typename Op>
template <typename T, typename Op, typename IdxT = size_t>
[[kernel]] void ternary_g_nd1(
device const bool* a,
device const T* b,
@ -32,9 +32,9 @@ template <typename T, typename Op>
constant const size_t& b_strides,
constant const size_t& c_strides,
uint index [[thread_position_in_grid]]) {
auto a_idx = elem_to_loc_1<size_t, uint>(index, a_strides);
auto b_idx = elem_to_loc_1<size_t, uint>(index, b_strides);
auto c_idx = elem_to_loc_1<size_t, uint>(index, c_strides);
auto a_idx = elem_to_loc_1<size_t, IdxT>(index, a_strides);
auto b_idx = elem_to_loc_1<size_t, IdxT>(index, b_strides);
auto c_idx = elem_to_loc_1<size_t, IdxT>(index, c_strides);
d[index] = Op()(a[a_idx], b[b_idx], c[c_idx]);
}

View File

@ -12,9 +12,10 @@
instantiate_kernel("v_" #op #tname, ternary_v, type, op) \
instantiate_kernel("v2_" #op #tname, ternary_v2, type, op) \
instantiate_kernel("gn2_" #op #tname, ternary_g, type, op, 1, uint) \
instantiate_kernel("g1_" #op #tname, ternary_g_nd1, type, op) \
instantiate_kernel("g1_" #op #tname, ternary_g_nd1, type, op, uint) \
instantiate_kernel("g2_" #op #tname, ternary_g_nd2, type, op, uint) \
instantiate_kernel("g3_" #op #tname, ternary_g_nd3, type, op, uint) \
instantiate_kernel("g1large_" #op #tname, ternary_g_nd1, type, op) \
instantiate_kernel("g2large_" #op #tname, ternary_g_nd2, type, op) \
instantiate_kernel("g3large_" #op #tname, ternary_g_nd3, type, op) \
instantiate_kernel("gn4large_" #op #tname, ternary_g, type, op, 4) \

View File

@ -40,6 +40,9 @@ void ternary_op_gpu_inplace(
auto ndim = shape.size();
int work_per_thread;
if (topt == TernaryOpType::General) {
large |=
(a.data_size() > UINT32_MAX || b.data_size() > UINT32_MAX ||
c.data_size() > UINT32_MAX);
work_per_thread = large ? 4 : 2;
} else {
work_per_thread = 1;

View File

@ -36,7 +36,10 @@ void unary_op_gpu_inplace(
auto [shape, strides] = maybe_collapse();
int ndim = shape.size();
size_t nthreads = contig ? in.data_size() : in.size();
bool large = nthreads > UINT32_MAX;
bool large = in.data_size() > UINT32_MAX;
if (!contig) {
large |= in.size() > UINT32_MAX;
}
int work_per_thread = !contig && large ? 4 : 1;
std::string kernel_name;
if (contig) {