mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
fix large ops (#1620)
This commit is contained in:
parent
bb303c45a5
commit
211411faf2
@ -85,6 +85,7 @@ void binary_op_gpu_inplace(
|
||||
auto ndim = shape.size();
|
||||
int work_per_thread;
|
||||
if (bopt == BinaryOpType::General) {
|
||||
large |= (a.data_size() > UINT32_MAX || b.data_size() > UINT32_MAX);
|
||||
work_per_thread = large ? 4 : 2;
|
||||
} else {
|
||||
work_per_thread = 1;
|
||||
|
@ -77,7 +77,7 @@ void copy_gpu_inplace(
|
||||
bool large;
|
||||
if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
|
||||
// Allow for negative strides
|
||||
large = out.data_size() > INT32_MAX;
|
||||
large = in.data_size() > INT32_MAX || out.data_size() > INT32_MAX;
|
||||
} else {
|
||||
large = out.data_size() > UINT32_MAX;
|
||||
}
|
||||
@ -134,13 +134,13 @@ void copy_gpu_inplace(
|
||||
compute_encoder.set_vector_bytes(strides_out, ndim, 4);
|
||||
}
|
||||
|
||||
int dim0 = ndim > 0 ? shape[ndim - 1] : 1;
|
||||
int dim1 = ndim > 1 ? shape[ndim - 2] : 1;
|
||||
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
|
||||
size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
|
||||
|
||||
size_t data_size = 1;
|
||||
for (auto& s : shape)
|
||||
data_size *= s;
|
||||
int rest = data_size / (dim0 * dim1);
|
||||
size_t rest = data_size / (dim0 * dim1);
|
||||
|
||||
if (ndim > MAX_COPY_SPECIALIZED_DIMS) {
|
||||
compute_encoder.set_bytes(ndim, 5);
|
||||
|
@ -69,7 +69,7 @@ template <typename T, typename U, typename Op>
|
||||
c[offset] = Op()(a[offset], b[offset]);
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
template <typename T, typename U, typename Op, typename IdxT = size_t>
|
||||
[[kernel]] void binary_g_nd1(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
@ -77,8 +77,8 @@ template <typename T, typename U, typename Op>
|
||||
constant const size_t& a_stride,
|
||||
constant const size_t& b_stride,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
auto a_idx = elem_to_loc_1<size_t, uint>(index, a_stride);
|
||||
auto b_idx = elem_to_loc_1<size_t, uint>(index, b_stride);
|
||||
auto a_idx = elem_to_loc_1<size_t, IdxT>(index, a_stride);
|
||||
auto b_idx = elem_to_loc_1<size_t, IdxT>(index, b_stride);
|
||||
c[index] = Op()(a[a_idx], b[b_idx]);
|
||||
}
|
||||
|
||||
|
@ -19,7 +19,8 @@
|
||||
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
|
||||
instantiate_kernel("gn2_" #op #tname, binary_g, itype, otype, op, 2, uint) \
|
||||
instantiate_kernel("gn4large_" #op #tname, binary_g, itype, otype, op, 4) \
|
||||
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op) \
|
||||
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op, uint) \
|
||||
instantiate_kernel("g1large_" #op #tname, binary_g_nd1, itype, otype, op) \
|
||||
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op, uint) \
|
||||
instantiate_kernel("g2large_" #op #tname, binary_g_nd2, itype, otype, op) \
|
||||
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op, uint) \
|
||||
|
@ -90,7 +90,7 @@ template <typename T, typename U, typename Op>
|
||||
d[offset] = out[1];
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
template <typename T, typename U, typename Op, typename IdxT = size_t>
|
||||
[[kernel]] void binary_g_nd1(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
@ -99,8 +99,8 @@ template <typename T, typename U, typename Op>
|
||||
constant const size_t& a_stride,
|
||||
constant const size_t& b_stride,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
auto a_idx = elem_to_loc_1<size_t, uint>(index, a_stride);
|
||||
auto b_idx = elem_to_loc_1<size_t, uint>(index, b_stride);
|
||||
auto a_idx = elem_to_loc_1<size_t, IdxT>(index, a_stride);
|
||||
auto b_idx = elem_to_loc_1<size_t, IdxT>(index, b_stride);
|
||||
auto out = Op()(a[a_idx], b[b_idx]);
|
||||
c[index] = out[0];
|
||||
d[index] = out[1];
|
||||
|
@ -17,9 +17,10 @@
|
||||
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
|
||||
instantiate_kernel("gn2_" #op #tname, binary_g, itype, otype, op, 2, uint) \
|
||||
instantiate_kernel("gn4large_" #op #tname, binary_g, itype, otype, op, 4) \
|
||||
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op) \
|
||||
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op, uint) \
|
||||
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op, uint) \
|
||||
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op, uint) \
|
||||
instantiate_kernel("g1large_" #op #tname, binary_g_nd1, itype, otype, op) \
|
||||
instantiate_kernel("g2large_" #op #tname, binary_g_nd2, itype, otype, op) \
|
||||
instantiate_kernel("g3large_" #op #tname, binary_g_nd3, itype, otype, op)
|
||||
|
||||
|
@ -36,13 +36,13 @@ template <typename T, typename U>
|
||||
dst[offset] = static_cast<U>(src[offset]);
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
template <typename T, typename U, typename IdxT = int64_t>
|
||||
[[kernel]] void copy_g_nd1(
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
constant const int64_t& src_stride [[buffer(3)]],
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
auto src_idx = elem_to_loc_1<int64_t, int>(index, src_stride);
|
||||
auto src_idx = elem_to_loc_1<int64_t, IdxT>(index, src_stride);
|
||||
dst[index] = static_cast<U>(src[src_idx]);
|
||||
}
|
||||
|
||||
@ -97,15 +97,15 @@ template <typename T, typename U, int N = 1, typename IdxT = int64_t>
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
template <typename T, typename U, typename IdxT = int64_t>
|
||||
[[kernel]] void copy_gg_nd1(
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
constant const int64_t& src_stride [[buffer(3)]],
|
||||
constant const int64_t& dst_stride [[buffer(4)]],
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
auto src_idx = elem_to_loc_1<int64_t, int>(index, src_stride);
|
||||
auto dst_idx = elem_to_loc_1<int64_t, int>(index, dst_stride);
|
||||
auto src_idx = elem_to_loc_1<int64_t, IdxT>(index, src_stride);
|
||||
auto dst_idx = elem_to_loc_1<int64_t, IdxT>(index, dst_stride);
|
||||
dst[dst_idx] = static_cast<U>(src[src_idx]);
|
||||
}
|
||||
|
||||
|
@ -9,16 +9,18 @@
|
||||
instantiate_kernel("v_copy" #tname, copy_v, itype, otype) \
|
||||
instantiate_kernel("s2_copy" #tname, copy_s2, itype, otype) \
|
||||
instantiate_kernel("v2_copy" #tname, copy_v2, itype, otype) \
|
||||
instantiate_kernel("g1_copy" #tname, copy_g_nd1, itype, otype) \
|
||||
instantiate_kernel("g1_copy" #tname, copy_g_nd1, itype, otype, int) \
|
||||
instantiate_kernel("g2_copy" #tname, copy_g_nd2, itype, otype, int) \
|
||||
instantiate_kernel("g3_copy" #tname, copy_g_nd3, itype, otype, int) \
|
||||
instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, itype, otype) \
|
||||
instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, itype, otype, int) \
|
||||
instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, itype, otype, int) \
|
||||
instantiate_kernel("gg3_copy" #tname, copy_gg_nd3, itype, otype, int) \
|
||||
instantiate_kernel("gn2_copy" #tname, copy_g, itype, otype, 2, int) \
|
||||
instantiate_kernel("ggn2_copy" #tname, copy_gg, itype, otype, 2, int) \
|
||||
instantiate_kernel("g1large_copy" #tname, copy_g_nd1, itype, otype) \
|
||||
instantiate_kernel("g2large_copy" #tname, copy_g_nd2, itype, otype) \
|
||||
instantiate_kernel("g3large_copy" #tname, copy_g_nd3, itype, otype) \
|
||||
instantiate_kernel("gg1large_copy" #tname, copy_gg_nd1, itype, otype) \
|
||||
instantiate_kernel("gg2large_copy" #tname, copy_gg_nd2, itype, otype) \
|
||||
instantiate_kernel("gg3large_copy" #tname, copy_gg_nd3, itype, otype) \
|
||||
instantiate_kernel("gn4large_copy" #tname, copy_g, itype, otype, 4) \
|
||||
|
@ -22,7 +22,7 @@ template <typename T, typename Op>
|
||||
d[offset] = Op()(a[offset], b[offset], c[offset]);
|
||||
}
|
||||
|
||||
template <typename T, typename Op>
|
||||
template <typename T, typename Op, typename IdxT = size_t>
|
||||
[[kernel]] void ternary_g_nd1(
|
||||
device const bool* a,
|
||||
device const T* b,
|
||||
@ -32,9 +32,9 @@ template <typename T, typename Op>
|
||||
constant const size_t& b_strides,
|
||||
constant const size_t& c_strides,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
auto a_idx = elem_to_loc_1<size_t, uint>(index, a_strides);
|
||||
auto b_idx = elem_to_loc_1<size_t, uint>(index, b_strides);
|
||||
auto c_idx = elem_to_loc_1<size_t, uint>(index, c_strides);
|
||||
auto a_idx = elem_to_loc_1<size_t, IdxT>(index, a_strides);
|
||||
auto b_idx = elem_to_loc_1<size_t, IdxT>(index, b_strides);
|
||||
auto c_idx = elem_to_loc_1<size_t, IdxT>(index, c_strides);
|
||||
d[index] = Op()(a[a_idx], b[b_idx], c[c_idx]);
|
||||
}
|
||||
|
||||
|
@ -12,9 +12,10 @@
|
||||
instantiate_kernel("v_" #op #tname, ternary_v, type, op) \
|
||||
instantiate_kernel("v2_" #op #tname, ternary_v2, type, op) \
|
||||
instantiate_kernel("gn2_" #op #tname, ternary_g, type, op, 1, uint) \
|
||||
instantiate_kernel("g1_" #op #tname, ternary_g_nd1, type, op) \
|
||||
instantiate_kernel("g1_" #op #tname, ternary_g_nd1, type, op, uint) \
|
||||
instantiate_kernel("g2_" #op #tname, ternary_g_nd2, type, op, uint) \
|
||||
instantiate_kernel("g3_" #op #tname, ternary_g_nd3, type, op, uint) \
|
||||
instantiate_kernel("g1large_" #op #tname, ternary_g_nd1, type, op) \
|
||||
instantiate_kernel("g2large_" #op #tname, ternary_g_nd2, type, op) \
|
||||
instantiate_kernel("g3large_" #op #tname, ternary_g_nd3, type, op) \
|
||||
instantiate_kernel("gn4large_" #op #tname, ternary_g, type, op, 4) \
|
||||
|
@ -40,6 +40,9 @@ void ternary_op_gpu_inplace(
|
||||
auto ndim = shape.size();
|
||||
int work_per_thread;
|
||||
if (topt == TernaryOpType::General) {
|
||||
large |=
|
||||
(a.data_size() > UINT32_MAX || b.data_size() > UINT32_MAX ||
|
||||
c.data_size() > UINT32_MAX);
|
||||
work_per_thread = large ? 4 : 2;
|
||||
} else {
|
||||
work_per_thread = 1;
|
||||
|
@ -36,7 +36,10 @@ void unary_op_gpu_inplace(
|
||||
auto [shape, strides] = maybe_collapse();
|
||||
int ndim = shape.size();
|
||||
size_t nthreads = contig ? in.data_size() : in.size();
|
||||
bool large = nthreads > UINT32_MAX;
|
||||
bool large = in.data_size() > UINT32_MAX;
|
||||
if (!contig) {
|
||||
large |= in.size() > UINT32_MAX;
|
||||
}
|
||||
int work_per_thread = !contig && large ? 4 : 1;
|
||||
std::string kernel_name;
|
||||
if (contig) {
|
||||
|
Loading…
Reference in New Issue
Block a user