mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 21:04:41 +08:00
Do vectorized store/load in unary ops
This commit is contained in:
@@ -18,11 +18,29 @@ namespace cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
template <typename Op, typename In, typename Out, typename IdxT>
|
||||
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
|
||||
__global__ void unary_v(const In* in, Out* out, IdxT size) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
out[index] = Op{}(in[index]);
|
||||
int remaining = size - index * N_READS;
|
||||
if (remaining <= 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (remaining < N_READS) {
|
||||
for (int i = 0; i < remaining; ++i) {
|
||||
IdxT offset = index * N_READS + i;
|
||||
out[offset] = Op{}(in[offset]);
|
||||
}
|
||||
} else {
|
||||
auto in_vec = load_vector<N_READS>(in, index);
|
||||
|
||||
AlignedVector<Out, N_READS> out_vec;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N_READS; ++i) {
|
||||
out_vec.val[i] = Op{}(in_vec.val[i]);
|
||||
}
|
||||
|
||||
store_vector<N_READS>(out, index, out_vec);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -117,9 +135,16 @@ void unary_op_gpu_inplace(
|
||||
using OutType = cuda_type_t<CTYPE_OUT>;
|
||||
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
|
||||
if (contig) {
|
||||
auto kernel = cu::unary_v<Op, InType, OutType, IdxT>;
|
||||
// TODO: Choose optimized value based on type size.
|
||||
constexpr int N_READS = 4;
|
||||
auto kernel = cu::unary_v<Op, InType, OutType, IdxT, N_READS>;
|
||||
auto [num_blocks, block_dims] = get_launch_args(
|
||||
kernel, out.data_size(), out.shape(), out.strides(), large);
|
||||
kernel,
|
||||
out.data_size(),
|
||||
out.shape(),
|
||||
out.strides(),
|
||||
large,
|
||||
N_READS);
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
num_blocks,
|
||||
|
Reference in New Issue
Block a user