mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-26 02:33:21 +08:00
fix large ops (#1620)
This commit is contained in:
parent
bb303c45a5
commit
211411faf2
@ -85,6 +85,7 @@ void binary_op_gpu_inplace(
|
|||||||
auto ndim = shape.size();
|
auto ndim = shape.size();
|
||||||
int work_per_thread;
|
int work_per_thread;
|
||||||
if (bopt == BinaryOpType::General) {
|
if (bopt == BinaryOpType::General) {
|
||||||
|
large |= (a.data_size() > UINT32_MAX || b.data_size() > UINT32_MAX);
|
||||||
work_per_thread = large ? 4 : 2;
|
work_per_thread = large ? 4 : 2;
|
||||||
} else {
|
} else {
|
||||||
work_per_thread = 1;
|
work_per_thread = 1;
|
||||||
|
@ -77,7 +77,7 @@ void copy_gpu_inplace(
|
|||||||
bool large;
|
bool large;
|
||||||
if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
|
if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
|
||||||
// Allow for negative strides
|
// Allow for negative strides
|
||||||
large = out.data_size() > INT32_MAX;
|
large = in.data_size() > INT32_MAX || out.data_size() > INT32_MAX;
|
||||||
} else {
|
} else {
|
||||||
large = out.data_size() > UINT32_MAX;
|
large = out.data_size() > UINT32_MAX;
|
||||||
}
|
}
|
||||||
@ -134,13 +134,13 @@ void copy_gpu_inplace(
|
|||||||
compute_encoder.set_vector_bytes(strides_out, ndim, 4);
|
compute_encoder.set_vector_bytes(strides_out, ndim, 4);
|
||||||
}
|
}
|
||||||
|
|
||||||
int dim0 = ndim > 0 ? shape[ndim - 1] : 1;
|
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
|
||||||
int dim1 = ndim > 1 ? shape[ndim - 2] : 1;
|
size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
|
||||||
|
|
||||||
size_t data_size = 1;
|
size_t data_size = 1;
|
||||||
for (auto& s : shape)
|
for (auto& s : shape)
|
||||||
data_size *= s;
|
data_size *= s;
|
||||||
int rest = data_size / (dim0 * dim1);
|
size_t rest = data_size / (dim0 * dim1);
|
||||||
|
|
||||||
if (ndim > MAX_COPY_SPECIALIZED_DIMS) {
|
if (ndim > MAX_COPY_SPECIALIZED_DIMS) {
|
||||||
compute_encoder.set_bytes(ndim, 5);
|
compute_encoder.set_bytes(ndim, 5);
|
||||||
|
@ -69,7 +69,7 @@ template <typename T, typename U, typename Op>
|
|||||||
c[offset] = Op()(a[offset], b[offset]);
|
c[offset] = Op()(a[offset], b[offset]);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, typename U, typename Op>
|
template <typename T, typename U, typename Op, typename IdxT = size_t>
|
||||||
[[kernel]] void binary_g_nd1(
|
[[kernel]] void binary_g_nd1(
|
||||||
device const T* a,
|
device const T* a,
|
||||||
device const T* b,
|
device const T* b,
|
||||||
@ -77,8 +77,8 @@ template <typename T, typename U, typename Op>
|
|||||||
constant const size_t& a_stride,
|
constant const size_t& a_stride,
|
||||||
constant const size_t& b_stride,
|
constant const size_t& b_stride,
|
||||||
uint index [[thread_position_in_grid]]) {
|
uint index [[thread_position_in_grid]]) {
|
||||||
auto a_idx = elem_to_loc_1<size_t, uint>(index, a_stride);
|
auto a_idx = elem_to_loc_1<size_t, IdxT>(index, a_stride);
|
||||||
auto b_idx = elem_to_loc_1<size_t, uint>(index, b_stride);
|
auto b_idx = elem_to_loc_1<size_t, IdxT>(index, b_stride);
|
||||||
c[index] = Op()(a[a_idx], b[b_idx]);
|
c[index] = Op()(a[a_idx], b[b_idx]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -19,7 +19,8 @@
|
|||||||
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
|
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
|
||||||
instantiate_kernel("gn2_" #op #tname, binary_g, itype, otype, op, 2, uint) \
|
instantiate_kernel("gn2_" #op #tname, binary_g, itype, otype, op, 2, uint) \
|
||||||
instantiate_kernel("gn4large_" #op #tname, binary_g, itype, otype, op, 4) \
|
instantiate_kernel("gn4large_" #op #tname, binary_g, itype, otype, op, 4) \
|
||||||
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op) \
|
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op, uint) \
|
||||||
|
instantiate_kernel("g1large_" #op #tname, binary_g_nd1, itype, otype, op) \
|
||||||
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op, uint) \
|
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op, uint) \
|
||||||
instantiate_kernel("g2large_" #op #tname, binary_g_nd2, itype, otype, op) \
|
instantiate_kernel("g2large_" #op #tname, binary_g_nd2, itype, otype, op) \
|
||||||
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op, uint) \
|
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op, uint) \
|
||||||
|
@ -90,7 +90,7 @@ template <typename T, typename U, typename Op>
|
|||||||
d[offset] = out[1];
|
d[offset] = out[1];
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, typename U, typename Op>
|
template <typename T, typename U, typename Op, typename IdxT = size_t>
|
||||||
[[kernel]] void binary_g_nd1(
|
[[kernel]] void binary_g_nd1(
|
||||||
device const T* a,
|
device const T* a,
|
||||||
device const T* b,
|
device const T* b,
|
||||||
@ -99,8 +99,8 @@ template <typename T, typename U, typename Op>
|
|||||||
constant const size_t& a_stride,
|
constant const size_t& a_stride,
|
||||||
constant const size_t& b_stride,
|
constant const size_t& b_stride,
|
||||||
uint index [[thread_position_in_grid]]) {
|
uint index [[thread_position_in_grid]]) {
|
||||||
auto a_idx = elem_to_loc_1<size_t, uint>(index, a_stride);
|
auto a_idx = elem_to_loc_1<size_t, IdxT>(index, a_stride);
|
||||||
auto b_idx = elem_to_loc_1<size_t, uint>(index, b_stride);
|
auto b_idx = elem_to_loc_1<size_t, IdxT>(index, b_stride);
|
||||||
auto out = Op()(a[a_idx], b[b_idx]);
|
auto out = Op()(a[a_idx], b[b_idx]);
|
||||||
c[index] = out[0];
|
c[index] = out[0];
|
||||||
d[index] = out[1];
|
d[index] = out[1];
|
||||||
|
@ -17,9 +17,10 @@
|
|||||||
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
|
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
|
||||||
instantiate_kernel("gn2_" #op #tname, binary_g, itype, otype, op, 2, uint) \
|
instantiate_kernel("gn2_" #op #tname, binary_g, itype, otype, op, 2, uint) \
|
||||||
instantiate_kernel("gn4large_" #op #tname, binary_g, itype, otype, op, 4) \
|
instantiate_kernel("gn4large_" #op #tname, binary_g, itype, otype, op, 4) \
|
||||||
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op) \
|
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op, uint) \
|
||||||
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op, uint) \
|
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op, uint) \
|
||||||
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op, uint) \
|
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op, uint) \
|
||||||
|
instantiate_kernel("g1large_" #op #tname, binary_g_nd1, itype, otype, op) \
|
||||||
instantiate_kernel("g2large_" #op #tname, binary_g_nd2, itype, otype, op) \
|
instantiate_kernel("g2large_" #op #tname, binary_g_nd2, itype, otype, op) \
|
||||||
instantiate_kernel("g3large_" #op #tname, binary_g_nd3, itype, otype, op)
|
instantiate_kernel("g3large_" #op #tname, binary_g_nd3, itype, otype, op)
|
||||||
|
|
||||||
|
@ -36,13 +36,13 @@ template <typename T, typename U>
|
|||||||
dst[offset] = static_cast<U>(src[offset]);
|
dst[offset] = static_cast<U>(src[offset]);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, typename U>
|
template <typename T, typename U, typename IdxT = int64_t>
|
||||||
[[kernel]] void copy_g_nd1(
|
[[kernel]] void copy_g_nd1(
|
||||||
device const T* src [[buffer(0)]],
|
device const T* src [[buffer(0)]],
|
||||||
device U* dst [[buffer(1)]],
|
device U* dst [[buffer(1)]],
|
||||||
constant const int64_t& src_stride [[buffer(3)]],
|
constant const int64_t& src_stride [[buffer(3)]],
|
||||||
uint index [[thread_position_in_grid]]) {
|
uint index [[thread_position_in_grid]]) {
|
||||||
auto src_idx = elem_to_loc_1<int64_t, int>(index, src_stride);
|
auto src_idx = elem_to_loc_1<int64_t, IdxT>(index, src_stride);
|
||||||
dst[index] = static_cast<U>(src[src_idx]);
|
dst[index] = static_cast<U>(src[src_idx]);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -97,15 +97,15 @@ template <typename T, typename U, int N = 1, typename IdxT = int64_t>
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, typename U>
|
template <typename T, typename U, typename IdxT = int64_t>
|
||||||
[[kernel]] void copy_gg_nd1(
|
[[kernel]] void copy_gg_nd1(
|
||||||
device const T* src [[buffer(0)]],
|
device const T* src [[buffer(0)]],
|
||||||
device U* dst [[buffer(1)]],
|
device U* dst [[buffer(1)]],
|
||||||
constant const int64_t& src_stride [[buffer(3)]],
|
constant const int64_t& src_stride [[buffer(3)]],
|
||||||
constant const int64_t& dst_stride [[buffer(4)]],
|
constant const int64_t& dst_stride [[buffer(4)]],
|
||||||
uint index [[thread_position_in_grid]]) {
|
uint index [[thread_position_in_grid]]) {
|
||||||
auto src_idx = elem_to_loc_1<int64_t, int>(index, src_stride);
|
auto src_idx = elem_to_loc_1<int64_t, IdxT>(index, src_stride);
|
||||||
auto dst_idx = elem_to_loc_1<int64_t, int>(index, dst_stride);
|
auto dst_idx = elem_to_loc_1<int64_t, IdxT>(index, dst_stride);
|
||||||
dst[dst_idx] = static_cast<U>(src[src_idx]);
|
dst[dst_idx] = static_cast<U>(src[src_idx]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -9,16 +9,18 @@
|
|||||||
instantiate_kernel("v_copy" #tname, copy_v, itype, otype) \
|
instantiate_kernel("v_copy" #tname, copy_v, itype, otype) \
|
||||||
instantiate_kernel("s2_copy" #tname, copy_s2, itype, otype) \
|
instantiate_kernel("s2_copy" #tname, copy_s2, itype, otype) \
|
||||||
instantiate_kernel("v2_copy" #tname, copy_v2, itype, otype) \
|
instantiate_kernel("v2_copy" #tname, copy_v2, itype, otype) \
|
||||||
instantiate_kernel("g1_copy" #tname, copy_g_nd1, itype, otype) \
|
instantiate_kernel("g1_copy" #tname, copy_g_nd1, itype, otype, int) \
|
||||||
instantiate_kernel("g2_copy" #tname, copy_g_nd2, itype, otype, int) \
|
instantiate_kernel("g2_copy" #tname, copy_g_nd2, itype, otype, int) \
|
||||||
instantiate_kernel("g3_copy" #tname, copy_g_nd3, itype, otype, int) \
|
instantiate_kernel("g3_copy" #tname, copy_g_nd3, itype, otype, int) \
|
||||||
instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, itype, otype) \
|
instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, itype, otype, int) \
|
||||||
instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, itype, otype, int) \
|
instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, itype, otype, int) \
|
||||||
instantiate_kernel("gg3_copy" #tname, copy_gg_nd3, itype, otype, int) \
|
instantiate_kernel("gg3_copy" #tname, copy_gg_nd3, itype, otype, int) \
|
||||||
instantiate_kernel("gn2_copy" #tname, copy_g, itype, otype, 2, int) \
|
instantiate_kernel("gn2_copy" #tname, copy_g, itype, otype, 2, int) \
|
||||||
instantiate_kernel("ggn2_copy" #tname, copy_gg, itype, otype, 2, int) \
|
instantiate_kernel("ggn2_copy" #tname, copy_gg, itype, otype, 2, int) \
|
||||||
|
instantiate_kernel("g1large_copy" #tname, copy_g_nd1, itype, otype) \
|
||||||
instantiate_kernel("g2large_copy" #tname, copy_g_nd2, itype, otype) \
|
instantiate_kernel("g2large_copy" #tname, copy_g_nd2, itype, otype) \
|
||||||
instantiate_kernel("g3large_copy" #tname, copy_g_nd3, itype, otype) \
|
instantiate_kernel("g3large_copy" #tname, copy_g_nd3, itype, otype) \
|
||||||
|
instantiate_kernel("gg1large_copy" #tname, copy_gg_nd1, itype, otype) \
|
||||||
instantiate_kernel("gg2large_copy" #tname, copy_gg_nd2, itype, otype) \
|
instantiate_kernel("gg2large_copy" #tname, copy_gg_nd2, itype, otype) \
|
||||||
instantiate_kernel("gg3large_copy" #tname, copy_gg_nd3, itype, otype) \
|
instantiate_kernel("gg3large_copy" #tname, copy_gg_nd3, itype, otype) \
|
||||||
instantiate_kernel("gn4large_copy" #tname, copy_g, itype, otype, 4) \
|
instantiate_kernel("gn4large_copy" #tname, copy_g, itype, otype, 4) \
|
||||||
|
@ -22,7 +22,7 @@ template <typename T, typename Op>
|
|||||||
d[offset] = Op()(a[offset], b[offset], c[offset]);
|
d[offset] = Op()(a[offset], b[offset], c[offset]);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, typename Op>
|
template <typename T, typename Op, typename IdxT = size_t>
|
||||||
[[kernel]] void ternary_g_nd1(
|
[[kernel]] void ternary_g_nd1(
|
||||||
device const bool* a,
|
device const bool* a,
|
||||||
device const T* b,
|
device const T* b,
|
||||||
@ -32,9 +32,9 @@ template <typename T, typename Op>
|
|||||||
constant const size_t& b_strides,
|
constant const size_t& b_strides,
|
||||||
constant const size_t& c_strides,
|
constant const size_t& c_strides,
|
||||||
uint index [[thread_position_in_grid]]) {
|
uint index [[thread_position_in_grid]]) {
|
||||||
auto a_idx = elem_to_loc_1<size_t, uint>(index, a_strides);
|
auto a_idx = elem_to_loc_1<size_t, IdxT>(index, a_strides);
|
||||||
auto b_idx = elem_to_loc_1<size_t, uint>(index, b_strides);
|
auto b_idx = elem_to_loc_1<size_t, IdxT>(index, b_strides);
|
||||||
auto c_idx = elem_to_loc_1<size_t, uint>(index, c_strides);
|
auto c_idx = elem_to_loc_1<size_t, IdxT>(index, c_strides);
|
||||||
d[index] = Op()(a[a_idx], b[b_idx], c[c_idx]);
|
d[index] = Op()(a[a_idx], b[b_idx], c[c_idx]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -12,9 +12,10 @@
|
|||||||
instantiate_kernel("v_" #op #tname, ternary_v, type, op) \
|
instantiate_kernel("v_" #op #tname, ternary_v, type, op) \
|
||||||
instantiate_kernel("v2_" #op #tname, ternary_v2, type, op) \
|
instantiate_kernel("v2_" #op #tname, ternary_v2, type, op) \
|
||||||
instantiate_kernel("gn2_" #op #tname, ternary_g, type, op, 1, uint) \
|
instantiate_kernel("gn2_" #op #tname, ternary_g, type, op, 1, uint) \
|
||||||
instantiate_kernel("g1_" #op #tname, ternary_g_nd1, type, op) \
|
instantiate_kernel("g1_" #op #tname, ternary_g_nd1, type, op, uint) \
|
||||||
instantiate_kernel("g2_" #op #tname, ternary_g_nd2, type, op, uint) \
|
instantiate_kernel("g2_" #op #tname, ternary_g_nd2, type, op, uint) \
|
||||||
instantiate_kernel("g3_" #op #tname, ternary_g_nd3, type, op, uint) \
|
instantiate_kernel("g3_" #op #tname, ternary_g_nd3, type, op, uint) \
|
||||||
|
instantiate_kernel("g1large_" #op #tname, ternary_g_nd1, type, op) \
|
||||||
instantiate_kernel("g2large_" #op #tname, ternary_g_nd2, type, op) \
|
instantiate_kernel("g2large_" #op #tname, ternary_g_nd2, type, op) \
|
||||||
instantiate_kernel("g3large_" #op #tname, ternary_g_nd3, type, op) \
|
instantiate_kernel("g3large_" #op #tname, ternary_g_nd3, type, op) \
|
||||||
instantiate_kernel("gn4large_" #op #tname, ternary_g, type, op, 4) \
|
instantiate_kernel("gn4large_" #op #tname, ternary_g, type, op, 4) \
|
||||||
|
@ -40,6 +40,9 @@ void ternary_op_gpu_inplace(
|
|||||||
auto ndim = shape.size();
|
auto ndim = shape.size();
|
||||||
int work_per_thread;
|
int work_per_thread;
|
||||||
if (topt == TernaryOpType::General) {
|
if (topt == TernaryOpType::General) {
|
||||||
|
large |=
|
||||||
|
(a.data_size() > UINT32_MAX || b.data_size() > UINT32_MAX ||
|
||||||
|
c.data_size() > UINT32_MAX);
|
||||||
work_per_thread = large ? 4 : 2;
|
work_per_thread = large ? 4 : 2;
|
||||||
} else {
|
} else {
|
||||||
work_per_thread = 1;
|
work_per_thread = 1;
|
||||||
|
@ -36,7 +36,10 @@ void unary_op_gpu_inplace(
|
|||||||
auto [shape, strides] = maybe_collapse();
|
auto [shape, strides] = maybe_collapse();
|
||||||
int ndim = shape.size();
|
int ndim = shape.size();
|
||||||
size_t nthreads = contig ? in.data_size() : in.size();
|
size_t nthreads = contig ? in.data_size() : in.size();
|
||||||
bool large = nthreads > UINT32_MAX;
|
bool large = in.data_size() > UINT32_MAX;
|
||||||
|
if (!contig) {
|
||||||
|
large |= in.size() > UINT32_MAX;
|
||||||
|
}
|
||||||
int work_per_thread = !contig && large ? 4 : 1;
|
int work_per_thread = !contig && large ? 4 : 1;
|
||||||
std::string kernel_name;
|
std::string kernel_name;
|
||||||
if (contig) {
|
if (contig) {
|
||||||
|
Loading…
Reference in New Issue
Block a user