From f8bad606099169e486ef5c8761a4f2a9d158245b Mon Sep 17 00:00:00 2001
From: Cheng
Date: Mon, 9 Jun 2025 22:45:08 +0900
Subject: [PATCH] CUDA backend: unary ops (#2158)

---
 mlx/backend/common/copy.h                   |   4 +-
 mlx/backend/common/unary.h                  |  26 ++
 mlx/backend/cpu/unary.h                     |  21 +-
 mlx/backend/cuda/CMakeLists.txt             |   1 +
 .../cuda/iterators/general_iterator.cuh     | 121 ++++++
 mlx/backend/cuda/kernel_utils.cuh           |  20 +
 mlx/backend/cuda/kernels/cucomplex_math.cuh | 240 ++++++++++++
 mlx/backend/cuda/kernels/fp16_math.cuh      |  72 ++++
 mlx/backend/cuda/kernels/unary_ops.cuh      | 349 ++++++++++++++++++
 mlx/backend/cuda/kernels/utils.cuh          |  43 +++
 mlx/backend/cuda/primitives.cu              |  32 --
 mlx/backend/cuda/unary.cu                   | 196 ++++++++++
 mlx/backend/metal/unary.cpp                 |  19 +-
 13 files changed, 1074 insertions(+), 70 deletions(-)
 create mode 100644 mlx/backend/common/unary.h
 create mode 100644 mlx/backend/cuda/iterators/general_iterator.cuh
 create mode 100644 mlx/backend/cuda/kernels/cucomplex_math.cuh
 create mode 100644 mlx/backend/cuda/kernels/unary_ops.cuh
 create mode 100644 mlx/backend/cuda/kernels/utils.cuh
 create mode 100644 mlx/backend/cuda/unary.cu

diff --git a/mlx/backend/common/copy.h b/mlx/backend/common/copy.h
index 0c9f28c94..c23d2e79a 100644
--- a/mlx/backend/common/copy.h
+++ b/mlx/backend/common/copy.h
@@ -2,7 +2,7 @@
 
 #pragma once
 
-#include "mlx/array.h"
+#include "mlx/backend/common/utils.h"
 
 namespace mlx::core {
 
@@ -26,7 +26,7 @@ inline bool set_copy_output_data(const array& in, array& out, CopyType ctype) {
   if (ctype == CopyType::Vector) {
     // If the input is donatable, we are doing a vector copy, and the types
    // have the same size, then the input buffer can hold the output.
-    if (in.is_donatable() && in.itemsize() == out.itemsize()) {
+    if (is_donatable(in, out)) {
       out.copy_shared_buffer(in);
       return true;
     } else {
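For context: `is_donatable(in, out)` comes from mlx/backend/common/utils.h and centralizes the check that was previously written inline here. A rough sketch of the idea (illustrative only; the real helper may apply additional buffer-size checks):

    // Sketch only: an input may donate its buffer to the output when no
    // other array references it and the per-element sizes match, so the
    // "copy" degenerates into moving a pointer instead of launching a kernel.
    inline bool is_donatable_sketch(const array& in, const array& out) {
      return in.is_donatable() &&              // refcount allows reuse
          in.itemsize() == out.itemsize();     // same element width
    }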
diff --git a/mlx/backend/common/unary.h b/mlx/backend/common/unary.h
new file mode 100644
index 000000000..a27a1f45c
--- /dev/null
+++ b/mlx/backend/common/unary.h
@@ -0,0 +1,26 @@
+// Copyright © 2025 Apple Inc.
+
+#pragma once
+
+#include "mlx/allocator.h"
+#include "mlx/backend/common/utils.h"
+
+namespace mlx::core {
+
+inline void set_unary_output_data(const array& in, array& out) {
+  if (in.flags().contiguous) {
+    if (is_donatable(in, out)) {
+      out.copy_shared_buffer(in);
+    } else {
+      out.set_data(
+          allocator::malloc(in.data_size() * out.itemsize()),
+          in.data_size(),
+          in.strides(),
+          in.flags());
+    }
+  } else {
+    out.set_data(allocator::malloc(out.nbytes()));
+  }
+}
+
+} // namespace mlx::core
diff --git a/mlx/backend/cpu/unary.h b/mlx/backend/cpu/unary.h
index fa539541c..14c1dd479 100644
--- a/mlx/backend/cpu/unary.h
+++ b/mlx/backend/cpu/unary.h
@@ -2,32 +2,13 @@
 
 #pragma once
 
-#include "mlx/allocator.h"
-#include "mlx/array.h"
-#include "mlx/backend/common/utils.h"
+#include "mlx/backend/common/unary.h"
 #include "mlx/backend/cpu/encoder.h"
 #include "mlx/backend/cpu/simd/simd.h"
 #include "mlx/utils.h"
 
 namespace mlx::core {
 
-void set_unary_output_data(const array& in, array& out) {
-  if (in.flags().contiguous) {
-    if (is_donatable(in, out)) {
-      out.copy_shared_buffer(in);
-    } else {
-      auto size = in.data_size();
-      out.set_data(
-          allocator::malloc(size * out.itemsize()),
-          size,
-          in.strides(),
-          in.flags());
-    }
-  } else {
-    out.set_data(allocator::malloc(out.nbytes()));
-  }
-}
-
 template <typename T, typename U, typename Op>
 void unary_op(const T* a, U* out, size_t shape, size_t stride) {
   for (size_t i = 0; i < shape; i += 1) {
diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt
index 9eaf2a6c7..cd73843bf 100644
--- a/mlx/backend/cuda/CMakeLists.txt
+++ b/mlx/backend/cuda/CMakeLists.txt
@@ -15,6 +15,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/unary.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)
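`set_unary_output_data` is now shared by the CPU, Metal, and CUDA backends. For a contiguous input it either donates the input buffer or allocates `data_size()` elements and mirrors the input's strides and flags; only non-contiguous inputs get a fresh dense allocation of `out.nbytes()`. A minimal caller sketch (the eval function below is hypothetical, not part of the patch):

    // Hypothetical eval showing the intended call order: choose the output
    // buffer first, then run a flat elementwise kernel.
    #include "mlx/backend/common/unary.h"

    namespace mlx::core {
    void my_unary_eval(const array& in, array& out) {
      set_unary_output_data(in, out); // donate or allocate
      // In the contiguous case, out now has the same layout as in, so a
      // single loop over in.data_size() elements covers both arrays.
    }
    } // namespace mlx::core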
diff --git a/mlx/backend/cuda/iterators/general_iterator.cuh b/mlx/backend/cuda/iterators/general_iterator.cuh
new file mode 100644
index 000000000..3c8c098c3
--- /dev/null
+++ b/mlx/backend/cuda/iterators/general_iterator.cuh
@@ -0,0 +1,121 @@
+// Copyright © 2025 Apple Inc.
+
+#pragma once
+
+#include <cuda/std/utility>
+#include <thrust/iterator/iterator_adaptor.h>
+
+#include "mlx/backend/cuda/kernel_utils.cuh"
+
+namespace mlx::core::cu {
+
+// Iterator for traversing a non-contiguous array.
+template <typename Iterator, typename IdxT = int64_t>
+class general_iterator
+    : public thrust::
+          iterator_adaptor<general_iterator<Iterator, IdxT>, Iterator> {
+ public:
+  using super_t =
+      thrust::iterator_adaptor<general_iterator<Iterator, IdxT>, Iterator>;
+
+  using reference = typename super_t::reference;
+  using difference_type = typename super_t::difference_type;
+
+  __host__ __device__ general_iterator(
+      Iterator it,
+      IdxT index,
+      int ndim,
+      Shape shape,
+      Strides strides)
+      : super_t(it),
+        index_(index),
+        ndim_(ndim),
+        shape_(cuda::std::move(shape)),
+        strides_(cuda::std::move(strides)) {}
+
+  __host__ __device__ IdxT index() const {
+    return index_;
+  }
+
+  __host__ __device__ const Shape& shape() const {
+    return shape_;
+  }
+
+  __host__ __device__ const Strides& strides() const {
+    return strides_;
+  }
+
+ private:
+  friend class thrust::iterator_core_access;
+
+  __host__ __device__ bool equal(const general_iterator& other) const {
+    return this->base() == other.base() && this->index() == other.index();
+  }
+
+  __host__ __device__ void advance(difference_type n) {
+    this->index_ += n;
+  }
+
+  __host__ __device__ void increment() {
+    this->index_ += 1;
+  }
+
+  __host__ __device__ void decrement() {
+    this->index_ -= 1;
+  }
+
+  __host__ __device__ difference_type
+  distance_to(const general_iterator& other) const {
+    _CCCL_ASSERT(
+        this->base() == other.base(),
+        "Underlying iterator must point to same base iterator");
+    return other.index() - this->index();
+  }
+
+  // The dereference is device-only to avoid accidentally running it on the
+  // host.
+  __device__ typename super_t::reference dereference() const {
+    IdxT offset = elem_to_loc(index_, shape_.data(), strides_.data(), ndim_);
+    return *(this->base() + offset);
+  }
+
+  IdxT index_;
+  int ndim_;
+  Shape shape_;
+  Strides strides_;
+};
+
+template <typename IdxT, typename Iterator>
+__host__ __device__ auto make_general_iterator(
+    Iterator it,
+    IdxT index,
+    int ndim,
+    Shape shape,
+    Strides strides) {
+  return general_iterator<Iterator, IdxT>(
+      it, index, ndim, cuda::std::move(shape), cuda::std::move(strides));
+}
+
+template <typename IdxT, typename Iterator>
+auto make_general_iterator(
+    Iterator it,
+    const std::vector<int32_t>& shape,
+    const std::vector<int64_t>& strides) {
+  return make_general_iterator<IdxT>(
+      it, 0, shape.size(), const_param(shape), const_param(strides));
+}
+
+template <typename IdxT, typename Iterator>
+auto make_general_iterators(
+    Iterator it,
+    IdxT size,
+    const std::vector<int32_t>& shape,
+    const std::vector<int64_t>& strides) {
+  auto ndim = shape.size();
+  auto shape_arg = const_param(shape);
+  auto strides_arg = const_param(strides);
+  return std::make_pair(
+      make_general_iterator<IdxT>(it, 0, ndim, shape_arg, strides_arg),
+      make_general_iterator<IdxT>(it, size, ndim, shape_arg, strides_arg));
+}
+
+} // namespace mlx::core::cu
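The iterator's dereference combines the base pointer with `elem_to_loc` (defined in kernels/utils.cuh, later in this patch), which converts a linear element index into a memory offset using shape and strides. A worked host-side example for a transposed 2 x 3 array:

    #include <cstdint>
    #include <cstdio>

    // Host-side replica of elem_to_loc, for illustration only.
    int64_t elem_to_loc_demo(
        int64_t elem, const int* shape, const int64_t* strides, int ndim) {
      int64_t loc = 0;
      for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
        loc += (elem % shape[i]) * strides[i];
        elem /= shape[i];
      }
      return loc;
    }

    int main() {
      int shape[2] = {2, 3};
      int64_t strides[2] = {1, 2}; // a transposed (column-major) 2x3 array
      for (int64_t e = 0; e < 6; ++e) {
        std::printf(
            "elem %lld -> offset %lld\n",
            (long long)e,
            (long long)elem_to_loc_demo(e, shape, strides, 2));
      }
      // Prints offsets 0, 2, 4, 1, 3, 5: the iterator visits memory in
      // strided order while the logical index advances linearly, so
      // thrust::transform can still write a dense, contiguous output.
    }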
diff --git a/mlx/backend/cuda/kernel_utils.cuh b/mlx/backend/cuda/kernel_utils.cuh
index 67ac47449..6430b8c59 100644
--- a/mlx/backend/cuda/kernel_utils.cuh
+++ b/mlx/backend/cuda/kernel_utils.cuh
@@ -7,10 +7,12 @@
 #pragma once
 
 #include "mlx/array.h"
+#include "mlx/backend/cuda/kernels/utils.cuh"
 
 #include <cuComplex.h>
 #include <cuda_bf16.h>
 #include <cuda_fp16.h>
+#include <fmt/format.h>
 
 namespace mlx::core {
 
@@ -38,6 +40,24 @@ struct CTypeToCudaType {
 template <typename CType>
 using cuda_type_t = typename CTypeToCudaType<CType>::type;
 
+// Type trait for detecting floating-point types, including the half types.
+template <typename T>
+inline constexpr bool is_floating_v =
+    cuda::std::is_same_v<T, float> || cuda::std::is_same_v<T, double> ||
+    cuda::std::is_same_v<T, __half> || cuda::std::is_same_v<T, __nv_bfloat16>;
+
+// Utility to copy data from a std::vector to a fixed-size array on the host.
+template <size_t NDIM = MAX_NDIM, typename T>
+inline cuda::std::array<T, NDIM> const_param(const std::vector<T>& vec) {
+  if (vec.size() > NDIM) {
+    throw std::runtime_error(
+        fmt::format("ndim cannot be larger than {}.", NDIM));
+  }
+  cuda::std::array<T, NDIM> result;
+  std::copy_n(vec.begin(), vec.size(), result.begin());
+  return result;
+}
+
 // Compute the grid and block dimensions, check backend/common/utils.h for docs.
 dim3 get_block_dims(int dim0, int dim1, int dim2, int pow2 = 10);
 dim3 get_2d_grid_dims(const Shape& shape, const Strides& strides);
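`const_param` is the host-side bridge to the fixed-size `Shape`/`Strides` arrays that kernels take by value; anything above MAX_NDIM (8) dimensions is rejected up front. A short usage sketch (the shape values are made up for illustration):

    #include <cstdint>
    #include <vector>
    #include "mlx/backend/cuda/kernel_utils.cuh"

    void const_param_demo() {
      // A dense 2x3x4 array: row-major strides are {12, 4, 1}.
      std::vector<int32_t> shape = {2, 3, 4};
      std::vector<int64_t> strides = {12, 4, 1};
      auto shape_arg = mlx::core::const_param(shape);     // array<int32_t, 8>
      auto strides_arg = mlx::core::const_param(strides); // array<int64_t, 8>
      // Only the first shape.size() entries are meaningful; kernels also
      // receive ndim, so they never read the uninitialized tail.
    }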
diff --git a/mlx/backend/cuda/kernels/cucomplex_math.cuh b/mlx/backend/cuda/kernels/cucomplex_math.cuh
new file mode 100644
index 000000000..612650c06
--- /dev/null
+++ b/mlx/backend/cuda/kernels/cucomplex_math.cuh
@@ -0,0 +1,240 @@
+// Copyright © 2025 Apple Inc.
+// Copyright © 2017-2024 The Simons Foundation, Inc.
+//
+// FINUFFT is licensed under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance with the
+// License. You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+// Forked from
+// https://github.com/flatironinstitute/finufft/blob/main/include/cufinufft/contrib/helper_math.h
+
+#pragma once
+
+#include <cuComplex.h>
+
+// This header provides some helper functions for cuComplex types.
+// It mainly wraps existing CUDA implementations to provide operator
+// overloads, e.g. cuAdd, cuSub, cuMul, cuDiv, cuCreal, cuCimag, cuCabs,
+// cuCarg, and cuConj are all provided by CUDA.
+
+__forceinline__ __host__ __device__ cuDoubleComplex
+operator+(const cuDoubleComplex& a, const cuDoubleComplex& b) {
+  return cuCadd(a, b);
+}
+
+__forceinline__ __host__ __device__ cuDoubleComplex
+operator-(const cuDoubleComplex& a, const cuDoubleComplex& b) {
+  return cuCsub(a, b);
+}
+
+__forceinline__ __host__ __device__ cuDoubleComplex
+operator*(const cuDoubleComplex& a, const cuDoubleComplex& b) {
+  return cuCmul(a, b);
+}
+
+__forceinline__ __host__ __device__ cuDoubleComplex
+operator/(const cuDoubleComplex& a, const cuDoubleComplex& b) {
+  return cuCdiv(a, b);
+}
+
+__forceinline__ __host__ __device__ cuDoubleComplex
+operator%(const cuDoubleComplex& a, const cuDoubleComplex& b) {
+  double r = cuCreal(a) - (floor(cuCreal(a) / cuCreal(b)) * cuCreal(b));
+  double i = cuCimag(a) - (floor(cuCimag(a) / cuCimag(b)) * cuCimag(b));
+  return make_cuDoubleComplex(r, i);
+}
+
+__forceinline__ __host__ __device__ bool operator==(
+    const cuDoubleComplex& a,
+    const cuDoubleComplex& b) {
+  return cuCreal(a) == cuCreal(b) && cuCimag(a) == cuCimag(b);
+}
+
+__forceinline__ __host__ __device__ bool operator!=(
+    const cuDoubleComplex& a,
+    const cuDoubleComplex& b) {
+  return !(a == b);
+}
+
+__forceinline__ __host__ __device__ bool operator>(
+    const cuDoubleComplex& a,
+    const cuDoubleComplex& b) {
+  double mag_a = sqrt(cuCreal(a) * cuCreal(a) + cuCimag(a) * cuCimag(a));
+  double mag_b = sqrt(cuCreal(b) * cuCreal(b) + cuCimag(b) * cuCimag(b));
+  return mag_a > mag_b;
+}
+
+__forceinline__ __host__ __device__ bool operator>=(
+    const cuDoubleComplex& a,
+    const cuDoubleComplex& b) {
+  return a > b || a == b;
+}
+
+__forceinline__ __host__ __device__ bool operator<(
+    const cuDoubleComplex& a,
+    const cuDoubleComplex& b) {
+  return b > a;
+}
+
+__forceinline__ __host__ __device__ bool operator<=(
+    const cuDoubleComplex& a,
+    const cuDoubleComplex& b) {
+  return b > a || a == b;
+}
+
+__forceinline__ __host__ __device__ cuDoubleComplex
+operator+(const cuDoubleComplex& a, double b) {
+  return make_cuDoubleComplex(cuCreal(a) + b, cuCimag(a));
+}
+
+__forceinline__ __host__ __device__ cuDoubleComplex
+operator+(double a, const cuDoubleComplex& b) {
+  return make_cuDoubleComplex(a + cuCreal(b), cuCimag(b));
+}
+
+__forceinline__ __host__ __device__ cuDoubleComplex
+operator-(const cuDoubleComplex& a, double b) {
+  return make_cuDoubleComplex(cuCreal(a) - b, cuCimag(a));
+}
+
+__forceinline__ __host__ __device__ cuDoubleComplex
+operator-(double a, const cuDoubleComplex& b) {
+  return make_cuDoubleComplex(a - cuCreal(b), -cuCimag(b));
+}
+
+__forceinline__ __host__ __device__ cuDoubleComplex
+operator*(const cuDoubleComplex& a, double b) {
+  return make_cuDoubleComplex(cuCreal(a) * b, cuCimag(a) * b);
+}
+
+__forceinline__ __host__ __device__ cuDoubleComplex
+operator*(double a, const cuDoubleComplex& b) {
+  return make_cuDoubleComplex(a * cuCreal(b), a * cuCimag(b));
+}
+
+__forceinline__ __host__ __device__ cuDoubleComplex
+operator/(const cuDoubleComplex& a, double b) {
+  return make_cuDoubleComplex(cuCreal(a) / b, cuCimag(a) / b);
+}
+
+__forceinline__ __host__ __device__ cuDoubleComplex
+operator/(double a, const cuDoubleComplex& b) {
+  double denom = cuCreal(b) * cuCreal(b) + cuCimag(b) * cuCimag(b);
+  return make_cuDoubleComplex(
+      (a * cuCreal(b)) / denom, (-a * cuCimag(b)) / denom);
+}
+
+__forceinline__ __host__ __device__ cuFloatComplex
+operator+(const cuFloatComplex& a, const cuFloatComplex& b) {
+  return cuCaddf(a, b);
+}
+
+__forceinline__ __host__ __device__ cuFloatComplex
+operator-(const cuFloatComplex& a, const cuFloatComplex& b) {
+  return cuCsubf(a, b);
+}
+
+__forceinline__ __host__ __device__ cuFloatComplex
+operator*(const cuFloatComplex& a, const cuFloatComplex& b) {
+  return cuCmulf(a, b);
+}
+
+__forceinline__ __host__ __device__ cuFloatComplex
+operator/(const cuFloatComplex& a, const cuFloatComplex& b) {
+  return cuCdivf(a, b);
+}
+
+__forceinline__ __host__ __device__ cuFloatComplex
+operator%(const cuFloatComplex& a, const cuFloatComplex& b) {
+  float r = cuCrealf(a) - (floorf(cuCrealf(a) / cuCrealf(b)) * cuCrealf(b));
+  float i = cuCimagf(a) - (floorf(cuCimagf(a) / cuCimagf(b)) * cuCimagf(b));
+  return make_cuFloatComplex(r, i);
+}
+
+__forceinline__ __host__ __device__ bool operator==(
+    const cuFloatComplex& a,
+    const cuFloatComplex& b) {
+  return cuCrealf(a) == cuCrealf(b) && cuCimagf(a) == cuCimagf(b);
+}
+
+__forceinline__ __host__ __device__ bool operator!=(
+    const cuFloatComplex& a,
+    const cuFloatComplex& b) {
+  return !(a == b);
+}
+
+__forceinline__ __host__ __device__ bool operator>(
+    const cuFloatComplex& a,
+    const cuFloatComplex& b) {
+  float mag_a = sqrt(cuCrealf(a) * cuCrealf(a) + cuCimagf(a) * cuCimagf(a));
+  float mag_b = sqrt(cuCrealf(b) * cuCrealf(b) + cuCimagf(b) * cuCimagf(b));
+  return mag_a > mag_b;
+}
+
+__forceinline__ __host__ __device__ bool operator>=(
+    const cuFloatComplex& a,
+    const cuFloatComplex& b) {
+  return a > b || a == b;
+}
+
+__forceinline__ __host__ __device__ bool operator<(
+    const cuFloatComplex& a,
+    const cuFloatComplex& b) {
+  return b > a;
+}
+
+__forceinline__ __host__ __device__ bool operator<=(
+    const cuFloatComplex& a,
+    const cuFloatComplex& b) {
+  return b > a || a == b;
+}
+
+__forceinline__ __host__ __device__ cuFloatComplex
+operator+(const cuFloatComplex& a, float b) {
+  return make_cuFloatComplex(cuCrealf(a) + b, cuCimagf(a));
+}
+
+__forceinline__ __host__ __device__ cuFloatComplex
+operator+(float a, const cuFloatComplex& b) {
+  return make_cuFloatComplex(a + cuCrealf(b), cuCimagf(b));
+}
+
+__forceinline__ __host__ __device__ cuFloatComplex
+operator-(const cuFloatComplex& a, float b) {
+  return make_cuFloatComplex(cuCrealf(a) - b, cuCimagf(a));
+}
+
+__forceinline__ __host__ __device__ cuFloatComplex
+operator-(float a, const cuFloatComplex& b) {
+  return make_cuFloatComplex(a - cuCrealf(b), -cuCimagf(b));
+}
+
+__forceinline__ __host__ __device__ cuFloatComplex
+operator*(const cuFloatComplex& a, float b) {
+  return make_cuFloatComplex(cuCrealf(a) * b, cuCimagf(a) * b);
+}
+
+__forceinline__ __host__ __device__ cuFloatComplex
+operator*(float a, const cuFloatComplex& b) {
+  return make_cuFloatComplex(a * cuCrealf(b), a * cuCimagf(b));
+}
+
+__forceinline__ __host__ __device__ cuFloatComplex
+operator/(const cuFloatComplex& a, float b) {
+  return make_cuFloatComplex(cuCrealf(a) / b, cuCimagf(a) / b);
+}
+
+__forceinline__ __host__ __device__ cuFloatComplex
+operator/(float a, const cuFloatComplex& b) {
+  float denom = cuCrealf(b) * cuCrealf(b) + cuCimagf(b) * cuCimagf(b);
+  return make_cuFloatComplex(
+      (a * cuCrealf(b)) / denom, (-a * cuCimagf(b)) / denom);
+}
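Note that the ordering operators above compare magnitudes, which is not a total order consistent with `==`: two distinct values can have equal magnitude. A small illustration:

    #include <cuComplex.h>
    #include "mlx/backend/cuda/kernels/cucomplex_math.cuh"

    void complex_order_demo() {
      cuFloatComplex a = make_cuFloatComplex(3.0f, 4.0f); // |a| = 5
      cuFloatComplex b = make_cuFloatComplex(5.0f, 0.0f); // |b| = 5
      bool lt = a < b;  // false: magnitudes are equal
      bool gt = a > b;  // false
      bool eq = a == b; // false: equality is component-wise
      // So <, <=, >, >= order by magnitude only, and code using them must
      // not assume trichotomy (exactly one of <, ==, > holding).
    }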
diff --git a/mlx/backend/cuda/kernels/fp16_math.cuh b/mlx/backend/cuda/kernels/fp16_math.cuh
index edbd953de..cf5def4db 100644
--- a/mlx/backend/cuda/kernels/fp16_math.cuh
+++ b/mlx/backend/cuda/kernels/fp16_math.cuh
@@ -9,6 +9,78 @@
 
 namespace mlx::core::cu {
 
+///////////////////////////////////////////////////////////////////////////////
+// Unary ops for half types.
+///////////////////////////////////////////////////////////////////////////////
+
+#if CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800
+#define MLX_DEFINE_UNARY_OP(NAME, HALF_OP)           \
+  template <typename T>                              \
+  __forceinline__ __device__ auto NAME(T x) {        \
+    if constexpr (cuda::std::is_same_v<T, __half>) { \
+      return HALF_OP(x);                             \
+    } else {                                         \
+      return ::NAME(x);                              \
+    }                                                \
+  }
+#else
+#define MLX_DEFINE_UNARY_OP(NAME, HALF_OP)                          \
+  template <typename T>                                             \
+  __forceinline__ __device__ auto NAME(T x) {                       \
+    if constexpr (cuda::std::is_same_v<T, __half>) {                \
+      return HALF_OP(x);                                            \
+    } else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) {  \
+      return HALF_OP(x);                                            \
+    } else {                                                        \
+      return ::NAME(x);                                             \
+    }                                                               \
+  }
+#endif
+
+#define MLX_DEFINE_UNARY_OP_FALLBACK(NAME)                          \
+  template <typename T>                                             \
+  __forceinline__ __device__ auto NAME(T x) {                       \
+    if constexpr (cuda::std::is_same_v<T, __half>) {                \
+      return ::NAME(__half2float(x));                               \
+    } else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) {  \
+      return ::NAME(__bfloat162float(x));                           \
+    } else {                                                        \
+      return ::NAME(x);                                             \
+    }                                                               \
+  }
+
+MLX_DEFINE_UNARY_OP(abs, __habs)
+MLX_DEFINE_UNARY_OP(ceil, hceil)
+MLX_DEFINE_UNARY_OP(cos, hcos)
+MLX_DEFINE_UNARY_OP(exp, hexp)
+MLX_DEFINE_UNARY_OP(floor, hfloor)
+MLX_DEFINE_UNARY_OP(isnan, __hisnan)
+MLX_DEFINE_UNARY_OP(log, hlog)
+MLX_DEFINE_UNARY_OP(log2, hlog2)
+MLX_DEFINE_UNARY_OP(log10, hlog10)
+MLX_DEFINE_UNARY_OP(rint, hrint)
+MLX_DEFINE_UNARY_OP(rsqrt, hrsqrt)
+MLX_DEFINE_UNARY_OP(sin, hsin)
+MLX_DEFINE_UNARY_OP(sqrt, hsqrt)
+MLX_DEFINE_UNARY_OP_FALLBACK(acos)
+MLX_DEFINE_UNARY_OP_FALLBACK(acosh)
+MLX_DEFINE_UNARY_OP_FALLBACK(asin)
+MLX_DEFINE_UNARY_OP_FALLBACK(asinh)
+MLX_DEFINE_UNARY_OP_FALLBACK(atan)
+MLX_DEFINE_UNARY_OP_FALLBACK(atanh)
+MLX_DEFINE_UNARY_OP_FALLBACK(cosh)
+MLX_DEFINE_UNARY_OP_FALLBACK(log1p)
+MLX_DEFINE_UNARY_OP_FALLBACK(sinh)
+MLX_DEFINE_UNARY_OP_FALLBACK(tan)
+#if CUDART_VERSION >= 12080
+MLX_DEFINE_UNARY_OP(tanh, htanh)
+#else
+MLX_DEFINE_UNARY_OP_FALLBACK(tanh)
+#endif
+
+#undef MLX_DEFINE_UNARY_OP
+#undef MLX_DEFINE_UNARY_OP_FALLBACK
+
 ///////////////////////////////////////////////////////////////////////////////
 // Additional C++ operator overrides between half types and native types.
 ///////////////////////////////////////////////////////////////////////////////
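For reference, on a recent toolchain the second branch of `MLX_DEFINE_UNARY_OP(sin, hsin)` expands to roughly the following: half and bfloat16 use the native intrinsic, and every other type forwards to the global-namespace overload.

    // Expansion sketch of MLX_DEFINE_UNARY_OP(sin, hsin), CUDA >= 12 /
    // arch >= 800 branch.
    template <typename T>
    __forceinline__ __device__ auto sin(T x) {
      if constexpr (cuda::std::is_same_v<T, __half>) {
        return hsin(x);
      } else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) {
        return hsin(x);
      } else {
        return ::sin(x);
      }
    }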
diff --git a/mlx/backend/cuda/kernels/unary_ops.cuh b/mlx/backend/cuda/kernels/unary_ops.cuh
new file mode 100644
index 000000000..6637a6eeb
--- /dev/null
+++ b/mlx/backend/cuda/kernels/unary_ops.cuh
@@ -0,0 +1,349 @@
+// Copyright © 2025 Apple Inc.
+
+#pragma once
+
+#include "mlx/backend/cuda/kernels/fp16_math.cuh"
+#include "mlx/backend/cuda/kernels/utils.cuh"
+
+namespace mlx::core::cu {
+
+struct Abs {
+  template <typename T>
+  __device__ T operator()(T x) {
+    if constexpr (cuda::std::is_unsigned_v<T>) {
+      return x;
+    } else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
+      return {sqrt(cuCrealf(x) * cuCrealf(x) + cuCimagf(x) * cuCimagf(x)), 0};
+    } else {
+      return abs(x);
+    }
+  }
+};
+
+struct ArcCos {
+  template <typename T>
+  __device__ T operator()(T x) {
+    return acos(x);
+  }
+};
+
+struct ArcCosh {
+  template <typename T>
+  __device__ T operator()(T x) {
+    return acosh(x);
+  }
+};
+
+struct ArcSin {
+  template <typename T>
+  __device__ T operator()(T x) {
+    return asin(x);
+  }
+};
+
+struct ArcSinh {
+  template <typename T>
+  __device__ T operator()(T x) {
+    return asinh(x);
+  }
+};
+
+struct ArcTan {
+  template <typename T>
+  __device__ T operator()(T x) {
+    return atan(x);
+  }
+};
+
+struct ArcTanh {
+  template <typename T>
+  __device__ T operator()(T x) {
+    return atanh(x);
+  }
+};
+
+struct BitwiseInvert {
+  template <typename T>
+  __device__ T operator()(T x) {
+    return ~x;
+  }
+};
+
+struct Ceil {
+  template <typename T>
+  __device__ T operator()(T x) {
+    if constexpr (cuda::std::is_integral_v<T>) {
+      return x;
+    } else {
+      return ceil(x);
+    }
+  }
+};
+
+struct Conjugate {
+  __device__ cuComplex operator()(cuComplex x) {
+    return {cuCrealf(x), -cuCimagf(x)};
+  }
+};
+
+struct Cos {
+  template <typename T>
+  __device__ T operator()(T x) {
+    if constexpr (cuda::std::is_same_v<T, cuComplex>) {
+      return {
+          cos(cuCrealf(x)) * cosh(cuCimagf(x)),
+          -sin(cuCrealf(x)) * sinh(cuCimagf(x))};
+    } else {
+      return cos(x);
+    }
+  }
+};
+
+struct Cosh {
+  template <typename T>
+  __device__ T operator()(T x) {
+    if constexpr (cuda::std::is_same_v<T, cuComplex>) {
+      return {
+          cosh(cuCrealf(x)) * cos(cuCimagf(x)),
+          sinh(cuCrealf(x)) * sin(cuCimagf(x))};
+    } else {
+      return cosh(x);
+    }
+  }
+};
+
+struct Erf {
+  template <typename T>
+  __device__ T operator()(T x) {
+    if constexpr (cuda::std::is_same_v<T, __half>) {
+      return erf(__half2float(x));
+    } else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) {
+      return erf(__bfloat162float(x));
+    } else {
+      return erf(x);
+    }
+  }
+};
+
+struct ErfInv {
+  template <typename T>
+  __device__ T operator()(T x) {
+    if constexpr (cuda::std::is_same_v<T, __half>) {
+      return erfinv(__half2float(x));
+    } else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) {
+      return erfinv(__bfloat162float(x));
+    } else {
+      return erfinv(x);
+    }
+  }
+};
+
+struct Exp {
+  template <typename T>
+  __device__ T operator()(T x) {
+    if constexpr (cuda::std::is_same_v<T, cuComplex>) {
+      auto m = exp(cuCrealf(x));
+      return {m * cos(cuCimagf(x)), m * sin(cuCimagf(x))};
+    } else {
+      return exp(x);
+    }
+  }
+};
+
+struct Expm1 {
+  template <typename T>
+  __device__ T operator()(T x) {
+    if constexpr (cuda::std::is_same_v<T, __half>) {
+      return expm1(__half2float(x));
+    } else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) {
+      return expm1(__bfloat162float(x));
+    } else {
+      return expm1(x);
+    }
+  }
+};
+
+struct Floor {
+  template <typename T>
+  __device__ T operator()(T x) {
+    if constexpr (cuda::std::is_integral_v<T>) {
+      return x;
+    } else {
+      return floor(x);
+    }
+  }
+};
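+
+// Note (illustration): the complex branches above implement the standard
+// identities, e.g. exp(a+bi) = e^a * (cos b + i sin b) and
+// cos(a+bi) = cos a cosh b - i sin a sinh b. As a quick sanity check,
+// Exp{}(make_cuFloatComplex(0.f, 3.14159265f)) evaluates to approximately
+// (-1, 0), per Euler's identity.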
+
+struct Imag {
+  __device__ float operator()(cuComplex x) {
+    return cuCimagf(x);
+  }
+};
+
+struct Log {
+  template <typename T>
+  __device__ T operator()(T x) {
+    return log(x);
+  }
+};
+
+struct Log2 {
+  template <typename T>
+  __device__ T operator()(T x) {
+    return log2(x);
+  }
+};
+
+struct Log10 {
+  template <typename T>
+  __device__ T operator()(T x) {
+    return log10(x);
+  }
+};
+
+struct Log1p {
+  template <typename T>
+  __device__ T operator()(T x) {
+    return log1p(x);
+  }
+};
+
+struct LogicalNot {
+  __device__ bool operator()(bool x) {
+    return !x;
+  }
+};
+
+struct Negative {
+  template <typename T>
+  __device__ T operator()(T x) {
+    if constexpr (cuda::std::is_same_v<T, cuComplex>) {
+      return 0 - x;
+    } else {
+      return -x;
+    }
+  }
+};
+
+struct Real {
+  __device__ float operator()(cuComplex x) {
+    return cuCrealf(x);
+  }
+};
+
+struct Round {
+  template <typename T>
+  __device__ T operator()(T x) {
+    if constexpr (cuda::std::is_same_v<T, cuComplex>) {
+      return {rint(cuCrealf(x)), rint(cuCimagf(x))};
+    } else {
+      return rint(x);
+    }
+  }
+};
+
+struct Rsqrt {
+  template <typename T>
+  __device__ T operator()(T x) {
+    return rsqrt(x);
+  }
+};
+
+struct Sigmoid {
+  template <typename T>
+  __device__ T operator()(T x) {
+    T y = 1 / (1 + exp(-abs(x)));
+    return (x < 0) ? 1 - y : y;
+  }
+};
+
+struct Sign {
+  template <typename T>
+  __device__ T operator()(T x) {
+    if constexpr (cuda::std::is_unsigned_v<T>) {
+      return x != 0;
+    } else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
+      if (cuCrealf(x) == 0 && cuCimagf(x) == 0) {
+        return x;
+      } else {
+        return x / Abs()(x);
+      }
+    } else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) {
+      return static_cast<T>((x > T(0.f)) - (x < T(0.f)));
+    } else {
+      return (x > T(0)) - (x < T(0));
+    }
+  }
+};
+
+struct Sin {
+  template <typename T>
+  __device__ T operator()(T x) {
+    if constexpr (cuda::std::is_same_v<T, cuComplex>) {
+      return {
+          sin(cuCrealf(x)) * cosh(cuCimagf(x)),
+          cos(cuCrealf(x)) * sinh(cuCimagf(x))};
+    } else {
+      return sin(x);
+    }
+  }
+};
+
+struct Sinh {
+  template <typename T>
+  __device__ T operator()(T x) {
+    if constexpr (cuda::std::is_same_v<T, cuComplex>) {
+      return {
+          sinh(cuCrealf(x)) * cos(cuCimagf(x)),
+          cosh(cuCrealf(x)) * sin(cuCimagf(x))};
+    } else {
+      return sinh(x);
+    }
+  }
+};
+
+struct Square {
+  template <typename T>
+  __device__ T operator()(T x) {
+    return x * x;
+  }
+};
+
+struct Sqrt {
+  template <typename T>
+  __device__ T operator()(T x) {
+    return sqrt(x);
+  }
+};
+
+struct Tan {
+  template <typename T>
+  __device__ T operator()(T x) {
+    if constexpr (cuda::std::is_same_v<T, cuComplex>) {
+      float tan_a = tan(cuCrealf(x));
+      float tanh_b = tanh(cuCimagf(x));
+      float t1 = tan_a * tanh_b;
+      float denom = 1. + t1 * t1;
+      return {(tan_a - tanh_b * t1) / denom, (tanh_b + tan_a * t1) / denom};
+    } else {
+      return tan(x);
+    }
+  }
+};
+
+struct Tanh {
+  template <typename T>
+  __device__ T operator()(T x) {
+    if constexpr (cuda::std::is_same_v<T, cuComplex>) {
+      float tanh_a = tanh(cuCrealf(x));
+      float tan_b = tan(cuCimagf(x));
+      float t1 = tanh_a * tan_b;
+      float denom = 1. + t1 * t1;
+      return {(tanh_a + tan_b * t1) / denom, (tan_b - tanh_a * t1) / denom};
+    } else {
+      return tanh(x);
+    }
+  }
+};
+
+} // namespace mlx::core::cu
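One numeric detail worth noting from `Sigmoid` above: `exp` is only ever called on `-abs(x)`, so the exponent is never positive and the result stays in (0, 1] even for float16/bfloat16; negative inputs are then handled by the reflection sigmoid(-x) = 1 - sigmoid(x). A float-only illustration of the same trick:

    #include <cmath>

    // Stable logistic: exp never receives a positive argument, so it
    // cannot overflow to inf in low-precision types.
    float sigmoid_stable(float x) {
      float y = 1.0f / (1.0f + std::exp(-std::fabs(x)));
      return (x < 0.0f) ? 1.0f - y : y;
    }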
diff --git a/mlx/backend/cuda/kernels/utils.cuh b/mlx/backend/cuda/kernels/utils.cuh
new file mode 100644
index 000000000..4d69b7356
--- /dev/null
+++ b/mlx/backend/cuda/kernels/utils.cuh
@@ -0,0 +1,43 @@
+// Copyright © 2025 Apple Inc.
+
+// This file must not include any host-only code; utilities that work under
+// both host and device can be put here.
+//
+// See more about the requirements at:
+// https://docs.nvidia.com/cuda/nvrtc/#language
+
+#pragma once
+
+#include <cuComplex.h>
+#include <cuda/std/array>
+#include <cuda/std/cstdint>
+
+namespace mlx::core::cu {
+
+///////////////////////////////////////////////////////////////////////////////
+// CUDA kernel utils
+///////////////////////////////////////////////////////////////////////////////
+
+// To pass shape/strides to kernels via constant memory, their size must be
+// known at compile time.
+#define MAX_NDIM 8
+
+using Shape = cuda::std::array<int32_t, MAX_NDIM>;
+using Strides = cuda::std::array<int64_t, MAX_NDIM>;
+
+///////////////////////////////////////////////////////////////////////////////
+// Indexing utils
+///////////////////////////////////////////////////////////////////////////////
+
+template <typename IdxT>
+inline __host__ __device__ IdxT
+elem_to_loc(IdxT elem, const int* shape, const int64_t* strides, int ndim) {
+  IdxT loc = 0;
+  for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
+    loc += (elem % shape[i]) * IdxT(strides[i]);
+    elem /= shape[i];
+  }
+  return loc;
+}
+
+} // namespace mlx::core::cu
diff --git a/mlx/backend/cuda/primitives.cu b/mlx/backend/cuda/primitives.cu
index fad2d76d3..3d9186892 100644
--- a/mlx/backend/cuda/primitives.cu
+++ b/mlx/backend/cuda/primitives.cu
@@ -71,39 +71,22 @@ bool fast::ScaledDotProductAttention::use_fallback(
   throw std::runtime_error(#func " has no CUDA implementation."); \
   }
 
-NO_GPU(Abs)
 NO_GPU(Add)
-NO_GPU(ArcCos)
-NO_GPU(ArcCosh)
-NO_GPU(ArcSin)
-NO_GPU(ArcSinh)
-NO_GPU(ArcTan)
 NO_GPU(ArcTan2)
-NO_GPU(ArcTanh)
 NO_GPU(ArgPartition)
 NO_GPU(ArgReduce)
 NO_GPU(ArgSort)
 NO_GPU(BitwiseBinary)
-NO_GPU(BitwiseInvert)
 NO_GPU(BlockMaskedMM)
-NO_GPU(Ceil)
 NO_GPU_MULTI(Compiled)
-NO_GPU(Conjugate)
 NO_GPU(Convolution)
-NO_GPU(Cos)
-NO_GPU(Cosh)
 NO_GPU(Divide)
 NO_GPU_MULTI(DivMod)
 NO_GPU(DynamicSlice)
 NO_GPU(DynamicSliceUpdate)
 NO_GPU(Remainder)
 NO_GPU(Equal)
-NO_GPU(Erf)
-NO_GPU(ErfInv)
-NO_GPU(Exp)
-NO_GPU(Expm1)
 NO_GPU(FFT)
-NO_GPU(Floor)
 NO_GPU(Gather)
 NO_GPU(GatherAxis)
 NO_GPU(GatherMM)
@@ -111,13 +94,9 @@ NO_GPU(GatherQMM)
 NO_GPU(Greater)
 NO_GPU(GreaterEqual)
 NO_GPU(Hadamard)
-NO_GPU(Imag)
 NO_GPU(Less)
 NO_GPU(LessEqual)
 NO_GPU(Load)
-NO_GPU(Log)
-NO_GPU(Log1p)
-NO_GPU(LogicalNot)
 NO_GPU(LogicalAnd)
 NO_GPU(LogicalOr)
 NO_GPU(LogAddExp)
@@ -126,33 +105,22 @@ NO_GPU_MULTI(LUF)
 NO_GPU(Maximum)
 NO_GPU(Minimum)
 NO_GPU(Multiply)
-NO_GPU(Negative)
 NO_GPU(NotEqual)
 NO_GPU(Partition)
 NO_GPU(Power)
 NO_GPU_MULTI(QRF)
 NO_GPU(QuantizedMatmul)
 NO_GPU(RandomBits)
-NO_GPU(Real)
 NO_GPU(Reduce)
-NO_GPU(Round)
 NO_GPU(Scan)
 NO_GPU(Scatter)
 NO_GPU(ScatterAxis)
 NO_GPU(Select)
-NO_GPU(Sigmoid)
-NO_GPU(Sign)
-NO_GPU(Sin)
-NO_GPU(Sinh)
 NO_GPU(SliceUpdate)
 NO_GPU(Softmax)
 NO_GPU(Sort)
-NO_GPU(Square)
-NO_GPU(Sqrt)
 NO_GPU(Subtract)
 NO_GPU_MULTI(SVD)
-NO_GPU(Tan)
-NO_GPU(Tanh)
 NO_GPU(Inverse)
 NO_GPU(Cholesky)
 NO_GPU_MULTI(Eig)
diff --git a/mlx/backend/cuda/unary.cu b/mlx/backend/cuda/unary.cu
new file mode 100644
index 000000000..0ee31ee28
--- /dev/null
+++ b/mlx/backend/cuda/unary.cu
@@ -0,0 +1,196 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/backend/common/unary.h"
+#include "mlx/backend/cuda/device.h"
+#include "mlx/backend/cuda/iterators/general_iterator.cuh"
+#include "mlx/backend/cuda/kernel_utils.cuh"
+#include "mlx/backend/cuda/kernels/cucomplex_math.cuh"
+#include "mlx/backend/cuda/kernels/unary_ops.cuh"
+#include "mlx/dtype_utils.h"
+#include "mlx/primitives.h"
+
+#include <nvtx3/nvtx3.hpp>
+#include <thrust/device_ptr.h>
+#include <thrust/transform.h>
+
+namespace mlx::core {
+
+namespace cu {
+
+template <typename Op, typename In, typename Out>
+constexpr bool supports_unary_op() {
+  if (std::is_same_v<Op, Abs> || std::is_same_v<Op, Negative> ||
+      std::is_same_v<Op, Sign>) {
+    return std::is_same_v<In, Out>;
+  }
+  if (std::is_same_v<Op, ArcCos> || std::is_same_v<Op, ArcCosh> ||
+      std::is_same_v<Op, ArcSin> || std::is_same_v<Op, ArcSinh> ||
+      std::is_same_v<Op, ArcTan> || std::is_same_v<Op, ArcTanh> ||
+      std::is_same_v<Op, Erf> || std::is_same_v<Op, ErfInv> ||
+      std::is_same_v<Op, Expm1> || std::is_same_v<Op, Log> ||
+      std::is_same_v<Op, Log2> || std::is_same_v<Op, Log10> ||
+      std::is_same_v<Op, Log1p> || std::is_same_v<Op, Rsqrt> ||
+      std::is_same_v<Op, Sigmoid> || std::is_same_v<Op, Sqrt>) {
+    return std::is_same_v<In, Out> && is_floating_v<In>;
+  }
+  if (std::is_same_v<Op, BitwiseInvert>) {
+    return std::is_same_v<In, Out> && std::is_integral_v<In> &&
+        !std::is_same_v<In, bool>;
+  }
+  if (std::is_same_v<Op, Ceil> || std::is_same_v<Op, Floor> ||
+      std::is_same_v<Op, Square>) {
+    return std::is_same_v<In, Out> && !std::is_same_v<In, complex64_t>;
+  }
+  if (std::is_same_v<Op, Conjugate>) {
+    return std::is_same_v<In, Out> && std::is_same_v<In, complex64_t>;
+  }
+  if (std::is_same_v<Op, Cos> || std::is_same_v<Op, Cosh> ||
+      std::is_same_v<Op, Exp> || std::is_same_v<Op, Round> ||
+      std::is_same_v<Op, Sin> || std::is_same_v<Op, Sinh> ||
+      std::is_same_v<Op, Tan> || std::is_same_v<Op, Tanh>) {
+    return std::is_same_v<In, Out> &&
+        (is_floating_v<In> || std::is_same_v<In, complex64_t>);
+  }
+  if (std::is_same_v<Op, Imag> || std::is_same_v<Op, Real>) {
+    return std::is_same_v<In, complex64_t> && std::is_same_v<Out, float>;
+  }
+  if (std::is_same_v<Op, LogicalNot>) {
+    return std::is_same_v<In, Out> && std::is_same_v<In, bool>;
+  }
+  return false;
+}
+
+} // namespace cu
+
+template <typename Op>
+void unary_op_gpu_inplace(
+    const std::vector<array>& inputs,
+    array& out,
+    const std::string& op,
+    const Stream& s) {
+  auto& in = inputs[0];
+  if (in.size() == 0) {
+    return;
+  }
+
+  auto& encoder = cu::get_command_encoder(s);
+  encoder.set_input_array(in);
+  encoder.set_output_array(out);
+  encoder.launch_kernel([&](cudaStream_t stream) {
+    MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE_IN, {
+      MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, {
+        if constexpr (cu::supports_unary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
+          using InType = cuda_type_t<CTYPE_IN>;
+          using OutType = cuda_type_t<CTYPE_OUT>;
+          auto policy = cu::thrust_policy(stream);
+          auto in_ptr = thrust::device_pointer_cast(in.data<InType>());
+          auto out_ptr = thrust::device_pointer_cast(out.data<OutType>());
+          if (in.flags().contiguous) {
+            thrust::transform(
+                policy, in_ptr, in_ptr + in.data_size(), out_ptr, Op());
+          } else {
+            auto [shape, strides] = collapse_contiguous_dims(in);
+            auto [in_begin, in_end] = cu::make_general_iterators<int64_t>(
+                in_ptr, in.data_size(), shape, strides);
+            thrust::transform(policy, in_begin, in_end, out_ptr, Op());
+          }
+        } else {
+          throw std::runtime_error(fmt::format(
+              "Cannot do unary op {} on input of {} with output of {}.",
+              op,
+              dtype_to_string(in.dtype()),
+              dtype_to_string(out.dtype())));
+        }
+      });
+    });
+  });
+}
+
+template <typename Op>
+void unary_op_gpu(
+    const std::vector<array>& inputs,
+    array& out,
+    const std::string& op,
+    const Stream& s) {
+  set_unary_output_data(inputs[0], out);
+  unary_op_gpu_inplace<Op>(inputs, out, op, s);
+}
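+
+// Dispatch note (illustration): supports_unary_op is evaluated inside the
+// nested MLX_SWITCH_ALL_TYPES expansion above, so every (op, in, out)
+// combination is pruned at compile time and only valid instantiations
+// reach thrust::transform. Conceptually, for Op = cu::Sqrt with
+// float32 -> float32 the expansion reduces to:
+//
+//   if constexpr (cu::supports_unary_op<cu::Sqrt, float, float>()) { // true
+//     thrust::transform(policy, in_ptr, in_ptr + n, out_ptr, cu::Sqrt());
+//   }
+//
+// while for float32 -> int32 the branch is discarded at compile time and
+// the runtime path throws "Cannot do unary op ...".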
+
+#define UNARY_GPU(func)                                                   \
+  void func::eval_gpu(const std::vector<array>& inputs, array& out) {     \
+    nvtx3::scoped_range r(#func "::eval_gpu");                            \
+    auto& s = out.primitive().stream();                                   \
+    unary_op_gpu<cu::func>(inputs, out, get_primitive_string(this), s);   \
+  }
+
+UNARY_GPU(Abs)
+UNARY_GPU(ArcCos)
+UNARY_GPU(ArcCosh)
+UNARY_GPU(ArcSin)
+UNARY_GPU(ArcSinh)
+UNARY_GPU(ArcTan)
+UNARY_GPU(ArcTanh)
+UNARY_GPU(BitwiseInvert)
+UNARY_GPU(Ceil)
+UNARY_GPU(Conjugate)
+UNARY_GPU(Cos)
+UNARY_GPU(Cosh)
+UNARY_GPU(Erf)
+UNARY_GPU(ErfInv)
+UNARY_GPU(Exp)
+UNARY_GPU(Expm1)
+UNARY_GPU(Floor)
+UNARY_GPU(Imag)
+UNARY_GPU(Log1p)
+UNARY_GPU(LogicalNot)
+UNARY_GPU(Negative)
+UNARY_GPU(Real)
+UNARY_GPU(Sigmoid)
+UNARY_GPU(Sign)
+UNARY_GPU(Sin)
+UNARY_GPU(Sinh)
+UNARY_GPU(Square)
+UNARY_GPU(Tan)
+UNARY_GPU(Tanh)
+
+void Log::eval_gpu(const std::vector<array>& inputs, array& out) {
+  nvtx3::scoped_range r("Log::eval_gpu");
+  auto& s = out.primitive().stream();
+  auto op = get_primitive_string(this);
+  switch (base_) {
+    case Base::e:
+      unary_op_gpu<cu::Log>(inputs, out, op, s);
+      break;
+    case Base::two:
+      unary_op_gpu<cu::Log2>(inputs, out, op, s);
+      break;
+    case Base::ten:
+      unary_op_gpu<cu::Log10>(inputs, out, op, s);
+      break;
+  }
+}
+
+void Round::eval_gpu(const std::vector<array>& inputs, array& out) {
+  nvtx3::scoped_range r("Round::eval_gpu");
+  assert(inputs.size() == 1);
+  const auto& in = inputs[0];
+  auto& s = out.primitive().stream();
+  if (issubdtype(in.dtype(), inexact)) {
+    unary_op_gpu<cu::Round>(inputs, out, get_primitive_string(this), s);
+  } else {
+    // No-op integer types
+    out.copy_shared_buffer(in);
+  }
+}
+
+void Sqrt::eval_gpu(const std::vector<array>& inputs, array& out) {
+  nvtx3::scoped_range r("Sqrt::eval_gpu");
+  auto& s = out.primitive().stream();
+  if (recip_) {
+    unary_op_gpu<cu::Rsqrt>(inputs, out, "Rsqrt", s);
+  } else {
+    unary_op_gpu<cu::Sqrt>(inputs, out, "Sqrt", s);
+  }
+}
+
+} // namespace mlx::core
diff --git a/mlx/backend/metal/unary.cpp b/mlx/backend/metal/unary.cpp
index 850c17376..0b118b72f 100644
--- a/mlx/backend/metal/unary.cpp
+++ b/mlx/backend/metal/unary.cpp
@@ -1,5 +1,6 @@
 // Copyright © 2024 Apple Inc.
-#include "mlx/backend/common/utils.h"
+
+#include "mlx/backend/common/unary.h"
 #include "mlx/backend/metal/device.h"
 #include "mlx/backend/metal/kernels.h"
 #include "mlx/backend/metal/utils.h"
@@ -99,21 +100,7 @@ void unary_op_gpu(
     array& out,
     const std::string op,
     const Stream& s) {
-  auto& in = inputs[0];
-  bool contig = in.flags().contiguous;
-  if (contig) {
-    if (in.is_donatable() && in.itemsize() == out.itemsize()) {
-      out.copy_shared_buffer(in);
-    } else {
-      out.set_data(
-          allocator::malloc(in.data_size() * out.itemsize()),
-          in.data_size(),
-          in.strides(),
-          in.flags());
-    }
-  } else {
-    out.set_data(allocator::malloc(out.nbytes()));
-  }
+  set_unary_output_data(inputs[0], out);
   unary_op_gpu_inplace(inputs, out, op, s);
 }
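Assuming a CUDA build of MLX, a quick way to exercise several of these kernels end to end (hypothetical smoke test, not part of the patch):

    #include "mlx/mlx.h"

    namespace mx = mlx::core;

    int main() {
      auto x = mx::array({-1.0f, 0.0f, 2.0f});
      auto y = mx::exp(mx::abs(x)); // Abs and Exp both dispatch to unary.cu
      y.eval();                     // runs the thrust::transform path above
      return 0;
    }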