From 194212f65feb41724753d3d0f4d5c76ffdc3b40b Mon Sep 17 00:00:00 2001 From: Cheng Date: Sun, 13 Apr 2025 23:51:11 +0000 Subject: [PATCH] CUDA backend: binary ops --- mlx/backend/cuda/CMakeLists.txt | 1 + mlx/backend/cuda/binary.cu | 212 ++++++++++++++ .../cuda/iterators/repeat_iterator.cuh | 31 ++ mlx/backend/cuda/kernels/binary_ops.cuh | 275 ++++++++++++++++++ mlx/backend/cuda/kernels/fp16_math.cuh | 46 +++ mlx/backend/cuda/primitives.cu | 19 -- 6 files changed, 565 insertions(+), 19 deletions(-) create mode 100644 mlx/backend/cuda/binary.cu create mode 100644 mlx/backend/cuda/iterators/repeat_iterator.cuh create mode 100644 mlx/backend/cuda/kernels/binary_ops.cuh diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt index 8c5f16ca3..1a3e95059 100644 --- a/mlx/backend/cuda/CMakeLists.txt +++ b/mlx/backend/cuda/CMakeLists.txt @@ -6,6 +6,7 @@ target_sources( mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/binary.cu ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp diff --git a/mlx/backend/cuda/binary.cu b/mlx/backend/cuda/binary.cu new file mode 100644 index 000000000..afffc1ddf --- /dev/null +++ b/mlx/backend/cuda/binary.cu @@ -0,0 +1,212 @@ +// Copyright © 2025 Apple Inc. 
+
+#include "mlx/backend/common/binary.h"
+#include "mlx/backend/cuda/device.h"
+#include "mlx/backend/cuda/iterators/general_iterator.cuh"
+#include "mlx/backend/cuda/iterators/repeat_iterator.cuh"
+#include "mlx/backend/cuda/kernel_utils.cuh"
+#include "mlx/backend/cuda/kernels/binary_ops.cuh"
+#include "mlx/backend/cuda/kernels/cucomplex_math.cuh"
+#include "mlx/dtype_utils.h"
+#include "mlx/primitives.h"
+
+// NOTE(review): the four directives below have no header name in this copy of
+// the patch; presumably angle-bracket includes were lost in extraction.
+// Restore them from the upstream commit before applying.
+#include
+#include
+#include
+#include
+
+namespace mlx::core {
+
+namespace cu {
+
+// Compile-time predicate consulted by binary_op_gpu_inplace to decide whether
+// a kernel is instantiated for the combination being dispatched. It is a
+// chain of `if` tests, each returning a type-trait expression, and falls
+// through to `false` when none matches.
+// NOTE(review): the template parameter list and the arguments of every
+// `is_same_v` / `is_floating_v` / `is_integral_v` are missing in this copy, so
+// the ops and types each branch tests cannot be read from here.
+template
+constexpr bool supports_binary_op() {
+  if (std::is_same_v || std::is_same_v ||
+      std::is_same_v || std::is_same_v ||
+      std::is_same_v || std::is_same_v ||
+      std::is_same_v || std::is_same_v) {
+    return std::is_same_v;
+  }
+  if (std::is_same_v || std::is_same_v ||
+      std::is_same_v || std::is_same_v ||
+      std::is_same_v || std::is_same_v) {
+    return std::is_same_v;
+  }
+  if (std::is_same_v || std::is_same_v) {
+    return std::is_same_v && std::is_same_v;
+  }
+  if (std::is_same_v) {
+    return std::is_same_v &&
+        (is_floating_v || std::is_same_v);
+  }
+  if (std::is_same_v || std::is_same_v) {
+    return std::is_same_v && is_floating_v;
+  }
+  if (std::is_same_v || std::is_same_v ||
+      std::is_same_v) {
+    return std::is_same_v && std::is_integral_v;
+  }
+  if (std::is_same_v || std::is_same_v) {
+    return std::is_same_v && std::is_integral_v &&
+        !std::is_same_v;
+  }
+  return false;
+}
+
+} // namespace cu
+
+// Applies `Op` element-wise to inputs[0] and inputs[1] with thrust::transform
+// on the stream's command encoder, writing the result through outputs[0].
+// Only outputs[0] is read here; both callers below call
+// set_binary_op_output_data on the output before calling this function.
+//
+// The kernel is selected by a nested switch over a.dtype() and out.dtype();
+// b's dtype is not switched on separately. When supports_binary_op() is false
+// for the selected pair, a std::runtime_error naming `op` and both dtypes is
+// thrown.
+template
+void binary_op_gpu_inplace(
+    const std::vector& inputs,
+    std::vector& outputs,
+    std::string_view op,
+    const Stream& s) {
+  auto& a = inputs[0];
+  auto& b = inputs[1];
+  auto& out = outputs[0];
+  // Nothing to compute for an empty output.
+  if (out.size() == 0) {
+    return;
+  }
+
+  auto& encoder = cu::get_command_encoder(s);
+  encoder.set_input_array(a);
+  encoder.set_input_array(b);
+  encoder.set_output_array(out);
+  encoder.launch_kernel([&](cudaStream_t stream) {
+    MLX_SWITCH_ALL_TYPES(a.dtype(), CTYPE_IN, {
+      MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, {
+        if constexpr (cu::supports_binary_op()) {
+          using InType = cuda_type_t;
+          using OutType = cuda_type_t;
+          auto policy = cu::thrust_policy(stream);
+          auto a_ptr = thrust::device_pointer_cast(a.data());
+          auto b_ptr = thrust::device_pointer_cast(b.data());
+          auto out_ptr = thrust::device_pointer_cast(out.data());
+
+          // Pick the input iterators from the operand layout: an operand
+          // classified as a scalar is wrapped in repeat_iterator, a vector
+          // operand is walked through its raw device pointer.
+          auto bopt = get_binary_op_type(a, b);
+          if (bopt == BinaryOpType::ScalarScalar) {
+            auto a_begin = cu::repeat_iterator(a_ptr);
+            auto a_end = a_begin + out.data_size();
+            auto b_begin = cu::repeat_iterator(b_ptr);
+            thrust::transform(policy, a_begin, a_end, b_begin, out_ptr, Op());
+          } else if (bopt == BinaryOpType::ScalarVector) {
+            auto a_begin = cu::repeat_iterator(a_ptr);
+            auto a_end = a_begin + out.data_size();
+            auto b_begin = b_ptr;
+            thrust::transform(policy, a_begin, a_end, b_begin, out_ptr, Op());
+          } else if (bopt == BinaryOpType::VectorScalar) {
+            auto a_begin = a_ptr;
+            auto a_end = a_begin + out.data_size();
+            auto b_begin = cu::repeat_iterator(b_ptr);
+            thrust::transform(policy, a_begin, a_end, b_begin, out_ptr, Op());
+          } else if (bopt == BinaryOpType::VectorVector) {
+            auto a_begin = a_ptr;
+            auto a_end = a_begin + out.data_size();
+            auto b_begin = b_ptr;
+            thrust::transform(policy, a_begin, a_end, b_begin, out_ptr, Op());
+          } else {
+            // Every other layout: iterate a and b through general iterators
+            // built from the collapsed shape and each operand's own strides.
+            // NOTE(review): make_general_iterators may have carried an
+            // explicit template argument that is missing in this copy.
+            auto [shape, strides] = collapse_contiguous_dims(a, b, out);
+            auto [a_begin, a_end] = cu::make_general_iterators(
+                a_ptr, out.data_size(), shape, strides[0]);
+            auto [b_begin, b_end] = cu::make_general_iterators(
+                b_ptr, out.data_size(), shape, strides[1]);
+            thrust::transform(policy, a_begin, a_end, b_begin, out_ptr, Op());
+          }
+        } else {
+          throw std::runtime_error(fmt::format(
+              "Can not do binary op {} on inputs of {} with result of {}.",
+              op,
+              dtype_to_string(a.dtype()),
+              dtype_to_string(out.dtype())));
+        }
+      });
+    });
+  });
+}
+
+// Two-output entry point: sets output data for outputs[0] and outputs[1],
+// then runs the kernel.
+// NOTE(review): binary_op_gpu_inplace as written only reads outputs[0], so
+// nothing visible here computes outputs[1].
+template
+void binary_op_gpu(
+    const std::vector& inputs,
+    std::vector& outputs,
+    std::string_view op,
+    const Stream& s) {
+  auto& a = inputs[0];
+  auto& b = inputs[1];
+  auto bopt = get_binary_op_type(a, b);
+  set_binary_op_output_data(a, b, outputs[0], bopt);
+  set_binary_op_output_data(a, b, outputs[1], bopt);
+  binary_op_gpu_inplace(inputs, outputs, op, s);
+}
+
+// Single-output entry point: sets output data for `out`, wraps it in a
+// one-element vector and runs the kernel on that vector.
+template
+void binary_op_gpu(
+    const std::vector& inputs,
+    array& out,
+    std::string_view op,
+    const Stream& s) {
+  auto& a = inputs[0];
+  auto& b = inputs[1];
+  auto bopt = get_binary_op_type(a, b);
+  set_binary_op_output_data(a, b, out, bopt);
+  std::vector outputs{out};
+  binary_op_gpu_inplace(inputs, outputs, op, s);
+}
+
+// Defines func::eval_gpu for a single-output primitive: opens an NVTX range
+// named "<func>::eval_gpu" and forwards to binary_op_gpu on the output's
+// stream.
+// NOTE(review): the template argument of binary_op_gpu is missing in this
+// copy of both macros.
+#define BINARY_GPU(func) \
+  void func::eval_gpu(const std::vector& inputs, array& out) { \
+    nvtx3::scoped_range r(#func "::eval_gpu"); \
+    auto& s = out.primitive().stream(); \
+    binary_op_gpu(inputs, out, get_primitive_string(this), s); \
+  }
+
+// Multi-output counterpart of BINARY_GPU; it is not expanded anywhere in the
+// portion of the file shown here.
+#define BINARY_GPU_MULTI(func) \
+  void func::eval_gpu( \
+      const std::vector& inputs, std::vector& outputs) { \
+    nvtx3::scoped_range r(#func "::eval_gpu"); \
+    auto& s = outputs[0].primitive().stream(); \
+    binary_op_gpu(inputs, outputs, get_primitive_string(this), s); \
+  }
+
+BINARY_GPU(Add)
+BINARY_GPU(ArcTan2)
+BINARY_GPU(Divide)
+BINARY_GPU(Remainder)
+BINARY_GPU(Equal)
+BINARY_GPU(Greater)
+BINARY_GPU(GreaterEqual)
+BINARY_GPU(Less)
+BINARY_GPU(LessEqual)
+BINARY_GPU(LogicalAnd)
+BINARY_GPU(LogicalOr)
+BINARY_GPU(LogAddExp)
+BINARY_GPU(Maximum)
+BINARY_GPU(Minimum)
+BINARY_GPU(Multiply)
+BINARY_GPU(NotEqual)
+BINARY_GPU(Power)
+BINARY_GPU(Subtract)
+
+// Dispatches on op_ to the kernel for the selected bitwise operation.
+// NOTE(review): as shown, all five cases make a textually identical call;
+// the per-case template argument is presumably missing in this copy.
+void BitwiseBinary::eval_gpu(const std::vector& inputs, array& out) {
+  nvtx3::scoped_range r("BitwiseBinary::eval_gpu");
+  auto& s = out.primitive().stream();
+  auto op = get_primitive_string(this);
+  switch (op_) {
+    case BitwiseBinary::And:
+      binary_op_gpu(inputs, out, op, s);
+      break;
+    case BitwiseBinary::Or:
+      binary_op_gpu(inputs, out, op, s);
+      break;
+    case BitwiseBinary::Xor:
+      binary_op_gpu(inputs, out, op, s);
+      break;
+    case BitwiseBinary::LeftShift:
+      binary_op_gpu(inputs, out, op, s);
+      break;
+    case BitwiseBinary::RightShift:
+      binary_op_gpu(inputs, out, op, s);
+      break;
+  }
+}
+
+} // namespace mlx::core
diff --git a/mlx/backend/cuda/iterators/repeat_iterator.cuh b/mlx/backend/cuda/iterators/repeat_iterator.cuh
new file mode 100644
index 000000000..48c2a9de7
--- /dev/null
+++ b/mlx/backend/cuda/iterators/repeat_iterator.cuh
@@ -0,0 +1,31 @@
+// Copyright © 2025 Apple Inc.
+
+#pragma once
+
+#include
+
+namespace mlx::core::cu {
+
+// Dereferences to the initial iterator's value regardless of advancement.
+template
+class repeat_iterator
+    : public thrust::iterator_adaptor, Iterator> {
+ public:
+  using super_t = thrust::iterator_adaptor, Iterator>;
+  using reference = typename super_t::reference;
+  using difference_type = typename super_t::difference_type;
+
+  __host__ __device__ repeat_iterator(Iterator it) : super_t(it), it_(it) {}
+
+ private:
+  friend class thrust::iterator_core_access;
+
+  // Device-only dereference, so it cannot accidentally run on the host.
+  __device__ typename super_t::reference dereference() const {
+    return *it_;
+  }
+
+  Iterator it_;
+};
+
+} // namespace mlx::core::cu
diff --git a/mlx/backend/cuda/kernels/binary_ops.cuh b/mlx/backend/cuda/kernels/binary_ops.cuh
new file mode 100644
index 000000000..6076679c1
--- /dev/null
+++ b/mlx/backend/cuda/kernels/binary_ops.cuh
@@ -0,0 +1,275 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/backend/cuda/kernels/fp16_math.cuh"
+
+namespace mlx::core::cu {
+
+// Element-wise binary functors. Each one exposes a device-only templated
+// operator() taking two operands of the same type T.
+
+// x + y.
+struct Add {
+  template <typename T>
+  __device__ T operator()(T x, T y) {
+    return x + y;
+  }
+};
+
+// Quotient with the fractional part dropped: plain `/` for integral T,
+// trunc(x / y) for every other T.
+struct FloorDivide {
+  template <typename T>
+  __device__ T operator()(T x, T y) {
+    if constexpr (cuda::std::is_integral_v<T>) {
+      return x / y;
+    } else {
+      return trunc(x / y);
+    }
+  }
+};
+
+// x / y.
+struct Divide {
+  template <typename T>
+  __device__ T operator()(T x, T y) {
+    return x / y;
+  }
+};
+
+// Remainder whose sign follows the divisor for signed integral and
+// floating-point T. Unsigned integral T and cuComplex use `%` directly.
+struct Remainder {
+  template <typename T>
+  __device__ T operator()(T x, T y) {
+    if constexpr (cuda::std::is_integral_v<T>) {
+      if constexpr (cuda::std::is_signed_v<T>) {
+        auto r = x % y;
+        // `%` follows the dividend's sign; shift a non-zero remainder into
+        // the divisor's sign when the two disagree.
+        if (r != 0 && ((r < 0) != (y < 0))) {
+          r += y;
+        }
+        return r;
+      } else {
+        return x % y;
+      }
+    } else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
+      return x % y;
+    } else {
+      T r = fmod(x, y);
+      // Same sign correction as the signed-integer branch.
+      if (r != 0 && ((r < 0) != (y < 0))) {
+        r = r + y;
+      }
+      return r;
+    }
+  }
+};
+
+// x == y.
+struct Equal {
+  template <typename T>
+  __device__ bool operator()(T x, T y) {
+    return x == y;
+  }
+};
+
+// Equality that also holds when both operands are NaN. For cuComplex the
+// real and imaginary parts are matched independently: each pair must be
+// either equal or both NaN.
+struct NaNEqual {
+  template <typename T>
+  __device__ bool operator()(T x, T y) {
+    if constexpr (cuda::std::is_same_v<T, cuComplex>) {
+      return x == y ||
+          (isnan(cuCrealf(x)) && isnan(cuCrealf(y)) && isnan(cuCimagf(x)) &&
+           isnan(cuCimagf(y))) ||
+          (cuCrealf(x) == cuCrealf(y) && isnan(cuCimagf(x)) &&
+           isnan(cuCimagf(y))) ||
+          (isnan(cuCrealf(x)) && isnan(cuCrealf(y)) &&
+           cuCimagf(x) == cuCimagf(y));
+    } else {
+      return x == y || (isnan(x) && isnan(y));
+    }
+  }
+};
+
+// x > y.
+struct Greater {
+  template <typename T>
+  __device__ bool operator()(T x, T y) {
+    return x > y;
+  }
+};
+
+// x >= y.
+struct GreaterEqual {
+  template <typename T>
+  __device__ bool operator()(T x, T y) {
+    return x >= y;
+  }
+};
+
+// x < y.
+struct Less {
+  template <typename T>
+  __device__ bool operator()(T x, T y) {
+    return x < y;
+  }
+};
+
+// x <= y.
+struct LessEqual {
+  template <typename T>
+  __device__ bool operator()(T x, T y) {
+    return x <= y;
+  }
+};
+
+// log(exp(x) + exp(y)), evaluated as max + log1p(exp(min - max)) so the
+// exponential argument is never positive. Returns NaN if either operand is
+// NaN, and the larger operand when the smaller is -inf or the larger is +inf.
+struct LogAddExp {
+  template <typename T>
+  __device__ T operator()(T x, T y) {
+    if (isnan(x) || isnan(y)) {
+      return cuda::std::numeric_limits<T>::quiet_NaN();
+    }
+    T maxval = max(x, y);
+    T minval = min(x, y);
+    return (minval == -cuda::std::numeric_limits<T>::infinity() ||
+            maxval == cuda::std::numeric_limits<T>::infinity())
+        ? maxval
+        : T(float(maxval) + log1p(expf(minval - maxval)));
+  }
+};
+
+// Larger of x and y. For non-integral T a NaN in x is returned as is; when
+// only y is NaN the `x > y` comparison is false, so y is returned.
+struct Maximum {
+  template <typename T>
+  __device__ T operator()(T x, T y) {
+    if constexpr (cuda::std::is_integral_v<T>) {
+      return max(x, y);
+    } else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
+      if (isnan(cuCrealf(x)) || isnan(cuCimagf(x))) {
+        return x;
+      }
+      return x > y ? x : y;
+    } else {
+      if (isnan(x)) {
+        return x;
+      }
+      return x > y ? x : y;
+    }
+  }
+};
+
+// Smaller of x and y, with the same NaN handling as Maximum.
+struct Minimum {
+  template <typename T>
+  __device__ T operator()(T x, T y) {
+    if constexpr (cuda::std::is_integral_v<T>) {
+      return min(x, y);
+    } else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
+      if (isnan(cuCrealf(x)) || isnan(cuCimagf(x))) {
+        return x;
+      }
+      return x < y ? x : y;
+    } else {
+      if (isnan(x)) {
+        return x;
+      }
+      return x < y ? x : y;
+    }
+  }
+};
+
+// x * y.
+struct Multiply {
+  template <typename T>
+  __device__ T operator()(T x, T y) {
+    return x * y;
+  }
+};
+
+// x != y; cuComplex operands differ when either component differs.
+struct NotEqual {
+  template <typename T>
+  __device__ bool operator()(T x, T y) {
+    if constexpr (cuda::std::is_same_v<T, cuComplex>) {
+      return cuCrealf(x) != cuCrealf(y) || cuCimagf(x) != cuCimagf(y);
+    } else {
+      return x != y;
+    }
+  }
+};
+
+// base raised to exp.
+//   * integral T: exponentiation by squaring; a negative exponent yields 0.
+//   * cuComplex:  polar form, exp(exp * log(base)), in single precision.
+//   * otherwise:  powf.
+struct Power {
+  template <typename T>
+  __device__ T operator()(T base, T exp) {
+    if constexpr (cuda::std::is_integral_v<T>) {
+      if constexpr (cuda::std::is_signed_v<T>) {
+        // An arithmetic right shift never brings a negative exponent to
+        // zero, so the loop below would not terminate. A negative integer
+        // power has no general integer result; return 0 instead.
+        if (exp < 0) {
+          return 0;
+        }
+      }
+      T res = 1;
+      while (exp) {
+        if (exp & 1) {
+          res *= base;
+        }
+        exp >>= 1;
+        base *= base;
+      }
+      return res;
+    } else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
+      auto x_theta = atan2f(base.y, base.x);
+      // 0.5f, not 0.5: a double literal would promote this whole expression
+      // (and the magnitude/phase below) to double precision.
+      auto x_ln_r = 0.5f * logf(base.x * base.x + base.y * base.y);
+      auto mag = expf(exp.x * x_ln_r - exp.y * x_theta);
+      auto phase = exp.y * x_ln_r + exp.x * x_theta;
+      return make_cuFloatComplex(mag * cosf(phase), mag * sinf(phase));
+    } else {
+      return powf(base, exp);
+    }
+  }
+};
+
+// x - y.
+struct Subtract {
+  template <typename T>
+  __device__ T operator()(T x, T y) {
+    return x - y;
+  }
+};
+
+// x && y, converted back to T.
+struct LogicalAnd {
+  template <typename T>
+  __device__ T operator()(T x, T y) {
+    return x && y;
+  }
+};
+
+// x || y, converted back to T.
+struct LogicalOr {
+  template <typename T>
+  __device__ T operator()(T x, T y) {
+    return x || y;
+  }
+};
+
+// x & y.
+struct BitwiseAnd {
+  template <typename T>
+  __device__ T operator()(T x, T y) {
+    return x & y;
+  }
+};
+
+// x | y.
+struct BitwiseOr {
+  template <typename T>
+  __device__ T operator()(T x, T y) {
+    return x | y;
+  }
+};
+
+// x ^ y.
+struct BitwiseXor {
+  template <typename T>
+  __device__ T operator()(T x, T y) {
+    return x ^ y;
+  }
+};
+
+// x << y. The shift count is not range-checked here.
+struct LeftShift {
+  template <typename T>
+  __device__ T operator()(T x, T y) {
+    return x << y;
+  }
+};
+
+// x >> y. The shift count is not range-checked here.
+struct RightShift {
+  template <typename T>
+  __device__ T operator()(T x, T y) {
+    return x >> y;
+  }
+};
+
+// atan2f(y, x); note the argument order is (y, x).
+struct ArcTan2 {
+  template <typename T>
+  __device__ T operator()(T y, T x) {
+    return atan2f(y, x);
+  }
+};
+
+// Returns {FloorDivide(x, y), Remainder(x, y)} as a two-element array.
+struct DivMod {
+  template <typename T>
+  __device__ cuda::std::array<T, 2> operator()(T x, T y) {
+    return {FloorDivide{}(x, y), Remainder{}(x, y)};
+  }
+};
+
+} // namespace mlx::core::cu
diff --git a/mlx/backend/cuda/kernels/fp16_math.cuh b/mlx/backend/cuda/kernels/fp16_math.cuh
index cf5def4db..f6fa17bb9 100644
--- a/mlx/backend/cuda/kernels/fp16_math.cuh
+++ b/mlx/backend/cuda/kernels/fp16_math.cuh
@@ -81,6 +81,52 @@ MLX_DEFINE_UNARY_OP_FALLBCK(tanh)
 #undef MLX_DEFINE_UNARY_OP
 #undef MLX_DEFINE_UNARY_OP_FALLBCK
 
+///////////////////////////////////////////////////////////////////////////////
+// Binary ops for half types.
+///////////////////////////////////////////////////////////////////////////////
+
+// MLX_DEFINE_BINARY_OP(NAME, HALF_OP) defines a device function NAME(x, y)
+// that calls HALF_OP for half-precision operands and the global ::NAME for
+// every other type. The __nv_bfloat16 branch is compiled only when
+// CUDART_VERSION >= 12000 or __CUDA_ARCH__ >= 800.
+#if CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800
+#define MLX_DEFINE_BINARY_OP(NAME, HALF_OP) \
+  template <typename T> \
+  __forceinline__ __device__ auto NAME(T x, T y) { \
+    if constexpr (cuda::std::is_same_v<T, __half>) { \
+      return HALF_OP(x, y); \
+    } else { \
+      return ::NAME(x, y); \
+    } \
+  }
+#else
+#define MLX_DEFINE_BINARY_OP(NAME, HALF_OP) \
+  template <typename T> \
+  __forceinline__ __device__ auto NAME(T x, T y) { \
+    if constexpr (cuda::std::is_same_v<T, __half>) { \
+      return HALF_OP(x, y); \
+    } else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) { \
+      return HALF_OP(x, y); \
+    } else { \
+      return ::NAME(x, y); \
+    } \
+  }
+#endif
+
+MLX_DEFINE_BINARY_OP(max, __hmax)
+MLX_DEFINE_BINARY_OP(min, __hmin)
+
+#undef MLX_DEFINE_BINARY_OP
+
+// fmod for any T. Half-precision operands are widened to float, reduced with
+// the single-precision ::fmodf, and narrowed back; every other type goes to
+// the global ::fmod. The __nv_bfloat16 branch uses the same version/arch
+// guard as the max/min helpers.
+template <typename T>
+__forceinline__ __device__ T fmod(T x, T y) {
+  if constexpr (cuda::std::is_same_v<T, __half>) {
+    return __float2half(::fmodf(__half2float(x), __half2float(y)));
+#if CUDART_VERSION >= 12000 || __CUDA_ARCH__ >= 800
+  } else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) {
+    return __float2bfloat16(::fmodf(__bfloat162float(x), __bfloat162float(y)));
+#endif
+  } else {
+    return ::fmod(x, y);
+  }
+}
+
 ///////////////////////////////////////////////////////////////////////////////
 // Additional C++ operator overrides between half types and native types.
/////////////////////////////////////////////////////////////////////////////// diff --git a/mlx/backend/cuda/primitives.cu b/mlx/backend/cuda/primitives.cu index 35218f769..31f393bfa 100644 --- a/mlx/backend/cuda/primitives.cu +++ b/mlx/backend/cuda/primitives.cu @@ -54,45 +54,27 @@ void Arange::eval_gpu(const std::vector& inputs, array& out) { throw std::runtime_error(#func " has no CUDA implementation."); \ } -NO_GPU(Add) NO_GPU(AddMM) -NO_GPU(ArcTan2) NO_GPU(ArgPartition) NO_GPU(ArgReduce) NO_GPU(ArgSort) -NO_GPU(BitwiseBinary) NO_GPU(BlockMaskedMM) NO_GPU_MULTI(Compiled) NO_GPU(Convolution) -NO_GPU(Divide) NO_GPU_MULTI(DivMod) NO_GPU(DynamicSlice) NO_GPU(DynamicSliceUpdate) -NO_GPU(Remainder) -NO_GPU(Equal) NO_GPU(FFT) NO_GPU(Gather) NO_GPU(GatherAxis) NO_GPU(GatherMM) NO_GPU(GatherQMM) -NO_GPU(Greater) -NO_GPU(GreaterEqual) NO_GPU(Hadamard) -NO_GPU(Less) -NO_GPU(LessEqual) NO_GPU(Load) -NO_GPU(LogicalAnd) -NO_GPU(LogicalOr) -NO_GPU(LogAddExp) NO_GPU(LogSumExp) NO_GPU_MULTI(LUF) NO_GPU(Matmul) -NO_GPU(Maximum) -NO_GPU(Minimum) -NO_GPU(Multiply) -NO_GPU(NotEqual) NO_GPU(Partition) -NO_GPU(Power) NO_GPU_MULTI(QRF) NO_GPU(QuantizedMatmul) NO_GPU(RandomBits) @@ -104,7 +86,6 @@ NO_GPU(Select) NO_GPU(SliceUpdate) NO_GPU(Softmax) NO_GPU(Sort) -NO_GPU(Subtract) NO_GPU_MULTI(SVD) NO_GPU(Inverse) NO_GPU(Cholesky)