// Copyright © 2025 Apple Inc.

#include "mlx/backend/common/unary.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/unary_ops.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"

#include <cooperative_groups.h>
#include <nvtx3/nvtx3.hpp>

namespace mlx::core {

namespace cu {

namespace cg = cooperative_groups;

// Vectorized kernel for contiguous data: each thread applies Op to N_READS
// elements loaded and stored as one aligned vector.
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
__global__ void unary_v(const In* in, Out* out, IdxT size) {
  IdxT index = cg::this_grid().thread_rank();

  if ((index + 1) * N_READS > size) {
    // Tail: fewer than N_READS elements remain, fall back to scalar access.
    for (IdxT i = index * N_READS; i < size; ++i) {
      out[i] = Op{}(in[i]);
    }
  } else {
    auto in_vec = load_vector<N_READS>(in, index);

    AlignedVector<Out, N_READS> out_vec;
#pragma unroll
    for (int i = 0; i < N_READS; ++i) {
      out_vec[i] = Op{}(in_vec[i]);
    }

    store_vector(out, index, out_vec);
  }
}

// General kernel for non-contiguous data: the innermost dimension is
// traversed along x and the remaining (collapsed) dimensions along y.
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
__global__ void unary_g(
    const In* in,
    Out* out,
    IdxT size_rest,
    const __grid_constant__ Shape shape,
    const __grid_constant__ Strides strides,
    int ndim) {
  auto block = cg::this_thread_block();
  auto grid = cg::this_grid();
  IdxT index_rest =
      grid.block_index().y * block.dim_threads().y + block.thread_index().y;
  if (index_rest >= size_rest) {
    return;
  }

  auto shape_x = shape[ndim - 1];
  auto stride_x = strides[ndim - 1];
  IdxT index_x =
      grid.block_index().x * block.dim_threads().x + block.thread_index().x;
  auto idx =
      elem_to_loc(index_rest * shape_x, shape.data(), strides.data(), ndim);

  auto in_vec =
      load_vector<N_READS>(in + idx, index_x, shape_x, stride_x, In(0));
  AlignedVector<Out, N_READS> out_vec;
#pragma unroll
  for (int i = 0; i < N_READS; ++i) {
    out_vec[i] = Op{}(in_vec[i]);
  }
  store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x);
}

// Compile-time table of which input/output dtype pairs each op supports.
template <typename Op, typename In, typename Out>
constexpr bool supports_unary_op() {
  if (std::is_same_v<Op, Abs> || std::is_same_v<Op, Negative> ||
      std::is_same_v<Op, Sign> || std::is_same_v<Op, Square>) {
    return std::is_same_v<In, Out>;
  }
  if (std::is_same_v<Op, ArcCosh> || std::is_same_v<Op, ArcSinh> ||
      std::is_same_v<Op, ArcTanh> || std::is_same_v<Op, Erf> ||
      std::is_same_v<Op, ErfInv> || std::is_same_v<Op, Expm1> ||
      std::is_same_v<Op, Sigmoid>) {
    return std::is_same_v<In, Out> && is_floating_v<In>;
  }
  if (std::is_same_v<Op, BitwiseInvert>) {
    return std::is_same_v<In, Out> && std::is_integral_v<In> &&
        !std::is_same_v<In, bool>;
  }
  if (std::is_same_v<Op, Ceil> || std::is_same_v<Op, Floor>) {
    return std::is_same_v<In, Out> && !mlx::core::is_complex_v<In>;
  }
  if (std::is_same_v<Op, Conjugate>) {
    return std::is_same_v<In, Out> && mlx::core::is_complex_v<In>;
  }
  if (std::is_same_v<Op, ArcCos> || std::is_same_v<Op, ArcSin> ||
      std::is_same_v<Op, ArcTan> || std::is_same_v<Op, Cos> ||
      std::is_same_v<Op, Cosh> || std::is_same_v<Op, Exp> ||
      std::is_same_v<Op, Log> || std::is_same_v<Op, Log2> ||
      std::is_same_v<Op, Log10> || std::is_same_v<Op, Log1p> ||
      std::is_same_v<Op, Round> || std::is_same_v<Op, Rsqrt> ||
      std::is_same_v<Op, Sqrt> || std::is_same_v<Op, Sin> ||
      std::is_same_v<Op, Sinh> || std::is_same_v<Op, Tan> ||
      std::is_same_v<Op, Tanh>) {
    return std::is_same_v<In, Out> && is_inexact_v<In>;
  }
  if (std::is_same_v<Op, Imag> || std::is_same_v<Op, Real>) {
    return mlx::core::is_complex_v<In> && std::is_same_v<Out, float>;
  }
  if (std::is_same_v<Op, LogicalNot>) {
    return std::is_same_v<In, Out> && std::is_same_v<In, bool>;
  }
  return false;
}

} // namespace cu
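// Expository sanity checks (added for illustration, not in the upstream
// file). supports_unary_op() is constexpr, so the dtype rules above can be
// spot-checked at compile time at no runtime cost.
static_assert(
    cu::supports_unary_op<cu::Abs, float, float>(),
    "Abs maps a dtype to itself");
static_assert(
    !cu::supports_unary_op<cu::Erf, int, int>(),
    "Erf only supports floating-point dtypes");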
template <typename Op>
void unary_op_gpu_inplace(
    const std::vector<array>& inputs,
    array& out,
    const char* op,
    const Stream& s) {
  auto& in = inputs[0];
  if (in.size() == 0) {
    return;
  }
  bool contig = in.flags().contiguous;
  // Use 64-bit indexing only when 32-bit indices could overflow.
  bool large;
  if (!contig) {
    large = in.data_size() > INT32_MAX || out.size() > INT32_MAX;
  } else {
    large = in.data_size() > UINT32_MAX;
  }

  auto& encoder = cu::get_command_encoder(s);
  encoder.set_input_array(in);
  encoder.set_output_array(out);
  dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
    dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
      using CTYPE_IN = MLX_GET_TYPE(in_type_tag);
      using CTYPE_OUT = MLX_GET_TYPE(out_type_tag);
      if constexpr (cu::supports_unary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
        dispatch_bool(large, [&](auto large) {
          using InType = cuda_type_t<CTYPE_IN>;
          using OutType = cuda_type_t<CTYPE_OUT>;
          if (contig) {
            using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
            // Read 16 bytes per thread regardless of element width.
            constexpr int N_READS = 16 / sizeof(OutType);
            auto [num_blocks, block_dims] = get_launch_args(
                out.data_size(), out.shape(), out.strides(), large, N_READS);
            encoder.add_kernel_node(
                cu::unary_v<Op, InType, OutType, IdxT, N_READS>,
                num_blocks,
                block_dims,
                0,
                in.data<InType>(),
                out.data<OutType>(),
                out.data_size());
          } else {
            using IdxT = std::conditional_t<large(), int64_t, int32_t>;
            auto [shape, strides] = collapse_contiguous_dims(in);
            auto ndim = shape.size();
            int work_per_thread = 1;
            auto kernel = cu::unary_g<Op, InType, OutType, IdxT, 1>;
            auto dim0 = ndim > 0 ? shape.back() : 1;
            auto rest = out.size() / dim0;
            if (dim0 >= 4) {
              kernel = cu::unary_g<Op, InType, OutType, IdxT, 4>;
              work_per_thread = 4;
            }
            dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
            auto block_dims = get_block_dims(dim0, rest, 1);
            uint32_t num_blocks_x = cuda::ceil_div(dim0, block_dims.x);
            uint32_t num_blocks_y = cuda::ceil_div(rest, block_dims.y);
            encoder.add_kernel_node(
                kernel,
                {num_blocks_x, num_blocks_y},
                block_dims,
                0,
                in.data<InType>(),
                out.data<OutType>(),
                rest,
                const_param(shape),
                const_param(strides),
                ndim);
          }
        });
      } else {
        throw std::runtime_error(fmt::format(
            "Can not do unary op {} on input of {} with output of {}.",
            op,
            dtype_to_string(in.dtype()),
            dtype_to_string(out.dtype())));
      }
    });
  });
}

template <typename Op>
void unary_op_gpu(
    const std::vector<array>& inputs,
    array& out,
    const char* op,
    const Stream& s) {
  set_unary_output_data(inputs[0], out);
  unary_op_gpu_inplace<Op>(inputs, out, op, s);
}

#define UNARY_GPU(func)                                               \
  void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
    nvtx3::scoped_range r(#func "::eval_gpu");                        \
    auto& s = out.primitive().stream();                               \
    unary_op_gpu<cu::func>(inputs, out, name(), s);                   \
  }

UNARY_GPU(Abs)
UNARY_GPU(ArcCos)
UNARY_GPU(ArcCosh)
UNARY_GPU(ArcSin)
UNARY_GPU(ArcSinh)
UNARY_GPU(ArcTan)
UNARY_GPU(ArcTanh)
UNARY_GPU(BitwiseInvert)
UNARY_GPU(Ceil)
UNARY_GPU(Conjugate)
UNARY_GPU(Cos)
UNARY_GPU(Cosh)
UNARY_GPU(Erf)
UNARY_GPU(ErfInv)
UNARY_GPU(Exp)
UNARY_GPU(Expm1)
UNARY_GPU(Floor)
UNARY_GPU(Imag)
UNARY_GPU(Log1p)
UNARY_GPU(LogicalNot)
UNARY_GPU(Negative)
UNARY_GPU(Real)
UNARY_GPU(Sigmoid)
UNARY_GPU(Sign)
UNARY_GPU(Sin)
UNARY_GPU(Sinh)
UNARY_GPU(Square)
UNARY_GPU(Tan)
UNARY_GPU(Tanh)

void Log::eval_gpu(const std::vector<array>& inputs, array& out) {
  nvtx3::scoped_range r("Log::eval_gpu");
  auto& s = out.primitive().stream();
  switch (base_) {
    case Base::e:
      unary_op_gpu<cu::Log>(inputs, out, name(), s);
      break;
    case Base::two:
      unary_op_gpu<cu::Log2>(inputs, out, name(), s);
      break;
    case Base::ten:
      unary_op_gpu<cu::Log10>(inputs, out, name(), s);
      break;
  }
}

void Round::eval_gpu(const std::vector<array>& inputs, array& out) {
  nvtx3::scoped_range r("Round::eval_gpu");
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  auto& s = out.primitive().stream();
  if (issubdtype(in.dtype(), inexact)) {
    unary_op_gpu<cu::Round>(inputs, out, name(), s);
  } else {
    // No-op integer types
    out.copy_shared_buffer(in);
  }
}

void Sqrt::eval_gpu(const std::vector<array>& inputs, array& out) {
  nvtx3::scoped_range r("Sqrt::eval_gpu");
  auto& s = out.primitive().stream();
  if (recip_) {
    unary_op_gpu<cu::Rsqrt>(inputs, out, "Rsqrt", s);
  } else {
    unary_op_gpu<cu::Sqrt>(inputs, out, "Sqrt", s);
  }
}

} // namespace mlx::core
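// ---------------------------------------------------------------------------
// Expository note (not part of the upstream file): UNARY_GPU(Abs) above
// expands to roughly the following definition:
//
//   void Abs::eval_gpu(const std::vector<array>& inputs, array& out) {
//     nvtx3::scoped_range r("Abs::eval_gpu");
//     auto& s = out.primitive().stream();
//     unary_op_gpu<cu::Abs>(inputs, out, name(), s);
//   }
//
// On the contiguous path this launches cu::unary_v with
// N_READS = 16 / sizeof(OutType), i.e. one 16-byte vector per thread:
// 4 float32, 8 float16/bfloat16, or 2 complex64 elements. A final partial
// vector, if any, is handled by the scalar fallback loop in unary_v.
// ---------------------------------------------------------------------------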