diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt index 084e449b0..0d526400d 100644 --- a/mlx/backend/cuda/CMakeLists.txt +++ b/mlx/backend/cuda/CMakeLists.txt @@ -8,7 +8,6 @@ target_sources( PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp ${CMAKE_CURRENT_SOURCE_DIR}/arange.cu ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cu - ${CMAKE_CURRENT_SOURCE_DIR}/binary.cu ${CMAKE_CURRENT_SOURCE_DIR}/binary_two.cu ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp ${CMAKE_CURRENT_SOURCE_DIR}/copy.cu @@ -45,12 +44,14 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu ${CMAKE_CURRENT_SOURCE_DIR}/sort.cu ${CMAKE_CURRENT_SOURCE_DIR}/ternary.cu - ${CMAKE_CURRENT_SOURCE_DIR}/unary.cu ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/quantized/affine_quantize.cu ${CMAKE_CURRENT_SOURCE_DIR}/quantized/quantized.cpp ${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp) +add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/binary) +add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/unary) + if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.9.0) target_sources( mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm_batched_12_9.cu) diff --git a/mlx/backend/cuda/binary/CMakeLists.txt b/mlx/backend/cuda/binary/CMakeLists.txt new file mode 100644 index 000000000..bda289de7 --- /dev/null +++ b/mlx/backend/cuda/binary/CMakeLists.txt @@ -0,0 +1,21 @@ +target_sources( + mlx + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/add.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arctan2.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/bitwise_binary.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/divide.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/equal.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/greater.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/greater_equal.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/less.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/less_equal.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/logical_and.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/logical_or.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/log_add_exp.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/minimum.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/maximum.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/multiply.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/power.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/remainder.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/not_equal.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/subtract.cu) diff --git a/mlx/backend/cuda/binary/add.cu b/mlx/backend/cuda/binary/add.cu new file mode 100644 index 000000000..87dfd7e70 --- /dev/null +++ b/mlx/backend/cuda/binary/add.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/binary/binary.cuh" + +namespace mlx::core { +BINARY_GPU(Add) +} // namespace mlx::core diff --git a/mlx/backend/cuda/binary/arctan2.cu b/mlx/backend/cuda/binary/arctan2.cu new file mode 100644 index 000000000..2fd7e3922 --- /dev/null +++ b/mlx/backend/cuda/binary/arctan2.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. 
+ +#include "mlx/backend/cuda/binary/binary.cuh" + +namespace mlx::core { +BINARY_GPU(ArcTan2) +} // namespace mlx::core diff --git a/mlx/backend/cuda/binary.cu b/mlx/backend/cuda/binary/binary.cuh similarity index 72% rename from mlx/backend/cuda/binary.cu rename to mlx/backend/cuda/binary/binary.cuh index 0243d4f41..20bb199ec 100644 --- a/mlx/backend/cuda/binary.cu +++ b/mlx/backend/cuda/binary/binary.cuh @@ -99,39 +99,89 @@ __global__ void binary_vv(const In* a, const In* b, Out* out, IdxT size) { } } -template +template < + typename Op, + typename In, + typename Out, + typename IdxT, + int NDIM, + int N_READS> __global__ void binary_g_nd( const In* a, const In* b, Out* out, - IdxT size, + IdxT size_rest, const __grid_constant__ cuda::std::array shape, const __grid_constant__ cuda::std::array a_strides, const __grid_constant__ cuda::std::array b_strides) { - IdxT index = cg::this_grid().thread_rank(); - if (index < size) { - auto [a_idx, b_idx] = elem_to_loc_nd( - index, shape.data(), a_strides.data(), b_strides.data()); - out[index] = Op{}(a[a_idx], b[b_idx]); + auto block = cg::this_thread_block(); + auto grid = cg::this_grid(); + IdxT index_rest = + grid.block_index().y * block.dim_threads().y + block.thread_index().y; + if (index_rest >= size_rest) { + return; } + + auto shape_x = shape[NDIM - 1]; + auto a_stride_x = a_strides[NDIM - 1]; + auto b_stride_x = b_strides[NDIM - 1]; + IdxT index_x = + grid.block_index().x * block.dim_threads().x + block.thread_index().x; + auto [a_idx, b_idx] = elem_to_loc_nd( + index_rest * shape_x, shape.data(), a_strides.data(), b_strides.data()); + auto a_vec = + load_vector(a + a_idx, index_x, shape_x, a_stride_x, In(0)); + auto b_vec = + load_vector(b + b_idx, index_x, shape_x, b_stride_x, In(0)); + + AlignedVector out_vec; +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + out_vec[i] = Op{}(a_vec[i], b_vec[i]); + } + store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x); } -template +template __global__ void binary_g( const In* a, const In* b, Out* out, - IdxT size, + IdxT size_rest, const __grid_constant__ Shape shape, const __grid_constant__ Strides a_strides, const __grid_constant__ Strides b_strides, int ndim) { - IdxT index = cg::this_grid().thread_rank(); - if (index < size) { - auto [a_idx, b_idx] = elem_to_loc( - index, shape.data(), a_strides.data(), b_strides.data(), ndim); - out[index] = Op{}(a[a_idx], b[b_idx]); + auto block = cg::this_thread_block(); + auto grid = cg::this_grid(); + IdxT index_rest = + grid.block_index().y * block.dim_threads().y + block.thread_index().y; + if (index_rest >= size_rest) { + return; } + + auto shape_x = shape[ndim - 1]; + auto a_stride_x = a_strides[ndim - 1]; + auto b_stride_x = b_strides[ndim - 1]; + IdxT index_x = + grid.block_index().x * block.dim_threads().x + block.thread_index().x; + auto [a_idx, b_idx] = elem_to_loc( + index_rest * shape_x, + shape.data(), + a_strides.data(), + b_strides.data(), + ndim); + auto a_vec = + load_vector(a + a_idx, index_x, shape_x, a_stride_x, In(0)); + auto b_vec = + load_vector(b + b_idx, index_x, shape_x, b_stride_x, In(0)); + + AlignedVector out_vec; +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + out_vec[i] = Op{}(a_vec[i], b_vec[i]); + } + store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x); } template @@ -209,39 +259,61 @@ void binary_op_gpu_inplace( auto& a_strides = strides[0]; auto& b_strides = strides[1]; int ndim = shape.size(); + int work_per_thread = 1; + auto dim0 = ndim > 0 ? 
shape.back() : 1; + auto rest = out.size() / dim0; + if (dim0 >= 4) { + work_per_thread = 4; + } + dim0 = (dim0 + work_per_thread - 1) / work_per_thread; + auto block_dims = get_block_dims(dim0, rest, 1); + uint32_t num_blocks_x = cuda::ceil_div(dim0, block_dims.x); + uint32_t num_blocks_y = cuda::ceil_div(rest, block_dims.y); if (ndim <= 3) { dispatch_1_2_3(ndim, [&](auto dims_constant) { - auto [num_blocks, block_dims] = - get_launch_args(out, large()); + auto kernel = cu::binary_g_nd< + Op, + InType, + OutType, + IdxT, + dims_constant(), + 1>; + if (work_per_thread == 4) { + kernel = cu::binary_g_nd< + Op, + InType, + OutType, + IdxT, + dims_constant(), + 4>; + } encoder.add_kernel_node( - cu::binary_g_nd< - Op, - InType, - OutType, - IdxT, - dims_constant()>, - num_blocks, + kernel, + {num_blocks_x, num_blocks_y}, block_dims, 0, a.data(), b.data(), out.data(), - out.size(), + rest, const_param(shape), const_param(a_strides), const_param(b_strides)); }); } else { - auto [num_blocks, block_dims] = get_launch_args(out, large()); + auto kernel = cu::binary_g; + if (work_per_thread == 4) { + kernel = cu::binary_g; + } encoder.add_kernel_node( - cu::binary_g, - num_blocks, + kernel, + {num_blocks_x, num_blocks_y}, block_dims, 0, a.data(), b.data(), out.data(), - out.size(), + rest, const_param(shape), const_param(a_strides), const_param(b_strides), @@ -304,54 +376,4 @@ void binary_op_gpu( binary_op_gpu(inputs, out, name(), s); \ } -BINARY_GPU(Add) -BINARY_GPU(ArcTan2) -BINARY_GPU(Divide) -BINARY_GPU(Remainder) -BINARY_GPU(Greater) -BINARY_GPU(GreaterEqual) -BINARY_GPU(Less) -BINARY_GPU(LessEqual) -BINARY_GPU(LogicalAnd) -BINARY_GPU(LogicalOr) -BINARY_GPU(LogAddExp) -BINARY_GPU(Maximum) -BINARY_GPU(Minimum) -BINARY_GPU(Multiply) -BINARY_GPU(NotEqual) -BINARY_GPU(Power) -BINARY_GPU(Subtract) - -void Equal::eval_gpu(const std::vector& inputs, array& out) { - nvtx3::scoped_range r("Equal::eval_gpu"); - auto& s = out.primitive().stream(); - if (equal_nan_) { - binary_op_gpu(inputs, out, name(), s); - } else { - binary_op_gpu(inputs, out, name(), s); - } -} - -void BitwiseBinary::eval_gpu(const std::vector& inputs, array& out) { - nvtx3::scoped_range r("BitwiseBinary::eval_gpu"); - auto& s = out.primitive().stream(); - switch (op_) { - case BitwiseBinary::And: - binary_op_gpu(inputs, out, name(), s); - break; - case BitwiseBinary::Or: - binary_op_gpu(inputs, out, name(), s); - break; - case BitwiseBinary::Xor: - binary_op_gpu(inputs, out, name(), s); - break; - case BitwiseBinary::LeftShift: - binary_op_gpu(inputs, out, name(), s); - break; - case BitwiseBinary::RightShift: - binary_op_gpu(inputs, out, name(), s); - break; - } -} - } // namespace mlx::core diff --git a/mlx/backend/cuda/binary/bitwise_binary.cu b/mlx/backend/cuda/binary/bitwise_binary.cu new file mode 100644 index 000000000..8025a3bd1 --- /dev/null +++ b/mlx/backend/cuda/binary/bitwise_binary.cu @@ -0,0 +1,27 @@ +// Copyright © 2025 Apple Inc. 
+ +#include "mlx/backend/cuda/binary/binary.cuh" + +namespace mlx::core { +void BitwiseBinary::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("BitwiseBinary::eval_gpu"); + auto& s = out.primitive().stream(); + switch (op_) { + case BitwiseBinary::And: + binary_op_gpu(inputs, out, name(), s); + break; + case BitwiseBinary::Or: + binary_op_gpu(inputs, out, name(), s); + break; + case BitwiseBinary::Xor: + binary_op_gpu(inputs, out, name(), s); + break; + case BitwiseBinary::LeftShift: + binary_op_gpu(inputs, out, name(), s); + break; + case BitwiseBinary::RightShift: + binary_op_gpu(inputs, out, name(), s); + break; + } +} +} // namespace mlx::core diff --git a/mlx/backend/cuda/binary/divide.cu b/mlx/backend/cuda/binary/divide.cu new file mode 100644 index 000000000..fcf3dc77e --- /dev/null +++ b/mlx/backend/cuda/binary/divide.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/binary/binary.cuh" + +namespace mlx::core { +BINARY_GPU(Divide) +} // namespace mlx::core diff --git a/mlx/backend/cuda/binary/equal.cu b/mlx/backend/cuda/binary/equal.cu new file mode 100644 index 000000000..559b3e8ed --- /dev/null +++ b/mlx/backend/cuda/binary/equal.cu @@ -0,0 +1,15 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/binary/binary.cuh" + +namespace mlx::core { +void Equal::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("Equal::eval_gpu"); + auto& s = out.primitive().stream(); + if (equal_nan_) { + binary_op_gpu(inputs, out, name(), s); + } else { + binary_op_gpu(inputs, out, name(), s); + } +} +} // namespace mlx::core diff --git a/mlx/backend/cuda/binary/greater.cu b/mlx/backend/cuda/binary/greater.cu new file mode 100644 index 000000000..c9820206b --- /dev/null +++ b/mlx/backend/cuda/binary/greater.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/binary/binary.cuh" + +namespace mlx::core { +BINARY_GPU(Greater) +} // namespace mlx::core diff --git a/mlx/backend/cuda/binary/greater_equal.cu b/mlx/backend/cuda/binary/greater_equal.cu new file mode 100644 index 000000000..4666fb4a9 --- /dev/null +++ b/mlx/backend/cuda/binary/greater_equal.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/binary/binary.cuh" + +namespace mlx::core { +BINARY_GPU(GreaterEqual) +} // namespace mlx::core diff --git a/mlx/backend/cuda/binary/less.cu b/mlx/backend/cuda/binary/less.cu new file mode 100644 index 000000000..a2053fa8b --- /dev/null +++ b/mlx/backend/cuda/binary/less.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/binary/binary.cuh" + +namespace mlx::core { +BINARY_GPU(Less) +} // namespace mlx::core diff --git a/mlx/backend/cuda/binary/less_equal.cu b/mlx/backend/cuda/binary/less_equal.cu new file mode 100644 index 000000000..7f9bc5161 --- /dev/null +++ b/mlx/backend/cuda/binary/less_equal.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/binary/binary.cuh" + +namespace mlx::core { +BINARY_GPU(LessEqual) +} // namespace mlx::core diff --git a/mlx/backend/cuda/binary/log_add_exp.cu b/mlx/backend/cuda/binary/log_add_exp.cu new file mode 100644 index 000000000..17614f862 --- /dev/null +++ b/mlx/backend/cuda/binary/log_add_exp.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. 
+ +#include "mlx/backend/cuda/binary/binary.cuh" + +namespace mlx::core { +BINARY_GPU(LogAddExp) +} // namespace mlx::core diff --git a/mlx/backend/cuda/binary/logical_and.cu b/mlx/backend/cuda/binary/logical_and.cu new file mode 100644 index 000000000..6bbeb1a4c --- /dev/null +++ b/mlx/backend/cuda/binary/logical_and.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/binary/binary.cuh" + +namespace mlx::core { +BINARY_GPU(LogicalAnd) +} // namespace mlx::core diff --git a/mlx/backend/cuda/binary/logical_or.cu b/mlx/backend/cuda/binary/logical_or.cu new file mode 100644 index 000000000..63afdb98c --- /dev/null +++ b/mlx/backend/cuda/binary/logical_or.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/binary/binary.cuh" + +namespace mlx::core { +BINARY_GPU(LogicalOr) +} // namespace mlx::core diff --git a/mlx/backend/cuda/binary/maximum.cu b/mlx/backend/cuda/binary/maximum.cu new file mode 100644 index 000000000..4f6cb6e0b --- /dev/null +++ b/mlx/backend/cuda/binary/maximum.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/binary/binary.cuh" + +namespace mlx::core { +BINARY_GPU(Maximum) +} // namespace mlx::core diff --git a/mlx/backend/cuda/binary/minimum.cu b/mlx/backend/cuda/binary/minimum.cu new file mode 100644 index 000000000..ec4c1abb0 --- /dev/null +++ b/mlx/backend/cuda/binary/minimum.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/binary/binary.cuh" + +namespace mlx::core { +BINARY_GPU(Minimum) +} // namespace mlx::core diff --git a/mlx/backend/cuda/binary/multiply.cu b/mlx/backend/cuda/binary/multiply.cu new file mode 100644 index 000000000..bfc15fcaa --- /dev/null +++ b/mlx/backend/cuda/binary/multiply.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/binary/binary.cuh" + +namespace mlx::core { +BINARY_GPU(Multiply) +} // namespace mlx::core diff --git a/mlx/backend/cuda/binary/not_equal.cu b/mlx/backend/cuda/binary/not_equal.cu new file mode 100644 index 000000000..49f05c90a --- /dev/null +++ b/mlx/backend/cuda/binary/not_equal.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/binary/binary.cuh" + +namespace mlx::core { +BINARY_GPU(NotEqual) +} // namespace mlx::core diff --git a/mlx/backend/cuda/binary/power.cu b/mlx/backend/cuda/binary/power.cu new file mode 100644 index 000000000..cacdc75c4 --- /dev/null +++ b/mlx/backend/cuda/binary/power.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/binary/binary.cuh" + +namespace mlx::core { +BINARY_GPU(Power) +} // namespace mlx::core diff --git a/mlx/backend/cuda/binary/remainder.cu b/mlx/backend/cuda/binary/remainder.cu new file mode 100644 index 000000000..a55006ba0 --- /dev/null +++ b/mlx/backend/cuda/binary/remainder.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/binary/binary.cuh" + +namespace mlx::core { +BINARY_GPU(Remainder) +} // namespace mlx::core diff --git a/mlx/backend/cuda/binary/subtract.cu b/mlx/backend/cuda/binary/subtract.cu new file mode 100644 index 000000000..37f3874cc --- /dev/null +++ b/mlx/backend/cuda/binary/subtract.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. 
+ +#include "mlx/backend/cuda/binary/binary.cuh" + +namespace mlx::core { +BINARY_GPU(Subtract) +} // namespace mlx::core diff --git a/mlx/backend/cuda/binary_two.cu b/mlx/backend/cuda/binary_two.cu index 49a747829..cd0fe2c46 100644 --- a/mlx/backend/cuda/binary_two.cu +++ b/mlx/backend/cuda/binary_two.cu @@ -127,45 +127,99 @@ binary_two_vv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) { } } -template +template < + typename Op, + typename In, + typename Out, + typename IdxT, + int NDIM, + int N_READS> __global__ void binary_two_g_nd( const In* a, const In* b, Out* out_a, Out* out_b, - IdxT size, + IdxT size_rest, const __grid_constant__ cuda::std::array shape, const __grid_constant__ cuda::std::array a_strides, const __grid_constant__ cuda::std::array b_strides) { - IdxT index = cg::this_grid().thread_rank(); - if (index < size) { - auto [a_idx, b_idx] = elem_to_loc_nd( - index, shape.data(), a_strides.data(), b_strides.data()); - auto out = Op{}(a[a_idx], b[b_idx]); - out_a[index] = out[0]; - out_b[index] = out[1]; + auto block = cg::this_thread_block(); + auto grid = cg::this_grid(); + IdxT index_rest = + grid.block_index().y * block.dim_threads().y + block.thread_index().y; + if (index_rest >= size_rest) { + return; } + + auto shape_x = shape[NDIM - 1]; + auto a_stride_x = a_strides[NDIM - 1]; + auto b_stride_x = b_strides[NDIM - 1]; + IdxT index_x = + grid.block_index().x * block.dim_threads().x + block.thread_index().x; + auto [a_idx, b_idx] = elem_to_loc_nd( + index_rest * shape_x, shape.data(), a_strides.data(), b_strides.data()); + auto a_vec = + load_vector(a + a_idx, index_x, shape_x, a_stride_x, In(0)); + auto b_vec = + load_vector(b + b_idx, index_x, shape_x, b_stride_x, In(0)); + + AlignedVector out_vec_a; + AlignedVector out_vec_b; +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + auto out = Op{}(a_vec[i], b_vec[i]); + out_vec_a[i] = out[0]; + out_vec_b[i] = out[1]; + } + store_vector(out_a + shape_x * index_rest, index_x, out_vec_a, shape_x); + store_vector(out_b + shape_x * index_rest, index_x, out_vec_b, shape_x); } -template +template __global__ void binary_two_g( const In* a, const In* b, Out* out_a, Out* out_b, - IdxT size, + IdxT size_rest, const __grid_constant__ Shape shape, const __grid_constant__ Strides a_strides, const __grid_constant__ Strides b_strides, int ndim) { - IdxT index = cg::this_grid().thread_rank(); - if (index < size) { - auto [a_idx, b_idx] = elem_to_loc( - index, shape.data(), a_strides.data(), b_strides.data(), ndim); - auto out = Op{}(a[a_idx], b[b_idx]); - out_a[index] = out[0]; - out_b[index] = out[1]; + auto block = cg::this_thread_block(); + auto grid = cg::this_grid(); + IdxT index_rest = + grid.block_index().y * block.dim_threads().y + block.thread_index().y; + if (index_rest >= size_rest) { + return; } + + auto shape_x = shape[ndim - 1]; + auto a_stride_x = a_strides[ndim - 1]; + auto b_stride_x = b_strides[ndim - 1]; + IdxT index_x = + grid.block_index().x * block.dim_threads().x + block.thread_index().x; + auto [a_idx, b_idx] = elem_to_loc( + index_rest * shape_x, + shape.data(), + a_strides.data(), + b_strides.data(), + ndim); + auto a_vec = + load_vector(a + a_idx, index_x, shape_x, a_stride_x, In(0)); + auto b_vec = + load_vector(b + b_idx, index_x, shape_x, b_stride_x, In(0)); + + AlignedVector out_vec_a; + AlignedVector out_vec_b; +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + auto out = Op{}(a_vec[i], b_vec[i]); + out_vec_a[i] = out[0]; + out_vec_b[i] = out[1]; + } + store_vector(out_a + shape_x * 
index_rest, index_x, out_vec_a, shape_x); + store_vector(out_b + shape_x * index_rest, index_x, out_vec_b, shape_x); } template @@ -225,42 +279,64 @@ void binary_two_op_gpu_inplace( auto& a_strides = strides[0]; auto& b_strides = strides[1]; int ndim = shape.size(); + int work_per_thread = 1; + auto dim0 = ndim > 0 ? shape.back() : 1; + auto rest = out_a.size() / dim0; + if (dim0 >= 4) { + work_per_thread = 4; + } + dim0 = (dim0 + work_per_thread - 1) / work_per_thread; + auto block_dims = get_block_dims(dim0, rest, 1); + uint32_t num_blocks_x = cuda::ceil_div(dim0, block_dims.x); + uint32_t num_blocks_y = cuda::ceil_div(rest, block_dims.y); + if (ndim <= 3) { dispatch_1_2_3(ndim, [&](auto dims_constant) { - auto [num_blocks, block_dims] = - get_launch_args(out_a, large()); + auto kernel = cu::binary_two_g_nd< + Op, + InType, + OutType, + IdxT, + dims_constant(), + 1>; + if (work_per_thread == 4) { + kernel = cu::binary_two_g_nd< + Op, + InType, + OutType, + IdxT, + dims_constant(), + 4>; + } encoder.add_kernel_node( - cu::binary_two_g_nd< - Op, - InType, - OutType, - IdxT, - dims_constant()>, - num_blocks, + kernel, + {num_blocks_x, num_blocks_y}, block_dims, 0, a.data(), b.data(), out_a.data(), out_b.data(), - out_a.size(), + rest, const_param(shape), const_param(a_strides), const_param(b_strides)); }); } else { - auto [num_blocks, block_dims] = - get_launch_args(out_a, large()); + auto kernel = cu::binary_two_g; + if (work_per_thread == 4) { + kernel = cu::binary_two_g; + } encoder.add_kernel_node( - cu::binary_two_g, - num_blocks, + kernel, + {num_blocks_x, num_blocks_y}, block_dims, 0, a.data(), b.data(), out_a.data(), out_b.data(), - out_a.size(), + rest, const_param(shape), const_param(a_strides), const_param(b_strides), diff --git a/mlx/backend/cuda/copy/copy_general.cu b/mlx/backend/cuda/copy/copy_general.cu index 64c67a176..6ac42751a 100644 --- a/mlx/backend/cuda/copy/copy_general.cu +++ b/mlx/backend/cuda/copy/copy_general.cu @@ -10,37 +10,80 @@ namespace cu { namespace cg = cooperative_groups; -template +template __global__ void copy_gg_nd( const In* in, Out* out, - IdxT size, + IdxT size_rest, const __grid_constant__ cuda::std::array shape, const __grid_constant__ cuda::std::array strides_in, const __grid_constant__ cuda::std::array strides_out) { - IdxT index = cg::this_grid().thread_rank(); - if (index < size) { - auto [idx_in, idx_out] = elem_to_loc_nd( - index, shape.data(), strides_in.data(), strides_out.data()); - out[idx_out] = CastOp{}(in[idx_in]); + auto block = cg::this_thread_block(); + auto grid = cg::this_grid(); + IdxT index_rest = + grid.block_index().y * block.dim_threads().y + block.thread_index().y; + if (index_rest >= size_rest) { + return; } + + auto shape_x = shape[NDIM - 1]; + auto in_stride_x = strides_in[NDIM - 1]; + auto out_stride_x = strides_out[NDIM - 1]; + IdxT index_x = + grid.block_index().x * block.dim_threads().x + block.thread_index().x; + auto [idx_in, idx_out] = elem_to_loc_nd( + index_rest * shape_x, + shape.data(), + strides_in.data(), + strides_out.data()); + + auto in_vec = + load_vector(in + idx_in, index_x, shape_x, in_stride_x, In(0)); + AlignedVector out_vec; +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + out_vec[i] = CastOp{}(in_vec[i]); + } + store_vector(out + idx_out, index_x, out_vec, shape_x, out_stride_x); } -template +template __global__ void copy_gg( const In* in, Out* out, - IdxT size, + IdxT size_rest, const __grid_constant__ Shape shape, const __grid_constant__ Strides strides_in, const __grid_constant__ Strides 
strides_out, int ndim) { - IdxT index = cg::this_grid().thread_rank(); - if (index < size) { - auto [idx_in, idx_out] = elem_to_loc( - index, shape.data(), strides_in.data(), strides_out.data(), ndim); - out[idx_out] = CastOp{}(in[idx_in]); + auto block = cg::this_thread_block(); + auto grid = cg::this_grid(); + IdxT index_rest = + grid.block_index().y * block.dim_threads().y + block.thread_index().y; + if (index_rest >= size_rest) { + return; } + + auto shape_x = shape[ndim - 1]; + auto in_stride_x = strides_in[ndim - 1]; + auto out_stride_x = strides_out[ndim - 1]; + IdxT index_x = + grid.block_index().x * block.dim_threads().x + block.thread_index().x; + auto [idx_in, idx_out] = elem_to_loc( + index_rest * shape_x, + shape.data(), + strides_in.data(), + strides_out.data(), + ndim); + + auto in_vec = + load_vector(in + idx_in, index_x, shape_x, in_stride_x, In(0)); + AlignedVector out_vec; +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + out_vec[i] = CastOp{}(in_vec[i]); + } + store_vector(out + idx_out, index_x, out_vec, shape_x, out_stride_x); } } // namespace cu @@ -69,33 +112,52 @@ void copy_general( size_t data_size = 1; for (auto& s : shape) data_size *= s; + + int work_per_thread = 1; + auto dim0 = ndim > 0 ? shape.back() : 1; + auto rest = data_size / dim0; + if (dim0 >= 4) { + work_per_thread = 4; + } + + dim0 = (dim0 + work_per_thread - 1) / work_per_thread; + auto block_dims = get_block_dims(dim0, rest, 1); + uint32_t num_blocks_x = cuda::ceil_div(dim0, block_dims.x); + uint32_t num_blocks_y = cuda::ceil_div(rest, block_dims.y); + if (ndim <= 3) { dispatch_1_2_3(ndim, [&](auto ndim_constant) { - auto [num_blocks, block_dims] = - get_launch_args(data_size, shape, out.strides(), large()); + auto kernel = + cu::copy_gg_nd; + if (work_per_thread == 4) { + kernel = + cu::copy_gg_nd; + } encoder.add_kernel_node( - cu::copy_gg_nd, - num_blocks, + kernel, + {num_blocks_x, num_blocks_y}, block_dims, 0, in_ptr, out_ptr, - data_size, + rest, const_param(shape), const_param(strides_in), const_param(strides_out)); }); } else { // ndim >= 4 - auto [num_blocks, block_dims] = - get_launch_args(data_size, shape, out.strides(), large()); + auto kernel = cu::copy_gg; + if (work_per_thread == 4) { + kernel = cu::copy_gg; + } encoder.add_kernel_node( - cu::copy_gg, - num_blocks, + kernel, + {num_blocks_x, num_blocks_y}, block_dims, 0, in_ptr, out_ptr, - data_size, + rest, const_param(shape), const_param(strides_in), const_param(strides_out), diff --git a/mlx/backend/cuda/copy/copy_general_input.cu b/mlx/backend/cuda/copy/copy_general_input.cu index f381f14fa..ce8bb1b78 100644 --- a/mlx/backend/cuda/copy/copy_general_input.cu +++ b/mlx/backend/cuda/copy/copy_general_input.cu @@ -10,33 +10,67 @@ namespace cu { namespace cg = cooperative_groups; -template +template __global__ void copy_g_nd( const In* in, Out* out, - IdxT size, + IdxT size_rest, const __grid_constant__ cuda::std::array shape, - const __grid_constant__ cuda::std::array strides_in) { - IdxT index = cg::this_grid().thread_rank(); - if (index < size) { - IdxT idx_in = elem_to_loc_nd(index, shape.data(), strides_in.data()); - out[index] = CastOp{}(in[idx_in]); + const __grid_constant__ cuda::std::array strides) { + auto block = cg::this_thread_block(); + auto grid = cg::this_grid(); + IdxT index_rest = + grid.block_index().y * block.dim_threads().y + block.thread_index().y; + if (index_rest >= size_rest) { + return; } + + auto shape_x = shape[NDIM - 1]; + auto stride_x = strides[NDIM - 1]; + IdxT index_x = + grid.block_index().x * 
block.dim_threads().x + block.thread_index().x; + auto idx = + elem_to_loc_nd(index_rest * shape_x, shape.data(), strides.data()); + auto in_vec = + load_vector(in + idx, index_x, shape_x, stride_x, In(0)); + AlignedVector out_vec; +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + out_vec[i] = CastOp{}(in_vec[i]); + } + store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x); } -template +template __global__ void copy_g( const In* in, Out* out, - IdxT size, + IdxT size_rest, const __grid_constant__ Shape shape, - const __grid_constant__ Strides strides_in, + const __grid_constant__ Strides strides, int ndim) { - IdxT index = cg::this_grid().thread_rank(); - if (index < size) { - IdxT idx_in = elem_to_loc(index, shape.data(), strides_in.data(), ndim); - out[index] = CastOp{}(in[idx_in]); + auto block = cg::this_thread_block(); + auto grid = cg::this_grid(); + IdxT index_rest = + grid.block_index().y * block.dim_threads().y + block.thread_index().y; + if (index_rest >= size_rest) { + return; } + + auto shape_x = shape[ndim - 1]; + auto stride_x = strides[ndim - 1]; + IdxT index_x = + grid.block_index().x * block.dim_threads().x + block.thread_index().x; + auto idx = + elem_to_loc(index_rest * shape_x, shape.data(), strides.data(), ndim); + auto in_vec = + load_vector(in + idx, index_x, shape_x, stride_x, In(0)); + AlignedVector out_vec; +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + out_vec[i] = CastOp{}(in_vec[i]); + } + store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x); } } // namespace cu @@ -61,30 +95,49 @@ void copy_general_input( const InType* in_ptr = in.data() + offset_in; OutType* out_ptr = out.data() + offset_out; int ndim = shape.size(); + int work_per_thread = 1; + auto dim0 = ndim > 0 ? shape.back() : 1; + auto rest = out.size() / dim0; + if (dim0 >= 4) { + work_per_thread = 4; + } + dim0 = (dim0 + work_per_thread - 1) / work_per_thread; + auto block_dims = get_block_dims(dim0, rest, 1); + uint32_t num_blocks_x = cuda::ceil_div(dim0, block_dims.x); + uint32_t num_blocks_y = cuda::ceil_div(rest, block_dims.y); + if (ndim <= 3) { dispatch_1_2_3(ndim, [&](auto dims_constant) { - auto [num_blocks, block_dims] = get_launch_args(out, large()); + auto kernel = + cu::copy_g_nd; + if (work_per_thread == 4) { + kernel = + cu::copy_g_nd; + } encoder.add_kernel_node( - cu::copy_g_nd, - num_blocks, + kernel, + {num_blocks_x, num_blocks_y}, block_dims, 0, in_ptr, out_ptr, - out.size(), + rest, const_param(shape), const_param(strides_in)); }); } else { // ndim >= 4 - auto [num_blocks, block_dims] = get_launch_args(out, large()); + auto kernel = cu::copy_g; + if (work_per_thread == 4) { + kernel = cu::copy_g; + } encoder.add_kernel_node( - cu::copy_g, - num_blocks, + kernel, + {num_blocks_x, num_blocks_y}, block_dims, 0, in_ptr, out_ptr, - out.size(), + rest, const_param(shape), const_param(strides_in), ndim); diff --git a/mlx/backend/cuda/device/utils.cuh b/mlx/backend/cuda/device/utils.cuh index bc055c9df..7ebc5d654 100644 --- a/mlx/backend/cuda/device/utils.cuh +++ b/mlx/backend/cuda/device/utils.cuh @@ -146,6 +146,23 @@ inline __device__ void store_vector( } } +template +inline __device__ void store_vector( + T* ptr, + uint32_t offset, + const AlignedVector& vec, + SizeT size, + int64_t stride) { + if (is_aligned(ptr) && (offset + 1) * N <= size && stride == 1) { + auto* to = reinterpret_cast*>(ptr); + to[offset] = vec; + } else { + for (int i = 0; (offset * N + i) < size && i < N; ++i) { + ptr[stride * (offset * N + i)] = vec[i]; + } + } +} + 
/////////////////////////////////////////////////////////////////////////////// // Type limits utils /////////////////////////////////////////////////////////////////////////////// diff --git a/mlx/backend/cuda/ternary.cu b/mlx/backend/cuda/ternary.cu index cfc0e10b8..67937fc8e 100644 --- a/mlx/backend/cuda/ternary.cu +++ b/mlx/backend/cuda/ternary.cu @@ -39,52 +39,98 @@ ternary_v(const bool* a, const T* b, const T* c, T* out, IdxT size) { } } -template +template __global__ void ternary_g_nd( const bool* a, const T* b, const T* c, T* out, - IdxT size, + IdxT size_rest, const __grid_constant__ cuda::std::array shape, const __grid_constant__ cuda::std::array a_strides, const __grid_constant__ cuda::std::array b_strides, const __grid_constant__ cuda::std::array c_strides) { - IdxT index = cg::this_grid().thread_rank(); - if (index < size) { - auto [a_idx, b_idx, c_idx] = elem_to_loc_nd( - index, - shape.data(), - a_strides.data(), - b_strides.data(), - c_strides.data()); - out[index] = Op{}(a[a_idx], b[b_idx], c[c_idx]); + auto block = cg::this_thread_block(); + auto grid = cg::this_grid(); + IdxT index_rest = + grid.block_index().y * block.dim_threads().y + block.thread_index().y; + if (index_rest >= size_rest) { + return; } + + auto shape_x = shape[NDIM - 1]; + auto a_stride_x = a_strides[NDIM - 1]; + auto b_stride_x = b_strides[NDIM - 1]; + auto c_stride_x = c_strides[NDIM - 1]; + IdxT index_x = + grid.block_index().x * block.dim_threads().x + block.thread_index().x; + auto [a_idx, b_idx, c_idx] = elem_to_loc_nd( + index_rest * shape_x, + shape.data(), + a_strides.data(), + b_strides.data(), + c_strides.data()); + auto a_vec = + load_vector(a + a_idx, index_x, shape_x, a_stride_x, false); + auto b_vec = + load_vector(b + b_idx, index_x, shape_x, b_stride_x, T(0)); + auto c_vec = + load_vector(c + c_idx, index_x, shape_x, c_stride_x, T(0)); + + AlignedVector out_vec; +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + out_vec[i] = Op{}(a_vec[i], b_vec[i], c_vec[i]); + } + store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x); } -template +template __global__ void ternary_g( const bool* a, const T* b, const T* c, T* out, - IdxT size, + IdxT size_rest, const __grid_constant__ Shape shape, const __grid_constant__ Strides a_strides, const __grid_constant__ Strides b_strides, const __grid_constant__ Strides c_strides, int ndim) { - IdxT index = cg::this_grid().thread_rank(); - if (index < size) { - auto [a_idx, b_idx, c_idx] = elem_to_loc( - index, - shape.data(), - a_strides.data(), - b_strides.data(), - c_strides.data(), - ndim); - out[index] = Op{}(a[a_idx], b[b_idx], c[c_idx]); + auto block = cg::this_thread_block(); + auto grid = cg::this_grid(); + IdxT index_rest = + grid.block_index().y * block.dim_threads().y + block.thread_index().y; + if (index_rest >= size_rest) { + return; } + + auto shape_x = shape[ndim - 1]; + auto a_stride_x = a_strides[ndim - 1]; + auto b_stride_x = b_strides[ndim - 1]; + auto c_stride_x = c_strides[ndim - 1]; + IdxT index_x = + grid.block_index().x * block.dim_threads().x + block.thread_index().x; + auto [a_idx, b_idx, c_idx] = elem_to_loc( + index_rest * shape_x, + shape.data(), + a_strides.data(), + b_strides.data(), + c_strides.data(), + ndim); + auto a_vec = + load_vector(a + a_idx, index_x, shape_x, a_stride_x, false); + auto b_vec = + load_vector(b + b_idx, index_x, shape_x, b_stride_x, T(0)); + auto c_vec = + load_vector(c + c_idx, index_x, shape_x, c_stride_x, T(0)); + + AlignedVector out_vec; +#pragma unroll + for (int i = 0; i < 
N_READS; ++i) { + out_vec[i] = Op{}(a_vec[i], b_vec[i], c_vec[i]); + } + store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x); } } // namespace cu @@ -123,36 +169,55 @@ void ternary_op_gpu_inplace( auto& b_strides = strides[1]; auto& c_strides = strides[2]; int ndim = shape.size(); + int work_per_thread = 1; + auto dim0 = ndim > 0 ? shape.back() : 1; + auto rest = out.size() / dim0; + if (dim0 >= 4) { + work_per_thread = 4; + } + dim0 = (dim0 + work_per_thread - 1) / work_per_thread; + auto block_dims = get_block_dims(dim0, rest, 1); + uint32_t num_blocks_x = cuda::ceil_div(dim0, block_dims.x); + uint32_t num_blocks_y = cuda::ceil_div(rest, block_dims.y); + if (ndim <= 3) { dispatch_1_2_3(ndim, [&](auto dims_constant) { - auto [num_blocks, block_dims] = get_launch_args(out, large()); + auto kernel = + cu::ternary_g_nd; + if (work_per_thread == 4) { + kernel = + cu::ternary_g_nd; + } encoder.add_kernel_node( - cu::ternary_g_nd, - num_blocks, + kernel, + {num_blocks_x, num_blocks_y}, block_dims, 0, a.data(), b.data(), c.data(), out.data(), - out.size(), + rest, const_param(shape), const_param(a_strides), const_param(b_strides), const_param(c_strides)); }); } else { - auto [num_blocks, block_dims] = get_launch_args(out, large()); + auto kernel = cu::ternary_g; + if (work_per_thread == 4) { + kernel = cu::ternary_g; + } encoder.add_kernel_node( - cu::ternary_g, - num_blocks, + kernel, + {num_blocks_x, num_blocks_y}, block_dims, 0, a.data(), b.data(), c.data(), out.data(), - out.data_size(), + rest, const_param(shape), const_param(a_strides), const_param(b_strides), diff --git a/mlx/backend/cuda/unary.cu b/mlx/backend/cuda/unary.cu index 96888da97..4102dbfb3 100644 --- a/mlx/backend/cuda/unary.cu +++ b/mlx/backend/cuda/unary.cu @@ -37,19 +37,36 @@ __global__ void unary_v(const In* in, Out* out, IdxT size) { } } -template +template __global__ void unary_g( const In* in, Out* out, - IdxT size, + IdxT size_rest, const __grid_constant__ Shape shape, const __grid_constant__ Strides strides, int ndim) { - IdxT index = cg::this_grid().thread_rank(); - if (index < size) { - auto idx = elem_to_loc(index, shape.data(), strides.data(), ndim); - out[index] = Op{}(in[idx]); + auto block = cg::this_thread_block(); + auto grid = cg::this_grid(); + IdxT index_rest = + grid.block_index().y * block.dim_threads().y + block.thread_index().y; + if (index_rest >= size_rest) { + return; } + + auto shape_x = shape[ndim - 1]; + auto stride_x = strides[ndim - 1]; + IdxT index_x = + grid.block_index().x * block.dim_threads().x + block.thread_index().x; + auto idx = + elem_to_loc(index_rest * shape_x, shape.data(), strides.data(), ndim); + auto in_vec = + load_vector(in + idx, index_x, shape_x, stride_x, In(0)); + AlignedVector out_vec; +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + out_vec[i] = Op{}(in_vec[i]); + } + store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x); } template @@ -127,8 +144,7 @@ void unary_op_gpu_inplace( using OutType = cuda_type_t; if (contig) { using IdxT = std::conditional_t; - // TODO: Choose optimized value based on type size. 
- constexpr int N_READS = 4; + constexpr int N_READS = 16 / sizeof(OutType); auto [num_blocks, block_dims] = get_launch_args( out.data_size(), out.shape(), out.strides(), large, N_READS); encoder.add_kernel_node( @@ -142,18 +158,30 @@ void unary_op_gpu_inplace( } else { using IdxT = std::conditional_t; auto [shape, strides] = collapse_contiguous_dims(in); - auto [num_blocks, block_dims] = get_launch_args(out, large); + auto ndim = shape.size(); + int work_per_thread = 1; + auto kernel = cu::unary_g; + auto dim0 = ndim > 0 ? shape.back() : 1; + auto rest = out.size() / dim0; + if (dim0 >= 4) { + kernel = cu::unary_g; + work_per_thread = 4; + } + dim0 = (dim0 + work_per_thread - 1) / work_per_thread; + auto block_dims = get_block_dims(dim0, rest, 1); + uint32_t num_blocks_x = cuda::ceil_div(dim0, block_dims.x); + uint32_t num_blocks_y = cuda::ceil_div(rest, block_dims.y); encoder.add_kernel_node( - cu::unary_g, - num_blocks, + kernel, + {num_blocks_x, num_blocks_y}, block_dims, 0, in.data(), out.data(), - out.data_size(), + rest, const_param(shape), const_param(strides), - shape.size()); + ndim); } }); } else { diff --git a/mlx/backend/cuda/unary/CMakeLists.txt b/mlx/backend/cuda/unary/CMakeLists.txt new file mode 100644 index 000000000..532c5645e --- /dev/null +++ b/mlx/backend/cuda/unary/CMakeLists.txt @@ -0,0 +1,34 @@ +target_sources( + mlx + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/abs.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arccos.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arccosh.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arcsin.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arcsinh.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arctan.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arctanh.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/bitwise_invert.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/ceil.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/conjugate.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/cos.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/cosh.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/erf.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/erf_inv.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/exp.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/expm1.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/floor.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/imag.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/log.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/log1p.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/logical_not.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/negative.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/real.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/round.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/sigmoid.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/sign.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/sin.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/sinh.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/sqrt.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/square.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/tan.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/tanh.cu) diff --git a/mlx/backend/cuda/unary/abs.cu b/mlx/backend/cuda/unary/abs.cu new file mode 100644 index 000000000..90b197d21 --- /dev/null +++ b/mlx/backend/cuda/unary/abs.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/unary/unary.cuh" + +namespace mlx::core { +UNARY_GPU(Abs) +} // namespace mlx::core diff --git a/mlx/backend/cuda/unary/arccos.cu b/mlx/backend/cuda/unary/arccos.cu new file mode 100644 index 000000000..38849970d --- /dev/null +++ b/mlx/backend/cuda/unary/arccos.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. 
+ +#include "mlx/backend/cuda/unary/unary.cuh" + +namespace mlx::core { +UNARY_GPU(ArcCos) +} // namespace mlx::core diff --git a/mlx/backend/cuda/unary/arccosh.cu b/mlx/backend/cuda/unary/arccosh.cu new file mode 100644 index 000000000..0ef0738a4 --- /dev/null +++ b/mlx/backend/cuda/unary/arccosh.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/unary/unary.cuh" + +namespace mlx::core { +UNARY_GPU(ArcCosh) +} // namespace mlx::core diff --git a/mlx/backend/cuda/unary/arcsin.cu b/mlx/backend/cuda/unary/arcsin.cu new file mode 100644 index 000000000..07956ee9b --- /dev/null +++ b/mlx/backend/cuda/unary/arcsin.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/unary/unary.cuh" + +namespace mlx::core { +UNARY_GPU(ArcSin) +} // namespace mlx::core diff --git a/mlx/backend/cuda/unary/arcsinh.cu b/mlx/backend/cuda/unary/arcsinh.cu new file mode 100644 index 000000000..a7ab63e17 --- /dev/null +++ b/mlx/backend/cuda/unary/arcsinh.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/unary/unary.cuh" + +namespace mlx::core { +UNARY_GPU(ArcSinh) +} // namespace mlx::core diff --git a/mlx/backend/cuda/unary/arctan.cu b/mlx/backend/cuda/unary/arctan.cu new file mode 100644 index 000000000..78639afaa --- /dev/null +++ b/mlx/backend/cuda/unary/arctan.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/unary/unary.cuh" + +namespace mlx::core { +UNARY_GPU(ArcTan) +} // namespace mlx::core diff --git a/mlx/backend/cuda/unary/arctanh.cu b/mlx/backend/cuda/unary/arctanh.cu new file mode 100644 index 000000000..488268c9e --- /dev/null +++ b/mlx/backend/cuda/unary/arctanh.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/unary/unary.cuh" + +namespace mlx::core { +UNARY_GPU(ArcTanh) +} // namespace mlx::core diff --git a/mlx/backend/cuda/unary/bitwise_invert.cu b/mlx/backend/cuda/unary/bitwise_invert.cu new file mode 100644 index 000000000..77b88f30f --- /dev/null +++ b/mlx/backend/cuda/unary/bitwise_invert.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/unary/unary.cuh" + +namespace mlx::core { +UNARY_GPU(BitwiseInvert) +} // namespace mlx::core diff --git a/mlx/backend/cuda/unary/ceil.cu b/mlx/backend/cuda/unary/ceil.cu new file mode 100644 index 000000000..5ee300ffe --- /dev/null +++ b/mlx/backend/cuda/unary/ceil.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/unary/unary.cuh" + +namespace mlx::core { +UNARY_GPU(Ceil) +} // namespace mlx::core diff --git a/mlx/backend/cuda/unary/conjugate.cu b/mlx/backend/cuda/unary/conjugate.cu new file mode 100644 index 000000000..1d1d60e77 --- /dev/null +++ b/mlx/backend/cuda/unary/conjugate.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/unary/unary.cuh" + +namespace mlx::core { +UNARY_GPU(Conjugate) +} // namespace mlx::core diff --git a/mlx/backend/cuda/unary/cos.cu b/mlx/backend/cuda/unary/cos.cu new file mode 100644 index 000000000..cfceb86ab --- /dev/null +++ b/mlx/backend/cuda/unary/cos.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/unary/unary.cuh" + +namespace mlx::core { +UNARY_GPU(Cos) +} // namespace mlx::core diff --git a/mlx/backend/cuda/unary/cosh.cu b/mlx/backend/cuda/unary/cosh.cu new file mode 100644 index 000000000..d5fcc7081 --- /dev/null +++ b/mlx/backend/cuda/unary/cosh.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. 
+ +#include "mlx/backend/cuda/unary/unary.cuh" + +namespace mlx::core { +UNARY_GPU(Cosh) +} // namespace mlx::core diff --git a/mlx/backend/cuda/unary/erf.cu b/mlx/backend/cuda/unary/erf.cu new file mode 100644 index 000000000..c7859322b --- /dev/null +++ b/mlx/backend/cuda/unary/erf.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/unary/unary.cuh" + +namespace mlx::core { +UNARY_GPU(Erf) +} // namespace mlx::core diff --git a/mlx/backend/cuda/unary/erf_inv.cu b/mlx/backend/cuda/unary/erf_inv.cu new file mode 100644 index 000000000..16bbaba19 --- /dev/null +++ b/mlx/backend/cuda/unary/erf_inv.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/unary/unary.cuh" + +namespace mlx::core { +UNARY_GPU(ErfInv) +} // namespace mlx::core diff --git a/mlx/backend/cuda/unary/exp.cu b/mlx/backend/cuda/unary/exp.cu new file mode 100644 index 000000000..5a566691d --- /dev/null +++ b/mlx/backend/cuda/unary/exp.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/unary/unary.cuh" + +namespace mlx::core { +UNARY_GPU(Exp) +} // namespace mlx::core diff --git a/mlx/backend/cuda/unary/expm1.cu b/mlx/backend/cuda/unary/expm1.cu new file mode 100644 index 000000000..15e6ce445 --- /dev/null +++ b/mlx/backend/cuda/unary/expm1.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/unary/unary.cuh" + +namespace mlx::core { +UNARY_GPU(Expm1) +} // namespace mlx::core diff --git a/mlx/backend/cuda/unary/floor.cu b/mlx/backend/cuda/unary/floor.cu new file mode 100644 index 000000000..a8c7ab0bb --- /dev/null +++ b/mlx/backend/cuda/unary/floor.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/unary/unary.cuh" + +namespace mlx::core { +UNARY_GPU(Floor) +} // namespace mlx::core diff --git a/mlx/backend/cuda/unary/imag.cu b/mlx/backend/cuda/unary/imag.cu new file mode 100644 index 000000000..9e3c05c3b --- /dev/null +++ b/mlx/backend/cuda/unary/imag.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/unary/unary.cuh" + +namespace mlx::core { +UNARY_GPU(Imag) +} // namespace mlx::core diff --git a/mlx/backend/cuda/unary/log.cu b/mlx/backend/cuda/unary/log.cu new file mode 100644 index 000000000..1fd2aa680 --- /dev/null +++ b/mlx/backend/cuda/unary/log.cu @@ -0,0 +1,21 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/unary/unary.cuh" + +namespace mlx::core { +void Log::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("Log::eval_gpu"); + auto& s = out.primitive().stream(); + switch (base_) { + case Base::e: + unary_op_gpu(inputs, out, name(), s); + break; + case Base::two: + unary_op_gpu(inputs, out, name(), s); + break; + case Base::ten: + unary_op_gpu(inputs, out, name(), s); + break; + } +} +} // namespace mlx::core diff --git a/mlx/backend/cuda/unary/log1p.cu b/mlx/backend/cuda/unary/log1p.cu new file mode 100644 index 000000000..5396c3da0 --- /dev/null +++ b/mlx/backend/cuda/unary/log1p.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/unary/unary.cuh" + +namespace mlx::core { +UNARY_GPU(Log1p) +} // namespace mlx::core diff --git a/mlx/backend/cuda/unary/logical_not.cu b/mlx/backend/cuda/unary/logical_not.cu new file mode 100644 index 000000000..7f398707f --- /dev/null +++ b/mlx/backend/cuda/unary/logical_not.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. 
+ +#include "mlx/backend/cuda/unary/unary.cuh" + +namespace mlx::core { +UNARY_GPU(LogicalNot) +} // namespace mlx::core diff --git a/mlx/backend/cuda/unary/negative.cu b/mlx/backend/cuda/unary/negative.cu new file mode 100644 index 000000000..9c7e576ec --- /dev/null +++ b/mlx/backend/cuda/unary/negative.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/unary/unary.cuh" + +namespace mlx::core { +UNARY_GPU(Negative) +} // namespace mlx::core diff --git a/mlx/backend/cuda/unary/real.cu b/mlx/backend/cuda/unary/real.cu new file mode 100644 index 000000000..361ffd3f9 --- /dev/null +++ b/mlx/backend/cuda/unary/real.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/unary/unary.cuh" + +namespace mlx::core { +UNARY_GPU(Real) +} // namespace mlx::core diff --git a/mlx/backend/cuda/unary/round.cu b/mlx/backend/cuda/unary/round.cu new file mode 100644 index 000000000..4e80fdb60 --- /dev/null +++ b/mlx/backend/cuda/unary/round.cu @@ -0,0 +1,18 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/unary/unary.cuh" + +namespace mlx::core { +void Round::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("Round::eval_gpu"); + assert(inputs.size() == 1); + const auto& in = inputs[0]; + auto& s = out.primitive().stream(); + if (issubdtype(in.dtype(), inexact)) { + unary_op_gpu(inputs, out, name(), s); + } else { + // No-op integer types + out.copy_shared_buffer(in); + } +} +} // namespace mlx::core diff --git a/mlx/backend/cuda/unary/sigmoid.cu b/mlx/backend/cuda/unary/sigmoid.cu new file mode 100644 index 000000000..3d943726c --- /dev/null +++ b/mlx/backend/cuda/unary/sigmoid.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/unary/unary.cuh" + +namespace mlx::core { +UNARY_GPU(Sigmoid) +} // namespace mlx::core diff --git a/mlx/backend/cuda/unary/sign.cu b/mlx/backend/cuda/unary/sign.cu new file mode 100644 index 000000000..d586d8275 --- /dev/null +++ b/mlx/backend/cuda/unary/sign.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/unary/unary.cuh" + +namespace mlx::core { +UNARY_GPU(Sign) +} // namespace mlx::core diff --git a/mlx/backend/cuda/unary/sin.cu b/mlx/backend/cuda/unary/sin.cu new file mode 100644 index 000000000..47a5adc84 --- /dev/null +++ b/mlx/backend/cuda/unary/sin.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/unary/unary.cuh" + +namespace mlx::core { +UNARY_GPU(Sin) +} // namespace mlx::core diff --git a/mlx/backend/cuda/unary/sinh.cu b/mlx/backend/cuda/unary/sinh.cu new file mode 100644 index 000000000..7a73b7fd4 --- /dev/null +++ b/mlx/backend/cuda/unary/sinh.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/unary/unary.cuh" + +namespace mlx::core { +UNARY_GPU(Sinh) +} // namespace mlx::core diff --git a/mlx/backend/cuda/unary/sqrt.cu b/mlx/backend/cuda/unary/sqrt.cu new file mode 100644 index 000000000..21f5f08f2 --- /dev/null +++ b/mlx/backend/cuda/unary/sqrt.cu @@ -0,0 +1,15 @@ +// Copyright © 2025 Apple Inc. 
+ +#include "mlx/backend/cuda/unary/unary.cuh" + +namespace mlx::core { +void Sqrt::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("Sqrt::eval_gpu"); + auto& s = out.primitive().stream(); + if (recip_) { + unary_op_gpu(inputs, out, "Rsqrt", s); + } else { + unary_op_gpu(inputs, out, "Sqrt", s); + } +} +} // namespace mlx::core diff --git a/mlx/backend/cuda/unary/square.cu b/mlx/backend/cuda/unary/square.cu new file mode 100644 index 000000000..bbb5f5130 --- /dev/null +++ b/mlx/backend/cuda/unary/square.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/unary/unary.cuh" + +namespace mlx::core { +UNARY_GPU(Square) +} // namespace mlx::core diff --git a/mlx/backend/cuda/unary/tan.cu b/mlx/backend/cuda/unary/tan.cu new file mode 100644 index 000000000..3039dcdc1 --- /dev/null +++ b/mlx/backend/cuda/unary/tan.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/unary/unary.cuh" + +namespace mlx::core { +UNARY_GPU(Tan) +} // namespace mlx::core diff --git a/mlx/backend/cuda/unary/tanh.cu b/mlx/backend/cuda/unary/tanh.cu new file mode 100644 index 000000000..ae69a51b5 --- /dev/null +++ b/mlx/backend/cuda/unary/tanh.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/unary/unary.cuh" + +namespace mlx::core { +UNARY_GPU(Tanh) +} // namespace mlx::core diff --git a/mlx/backend/cuda/unary/unary.cuh b/mlx/backend/cuda/unary/unary.cuh new file mode 100644 index 000000000..a20e119ca --- /dev/null +++ b/mlx/backend/cuda/unary/unary.cuh @@ -0,0 +1,215 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/common/unary.h" +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/device/unary_ops.cuh" +#include "mlx/backend/cuda/kernel_utils.cuh" +#include "mlx/dtype_utils.h" +#include "mlx/primitives.h" + +#include +#include + +namespace mlx::core { + +namespace cu { + +namespace cg = cooperative_groups; + +template +__global__ void unary_v(const In* in, Out* out, IdxT size) { + IdxT index = cg::this_grid().thread_rank(); + + if ((index + 1) * N_READS > size) { + for (IdxT i = index * N_READS; i < size; ++i) { + out[i] = Op{}(in[i]); + } + } else { + auto in_vec = load_vector(in, index); + + AlignedVector out_vec; +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + out_vec[i] = Op{}(in_vec[i]); + } + + store_vector(out, index, out_vec); + } +} + +template +__global__ void unary_g( + const In* in, + Out* out, + IdxT size_rest, + const __grid_constant__ Shape shape, + const __grid_constant__ Strides strides, + int ndim) { + auto block = cg::this_thread_block(); + auto grid = cg::this_grid(); + IdxT index_rest = + grid.block_index().y * block.dim_threads().y + block.thread_index().y; + if (index_rest >= size_rest) { + return; + } + + auto shape_x = shape[ndim - 1]; + auto stride_x = strides[ndim - 1]; + IdxT index_x = + grid.block_index().x * block.dim_threads().x + block.thread_index().x; + auto idx = + elem_to_loc(index_rest * shape_x, shape.data(), strides.data(), ndim); + auto in_vec = + load_vector(in + idx, index_x, shape_x, stride_x, In(0)); + AlignedVector out_vec; +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + out_vec[i] = Op{}(in_vec[i]); + } + store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x); +} + +template +constexpr bool supports_unary_op() { + if (std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) { + return std::is_same_v; + } + if (std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + 
std::is_same_v || std::is_same_v || + std::is_same_v) { + return std::is_same_v && is_floating_v; + } + if (std::is_same_v) { + return std::is_same_v && std::is_integral_v && + !std::is_same_v; + } + if (std::is_same_v || std::is_same_v) { + return std::is_same_v && !mlx::core::is_complex_v; + } + if (std::is_same_v) { + return std::is_same_v && mlx::core::is_complex_v; + } + if (std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v) { + return std::is_same_v && is_inexact_v; + } + if (std::is_same_v || std::is_same_v) { + return mlx::core::is_complex_v && std::is_same_v; + } + if (std::is_same_v) { + return std::is_same_v && std::is_same_v; + } + return false; +} + +} // namespace cu + +template +void unary_op_gpu_inplace( + const std::vector& inputs, + array& out, + const char* op, + const Stream& s) { + auto& in = inputs[0]; + if (in.size() == 0) { + return; + } + bool contig = in.flags().contiguous; + bool large; + if (!contig) { + large = in.data_size() > INT32_MAX || out.size() > INT32_MAX; + } else { + large = in.data_size() > UINT32_MAX; + } + + auto& encoder = cu::get_command_encoder(s); + encoder.set_input_array(in); + encoder.set_output_array(out); + dispatch_all_types(in.dtype(), [&](auto in_type_tag) { + dispatch_all_types(out.dtype(), [&](auto out_type_tag) { + using CTYPE_IN = MLX_GET_TYPE(in_type_tag); + using CTYPE_OUT = MLX_GET_TYPE(out_type_tag); + if constexpr (cu::supports_unary_op()) { + dispatch_bool(large, [&](auto large) { + using InType = cuda_type_t; + using OutType = cuda_type_t; + if (contig) { + using IdxT = std::conditional_t; + constexpr int N_READS = 16 / sizeof(OutType); + auto [num_blocks, block_dims] = get_launch_args( + out.data_size(), out.shape(), out.strides(), large, N_READS); + encoder.add_kernel_node( + cu::unary_v, + num_blocks, + block_dims, + 0, + in.data(), + out.data(), + out.data_size()); + } else { + using IdxT = std::conditional_t; + auto [shape, strides] = collapse_contiguous_dims(in); + auto ndim = shape.size(); + int work_per_thread = 1; + auto kernel = cu::unary_g; + auto dim0 = ndim > 0 ? shape.back() : 1; + auto rest = out.size() / dim0; + if (dim0 >= 4) { + kernel = cu::unary_g; + work_per_thread = 4; + } + dim0 = (dim0 + work_per_thread - 1) / work_per_thread; + auto block_dims = get_block_dims(dim0, rest, 1); + uint32_t num_blocks_x = cuda::ceil_div(dim0, block_dims.x); + uint32_t num_blocks_y = cuda::ceil_div(rest, block_dims.y); + encoder.add_kernel_node( + kernel, + {num_blocks_x, num_blocks_y}, + block_dims, + 0, + in.data(), + out.data(), + rest, + const_param(shape), + const_param(strides), + ndim); + } + }); + } else { + throw std::runtime_error(fmt::format( + "Can not do unary op {} on input of {} with output of {}.", + op, + dtype_to_string(in.dtype()), + dtype_to_string(out.dtype()))); + } + }); + }); +} + +template +void unary_op_gpu( + const std::vector& inputs, + array& out, + const char* op, + const Stream& s) { + set_unary_output_data(inputs[0], out); + unary_op_gpu_inplace(inputs, out, op, s); +} + +#define UNARY_GPU(func) \ + void func::eval_gpu(const std::vector& inputs, array& out) { \ + nvtx3::scoped_range r(#func "::eval_gpu"); \ + auto& s = out.primitive().stream(); \ + unary_op_gpu(inputs, out, name(), s); \ + } + +} // namespace mlx::core
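
Note (not part of the patch): the hunks above repeat one pattern across the binary, binary_two, ternary, copy, and unary kernels: the innermost dimension is vectorized with N_READS elements per thread along the grid's x axis, the remaining dimensions are flattened into a "rest" count handled by the y axis, and the host side picks work_per_thread = 4 whenever the innermost dimension has at least 4 elements. The sketch below only illustrates that launch scheme in isolation; it is a simplified, contiguous-row approximation (no strides, no elem_to_loc), and the helpers pick_block_dims, rowwise_kernel, and launch_rowwise are hypothetical stand-ins, not MLX APIs from this diff.

// Minimal sketch of the 2D launch geometry used in this PR, under assumed
// helper semantics. Each thread handles N_READS contiguous elements of the
// innermost dimension; all outer dimensions are flattened into "rest" rows
// covered by the grid's y dimension.
#include <cstdint>
#include <cuda_runtime.h>

// Hypothetical stand-in for MLX's get_block_dims(): cover dim0 along x and
// rest along y, capped at 1024 threads per block.
inline dim3 pick_block_dims(uint32_t dim0, uint32_t rest) {
  uint32_t bx = dim0 < 32 ? (dim0 == 0 ? 1 : dim0) : 32;
  uint32_t by = 1024 / bx;
  if (by > rest) {
    by = rest == 0 ? 1 : rest;
  }
  return dim3(bx, by, 1);
}

template <typename T, typename Op, int N_READS>
__global__ void rowwise_kernel(const T* in, T* out, uint32_t rest, uint32_t shape_x) {
  // y axis indexes the flattened outer dimensions ("rest"),
  // x axis indexes groups of N_READS elements of the innermost dimension.
  uint32_t row = blockIdx.y * blockDim.y + threadIdx.y;
  if (row >= rest) {
    return;
  }
  uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * N_READS;
  const T* in_row = in + uint64_t(row) * shape_x;
  T* out_row = out + uint64_t(row) * shape_x;
#pragma unroll
  for (int i = 0; i < N_READS; ++i) {
    if (col + i < shape_x) {
      out_row[col + i] = Op{}(in_row[col + i]);
    }
  }
}

// Host-side launch mirroring the diff's work_per_thread selection: vectorize
// by 4 when the innermost dimension is at least 4 wide, otherwise by 1.
template <typename T, typename Op>
void launch_rowwise(const T* in, T* out, uint32_t shape_x, uint32_t rest, cudaStream_t stream) {
  int work_per_thread = shape_x >= 4 ? 4 : 1;
  uint32_t dim0 = (shape_x + work_per_thread - 1) / work_per_thread;
  dim3 block = pick_block_dims(dim0, rest);
  dim3 grid((dim0 + block.x - 1) / block.x, (rest + block.y - 1) / block.y, 1);
  if (work_per_thread == 4) {
    rowwise_kernel<T, Op, 4><<<grid, block, 0, stream>>>(in, out, rest, shape_x);
  } else {
    rowwise_kernel<T, Op, 1><<<grid, block, 0, stream>>>(in, out, rest, shape_x);
  }
}

In the actual kernels above, the strided case is handled by computing the row's base offsets with elem_to_loc / elem_to_loc_nd and then using load_vector / store_vector with the innermost stride, which fall back to scalar element accesses when the data is not aligned, not unit-stride, or the row tail is shorter than N_READS.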