From 21c4a92ec13242aaf55e3684106f1a8590816d5c Mon Sep 17 00:00:00 2001 From: Cheng Date: Sun, 13 Apr 2025 23:51:11 +0000 Subject: [PATCH] CUDA backend: binary ops --- mlx/backend/cuda/CMakeLists.txt | 1 + mlx/backend/cuda/binary.cu | 305 ++++++++++++++++++++++++ mlx/backend/cuda/kernel_utils.cuh | 62 +++++ mlx/backend/cuda/kernels/binary_ops.cuh | 278 +++++++++++++++++++++ mlx/backend/cuda/kernels/fp16_math.cuh | 46 ++++ mlx/backend/cuda/kernels/utils.cuh | 61 +++++ mlx/backend/cuda/primitives.cu | 19 -- 7 files changed, 753 insertions(+), 19 deletions(-) create mode 100644 mlx/backend/cuda/binary.cu create mode 100644 mlx/backend/cuda/kernels/binary_ops.cuh diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt index cd73843bf..c813f8fd4 100644 --- a/mlx/backend/cuda/CMakeLists.txt +++ b/mlx/backend/cuda/CMakeLists.txt @@ -6,6 +6,7 @@ target_sources( mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/binary.cu ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp diff --git a/mlx/backend/cuda/binary.cu b/mlx/backend/cuda/binary.cu new file mode 100644 index 000000000..360772998 --- /dev/null +++ b/mlx/backend/cuda/binary.cu @@ -0,0 +1,305 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/common/binary.h" +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/kernel_utils.cuh" +#include "mlx/backend/cuda/kernels/binary_ops.cuh" +#include "mlx/backend/cuda/kernels/cucomplex_math.cuh" +#include "mlx/dtype_utils.h" +#include "mlx/primitives.h" + +#include +#include + +namespace mlx::core { + +namespace cu { + +namespace cg = cooperative_groups; + +template +__global__ void binary_ss(const In* a, const In* b, Out* out, IdxT size) { + IdxT index = cg::this_grid().thread_rank(); + if (index < size) { + out[index] = Op{}(a[0], b[0]); + } +} + +template +__global__ void binary_sv(const In* a, const In* b, Out* out, IdxT size) { + IdxT index = cg::this_grid().thread_rank(); + if (index < size) { + out[index] = Op{}(a[0], b[index]); + } +} + +template +__global__ void binary_vs(const In* a, const In* b, Out* out, IdxT size) { + IdxT index = cg::this_grid().thread_rank(); + if (index < size) { + out[index] = Op{}(a[index], b[0]); + } +} + +template +__global__ void binary_vv(const In* a, const In* b, Out* out, IdxT size) { + IdxT index = cg::this_grid().thread_rank(); + if (index < size) { + out[index] = Op{}(a[index], b[index]); + } +} + +template +__global__ void binary_g_nd( + const In* a, + const In* b, + Out* out, + IdxT size, + const __grid_constant__ cuda::std::array shape, + const __grid_constant__ cuda::std::array a_strides, + const __grid_constant__ cuda::std::array b_strides) { + IdxT index = cg::this_grid().thread_rank(); + if (index < size) { + auto [a_idx, b_idx] = elem_to_loc_nd( + index, shape.data(), a_strides.data(), b_strides.data()); + out[index] = Op{}(a[a_idx], b[b_idx]); + } +} + +template +__global__ void binary_g( + const In* a, + const In* b, + Out* out, + IdxT size, + const __grid_constant__ Shape shape, + const __grid_constant__ Strides a_strides, + const __grid_constant__ Strides b_strides, + int ndim) { + IdxT index = cg::this_grid().thread_rank(); + if (index < size) { + auto [a_idx, b_idx] = elem_to_loc_4d( + index, shape.data(), a_strides.data(), b_strides.data(), ndim); + out[index] = Op{}(a[a_idx], b[b_idx]); + } +} + +template +constexpr bool supports_binary_op() { + if (std::is_same_v || std::is_same_v || + std::is_same_v || 
std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) { + return std::is_same_v; + } + if (std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) { + return std::is_same_v; + } + if (std::is_same_v || std::is_same_v) { + return std::is_same_v && std::is_same_v; + } + if (std::is_same_v) { + return std::is_same_v && + (is_floating_v || std::is_same_v); + } + if (std::is_same_v || std::is_same_v) { + return std::is_same_v && is_floating_v; + } + if (std::is_same_v || std::is_same_v || + std::is_same_v) { + return std::is_same_v && std::is_integral_v; + } + if (std::is_same_v || std::is_same_v) { + return std::is_same_v && std::is_integral_v && + !std::is_same_v; + } + return false; +} + +} // namespace cu + +template +void binary_op_gpu_inplace( + const std::vector& inputs, + std::vector& outputs, + std::string_view op, + const Stream& s) { + assert(inputs.size() > 1); + const auto& a = inputs[0]; + const auto& b = inputs[1]; + auto& out = outputs[0]; + if (out.size() == 0) { + return; + } + + auto& encoder = cu::get_command_encoder(s); + encoder.set_input_array(a); + encoder.set_input_array(b); + encoder.set_output_array(out); + encoder.launch_kernel([&](cudaStream_t stream) { + MLX_SWITCH_ALL_TYPES(a.dtype(), CTYPE_IN, { + MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, { + if constexpr (cu::supports_binary_op()) { + using InType = cuda_type_t; + using OutType = cuda_type_t; + + auto bopt = get_binary_op_type(a, b); + if (bopt == BinaryOpType::General) { + auto [shape, strides] = collapse_contiguous_dims(a, b, out); + auto& a_strides = strides[0]; + auto& b_strides = strides[1]; + bool large = a.data_size() > UINT32_MAX || + b.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX; + MLX_SWITCH_BOOL(large, LARGE, { + using IdxT = std::conditional_t; + int ndim = shape.size(); + if (ndim <= 3) { + MLX_SWITCH_1_2_3(ndim, NDIM, { + auto kernel = + &cu::binary_g_nd; + auto [num_blocks, block_dims] = + get_launch_args(kernel, out, large); + kernel<<>>( + a.data(), + b.data(), + out.data(), + out.data_size(), + const_param(shape), + const_param(a_strides), + const_param(b_strides)); + }); + } else { + auto kernel = cu::binary_g; + auto [num_blocks, block_dims] = + get_launch_args(kernel, out, large); + kernel<<>>( + a.data(), + b.data(), + out.data(), + out.data_size(), + const_param(shape), + const_param(a_strides), + const_param(b_strides), + ndim); + } + }); + } else { + MLX_SWITCH_BOOL(out.data_size() > UINT32_MAX, LARGE, { + using IdxT = std::conditional_t; + auto kernel = cu::binary_ss; + if (bopt == BinaryOpType::ScalarVector) { + kernel = cu::binary_sv; + } else if (bopt == BinaryOpType::VectorScalar) { + kernel = cu::binary_vs; + } else if (bopt == BinaryOpType::VectorVector) { + kernel = cu::binary_vv; + } + auto [num_blocks, block_dims] = + get_launch_args(kernel, out, LARGE); + kernel<<>>( + a.data(), + b.data(), + out.data(), + out.data_size()); + }); + } + } else { + throw std::runtime_error(fmt::format( + "Can not do binary op {} on inputs of {} with result of {}.", + op, + dtype_to_string(a.dtype()), + dtype_to_string(out.dtype()))); + } + }); + }); + }); +} + +template +void binary_op_gpu( + const std::vector& inputs, + std::vector& outputs, + std::string_view op, + const Stream& s) { + auto& a = inputs[0]; + auto& b = inputs[1]; + auto bopt = get_binary_op_type(a, b); + set_binary_op_output_data(a, b, outputs[0], bopt); + set_binary_op_output_data(a, b, outputs[1], bopt); + 
binary_op_gpu_inplace(inputs, outputs, op, s); +} + +template +void binary_op_gpu( + const std::vector& inputs, + array& out, + std::string_view op, + const Stream& s) { + auto& a = inputs[0]; + auto& b = inputs[1]; + auto bopt = get_binary_op_type(a, b); + set_binary_op_output_data(a, b, out, bopt); + std::vector outputs{out}; + binary_op_gpu_inplace(inputs, outputs, op, s); +} + +#define BINARY_GPU(func) \ + void func::eval_gpu(const std::vector& inputs, array& out) { \ + nvtx3::scoped_range r(#func "::eval_gpu"); \ + auto& s = out.primitive().stream(); \ + binary_op_gpu(inputs, out, get_primitive_string(this), s); \ + } + +#define BINARY_GPU_MULTI(func) \ + void func::eval_gpu( \ + const std::vector& inputs, std::vector& outputs) { \ + nvtx3::scoped_range r(#func "::eval_gpu"); \ + auto& s = outputs[0].primitive().stream(); \ + binary_op_gpu(inputs, outputs, get_primitive_string(this), s); \ + } + +BINARY_GPU(Add) +BINARY_GPU(ArcTan2) +BINARY_GPU(Divide) +BINARY_GPU(Remainder) +BINARY_GPU(Equal) +BINARY_GPU(Greater) +BINARY_GPU(GreaterEqual) +BINARY_GPU(Less) +BINARY_GPU(LessEqual) +BINARY_GPU(LogicalAnd) +BINARY_GPU(LogicalOr) +BINARY_GPU(LogAddExp) +BINARY_GPU(Maximum) +BINARY_GPU(Minimum) +BINARY_GPU(Multiply) +BINARY_GPU(NotEqual) +BINARY_GPU(Power) +BINARY_GPU(Subtract) + +void BitwiseBinary::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("BitwiseBinary::eval_gpu"); + auto& s = out.primitive().stream(); + auto op = get_primitive_string(this); + switch (op_) { + case BitwiseBinary::And: + binary_op_gpu(inputs, out, op, s); + break; + case BitwiseBinary::Or: + binary_op_gpu(inputs, out, op, s); + break; + case BitwiseBinary::Xor: + binary_op_gpu(inputs, out, op, s); + break; + case BitwiseBinary::LeftShift: + binary_op_gpu(inputs, out, op, s); + break; + case BitwiseBinary::RightShift: + binary_op_gpu(inputs, out, op, s); + break; + } +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/kernel_utils.cuh b/mlx/backend/cuda/kernel_utils.cuh index 6430b8c59..aeb065206 100644 --- a/mlx/backend/cuda/kernel_utils.cuh +++ b/mlx/backend/cuda/kernel_utils.cuh @@ -13,9 +13,40 @@ #include #include #include +#include namespace mlx::core { +// Convert a number between 1~3 to constexpr. +#define MLX_SWITCH_1_2_3(N, NDIM, ...) \ + switch (N) { \ + case 1: { \ + constexpr int NDIM = 1; \ + __VA_ARGS__; \ + break; \ + } \ + case 2: { \ + constexpr int NDIM = 2; \ + __VA_ARGS__; \ + break; \ + } \ + case 3: { \ + constexpr int NDIM = 3; \ + __VA_ARGS__; \ + break; \ + } \ + } + +// Like MLX_SWITCH_ALL_TYPES but for booleans. +#define MLX_SWITCH_BOOL(BOOL, BOOL_ALIAS, ...) \ + if (BOOL) { \ + constexpr bool BOOL_ALIAS = true; \ + __VA_ARGS__; \ + } else { \ + constexpr bool BOOL_ALIAS = false; \ + __VA_ARGS__; \ + } + // Maps CPU types to CUDA types. template struct CTypeToCudaType { @@ -66,4 +97,35 @@ dim3 get_2d_grid_dims( const Strides& strides, size_t divisor); +// Return a block size that achieves maximum potential occupancy for kernel. +template +inline uint max_occupancy_block_dim(T kernel) { + int _, block_dim; + CHECK_CUDA_ERROR(cudaOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel)); + return block_dim; +} + +// Get the num_blocks and block_dims that maximize occupancy for |kernel|, +// assuming each thread handles |work_per_thread| elements of |arr|. 
+template +inline std::tuple get_launch_args( + T kernel, + const array& arr, + bool large, + int work_per_thread = 1) { + size_t nthreads = cuda::ceil_div(arr.size(), work_per_thread); + uint block_dim = max_occupancy_block_dim(kernel); + if (block_dim > nthreads) { + block_dim = nthreads; + } + dim3 num_blocks; + if (large) { + num_blocks = get_2d_grid_dims(arr.shape(), arr.strides(), work_per_thread); + num_blocks.x = cuda::ceil_div(num_blocks.x, block_dim); + } else { + num_blocks.x = cuda::ceil_div(nthreads, block_dim); + } + return std::make_tuple(num_blocks, block_dim); +} + } // namespace mlx::core diff --git a/mlx/backend/cuda/kernels/binary_ops.cuh b/mlx/backend/cuda/kernels/binary_ops.cuh new file mode 100644 index 000000000..3bc30eb02 --- /dev/null +++ b/mlx/backend/cuda/kernels/binary_ops.cuh @@ -0,0 +1,278 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/kernels/fp16_math.cuh" + +#include +#include + +namespace mlx::core::cu { + +struct Add { + template + __device__ T operator()(T x, T y) { + return x + y; + } +}; + +struct FloorDivide { + template + __device__ T operator()(T x, T y) { + if constexpr (cuda::std::is_integral_v) { + return x / y; + } else { + return trunc(x / y); + } + } +}; + +struct Divide { + template + __device__ T operator()(T x, T y) { + return x / y; + } +}; + +struct Remainder { + template + __device__ T operator()(T x, T y) { + if constexpr (cuda::std::is_integral_v) { + if constexpr (cuda::std::is_signed_v) { + auto r = x % y; + if (r != 0 && (r < 0 != y < 0)) { + r += y; + } + return r; + } else { + return x % y; + } + } else if constexpr (cuda::std::is_same_v) { + return x % y; + } else { + T r = fmod(x, y); + if (r != 0 && (r < 0 != y < 0)) { + r = r + y; + } + return r; + } + } +}; + +struct Equal { + template + __device__ bool operator()(T x, T y) { + return x == y; + } +}; + +struct NaNEqual { + template + __device__ bool operator()(T x, T y) { + if constexpr (std::is_same_v) { + return x == y || + (isnan(cuCrealf(x)) && isnan(cuCrealf(y)) && isnan(cuCimagf(x)) && + isnan(cuCimagf(y))) || + (cuCrealf(x) == cuCrealf(y) && isnan(cuCimagf(x)) && + isnan(cuCimagf(y))) || + (isnan(cuCrealf(x)) && isnan(cuCrealf(y)) && + cuCimagf(x) == cuCimagf(y)); + } else { + return x == y || (isnan(x) && isnan(y)); + } + } +}; + +struct Greater { + template + __device__ bool operator()(T x, T y) { + return x > y; + } +}; + +struct GreaterEqual { + template + __device__ bool operator()(T x, T y) { + return x >= y; + } +}; + +struct Less { + template + __device__ bool operator()(T x, T y) { + return x < y; + } +}; + +struct LessEqual { + template + __device__ bool operator()(T x, T y) { + return x <= y; + } +}; + +struct LogAddExp { + template + __device__ T operator()(T x, T y) { + if (isnan(x) || isnan(y)) { + return cuda::std::numeric_limits::quiet_NaN(); + } + T maxval = max(x, y); + T minval = min(x, y); + return (minval == -cuda::std::numeric_limits::infinity() || + maxval == cuda::std::numeric_limits::infinity()) + ? maxval + : T(float(maxval) + log1p(expf(minval - maxval))); + }; +}; + +struct Maximum { + template + __device__ T operator()(T x, T y) { + if constexpr (cuda::std::is_integral_v) { + return max(x, y); + } else if constexpr (cuda::std::is_same_v) { + if (isnan(cuCrealf(x)) || isnan(cuCimagf(x))) { + return x; + } + return x > y ? x : y; + } else { + if (isnan(x)) { + return x; + } + return x > y ? 
x : y; + } + } +}; + +struct Minimum { + template + __device__ T operator()(T x, T y) { + if constexpr (cuda::std::is_integral_v) { + return min(x, y); + } else if constexpr (cuda::std::is_same_v) { + if (isnan(cuCrealf(x)) || isnan(cuCimagf(x))) { + return x; + } + return x < y ? x : y; + } else { + if (isnan(x)) { + return x; + } + return x < y ? x : y; + } + } +}; + +struct Multiply { + template + __device__ T operator()(T x, T y) { + return x * y; + } +}; + +struct NotEqual { + template + __device__ bool operator()(T x, T y) { + if constexpr (std::is_same_v) { + return cuCrealf(x) != cuCrealf(y) || cuCimagf(x) != cuCimagf(y); + } else { + return x != y; + } + } +}; + +struct Power { + template + __device__ T operator()(T base, T exp) { + if constexpr (cuda::std::is_integral_v) { + T res = 1; + while (exp) { + if (exp & 1) { + res *= base; + } + exp >>= 1; + base *= base; + } + return res; + } else if constexpr (cuda::std::is_same_v) { + auto x_theta = atan2f(base.y, base.x); + auto x_ln_r = 0.5 * logf(base.x * base.x + base.y * base.y); + auto mag = expf(exp.x * x_ln_r - exp.y * x_theta); + auto phase = exp.y * x_ln_r + exp.x * x_theta; + return make_cuFloatComplex(mag * cosf(phase), mag * sinf(phase)); + } else { + return powf(base, exp); + } + } +}; + +struct Subtract { + template + __device__ T operator()(T x, T y) { + return x - y; + } +}; + +struct LogicalAnd { + template + __device__ T operator()(T x, T y) { + return x && y; + }; +}; + +struct LogicalOr { + template + __device__ T operator()(T x, T y) { + return x || y; + }; +}; + +struct BitwiseAnd { + template + __device__ T operator()(T x, T y) { + return x & y; + }; +}; + +struct BitwiseOr { + template + __device__ T operator()(T x, T y) { + return x | y; + }; +}; + +struct BitwiseXor { + template + __device__ T operator()(T x, T y) { + return x ^ y; + }; +}; + +struct LeftShift { + template + __device__ T operator()(T x, T y) { + return x << y; + }; +}; + +struct RightShift { + template + __device__ T operator()(T x, T y) { + return x >> y; + }; +}; + +struct ArcTan2 { + template + __device__ T operator()(T y, T x) { + return atan2f(y, x); + } +}; + +struct DivMod { + template + __device__ cuda::std::array operator()(T x, T y) { + return {FloorDivide{}(x, y), Remainder{}(x, y)}; + }; +}; + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/kernels/fp16_math.cuh b/mlx/backend/cuda/kernels/fp16_math.cuh index cf5def4db..f6fa17bb9 100644 --- a/mlx/backend/cuda/kernels/fp16_math.cuh +++ b/mlx/backend/cuda/kernels/fp16_math.cuh @@ -81,6 +81,52 @@ MLX_DEFINE_UNARY_OP_FALLBCK(tanh) #undef MLX_DEFINE_UNARY_OP #undef MLX_DEFINE_UNARY_OP_FALLBCK +/////////////////////////////////////////////////////////////////////////////// +// Binary ops for half types. 
+/////////////////////////////////////////////////////////////////////////////// + +#if CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800 +#define MLX_DEFINE_BINARY_OP(NAME, HALF_OP) \ + template \ + __forceinline__ __device__ auto NAME(T x, T y) { \ + if constexpr (cuda::std::is_same_v) { \ + return HALF_OP(x, y); \ + } else { \ + return ::NAME(x, y); \ + } \ + } +#else +#define MLX_DEFINE_BINARY_OP(NAME, HALF_OP) \ + template \ + __forceinline__ __device__ auto NAME(T x, T y) { \ + if constexpr (cuda::std::is_same_v) { \ + return HALF_OP(x, y); \ + } else if constexpr (cuda::std::is_same_v) { \ + return HALF_OP(x, y); \ + } else { \ + return ::NAME(x, y); \ + } \ + } +#endif + +MLX_DEFINE_BINARY_OP(max, __hmax) +MLX_DEFINE_BINARY_OP(min, __hmin) + +#undef MLX_DEFINE_BINARY_OP + +template +__forceinline__ __device__ T fmod(T x, T y) { + if constexpr (cuda::std::is_same_v) { + return __float2half(::fmod(__half2float(x), __half2float(y))); +#if CUDART_VERSION >= 12000 || __CUDA_ARCH__ >= 800 + } else if constexpr (cuda::std::is_same_v) { + return __float2bfloat16(::fmod(__bfloat162float(x), __bfloat162float(y))); +#endif + } else { + return ::fmod(x, y); + } +} + /////////////////////////////////////////////////////////////////////////////// // Additional C++ operator overrides between half types and native types. /////////////////////////////////////////////////////////////////////////////// diff --git a/mlx/backend/cuda/kernels/utils.cuh b/mlx/backend/cuda/kernels/utils.cuh index 4d69b7356..16957d132 100644 --- a/mlx/backend/cuda/kernels/utils.cuh +++ b/mlx/backend/cuda/kernels/utils.cuh @@ -11,6 +11,7 @@ #include #include #include +#include namespace mlx::core::cu { @@ -40,4 +41,64 @@ elem_to_loc(IdxT elem, const int* shape, const int64_t* strides, int ndim) { return loc; } +// Optimize when the ndim is known at compile time. +template +inline __host__ __device__ IdxT +elem_to_loc_nd(IdxT elem, const int* shape, const int64_t* strides) { + IdxT loc = 0; +#pragma unroll + for (int i = NDIM - 1; i >= 0; --i) { + loc += (elem % shape[i]) * IdxT(strides[i]); + elem /= shape[i]; + } + return loc; +} + +template +inline __host__ __device__ cuda::std::tuple elem_to_loc_nd( + IdxT elem, + const int* shape, + const int64_t* a_strides, + const int64_t* b_strides) { + IdxT a_loc = 0; + IdxT b_loc = 0; +#pragma unroll + for (int i = NDIM - 1; i >= 0; --i) { + int dim_idx = elem % shape[i]; + a_loc += dim_idx * a_strides[i]; + b_loc += dim_idx * b_strides[i]; + elem /= shape[i]; + } + return cuda::std::make_tuple(a_loc, b_loc); +} + +// Optimized version when ndim is larger than 4. 
+template <typename IdxT>
+inline __host__ __device__ IdxT
+elem_to_loc_4d(IdxT elem, const int* shape, const int64_t* strides, int ndim) {
+  IdxT loc = elem_to_loc_nd<3>(elem, shape, strides);
+  for (int i = ndim - 1; i >= 3; --i) {
+    loc += (elem % shape[i]) * IdxT(strides[i]);
+    elem /= shape[i];
+  }
+  return loc;
+}
+
+template <typename IdxT>
+inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_4d(
+    IdxT elem,
+    const int* shape,
+    const int64_t* a_strides,
+    const int64_t* b_strides,
+    int ndim) {
+  auto [a_loc, b_loc] = elem_to_loc_nd<3>(elem, shape, a_strides, b_strides);
+  for (int i = ndim - 1; i >= 3; --i) {
+    int dim_idx = elem % shape[i];
+    a_loc += dim_idx * a_strides[i];
+    b_loc += dim_idx * b_strides[i];
+    elem /= shape[i];
+  }
+  return cuda::std::make_tuple(a_loc, b_loc);
+}
+
 } // namespace mlx::core::cu
diff --git a/mlx/backend/cuda/primitives.cu b/mlx/backend/cuda/primitives.cu
index 3d9186892..2c3a73c42 100644
--- a/mlx/backend/cuda/primitives.cu
+++ b/mlx/backend/cuda/primitives.cu
@@ -71,43 +71,25 @@ bool fast::ScaledDotProductAttention::use_fallback(
     throw std::runtime_error(#func " has no CUDA implementation."); \
   }
 
-NO_GPU(Add)
-NO_GPU(ArcTan2)
 NO_GPU(ArgPartition)
 NO_GPU(ArgReduce)
 NO_GPU(ArgSort)
-NO_GPU(BitwiseBinary)
 NO_GPU(BlockMaskedMM)
 NO_GPU_MULTI(Compiled)
 NO_GPU(Convolution)
-NO_GPU(Divide)
 NO_GPU_MULTI(DivMod)
 NO_GPU(DynamicSlice)
 NO_GPU(DynamicSliceUpdate)
-NO_GPU(Remainder)
-NO_GPU(Equal)
 NO_GPU(FFT)
 NO_GPU(Gather)
 NO_GPU(GatherAxis)
 NO_GPU(GatherMM)
 NO_GPU(GatherQMM)
-NO_GPU(Greater)
-NO_GPU(GreaterEqual)
 NO_GPU(Hadamard)
-NO_GPU(Less)
-NO_GPU(LessEqual)
 NO_GPU(Load)
-NO_GPU(LogicalAnd)
-NO_GPU(LogicalOr)
-NO_GPU(LogAddExp)
 NO_GPU(LogSumExp)
 NO_GPU_MULTI(LUF)
-NO_GPU(Maximum)
-NO_GPU(Minimum)
-NO_GPU(Multiply)
-NO_GPU(NotEqual)
 NO_GPU(Partition)
-NO_GPU(Power)
 NO_GPU_MULTI(QRF)
 NO_GPU(QuantizedMatmul)
 NO_GPU(RandomBits)
@@ -119,7 +101,6 @@ NO_GPU(Select)
 NO_GPU(SliceUpdate)
 NO_GPU(Softmax)
 NO_GPU(Sort)
-NO_GPU(Subtract)
 NO_GPU_MULTI(SVD)
 NO_GPU(Inverse)
 NO_GPU(Cholesky)
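
Reviewer notes (illustrative sketches only, not part of the patch):

1) Index mapping in the general-stride path. binary_g_nd and binary_g turn a flat
thread index into one memory offset per input by peeling dimensions from the
innermost axis outward, the same loop used by elem_to_loc_nd/elem_to_loc_4d above;
broadcasting falls out of a zero stride. A host-side sketch of that arithmetic
(all names below are mine, not from the patch):

#include <cstdint>
#include <cstdio>

// Map a flat element index to a memory offset for one strided array,
// peeling dimensions from the last (fastest-varying) axis outward.
int64_t flat_to_offset(
    int64_t elem, const int* shape, const int64_t* strides, int ndim) {
  int64_t loc = 0;
  for (int i = ndim - 1; i >= 0; --i) {
    loc += (elem % shape[i]) * strides[i];
    elem /= shape[i];
  }
  return loc;
}

int main() {
  // A 2x3 row-major array has strides {3, 1}; a row broadcast over the
  // first axis has strides {0, 1}.
  int shape[2] = {2, 3};
  int64_t a_strides[2] = {3, 1};
  int64_t b_strides[2] = {0, 1};
  for (int64_t e = 0; e < 6; ++e) {
    std::printf(
        "elem %lld -> a %lld, b %lld\n",
        (long long)e,
        (long long)flat_to_offset(e, shape, a_strides, 2),
        (long long)flat_to_offset(e, shape, b_strides, 2));
  }
  return 0;
}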
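
2) Launch configuration. get_launch_args asks the runtime for the block size that
maximizes occupancy for the specific kernel (cudaOccupancyMaxPotentialBlockSize),
then sizes the grid with a ceiling division; the "large" branch switches to a 2-D
grid so arrays with more than UINT32_MAX elements remain addressable. A standalone
sketch of the 1-D path with a toy kernel of my own (buffers are left
uninitialized; only the launch sizing is the point):

#include <cuda_runtime.h>
#include <cstdio>

__global__ void add_one(const float* in, float* out, unsigned n) {
  unsigned i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    out[i] = in[i] + 1.0f;
  }
}

int main() {
  unsigned n = 1u << 20;
  // Block size that achieves maximum potential occupancy for this kernel.
  int min_grid = 0, block_dim = 0;
  cudaOccupancyMaxPotentialBlockSize(&min_grid, &block_dim, add_one);
  unsigned num_blocks = (n + block_dim - 1) / block_dim;  // ceil_div
  std::printf("block_dim=%d num_blocks=%u\n", block_dim, num_blocks);

  float *in = nullptr, *out = nullptr;
  cudaMalloc((void**)&in, n * sizeof(float));
  cudaMalloc((void**)&out, n * sizeof(float));
  add_one<<<num_blocks, block_dim>>>(in, out, n);
  cudaDeviceSynchronize();
  cudaFree(in);
  cudaFree(out);
  return 0;
}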
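
3) Runtime-to-compile-time dispatch. MLX_SWITCH_BOOL and MLX_SWITCH_1_2_3 promote
a runtime value to a constexpr so a distinct kernel instantiation is selected per
case; the same trick picks a 32-bit index type whenever every buffer has at most
UINT32_MAX elements and falls back to 64-bit otherwise. A compact sketch of the
idea using a lambda instead of the patch's macros (all names mine):

#include <cstdint>
#include <cstdio>
#include <limits>
#include <type_traits>

// Turn a runtime bool into a compile-time constant, as MLX_SWITCH_BOOL does.
template <typename F>
void switch_bool(bool b, F&& f) {
  if (b) {
    f(std::integral_constant<bool, true>{});
  } else {
    f(std::integral_constant<bool, false>{});
  }
}

int main() {
  size_t data_size = size_t(1) << 33;  // too big for a 32-bit index
  bool large = data_size > std::numeric_limits<uint32_t>::max();
  switch_bool(large, [&](auto is_large) {
    // IdxT is fixed at compile time, so indexing code pays no per-element branch.
    using IdxT = std::conditional_t<decltype(is_large)::value, int64_t, uint32_t>;
    std::printf("index width: %zu bytes\n", sizeof(IdxT));
  });
  return 0;
}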
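
4) Half-precision fallbacks. fp16_math.cuh routes max/min to the __hmax/__hmin
intrinsics when they exist for the type and toolkit, and otherwise promotes to
float, computes, and converts back, which is what its fmod overload does for
__half unconditionally. A device-side fragment showing the promote-compute-convert
pattern on its own (compiles with nvcc; not the patch's exact macro expansion):

#include <cuda_fp16.h>

// Promote to float, apply the float routine, convert back to half. This is the
// fallback used where no native half intrinsic is available.
__device__ __half half_fmod(__half x, __half y) {
  return __float2half(fmodf(__half2float(x), __half2float(y)));
}

__global__ void fmod_kernel(const __half* a, const __half* b, __half* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    out[i] = half_fmod(a[i], b[i]);
  }
}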
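
5) Compile-time op/type gating. supports_binary_op prunes the kernel
instantiations generated by the nested MLX_SWITCH_ALL_TYPES loops, so comparison
ops only instantiate with a bool output, bitwise ops only with integral inputs,
and every unsupported combination falls through to the runtime_error in
binary_op_gpu_inplace. A minimal sketch of that gating with made-up op tags and a
simplified trait:

#include <cstdio>
#include <type_traits>

struct Add {};
struct BitwiseAnd {};

template <typename Op, typename In, typename Out>
constexpr bool supports_op() {
  if (std::is_same_v<Op, Add>) {
    return std::is_same_v<In, Out>;
  }
  if (std::is_same_v<Op, BitwiseAnd>) {
    return std::is_same_v<In, Out> && std::is_integral_v<In>;
  }
  return false;
}

template <typename Op, typename In, typename Out>
void dispatch() {
  if constexpr (supports_op<Op, In, Out>()) {
    std::printf("instantiate kernel\n");  // the real code launches a kernel here
  } else {
    std::printf("unsupported combination\n");  // the real code throws instead
  }
}

int main() {
  dispatch<Add, float, float>();         // instantiate kernel
  dispatch<BitwiseAnd, float, float>();  // unsupported combination
  return 0;
}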