mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Faster general unary op (#2472)
* faster general unary op * faster general ops + reorg * fix + comment * binary two * copy general
This commit is contained in:
@@ -10,33 +10,67 @@ namespace cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
template <typename In, typename Out, typename IdxT, int NDIM>
|
||||
template <typename In, typename Out, typename IdxT, int NDIM, int N_READS>
|
||||
__global__ void copy_g_nd(
|
||||
const In* in,
|
||||
Out* out,
|
||||
IdxT size,
|
||||
IdxT size_rest,
|
||||
const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
|
||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> strides_in) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
IdxT idx_in = elem_to_loc_nd<NDIM>(index, shape.data(), strides_in.data());
|
||||
out[index] = CastOp<In, Out>{}(in[idx_in]);
|
||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> strides) {
|
||||
auto block = cg::this_thread_block();
|
||||
auto grid = cg::this_grid();
|
||||
IdxT index_rest =
|
||||
grid.block_index().y * block.dim_threads().y + block.thread_index().y;
|
||||
if (index_rest >= size_rest) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto shape_x = shape[NDIM - 1];
|
||||
auto stride_x = strides[NDIM - 1];
|
||||
IdxT index_x =
|
||||
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
|
||||
auto idx =
|
||||
elem_to_loc_nd<NDIM>(index_rest * shape_x, shape.data(), strides.data());
|
||||
auto in_vec =
|
||||
load_vector<N_READS>(in + idx, index_x, shape_x, stride_x, In(0));
|
||||
AlignedVector<Out, N_READS> out_vec;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N_READS; ++i) {
|
||||
out_vec[i] = CastOp<In, Out>{}(in_vec[i]);
|
||||
}
|
||||
store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x);
|
||||
}
|
||||
|
||||
template <typename In, typename Out, typename IdxT>
|
||||
template <typename In, typename Out, typename IdxT, int N_READS>
|
||||
__global__ void copy_g(
|
||||
const In* in,
|
||||
Out* out,
|
||||
IdxT size,
|
||||
IdxT size_rest,
|
||||
const __grid_constant__ Shape shape,
|
||||
const __grid_constant__ Strides strides_in,
|
||||
const __grid_constant__ Strides strides,
|
||||
int ndim) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
IdxT idx_in = elem_to_loc(index, shape.data(), strides_in.data(), ndim);
|
||||
out[index] = CastOp<In, Out>{}(in[idx_in]);
|
||||
auto block = cg::this_thread_block();
|
||||
auto grid = cg::this_grid();
|
||||
IdxT index_rest =
|
||||
grid.block_index().y * block.dim_threads().y + block.thread_index().y;
|
||||
if (index_rest >= size_rest) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto shape_x = shape[ndim - 1];
|
||||
auto stride_x = strides[ndim - 1];
|
||||
IdxT index_x =
|
||||
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
|
||||
auto idx =
|
||||
elem_to_loc(index_rest * shape_x, shape.data(), strides.data(), ndim);
|
||||
auto in_vec =
|
||||
load_vector<N_READS>(in + idx, index_x, shape_x, stride_x, In(0));
|
||||
AlignedVector<Out, N_READS> out_vec;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N_READS; ++i) {
|
||||
out_vec[i] = CastOp<In, Out>{}(in_vec[i]);
|
||||
}
|
||||
store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x);
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
@@ -61,30 +95,49 @@ void copy_general_input(
|
||||
const InType* in_ptr = in.data<InType>() + offset_in;
|
||||
OutType* out_ptr = out.data<OutType>() + offset_out;
|
||||
int ndim = shape.size();
|
||||
int work_per_thread = 1;
|
||||
auto dim0 = ndim > 0 ? shape.back() : 1;
|
||||
auto rest = out.size() / dim0;
|
||||
if (dim0 >= 4) {
|
||||
work_per_thread = 4;
|
||||
}
|
||||
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
|
||||
auto block_dims = get_block_dims(dim0, rest, 1);
|
||||
uint32_t num_blocks_x = cuda::ceil_div(dim0, block_dims.x);
|
||||
uint32_t num_blocks_y = cuda::ceil_div(rest, block_dims.y);
|
||||
|
||||
if (ndim <= 3) {
|
||||
dispatch_1_2_3(ndim, [&](auto dims_constant) {
|
||||
auto [num_blocks, block_dims] = get_launch_args(out, large());
|
||||
auto kernel =
|
||||
cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 1>;
|
||||
if (work_per_thread == 4) {
|
||||
kernel =
|
||||
cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 4>;
|
||||
}
|
||||
encoder.add_kernel_node(
|
||||
cu::copy_g_nd<InType, OutType, IdxT, dims_constant()>,
|
||||
num_blocks,
|
||||
kernel,
|
||||
{num_blocks_x, num_blocks_y},
|
||||
block_dims,
|
||||
0,
|
||||
in_ptr,
|
||||
out_ptr,
|
||||
out.size(),
|
||||
rest,
|
||||
const_param<dims_constant()>(shape),
|
||||
const_param<dims_constant()>(strides_in));
|
||||
});
|
||||
} else { // ndim >= 4
|
||||
auto [num_blocks, block_dims] = get_launch_args(out, large());
|
||||
auto kernel = cu::copy_g<InType, OutType, IdxT, 1>;
|
||||
if (work_per_thread == 4) {
|
||||
kernel = cu::copy_g<InType, OutType, IdxT, 4>;
|
||||
}
|
||||
encoder.add_kernel_node(
|
||||
cu::copy_g<InType, OutType, IdxT>,
|
||||
num_blocks,
|
||||
kernel,
|
||||
{num_blocks_x, num_blocks_y},
|
||||
block_dims,
|
||||
0,
|
||||
in_ptr,
|
||||
out_ptr,
|
||||
out.size(),
|
||||
rest,
|
||||
const_param(shape),
|
||||
const_param(strides_in),
|
||||
ndim);
|
||||
|
||||
Reference in New Issue
Block a user