// Copyright © 2025 Apple Inc.

#include "mlx/backend/common/unary.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/cucomplex_math.cuh"
#include "mlx/backend/cuda/device/unary_ops.cuh"
#include "mlx/backend/cuda/iterators/general_iterator.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"

#include <nvtx3/nvtx3.hpp>
#include <thrust/device_ptr.h>
#include <thrust/transform.h>

namespace mlx::core {

namespace cu {

// Compile-time check that unary op `Op` has a kernel for the given
// input/output dtype combination.
template <typename Op, typename In, typename Out>
constexpr bool supports_unary_op() {
  if (std::is_same_v<Op, Abs> || std::is_same_v<Op, Negative> ||
      std::is_same_v<Op, Sign> || std::is_same_v<Op, Square>) {
    return std::is_same_v<In, Out>;
  }
  if (std::is_same_v<Op, ArcCosh> || std::is_same_v<Op, ArcSinh> ||
      std::is_same_v<Op, ArcTanh> || std::is_same_v<Op, Erf> ||
      std::is_same_v<Op, ErfInv> || std::is_same_v<Op, Expm1> ||
      std::is_same_v<Op, Sigmoid>) {
    return std::is_same_v<In, Out> && is_floating_v<In>;
  }
  if (std::is_same_v<Op, BitwiseInvert>) {
    return std::is_same_v<In, Out> && std::is_integral_v<In> &&
        !std::is_same_v<In, bool>;
  }
  if (std::is_same_v<Op, Ceil> || std::is_same_v<Op, Floor>) {
    return std::is_same_v<In, Out> && !std::is_same_v<In, complex64_t>;
  }
  if (std::is_same_v<Op, Conjugate>) {
    return std::is_same_v<In, Out> && std::is_same_v<In, complex64_t>;
  }
  if (std::is_same_v<Op, ArcCos> || std::is_same_v<Op, ArcSin> ||
      std::is_same_v<Op, ArcTan> || std::is_same_v<Op, Cos> ||
      std::is_same_v<Op, Cosh> || std::is_same_v<Op, Exp> ||
      std::is_same_v<Op, Log> || std::is_same_v<Op, Log2> ||
      std::is_same_v<Op, Log10> || std::is_same_v<Op, Log1p> ||
      std::is_same_v<Op, Round> || std::is_same_v<Op, Rsqrt> ||
      std::is_same_v<Op, Sqrt> || std::is_same_v<Op, Sin> ||
      std::is_same_v<Op, Sinh> || std::is_same_v<Op, Tan> ||
      std::is_same_v<Op, Tanh>) {
    return std::is_same_v<In, Out> && is_inexact_v<In>;
  }
  if (std::is_same_v<Op, Imag> || std::is_same_v<Op, Real>) {
    return std::is_same_v<In, complex64_t> && std::is_same_v<Out, float>;
  }
  if (std::is_same_v<Op, LogicalNot>) {
    return std::is_same_v<In, Out> && std::is_same_v<In, bool>;
  }
  return false;
}

} // namespace cu

template <typename Op>
void unary_op_gpu_inplace(
    const std::vector<array>& inputs,
    array& out,
    const std::string& op,
    const Stream& s) {
  auto& in = inputs[0];
  if (in.size() == 0) {
    return;
  }
  auto& encoder = cu::get_command_encoder(s);
  encoder.set_input_array(in);
  encoder.set_output_array(out);
  encoder.launch_kernel([&](cudaStream_t stream) {
    dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
      dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
        using CTYPE_IN = MLX_GET_TYPE(in_type_tag);
        using CTYPE_OUT = MLX_GET_TYPE(out_type_tag);
        if constexpr (cu::supports_unary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
          using InType = cuda_type_t<CTYPE_IN>;
          using OutType = cuda_type_t<CTYPE_OUT>;
          auto policy = cu::thrust_policy(stream);
          auto in_ptr = thrust::device_pointer_cast(in.data<InType>());
          auto out_ptr = thrust::device_pointer_cast(out.data<OutType>());
          if (in.flags().contiguous) {
            // Contiguous input: apply the op elementwise over the flat buffer.
            thrust::transform(
                policy, in_ptr, in_ptr + in.data_size(), out_ptr, Op());
          } else {
            // Strided input: iterate with a shape/strides-aware iterator.
            auto [shape, strides] = collapse_contiguous_dims(in);
            auto [in_begin, in_end] = cu::make_general_iterators<int64_t>(
                in_ptr, in.size(), shape, strides);
            thrust::transform(policy, in_begin, in_end, out_ptr, Op());
          }
        } else {
          throw std::runtime_error(fmt::format(
              "Can not do unary op {} on input of {} with output of {}.",
              op,
              dtype_to_string(in.dtype()),
              dtype_to_string(out.dtype())));
        }
      });
    });
  });
}

template <typename Op>
void unary_op_gpu(
    const std::vector<array>& inputs,
    array& out,
    const std::string& op,
    const Stream& s) {
  set_unary_output_data(inputs[0], out);
  unary_op_gpu_inplace<Op>(inputs, out, op, s);
}

// Define eval_gpu for primitives that map one-to-one onto a cu:: unary op.
#define UNARY_GPU(func)                                                   \
  void func::eval_gpu(const std::vector<array>& inputs, array& out) {    \
    nvtx3::scoped_range r(#func "::eval_gpu");                            \
    auto& s = out.primitive().stream();                                   \
    unary_op_gpu<cu::func>(inputs, out, get_primitive_string(this), s);  \
  }

UNARY_GPU(Abs)
UNARY_GPU(ArcCos)
UNARY_GPU(ArcCosh)
UNARY_GPU(ArcSin)
UNARY_GPU(ArcSinh)
UNARY_GPU(ArcTan)
UNARY_GPU(ArcTanh)
UNARY_GPU(BitwiseInvert)
UNARY_GPU(Ceil)
UNARY_GPU(Conjugate)
UNARY_GPU(Cos)
UNARY_GPU(Cosh)
UNARY_GPU(Erf)
UNARY_GPU(ErfInv)
UNARY_GPU(Exp)
UNARY_GPU(Expm1)
UNARY_GPU(Floor)
UNARY_GPU(Imag)
UNARY_GPU(Log1p)
UNARY_GPU(LogicalNot)
UNARY_GPU(Negative)
UNARY_GPU(Real)
UNARY_GPU(Sigmoid)
UNARY_GPU(Sign)
UNARY_GPU(Sin)
UNARY_GPU(Sinh)
UNARY_GPU(Square)
UNARY_GPU(Tan)
UNARY_GPU(Tanh)

void Log::eval_gpu(const std::vector<array>& inputs, array& out) {
  nvtx3::scoped_range r("Log::eval_gpu");
  auto& s = out.primitive().stream();
  auto op = get_primitive_string(this);
  switch (base_) {
    case Base::e:
      unary_op_gpu<cu::Log>(inputs, out, op, s);
      break;
    case Base::two:
      unary_op_gpu<cu::Log2>(inputs, out, op, s);
      break;
    case Base::ten:
      unary_op_gpu<cu::Log10>(inputs, out, op, s);
      break;
  }
}

void Round::eval_gpu(const std::vector<array>& inputs, array& out) {
  nvtx3::scoped_range r("Round::eval_gpu");
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  auto& s = out.primitive().stream();
  if (issubdtype(in.dtype(), inexact)) {
    unary_op_gpu<cu::Round>(inputs, out, get_primitive_string(this), s);
  } else {
    // No-op integer types
    out.copy_shared_buffer(in);
  }
}

void Sqrt::eval_gpu(const std::vector<array>& inputs, array& out) {
  nvtx3::scoped_range r("Sqrt::eval_gpu");
  auto& s = out.primitive().stream();
  if (recip_) {
    unary_op_gpu<cu::Rsqrt>(inputs, out, "Rsqrt", s);
  } else {
    unary_op_gpu<cu::Sqrt>(inputs, out, "Sqrt", s);
  }
}

} // namespace mlx::core