mirror of https://github.com/ml-explore/mlx.git

CUDA backend: compile (#2276)

* CUDA backend: compile
* Rename kernels/ to device/
mlx/backend/cuda/device/arange.cuh (new file, 15 lines)
@@ -0,0 +1,15 @@
// Copyright © 2025 Apple Inc.

namespace mlx::core::cu {

template <typename T>
struct Arange {
  const T start;
  const T step;

  __device__ T operator()(uint32_t i) const {
    return start + i * step;
  }
};

} // namespace mlx::core::cu
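Not part of the commit: a minimal sketch of how a functor like Arange can be driven from a plain CUDA kernel. The kernel name, launch shape, and start/step values below are illustrative assumptions, not MLX's actual dispatch path.

// Hypothetical usage sketch (not from this commit).
#include "mlx/backend/cuda/device/arange.cuh"
#include <cuda_runtime.h>

template <typename T, typename Op>
__global__ void fill_indexed(T* out, Op op, uint32_t n) {
  uint32_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    out[i] = op(i); // Arange maps a flat index to start + i * step.
  }
}

void launch_arange_example(float* out, uint32_t n, cudaStream_t stream) {
  mlx::core::cu::Arange<float> op{/* start */ 0.0f, /* step */ 0.5f};
  fill_indexed<<<(n + 255) / 256, 256, 0, stream>>>(out, op, n);
}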
mlx/backend/cuda/device/binary_ops.cuh (new file, 278 lines)
@@ -0,0 +1,278 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/device/fp16_math.cuh"

#include <cuComplex.h>
#include <cuda/std/array>

namespace mlx::core::cu {

struct Add {
  template <typename T>
  __device__ T operator()(T x, T y) {
    return x + y;
  }
};

struct FloorDivide {
  template <typename T>
  __device__ T operator()(T x, T y) {
    if constexpr (cuda::std::is_integral_v<T>) {
      return x / y;
    } else {
      return trunc(x / y);
    }
  }
};

struct Divide {
  template <typename T>
  __device__ T operator()(T x, T y) {
    return x / y;
  }
};

struct Remainder {
  template <typename T>
  __device__ T operator()(T x, T y) {
    if constexpr (cuda::std::is_integral_v<T>) {
      if constexpr (cuda::std::is_signed_v<T>) {
        auto r = x % y;
        if (r != 0 && (r < 0 != y < 0)) {
          r += y;
        }
        return r;
      } else {
        return x % y;
      }
    } else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
      return x % y;
    } else {
      T r = fmod(x, y);
      if (r != 0 && (r < 0 != y < 0)) {
        r = r + y;
      }
      return r;
    }
  }
};

struct Equal {
  template <typename T>
  __device__ bool operator()(T x, T y) {
    return x == y;
  }
};

struct NaNEqual {
  template <typename T>
  __device__ bool operator()(T x, T y) {
    if constexpr (std::is_same_v<T, cuComplex>) {
      return x == y ||
          (isnan(cuCrealf(x)) && isnan(cuCrealf(y)) && isnan(cuCimagf(x)) &&
           isnan(cuCimagf(y))) ||
          (cuCrealf(x) == cuCrealf(y) && isnan(cuCimagf(x)) &&
           isnan(cuCimagf(y))) ||
          (isnan(cuCrealf(x)) && isnan(cuCrealf(y)) &&
           cuCimagf(x) == cuCimagf(y));
    } else {
      return x == y || (isnan(x) && isnan(y));
    }
  }
};

struct Greater {
  template <typename T>
  __device__ bool operator()(T x, T y) {
    return x > y;
  }
};

struct GreaterEqual {
  template <typename T>
  __device__ bool operator()(T x, T y) {
    return x >= y;
  }
};

struct Less {
  template <typename T>
  __device__ bool operator()(T x, T y) {
    return x < y;
  }
};

struct LessEqual {
  template <typename T>
  __device__ bool operator()(T x, T y) {
    return x <= y;
  }
};

struct LogAddExp {
  template <typename T>
  __device__ T operator()(T x, T y) {
    if (isnan(x) || isnan(y)) {
      return cuda::std::numeric_limits<T>::quiet_NaN();
    }
    T maxval = max(x, y);
    T minval = min(x, y);
    return (minval == -cuda::std::numeric_limits<T>::infinity() ||
            maxval == cuda::std::numeric_limits<T>::infinity())
        ? maxval
        : T(float(maxval) + log1p(expf(minval - maxval)));
  };
};

struct Maximum {
  template <typename T>
  __device__ T operator()(T x, T y) {
    if constexpr (cuda::std::is_integral_v<T>) {
      return max(x, y);
    } else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
      if (isnan(cuCrealf(x)) || isnan(cuCimagf(x))) {
        return x;
      }
      return x > y ? x : y;
    } else {
      if (isnan(x)) {
        return x;
      }
      return x > y ? x : y;
    }
  }
};

struct Minimum {
  template <typename T>
  __device__ T operator()(T x, T y) {
    if constexpr (cuda::std::is_integral_v<T>) {
      return min(x, y);
    } else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
      if (isnan(cuCrealf(x)) || isnan(cuCimagf(x))) {
        return x;
      }
      return x < y ? x : y;
    } else {
      if (isnan(x)) {
        return x;
      }
      return x < y ? x : y;
    }
  }
};

struct Multiply {
  template <typename T>
  __device__ T operator()(T x, T y) {
    return x * y;
  }
};

struct NotEqual {
  template <typename T>
  __device__ bool operator()(T x, T y) {
    if constexpr (std::is_same_v<T, cuComplex>) {
      return cuCrealf(x) != cuCrealf(y) || cuCimagf(x) != cuCimagf(y);
    } else {
      return x != y;
    }
  }
};

struct Power {
  template <typename T>
  __device__ T operator()(T base, T exp) {
    if constexpr (cuda::std::is_integral_v<T>) {
      T res = 1;
      while (exp) {
        if (exp & 1) {
          res *= base;
        }
        exp >>= 1;
        base *= base;
      }
      return res;
    } else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
      auto x_theta = atan2f(base.y, base.x);
      auto x_ln_r = 0.5 * logf(base.x * base.x + base.y * base.y);
      auto mag = expf(exp.x * x_ln_r - exp.y * x_theta);
      auto phase = exp.y * x_ln_r + exp.x * x_theta;
      return make_cuFloatComplex(mag * cosf(phase), mag * sinf(phase));
    } else {
      return powf(base, exp);
    }
  }
};

struct Subtract {
  template <typename T>
  __device__ T operator()(T x, T y) {
    return x - y;
  }
};

struct LogicalAnd {
  template <typename T>
  __device__ T operator()(T x, T y) {
    return x && y;
  };
};

struct LogicalOr {
  template <typename T>
  __device__ T operator()(T x, T y) {
    return x || y;
  };
};

struct BitwiseAnd {
  template <typename T>
  __device__ T operator()(T x, T y) {
    return x & y;
  };
};

struct BitwiseOr {
  template <typename T>
  __device__ T operator()(T x, T y) {
    return x | y;
  };
};

struct BitwiseXor {
  template <typename T>
  __device__ T operator()(T x, T y) {
    return x ^ y;
  };
};

struct LeftShift {
  template <typename T>
  __device__ T operator()(T x, T y) {
    return x << y;
  };
};

struct RightShift {
  template <typename T>
  __device__ T operator()(T x, T y) {
    return x >> y;
  };
};

struct ArcTan2 {
  template <typename T>
  __device__ T operator()(T y, T x) {
    return atan2f(y, x);
  }
};

struct DivMod {
  template <typename T>
  __device__ cuda::std::array<T, 2> operator()(T x, T y) {
    return {FloorDivide{}(x, y), Remainder{}(x, y)};
  };
};

} // namespace mlx::core::cu
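These functors are plain callables over a single element pair, so any generic elementwise driver can apply them. A minimal sketch using Thrust follows; the functor choice, sizes, and driver are illustrative assumptions, not MLX's actual dispatch path.

// Hypothetical usage sketch (not from this commit).
#include "mlx/backend/cuda/device/binary_ops.cuh"
#include <thrust/device_vector.h>
#include <thrust/transform.h>

int main() {
  thrust::device_vector<float> a(8, 1.0f);
  thrust::device_vector<float> b(8, 2.0f);
  thrust::device_vector<float> out(8);
  // Any of the ops above (Multiply, Maximum, LogAddExp, ...) drops in the
  // same way, since each exposes a templated __device__ operator().
  thrust::transform(
      a.begin(), a.end(), b.begin(), out.begin(), mlx::core::cu::Add{});
  return 0;
}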
mlx/backend/cuda/device/cast_op.cuh (new file, 59 lines)
@@ -0,0 +1,59 @@
// Copyright © 2025 Apple Inc.

#pragma once

#include <cuComplex.h>
#include <thrust/iterator/transform_iterator.h>

namespace mlx::core::cu {

// An op that does static_cast, with custom conversions for some types.
template <typename SrcT, typename DstT, typename = void>
struct CastOp {
  static constexpr bool is_castable = cuda::std::is_convertible_v<SrcT, DstT>;

  __device__ DstT operator()(SrcT x) {
    return static_cast<DstT>(x);
  }
};

// Converting a complex number to real number discards the imaginary part.
template <typename DstT>
struct CastOp<
    cuComplex,
    DstT,
    cuda::std::enable_if_t<!cuda::std::is_same_v<cuComplex, DstT>>> {
  static constexpr bool is_castable = cuda::std::is_convertible_v<float, DstT>;

  __device__ DstT operator()(cuComplex x) {
    static_assert(!cuda::std::is_same_v<cuComplex, DstT>);
    return static_cast<DstT>(cuCrealf(x));
  }
};

// Allow converting a real number to complex number.
template <typename SrcT>
struct CastOp<
    SrcT,
    cuComplex,
    cuda::std::enable_if_t<!cuda::std::is_same_v<SrcT, cuComplex>>> {
  static constexpr bool is_castable = cuda::std::is_convertible_v<SrcT, float>;

  __device__ cuComplex operator()(SrcT x) {
    static_assert(!cuda::std::is_same_v<SrcT, cuComplex>);
    return cuComplex{static_cast<float>(x), 0};
  }
};

// Return an iterator that casts the value to DstT using CastOp.
template <typename DstT, typename Iterator>
__host__ __device__ auto make_cast_iterator(Iterator it) {
  using SrcT = typename cuda::std::iterator_traits<Iterator>::value_type;
  if constexpr (std::is_same_v<SrcT, DstT>) {
    return it;
  } else {
    return thrust::make_transform_iterator(it, CastOp<SrcT, DstT>{});
  }
}

} // namespace mlx::core::cu
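make_cast_iterator composes with Thrust algorithms, so a dtype conversion can ride along with another pass over the data. A small sketch, assuming the input is a raw device pointer (for which iterator_traits is well defined); the copy-based driver is an illustration, not MLX's actual use.

// Hypothetical usage sketch (not from this commit).
#include "mlx/backend/cuda/device/cast_op.cuh"
#include <thrust/copy.h>
#include <thrust/device_vector.h>
#include <thrust/execution_policy.h>

int main() {
  thrust::device_vector<float> in(16, 3.7f);
  thrust::device_vector<int> out(16);
  // CastOp<float, int> is applied lazily by the transform iterator, so the
  // copy below performs the cast while streaming the data once.
  auto it = mlx::core::cu::make_cast_iterator<int>(
      thrust::raw_pointer_cast(in.data()));
  thrust::copy(
      thrust::device, it, it + in.size(), thrust::raw_pointer_cast(out.data()));
  return 0;
}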
mlx/backend/cuda/device/config.h (new file, 12 lines)
@@ -0,0 +1,12 @@
// Copyright © 2025 Apple Inc.

// This file is used by both CUDA kernel code and host-only C++ code.

#pragma once

// The maximum dimensions of shape/strides passed as kernel parameters.
#define MAX_NDIM 8

// All existing NVIDIA hardware has a fixed 32 warp size. Though a built-in
// warpSize variable exists, using it would prevent compile-time optimizations.
#define WARP_SIZE 32
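A compile-time WARP_SIZE is what allows fully unrolled warp-level primitives. The reduction below is only a sketch of the kind of code this enables; it is not part of the commit.

// Hypothetical usage sketch (not from this commit).
#include "mlx/backend/cuda/device/config.h"

__device__ float warp_sum(float v) {
  // With WARP_SIZE as a macro the loop trip count is known at compile time,
  // so the compiler can unroll it; the built-in warpSize variable could not.
  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
    v += __shfl_down_sync(0xffffffff, v, offset);
  }
  return v;
}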
mlx/backend/cuda/device/cucomplex_math.cuh (new file, 240 lines)
@@ -0,0 +1,240 @@
// Copyright © 2025 Apple Inc.
// Copyright © 2017-2024 The Simons Foundation, Inc.
//
// FINUFFT is licensed under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance with the
// License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Forked from
// https://github.com/flatironinstitute/finufft/blob/main/include/cufinufft/contrib/helper_math.h

#pragma once

#include <cuComplex.h>

// This header provides some helper functions for cuComplex types.
// It mainly wraps existing CUDA implementations to provide operator overloads
// e.g. cuAdd, cuSub, cuMul, cuDiv, cuCreal, cuCimag, cuCabs, cuCarg, cuConj are
// all provided by CUDA

__forceinline__ __host__ __device__ cuDoubleComplex
operator+(const cuDoubleComplex& a, const cuDoubleComplex& b) {
  return cuCadd(a, b);
}

__forceinline__ __host__ __device__ cuDoubleComplex
operator-(const cuDoubleComplex& a, const cuDoubleComplex& b) {
  return cuCsub(a, b);
}

__forceinline__ __host__ __device__ cuDoubleComplex
operator*(const cuDoubleComplex& a, const cuDoubleComplex& b) {
  return cuCmul(a, b);
}

__forceinline__ __host__ __device__ cuDoubleComplex
operator/(const cuDoubleComplex& a, const cuDoubleComplex& b) {
  return cuCdiv(a, b);
}

__forceinline__ __host__ __device__ cuDoubleComplex
operator%(const cuDoubleComplex& a, const cuDoubleComplex& b) {
  double r = cuCreal(a) - (floorf(cuCreal(a) / cuCreal(b)) * cuCreal(b));
  double i = cuCimag(a) - (floorf(cuCimag(a) / cuCimag(b)) * cuCimag(b));
  return make_cuDoubleComplex(r, i);
}

__forceinline__ __host__ __device__ bool operator==(
    const cuDoubleComplex& a,
    const cuDoubleComplex& b) {
  return cuCreal(a) == cuCreal(b) && cuCimag(a) == cuCimag(b);
}

__forceinline__ __host__ __device__ bool operator!=(
    const cuDoubleComplex& a,
    const cuDoubleComplex& b) {
  return !(a == b);
}

__forceinline__ __host__ __device__ bool operator>(
    const cuDoubleComplex& a,
    const cuDoubleComplex& b) {
  double mag_a = sqrt(cuCreal(a) * cuCreal(a) + cuCimag(a) * cuCimag(a));
  double mag_b = sqrt(cuCreal(b) * cuCreal(b) + cuCimag(b) * cuCimag(b));
  return mag_a > mag_b;
}

__forceinline__ __host__ __device__ bool operator>=(
    const cuDoubleComplex& a,
    const cuDoubleComplex& b) {
  return a > b || a == b;
}

__forceinline__ __host__ __device__ bool operator<(
    const cuDoubleComplex& a,
    const cuDoubleComplex& b) {
  return b > a;
}

__forceinline__ __host__ __device__ bool operator<=(
    const cuDoubleComplex& a,
    const cuDoubleComplex& b) {
  return b > a || a == b;
}

__forceinline__ __host__ __device__ cuDoubleComplex
operator+(const cuDoubleComplex& a, double b) {
  return make_cuDoubleComplex(cuCreal(a) + b, cuCimag(a));
}

__forceinline__ __host__ __device__ cuDoubleComplex
operator+(double a, const cuDoubleComplex& b) {
  return make_cuDoubleComplex(a + cuCreal(b), cuCimag(b));
}

__forceinline__ __host__ __device__ cuDoubleComplex
operator-(const cuDoubleComplex& a, double b) {
  return make_cuDoubleComplex(cuCreal(a) - b, cuCimag(a));
}

__forceinline__ __host__ __device__ cuDoubleComplex
operator-(double a, const cuDoubleComplex& b) {
  return make_cuDoubleComplex(a - cuCreal(b), -cuCimag(b));
}

__forceinline__ __host__ __device__ cuDoubleComplex
operator*(const cuDoubleComplex& a, double b) {
  return make_cuDoubleComplex(cuCreal(a) * b, cuCimag(a) * b);
}

__forceinline__ __host__ __device__ cuDoubleComplex
operator*(double a, const cuDoubleComplex& b) {
  return make_cuDoubleComplex(a * cuCreal(b), a * cuCimag(b));
}

__forceinline__ __host__ __device__ cuDoubleComplex
operator/(const cuDoubleComplex& a, double b) {
  return make_cuDoubleComplex(cuCreal(a) / b, cuCimag(a) / b);
}

__forceinline__ __host__ __device__ cuDoubleComplex
operator/(double a, const cuDoubleComplex& b) {
  double denom = cuCreal(b) * cuCreal(b) + cuCimag(b) * cuCimag(b);
  return make_cuDoubleComplex(
      (a * cuCreal(b)) / denom, (-a * cuCimag(b)) / denom);
}

__forceinline__ __host__ __device__ cuFloatComplex
operator+(const cuFloatComplex& a, const cuFloatComplex& b) {
  return cuCaddf(a, b);
}

__forceinline__ __host__ __device__ cuFloatComplex
operator-(const cuFloatComplex& a, const cuFloatComplex& b) {
  return cuCsubf(a, b);
}

__forceinline__ __host__ __device__ cuFloatComplex
operator*(const cuFloatComplex& a, const cuFloatComplex& b) {
  return cuCmulf(a, b);
}

__forceinline__ __host__ __device__ cuFloatComplex
operator/(const cuFloatComplex& a, const cuFloatComplex& b) {
  return cuCdivf(a, b);
}

__forceinline__ __host__ __device__ cuFloatComplex
operator%(const cuFloatComplex& a, const cuFloatComplex& b) {
  float r = cuCrealf(a) - (floorf(cuCrealf(a) / cuCrealf(b)) * cuCrealf(b));
  float i = cuCimagf(a) - (floorf(cuCimagf(a) / cuCimagf(b)) * cuCimagf(b));
  return make_cuFloatComplex(r, i);
}

__forceinline__ __host__ __device__ bool operator==(
    const cuFloatComplex& a,
    const cuFloatComplex& b) {
  return cuCrealf(a) == cuCrealf(b) && cuCimagf(a) == cuCimagf(b);
}

__forceinline__ __host__ __device__ bool operator!=(
    const cuFloatComplex& a,
    const cuFloatComplex& b) {
  return !(a == b);
}

__forceinline__ __host__ __device__ bool operator>(
    const cuFloatComplex& a,
    const cuFloatComplex& b) {
  float mag_a = sqrt(cuCrealf(a) * cuCrealf(a) + cuCimagf(a) * cuCimagf(a));
  float mag_b = sqrt(cuCrealf(b) * cuCrealf(b) + cuCimagf(b) * cuCimagf(b));
  return mag_a > mag_b;
}

__forceinline__ __host__ __device__ bool operator>=(
    const cuFloatComplex& a,
    const cuFloatComplex& b) {
  return a > b || a == b;
}

__forceinline__ __host__ __device__ bool operator<(
    const cuFloatComplex& a,
    const cuFloatComplex& b) {
  return b > a;
}

__forceinline__ __host__ __device__ bool operator<=(
    const cuFloatComplex& a,
    const cuFloatComplex& b) {
  return b > a || a == b;
}

__forceinline__ __host__ __device__ cuFloatComplex
operator+(const cuFloatComplex& a, float b) {
  return make_cuFloatComplex(cuCrealf(a) + b, cuCimagf(a));
}

__forceinline__ __host__ __device__ cuFloatComplex
operator+(float a, const cuFloatComplex& b) {
  return make_cuFloatComplex(a + cuCrealf(b), cuCimagf(b));
}

__forceinline__ __host__ __device__ cuFloatComplex
operator-(const cuFloatComplex& a, float b) {
  return make_cuFloatComplex(cuCrealf(a) - b, cuCimagf(a));
}

__forceinline__ __host__ __device__ cuFloatComplex
operator-(float a, const cuFloatComplex& b) {
  return make_cuFloatComplex(a - cuCrealf(b), -cuCimagf(b));
}

__forceinline__ __host__ __device__ cuFloatComplex
operator*(const cuFloatComplex& a, float b) {
  return make_cuFloatComplex(cuCrealf(a) * b, cuCimagf(a) * b);
}

__forceinline__ __host__ __device__ cuFloatComplex
operator*(float a, const cuFloatComplex& b) {
  return make_cuFloatComplex(a * cuCrealf(b), a * cuCimagf(b));
}

__forceinline__ __host__ __device__ cuFloatComplex
operator/(const cuFloatComplex& a, float b) {
  return make_cuFloatComplex(cuCrealf(a) / b, cuCimagf(a) / b);
}

__forceinline__ __host__ __device__ cuFloatComplex
operator/(float a, const cuFloatComplex& b) {
  float denom = cuCrealf(b) * cuCrealf(b) + cuCimagf(b) * cuCimagf(b);
  return make_cuFloatComplex(
      (a * cuCrealf(b)) / denom, (-a * cuCimagf(b)) / denom);
}
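With these overloads in scope, complex arithmetic reads like ordinary scalar code instead of nested cuCaddf/cuCmulf calls. A small illustrative helper (not from the commit; the function name is made up):

// Hypothetical usage sketch (not from this commit).
#include "mlx/backend/cuda/device/cucomplex_math.cuh"

__host__ __device__ inline cuFloatComplex fma_complex(
    cuFloatComplex a, cuFloatComplex x, cuFloatComplex y) {
  // Reads as a * x + y thanks to the operator overloads above.
  return a * x + y;
}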
mlx/backend/cuda/device/fp16_math.cuh (new file, 194 lines)
@@ -0,0 +1,194 @@
// Copyright © 2025 Apple Inc.

#pragma once

#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda/std/limits>
#include <cuda/std/type_traits>

namespace mlx::core::cu {

///////////////////////////////////////////////////////////////////////////////
// Unary ops for half types.
///////////////////////////////////////////////////////////////////////////////

#if CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800
#define MLX_DEFINE_UNARY_OP(NAME, HALF_OP)           \
  template <typename T>                              \
  __forceinline__ __device__ auto NAME(T x) {        \
    if constexpr (cuda::std::is_same_v<T, __half>) { \
      return HALF_OP(x);                             \
    } else {                                         \
      return ::NAME(x);                              \
    }                                                \
  }
#else
#define MLX_DEFINE_UNARY_OP(NAME, HALF_OP)                          \
  template <typename T>                                             \
  __forceinline__ __device__ auto NAME(T x) {                       \
    if constexpr (cuda::std::is_same_v<T, __half>) {                \
      return HALF_OP(x);                                            \
    } else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) {  \
      return HALF_OP(x);                                            \
    } else {                                                        \
      return ::NAME(x);                                             \
    }                                                               \
  }
#endif

#define MLX_DEFINE_UNARY_OP_FALLBCK(NAME)                           \
  template <typename T>                                             \
  __forceinline__ __device__ auto NAME(T x) {                       \
    if constexpr (cuda::std::is_same_v<T, __half>) {                \
      return ::NAME(__half2float(x));                               \
    } else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) {  \
      return ::NAME(__bfloat162float(x));                           \
    } else {                                                        \
      return ::NAME(x);                                             \
    }                                                               \
  }

MLX_DEFINE_UNARY_OP(abs, __habs)
MLX_DEFINE_UNARY_OP(ceil, hceil)
MLX_DEFINE_UNARY_OP(cos, hcos)
MLX_DEFINE_UNARY_OP(exp, hexp)
MLX_DEFINE_UNARY_OP(floor, hfloor)
MLX_DEFINE_UNARY_OP(isnan, __hisnan)
MLX_DEFINE_UNARY_OP(log, hlog)
MLX_DEFINE_UNARY_OP(log2, hlog2)
MLX_DEFINE_UNARY_OP(log10, hlog10)
MLX_DEFINE_UNARY_OP(rint, hrint)
MLX_DEFINE_UNARY_OP(rsqrt, hrsqrt)
MLX_DEFINE_UNARY_OP(sin, hsin)
MLX_DEFINE_UNARY_OP(sqrt, hsqrt)
MLX_DEFINE_UNARY_OP_FALLBCK(acos)
MLX_DEFINE_UNARY_OP_FALLBCK(acosh)
MLX_DEFINE_UNARY_OP_FALLBCK(asin)
MLX_DEFINE_UNARY_OP_FALLBCK(asinh)
MLX_DEFINE_UNARY_OP_FALLBCK(atan)
MLX_DEFINE_UNARY_OP_FALLBCK(atanh)
MLX_DEFINE_UNARY_OP_FALLBCK(cosh)
MLX_DEFINE_UNARY_OP_FALLBCK(log1p)
MLX_DEFINE_UNARY_OP_FALLBCK(sinh)
MLX_DEFINE_UNARY_OP_FALLBCK(tan)
#if __CUDA_ARCH__ >= 1280
MLX_DEFINE_UNARY_OP(tanh, htanh)
#else
MLX_DEFINE_UNARY_OP_FALLBCK(tanh)
#endif

#undef MLX_DEFINE_UNARY_OP
#undef MLX_DEFINE_UNARY_OP_FALLBCK

///////////////////////////////////////////////////////////////////////////////
// Binary ops for half types.
///////////////////////////////////////////////////////////////////////////////

#if CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800
#define MLX_DEFINE_BINARY_OP(NAME, HALF_OP)          \
  template <typename T>                              \
  __forceinline__ __device__ auto NAME(T x, T y) {   \
    if constexpr (cuda::std::is_same_v<T, __half>) { \
      return HALF_OP(x, y);                          \
    } else {                                         \
      return ::NAME(x, y);                           \
    }                                                \
  }
#else
#define MLX_DEFINE_BINARY_OP(NAME, HALF_OP)                         \
  template <typename T>                                             \
  __forceinline__ __device__ auto NAME(T x, T y) {                  \
    if constexpr (cuda::std::is_same_v<T, __half>) {                \
      return HALF_OP(x, y);                                         \
    } else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) {  \
      return HALF_OP(x, y);                                         \
    } else {                                                        \
      return ::NAME(x, y);                                          \
    }                                                               \
  }
#endif

MLX_DEFINE_BINARY_OP(max, __hmax)
MLX_DEFINE_BINARY_OP(min, __hmin)

#undef MLX_DEFINE_BINARY_OP

template <typename T>
__forceinline__ __device__ T fmod(T x, T y) {
  if constexpr (cuda::std::is_same_v<T, __half>) {
    return __float2half(::fmod(__half2float(x), __half2float(y)));
#if CUDART_VERSION >= 12000 || __CUDA_ARCH__ >= 800
  } else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) {
    return __float2bfloat16(::fmod(__bfloat162float(x), __bfloat162float(y)));
#endif
  } else {
    return ::fmod(x, y);
  }
}

///////////////////////////////////////////////////////////////////////////////
// Additional C++ operator overrides between half types and native types.
///////////////////////////////////////////////////////////////////////////////

template <typename T, typename U>
constexpr bool is_integral_except =
    cuda::std::is_integral_v<T> && !cuda::std::is_same_v<T, U>;

template <typename T, typename U>
constexpr bool is_arithmetic_except =
    cuda::std::is_arithmetic_v<T> && !cuda::std::is_same_v<T, U>;

#define MLX_DEFINE_HALF_OP(HALF, HALF2FLOAT, FLOAT2HALF, OP)          \
  template <                                                          \
      typename T,                                                     \
      typename = cuda::std::enable_if_t<is_integral_except<T, HALF>>> \
  __forceinline__ __device__ HALF operator OP(HALF x, T y) {          \
    return FLOAT2HALF(HALF2FLOAT(x) OP static_cast<float>(y));        \
  }                                                                   \
  template <                                                          \
      typename T,                                                     \
      typename = cuda::std::enable_if_t<is_integral_except<T, HALF>>> \
  __forceinline__ __device__ HALF operator OP(T x, HALF y) {          \
    return FLOAT2HALF(static_cast<float>(x) OP HALF2FLOAT(y));        \
  }

#define MLX_DEFINE_HALF_CMP(HALF, HALF2FLOAT, OP)                       \
  template <                                                            \
      typename T,                                                       \
      typename = cuda::std::enable_if_t<is_arithmetic_except<T, HALF>>> \
  __forceinline__ __device__ bool operator OP(HALF x, T y) {            \
    return HALF2FLOAT(x) OP static_cast<float>(y);                      \
  }                                                                     \
  template <                                                            \
      typename T,                                                       \
      typename = cuda::std::enable_if_t<is_arithmetic_except<T, HALF>>> \
  __forceinline__ __device__ bool operator OP(T x, HALF y) {            \
    return static_cast<float>(y) OP HALF2FLOAT(x);                      \
  }

MLX_DEFINE_HALF_OP(__half, __half2float, __float2half, +)
MLX_DEFINE_HALF_OP(__half, __half2float, __float2half, -)
MLX_DEFINE_HALF_OP(__half, __half2float, __float2half, *)
MLX_DEFINE_HALF_OP(__half, __half2float, __float2half, /)
MLX_DEFINE_HALF_OP(__nv_bfloat16, __bfloat162float, __float2bfloat16, +)
MLX_DEFINE_HALF_OP(__nv_bfloat16, __bfloat162float, __float2bfloat16, -)
MLX_DEFINE_HALF_OP(__nv_bfloat16, __bfloat162float, __float2bfloat16, *)
MLX_DEFINE_HALF_OP(__nv_bfloat16, __bfloat162float, __float2bfloat16, /)
MLX_DEFINE_HALF_CMP(__half, __half2float, <)
MLX_DEFINE_HALF_CMP(__half, __half2float, >)
MLX_DEFINE_HALF_CMP(__half, __half2float, <=)
MLX_DEFINE_HALF_CMP(__half, __half2float, >=)
MLX_DEFINE_HALF_CMP(__half, __half2float, ==)
MLX_DEFINE_HALF_CMP(__half, __half2float, !=)
MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, <)
MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, >)
MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, <=)
MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, >=)
MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, ==)
MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, !=)

#undef MLX_DEFINE_HALF_OP
#undef MLX_DEFINE_HALF_CMP

} // namespace mlx::core::cu
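The point of these wrappers is that one templated expression compiles for float, __half, and __nv_bfloat16 alike, with half types routed to the h* intrinsics or a float fallback automatically. A small sketch (the softplus-like function is illustrative, not part of the commit):

// Hypothetical usage sketch (not from this commit).
#include "mlx/backend/cuda/device/fp16_math.cuh"

template <typename T>
__device__ T softplus_like(T x) {
  // exp/log resolve to hexp/hlog for half types and to ::exp/::log otherwise.
  return mlx::core::cu::log(T(1.0f) + mlx::core::cu::exp(x));
}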
mlx/backend/cuda/device/unary_ops.cuh (new file, 349 lines)
@@ -0,0 +1,349 @@
// Copyright © 2025 Apple Inc.

#pragma once

#include "mlx/backend/cuda/device/fp16_math.cuh"
#include "mlx/backend/cuda/device/utils.cuh"

namespace mlx::core::cu {

struct Abs {
  template <typename T>
  __device__ T operator()(T x) {
    if constexpr (cuda::std::is_unsigned_v<T>) {
      return x;
    } else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
      return {sqrt(cuCrealf(x) * cuCrealf(x) + cuCimagf(x) * cuCimagf(x)), 0};
    } else {
      return abs(x);
    }
  }
};

struct ArcCos {
  template <typename T>
  __device__ T operator()(T x) {
    return acos(x);
  }
};

struct ArcCosh {
  template <typename T>
  __device__ T operator()(T x) {
    return acosh(x);
  }
};

struct ArcSin {
  template <typename T>
  __device__ T operator()(T x) {
    return asin(x);
  }
};

struct ArcSinh {
  template <typename T>
  __device__ T operator()(T x) {
    return asinh(x);
  }
};

struct ArcTan {
  template <typename T>
  __device__ T operator()(T x) {
    return atan(x);
  }
};

struct ArcTanh {
  template <typename T>
  __device__ T operator()(T x) {
    return atanh(x);
  }
};

struct BitwiseInvert {
  template <typename T>
  __device__ T operator()(T x) {
    return ~x;
  }
};

struct Ceil {
  template <typename T>
  __device__ T operator()(T x) {
    if constexpr (cuda::std::is_integral_v<T>) {
      return x;
    } else {
      return ceil(x);
    }
  }
};

struct Conjugate {
  __device__ cuComplex operator()(cuComplex x) {
    return {cuCrealf(x), -cuCimagf(x)};
  }
};

struct Cos {
  template <typename T>
  __device__ T operator()(T x) {
    if constexpr (cuda::std::is_same_v<T, cuComplex>) {
      return {
          cos(cuCrealf(x)) * cosh(cuCimagf(x)),
          -sin(cuCrealf(x)) * sinh(cuCimagf(x))};
    } else {
      return cos(x);
    }
  }
};

struct Cosh {
  template <typename T>
  __device__ T operator()(T x) {
    if constexpr (cuda::std::is_same_v<T, cuComplex>) {
      return {
          cosh(cuCrealf(x)) * cos(cuCimagf(x)),
          sinh(cuCrealf(x)) * sin(cuCimagf(x))};
    } else {
      return cosh(x);
    }
  }
};

struct Erf {
  template <typename T>
  __device__ T operator()(T x) {
    if constexpr (cuda::std::is_same_v<T, __half>) {
      return erf(__half2float(x));
    } else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) {
      return erf(__bfloat162float(x));
    } else {
      return erf(x);
    }
  }
};

struct ErfInv {
  template <typename T>
  __device__ T operator()(T x) {
    if constexpr (cuda::std::is_same_v<T, __half>) {
      return erfinv(__half2float(x));
    } else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) {
      return erfinv(__bfloat162float(x));
    } else {
      return erfinv(x);
    }
  }
};

struct Exp {
  template <typename T>
  __device__ T operator()(T x) {
    if constexpr (cuda::std::is_same_v<T, cuComplex>) {
      auto m = exp(cuCrealf(x));
      return {m * cos(cuCimagf(x)), m * sinh(cuCimagf(x))};
    } else {
      return exp(x);
    }
  }
};

struct Expm1 {
  template <typename T>
  __device__ T operator()(T x) {
    if constexpr (cuda::std::is_same_v<T, __half>) {
      return expm1(__half2float(x));
    } else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) {
      return expm1(__bfloat162float(x));
    } else {
      return expm1(x);
    }
  }
};

struct Floor {
  template <typename T>
  __device__ T operator()(T x) {
    if constexpr (cuda::std::is_integral_v<T>) {
      return x;
    } else {
      return floor(x);
    }
  }
};

struct Imag {
  __device__ float operator()(cuComplex x) {
    return cuCimagf(x);
  }
};

struct Log {
  template <typename T>
  __device__ T operator()(T x) {
    return log(x);
  }
};

struct Log2 {
  template <typename T>
  __device__ T operator()(T x) {
    return log2(x);
  }
};

struct Log10 {
  template <typename T>
  __device__ T operator()(T x) {
    return log10(x);
  }
};

struct Log1p {
  template <typename T>
  __device__ T operator()(T x) {
    return log1p(x);
  }
};

struct LogicalNot {
  __device__ bool operator()(bool x) {
    return !x;
  }
};

struct Negative {
  template <typename T>
  __device__ T operator()(T x) {
    if constexpr (cuda::std::is_same_v<T, cuComplex>) {
      return 0 - x;
    } else {
      return -x;
    }
  }
};

struct Real {
  __device__ float operator()(cuComplex x) {
    return cuCrealf(x);
  }
};

struct Round {
  template <typename T>
  __device__ T operator()(T x) {
    if constexpr (cuda::std::is_same_v<T, cuComplex>) {
      return {rint(cuCrealf(x)), rint(cuCimagf(x))};
    } else {
      return rint(x);
    }
  }
};

struct Rsqrt {
  template <typename T>
  __device__ T operator()(T x) {
    return rsqrt(x);
  }
};

struct Sigmoid {
  template <typename T>
  __device__ T operator()(T x) {
    T y = 1 / (1 + exp(-abs(x)));
    return (x < 0) ? 1 - y : y;
  }
};

struct Sign {
  template <typename T>
  __device__ T operator()(T x) {
    if constexpr (cuda::std::is_unsigned_v<T>) {
      return x != 0;
    } else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
      if (cuCrealf(x) == 0 && cuCimagf(x) == 0) {
        return x;
      } else {
        return x / Abs()(x);
      }
    } else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) {
      return static_cast<float>((x > T(0.f)) - (x < T(0.f)));
    } else {
      return (x > T(0)) - (x < T(0));
    }
  }
};

struct Sin {
  template <typename T>
  __device__ T operator()(T x) {
    if constexpr (cuda::std::is_same_v<T, cuComplex>) {
      return {
          sin(cuCrealf(x)) * cosh(cuCimagf(x)),
          cos(cuCrealf(x)) * sinh(cuCimagf(x))};
    } else {
      return sin(x);
    }
  }
};

struct Sinh {
  template <typename T>
  __device__ T operator()(T x) {
    if constexpr (cuda::std::is_same_v<T, cuComplex>) {
      return {
          sinh(cuCrealf(x)) * cos(cuCimagf(x)),
          cosh(cuCrealf(x)) * sin(cuCimagf(x))};
    } else {
      return sinh(x);
    }
  }
};

struct Square {
  template <typename T>
  __device__ T operator()(T x) {
    return x * x;
  }
};

struct Sqrt {
  template <typename T>
  __device__ T operator()(T x) {
    return sqrt(x);
  }
};

struct Tan {
  template <typename T>
  __device__ T operator()(T x) {
    if constexpr (cuda::std::is_same_v<T, cuComplex>) {
      float tan_a = tan(cuCrealf(x));
      float tanh_b = tanh(cuCimagf(x));
      float t1 = tan_a * tanh_b;
      float denom = 1. + t1 * t1;
      return {(tan_a - tanh_b * t1) / denom, (tanh_b + tan_a * t1) / denom};
    } else {
      return tan(x);
    }
  }
};

struct Tanh {
  template <typename T>
  __device__ T operator()(T x) {
    if constexpr (cuda::std::is_same_v<T, cuComplex>) {
      float tanh_a = tanh(cuCrealf(x));
      float tan_b = tan(cuCimagf(x));
      float t1 = tanh_a * tan_b;
      float denom = 1. + t1 * t1;
      return {(tanh_a + tan_b * t1) / denom, (tan_b - tanh_a * t1) / denom};
    } else {
      return tanh(x);
    }
  }
};

} // namespace mlx::core::cu
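Like the binary functors, these apply through any generic elementwise driver. A minimal Thrust-based sketch (illustrative only; the op choice and sizes are assumptions):

// Hypothetical usage sketch (not from this commit).
#include "mlx/backend/cuda/device/unary_ops.cuh"
#include <thrust/device_vector.h>
#include <thrust/transform.h>

int main() {
  thrust::device_vector<float> x(32, 4.0f);
  thrust::device_vector<float> y(32);
  // Sqrt is arbitrary here; Exp, Erf, Sigmoid, ... plug in identically.
  thrust::transform(x.begin(), x.end(), y.begin(), mlx::core::cu::Sqrt{});
  return 0;
}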
mlx/backend/cuda/device/utils.cuh (new file, 298 lines)
@@ -0,0 +1,298 @@
// Copyright © 2025 Apple Inc.

// This file must not include any host-only code; utilities that work under
// both host and device can be put here.
//
// See more about the requirements at:
// https://docs.nvidia.com/cuda/nvrtc/#language

#pragma once

#include "mlx/backend/cuda/device/config.h"

#include <cuComplex.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda/std/array>
#include <cuda/std/limits>
#include <cuda/std/tuple>

namespace mlx::core::cu {

///////////////////////////////////////////////////////////////////////////////
// CUDA kernel utils
///////////////////////////////////////////////////////////////////////////////

// To pass shape/strides to kernels via constant memory, their size must be
// known at compile time.
using Shape = cuda::std::array<int32_t, MAX_NDIM>;
using Strides = cuda::std::array<int64_t, MAX_NDIM>;

///////////////////////////////////////////////////////////////////////////////
// Type limits utils
///////////////////////////////////////////////////////////////////////////////

template <typename T, typename = void>
struct Limits {
  static constexpr __host__ __device__ T max() {
    return cuda::std::numeric_limits<T>::max();
  }
  static constexpr __host__ __device__ T min() {
    return cuda::std::numeric_limits<T>::min();
  }
  static constexpr __host__ __device__ T finite_max() {
    return cuda::std::numeric_limits<T>::max();
  }
  static constexpr __host__ __device__ T finite_min() {
    return cuda::std::numeric_limits<T>::min();
  }
};

template <typename T>
struct Limits<
    T,
    cuda::std::enable_if_t<
        cuda::std::is_same_v<T, float> || cuda::std::is_same_v<T, double>>> {
  static constexpr __host__ __device__ T max() {
    return cuda::std::numeric_limits<T>::infinity();
  }
  static constexpr __host__ __device__ T min() {
    return -cuda::std::numeric_limits<T>::infinity();
  }
  static constexpr __host__ __device__ T finite_max() {
    return cuda::std::numeric_limits<T>::max();
  }
  static constexpr __host__ __device__ T finite_min() {
    return cuda::std::numeric_limits<T>::lowest();
  }
};

// CUDA 11 does not have host side arithmetic operators for half types.
template <typename T>
struct Limits<
    T,
    cuda::std::enable_if_t<
        cuda::std::is_same_v<T, __half> ||
        cuda::std::is_same_v<T, __nv_bfloat16>>> {
  static constexpr __host__ __device__ T max() {
    return cuda::std::numeric_limits<T>::infinity();
  }
  static constexpr __host__ __device__ T min() {
#if defined(__CUDA_ARCH__) || CUDART_VERSION >= 12000
    return -cuda::std::numeric_limits<T>::infinity();
#else
    return -cuda::std::numeric_limits<float>::infinity();
#endif
  }
  static constexpr __host__ __device__ T finite_max() {
    return cuda::std::numeric_limits<T>::max();
  }
  static constexpr __host__ __device__ T finite_min() {
#if defined(__CUDA_ARCH__) || CUDART_VERSION >= 12000
    return cuda::std::numeric_limits<T>::lowest();
#else
    return cuda::std::numeric_limits<float>::lowest();
#endif
  }
};

template <>
struct Limits<bool> {
  static constexpr __host__ __device__ bool max() {
    return true;
  }
  static constexpr __host__ __device__ bool min() {
    return false;
  }
};

template <>
struct Limits<cuComplex> {
  static constexpr __host__ __device__ cuComplex max() {
    return {Limits<float>::max(), Limits<float>::max()};
  }
  static constexpr __host__ __device__ cuComplex min() {
    return {Limits<float>::min(), Limits<float>::min()};
  }
};

///////////////////////////////////////////////////////////////////////////////
// Indexing utils
///////////////////////////////////////////////////////////////////////////////

template <typename IdxT = int64_t>
inline __host__ __device__ IdxT
elem_to_loc(IdxT elem, const int* shape, const int64_t* strides, int ndim) {
  IdxT loc = 0;
  for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
    loc += (elem % shape[i]) * IdxT(strides[i]);
    elem /= shape[i];
  }
  return loc;
}

// Optimize when the ndim is known at compile time.
template <int NDIM, typename IdxT = int64_t>
inline __host__ __device__ IdxT
elem_to_loc_nd(IdxT elem, const int* shape, const int64_t* strides) {
  IdxT loc = 0;
#pragma unroll
  for (int i = NDIM - 1; i >= 0; --i) {
    loc += (elem % shape[i]) * IdxT(strides[i]);
    elem /= shape[i];
  }
  return loc;
}

template <int NDIM, typename IdxT = int64_t>
inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_nd(
    IdxT elem,
    const int* shape,
    const int64_t* a_strides,
    const int64_t* b_strides) {
  IdxT a_loc = 0;
  IdxT b_loc = 0;
#pragma unroll
  for (int i = NDIM - 1; i >= 0; --i) {
    int dim_idx = elem % shape[i];
    a_loc += dim_idx * a_strides[i];
    b_loc += dim_idx * b_strides[i];
    elem /= shape[i];
  }
  return cuda::std::make_tuple(a_loc, b_loc);
}

// Optimized version when ndim is larger than 4.
template <typename IdxT = int64_t>
inline __host__ __device__ IdxT
elem_to_loc_4d(IdxT elem, const int* shape, const int64_t* strides, int ndim) {
  IdxT loc = elem_to_loc_nd<3>(elem, shape, strides);
  for (int i = ndim - 1; i >= 3; --i) {
    loc += (elem % shape[i]) * IdxT(strides[i]);
    elem /= shape[i];
  }
  return loc;
}

template <typename IdxT = int64_t>
inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_4d(
    IdxT elem,
    const int* shape,
    const int64_t* a_strides,
    const int64_t* b_strides,
    int ndim) {
  auto [a_loc, b_loc] = elem_to_loc_nd<3>(elem, shape, a_strides, b_strides);
  for (int i = ndim - 1; i >= 3; --i) {
    int dim_idx = elem % shape[i];
    a_loc += dim_idx * a_strides[i];
    b_loc += dim_idx * b_strides[i];
    elem /= shape[i];
  }
  return cuda::std::make_tuple(a_loc, b_loc);
}

///////////////////////////////////////////////////////////////////////////////
// Elem to loc in a loop utils
///////////////////////////////////////////////////////////////////////////////

template <int DIM, bool General = true, typename OffsetT = size_t>
struct LoopedElemToLoc {
  int dim;
  LoopedElemToLoc<DIM - 1, General, OffsetT> inner_looper;
  OffsetT offset{0};
  int index{0};

  __device__ LoopedElemToLoc(int dim) : dim(dim), inner_looper(dim - 1) {}

  __device__ void next(const int* shape, const int64_t* strides) {
    if (dim == 0) {
      return;
    }
    index++;
    offset += OffsetT(strides[dim - 1]);
    if (index >= shape[dim - 1]) {
      index = 0;
      inner_looper.next(shape, strides);
      offset = inner_looper.offset;
    }
  }

  __device__ void next(int n, const int* shape, const int64_t* strides) {
    if (dim == 0) {
      return;
    }
    index += n;
    offset += n * OffsetT(strides[dim - 1]);

    if (index >= shape[dim - 1]) {
      int extra = index - shape[dim - 1];
      if (extra >= shape[dim - 1]) {
        inner_looper.next(1 + extra / shape[dim - 1], shape, strides);
        extra = extra % shape[dim - 1];
      } else {
        inner_looper.next(shape, strides);
      }
      index = 0;
      offset = inner_looper.offset;
      if (extra > 0) {
        next(extra, shape, strides);
      }
    }
  }

  __device__ OffsetT location() {
    return offset;
  }
};

template <typename OffsetT>
struct LoopedElemToLoc<1, true, OffsetT> {
  int dim;
  OffsetT offset{0};
  int index{0};

  __device__ LoopedElemToLoc(int dim) : dim(dim) {}

  __device__ void next(const int* shape, const int64_t* strides) {
    index++;
    if (dim > 1) {
      offset = elem_to_loc<OffsetT>(index, shape, strides, dim);
    } else {
      offset += OffsetT(strides[0]);
    }
  }

  __device__ void next(int n, const int* shape, const int64_t* strides) {
    index += n;
    if (dim > 1) {
      offset = elem_to_loc<OffsetT>(index, shape, strides, dim);
    } else {
      offset = index * OffsetT(strides[0]);
    }
  }

  __device__ OffsetT location() {
    return offset;
  }
};

template <typename OffsetT>
struct LoopedElemToLoc<1, false, OffsetT> {
  OffsetT offset{0};

  __device__ LoopedElemToLoc(int) {}

  __device__ void next(const int*, const int64_t* strides) {
    offset += OffsetT(strides[0]);
  }

  __device__ void next(int n, const int*, const int64_t* strides) {
    offset += n * OffsetT(strides[0]);
  }

  __device__ OffsetT location() {
    return offset;
  }
};

} // namespace mlx::core::cu
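Because Shape and Strides are fixed-size cuda::std::array types, they can be passed to a kernel by value, and elem_to_loc turns a flat element index into a strided offset. A naive gather kernel as an illustration (not MLX's actual copy kernel; the kernel name and launch details are assumptions):

// Hypothetical usage sketch (not from this commit).
#include "mlx/backend/cuda/device/utils.cuh"

template <typename T>
__global__ void gather_general(
    const T* in,
    T* out,
    mlx::core::cu::Shape shape,
    mlx::core::cu::Strides strides,
    int ndim,
    int64_t size) {
  int64_t i = blockIdx.x * int64_t(blockDim.x) + threadIdx.x;
  if (i < size) {
    // Map the contiguous output index i to the strided input offset.
    auto loc =
        mlx::core::cu::elem_to_loc(i, shape.data(), strides.data(), ndim);
    out[i] = in[loc];
  }
}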