mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Do vectorized store/load in binary_two ops
This commit is contained in:
@@ -17,47 +17,134 @@ namespace cu {
|
|||||||
|
|
||||||
namespace cg = cooperative_groups;
|
namespace cg = cooperative_groups;
|
||||||
|
|
||||||
template <typename Op, typename In, typename Out, typename IdxT>
|
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
|
||||||
__global__ void
|
__global__ void
|
||||||
binary_ss(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
|
binary_ss(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
|
||||||
IdxT index = cg::this_grid().thread_rank();
|
IdxT index = cg::this_grid().thread_rank();
|
||||||
if (index < size) {
|
int remaining = size - index * N_READS;
|
||||||
|
if (remaining <= 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (remaining < N_READS) {
|
||||||
|
for (int i = 0; i < remaining; ++i) {
|
||||||
|
IdxT offset = index * N_READS + i;
|
||||||
auto out = Op{}(a[0], b[0]);
|
auto out = Op{}(a[0], b[0]);
|
||||||
out_a[0] = out[0];
|
out_a[offset] = out[0];
|
||||||
out_b[0] = out[1];
|
out_b[offset] = out[1];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
AlignedVector<Out, N_READS> out_a_vec;
|
||||||
|
AlignedVector<Out, N_READS> out_b_vec;
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < N_READS; ++i) {
|
||||||
|
auto out = Op{}(a[0], b[0]);
|
||||||
|
out_a_vec.val[i] = out[0];
|
||||||
|
out_b_vec.val[i] = out[1];
|
||||||
|
}
|
||||||
|
|
||||||
|
store_vector<N_READS>(out_a, index, out_a_vec);
|
||||||
|
store_vector<N_READS>(out_b, index, out_b_vec);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename Op, typename In, typename Out, typename IdxT>
|
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
|
||||||
__global__ void
|
__global__ void
|
||||||
binary_sv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
|
binary_sv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
|
||||||
IdxT index = cg::this_grid().thread_rank();
|
IdxT index = cg::this_grid().thread_rank();
|
||||||
if (index < size) {
|
int remaining = size - index * N_READS;
|
||||||
auto out = Op{}(a[0], b[index]);
|
if (remaining <= 0) {
|
||||||
out_a[index] = out[0];
|
return;
|
||||||
out_b[index] = out[1];
|
}
|
||||||
|
|
||||||
|
if (remaining < N_READS) {
|
||||||
|
for (int i = 0; i < remaining; ++i) {
|
||||||
|
IdxT offset = index * N_READS + i;
|
||||||
|
auto out = Op{}(a[0], b[offset]);
|
||||||
|
out_a[offset] = out[0];
|
||||||
|
out_b[offset] = out[1];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
auto b_vec = load_vector<N_READS>(b, index);
|
||||||
|
|
||||||
|
AlignedVector<Out, N_READS> out_a_vec;
|
||||||
|
AlignedVector<Out, N_READS> out_b_vec;
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < N_READS; ++i) {
|
||||||
|
auto out = Op{}(a[0], b_vec.val[i]);
|
||||||
|
out_a_vec.val[i] = out[0];
|
||||||
|
out_b_vec.val[i] = out[1];
|
||||||
|
}
|
||||||
|
|
||||||
|
store_vector<N_READS>(out_a, index, out_a_vec);
|
||||||
|
store_vector<N_READS>(out_b, index, out_b_vec);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename Op, typename In, typename Out, typename IdxT>
|
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
|
||||||
__global__ void
|
__global__ void
|
||||||
binary_vs(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
|
binary_vs(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
|
||||||
IdxT index = cg::this_grid().thread_rank();
|
IdxT index = cg::this_grid().thread_rank();
|
||||||
if (index < size) {
|
int remaining = size - index * N_READS;
|
||||||
auto out = Op{}(a[index], b[0]);
|
if (remaining <= 0) {
|
||||||
out_a[index] = out[0];
|
return;
|
||||||
out_b[index] = out[1];
|
}
|
||||||
|
|
||||||
|
if (remaining < N_READS) {
|
||||||
|
for (int i = 0; i < remaining; ++i) {
|
||||||
|
IdxT offset = index * N_READS + i;
|
||||||
|
auto out = Op{}(a[offset], b[0]);
|
||||||
|
out_a[offset] = out[0];
|
||||||
|
out_b[offset] = out[1];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
auto a_vec = load_vector<N_READS>(a, index);
|
||||||
|
|
||||||
|
AlignedVector<Out, N_READS> out_a_vec;
|
||||||
|
AlignedVector<Out, N_READS> out_b_vec;
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < N_READS; ++i) {
|
||||||
|
auto out = Op{}(a_vec.val[i], b[0]);
|
||||||
|
out_a_vec.val[i] = out[0];
|
||||||
|
out_b_vec.val[i] = out[1];
|
||||||
|
}
|
||||||
|
|
||||||
|
store_vector<N_READS>(out_a, index, out_a_vec);
|
||||||
|
store_vector<N_READS>(out_b, index, out_b_vec);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename Op, typename In, typename Out, typename IdxT>
|
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
|
||||||
__global__ void
|
__global__ void
|
||||||
binary_vv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
|
binary_vv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
|
||||||
IdxT index = cg::this_grid().thread_rank();
|
IdxT index = cg::this_grid().thread_rank();
|
||||||
if (index < size) {
|
int remaining = size - index * N_READS;
|
||||||
auto out = Op{}(a[index], b[index]);
|
if (remaining <= 0) {
|
||||||
out_a[index] = out[0];
|
return;
|
||||||
out_b[index] = out[1];
|
}
|
||||||
|
|
||||||
|
if (remaining < N_READS) {
|
||||||
|
for (int i = 0; i < remaining; ++i) {
|
||||||
|
IdxT offset = index * N_READS + i;
|
||||||
|
auto out = Op{}(a[offset], b[offset]);
|
||||||
|
out_a[offset] = out[0];
|
||||||
|
out_b[offset] = out[1];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
auto a_vec = load_vector<N_READS>(a, index);
|
||||||
|
auto b_vec = load_vector<N_READS>(b, index);
|
||||||
|
|
||||||
|
AlignedVector<Out, N_READS> out_a_vec;
|
||||||
|
AlignedVector<Out, N_READS> out_b_vec;
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < N_READS; ++i) {
|
||||||
|
auto out = Op{}(a_vec.val[i], b_vec.val[i]);
|
||||||
|
out_a_vec.val[i] = out[0];
|
||||||
|
out_b_vec.val[i] = out[1];
|
||||||
|
}
|
||||||
|
|
||||||
|
store_vector<N_READS>(out_a, index, out_a_vec);
|
||||||
|
store_vector<N_READS>(out_b, index, out_b_vec);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -200,20 +287,23 @@ void binary_op_gpu_inplace(
|
|||||||
} else {
|
} else {
|
||||||
dispatch_bool(out_a.data_size() > INT32_MAX, [&](auto large) {
|
dispatch_bool(out_a.data_size() > INT32_MAX, [&](auto large) {
|
||||||
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
|
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
|
||||||
auto kernel = cu::binary_ss<Op, InType, OutType, IdxT>;
|
// TODO: Choose optimized value based on type size.
|
||||||
|
constexpr int N_READS = 4;
|
||||||
|
auto kernel = cu::binary_ss<Op, InType, OutType, IdxT, N_READS>;
|
||||||
if (bopt == BinaryOpType::ScalarVector) {
|
if (bopt == BinaryOpType::ScalarVector) {
|
||||||
kernel = cu::binary_sv<Op, InType, OutType, IdxT>;
|
kernel = cu::binary_sv<Op, InType, OutType, IdxT, N_READS>;
|
||||||
} else if (bopt == BinaryOpType::VectorScalar) {
|
} else if (bopt == BinaryOpType::VectorScalar) {
|
||||||
kernel = cu::binary_vs<Op, InType, OutType, IdxT>;
|
kernel = cu::binary_vs<Op, InType, OutType, IdxT, N_READS>;
|
||||||
} else if (bopt == BinaryOpType::VectorVector) {
|
} else if (bopt == BinaryOpType::VectorVector) {
|
||||||
kernel = cu::binary_vv<Op, InType, OutType, IdxT>;
|
kernel = cu::binary_vv<Op, InType, OutType, IdxT, N_READS>;
|
||||||
}
|
}
|
||||||
auto [num_blocks, block_dims] = get_launch_args(
|
auto [num_blocks, block_dims] = get_launch_args(
|
||||||
kernel,
|
kernel,
|
||||||
out_a.data_size(),
|
out_a.data_size(),
|
||||||
out_a.shape(),
|
out_a.shape(),
|
||||||
out_a.strides(),
|
out_a.strides(),
|
||||||
large());
|
large(),
|
||||||
|
N_READS);
|
||||||
encoder.add_kernel_node(
|
encoder.add_kernel_node(
|
||||||
kernel,
|
kernel,
|
||||||
num_blocks,
|
num_blocks,
|
||||||
|
|||||||
Reference in New Issue
Block a user