CUDA backend: binary ops (#2259)

This commit is contained in:
Cheng 2025-06-10 22:37:40 +09:00 committed by GitHub
parent 9ce77798b1
commit 7ebb2e0193
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 753 additions and 19 deletions

View File

@ -6,6 +6,7 @@
target_sources(
mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
${CMAKE_CURRENT_SOURCE_DIR}/binary.cu
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp

305
mlx/backend/cuda/binary.cu Normal file
View File

@ -0,0 +1,305 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/binary.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/kernels/binary_ops.cuh"
#include "mlx/backend/cuda/kernels/cucomplex_math.cuh"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
#include <cooperative_groups.h>
#include <nvtx3/nvtx3.hpp>
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
template <typename Op, typename In, typename Out, typename IdxT>
__global__ void binary_ss(const In* a, const In* b, Out* out, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
out[index] = Op{}(a[0], b[0]);
}
}
template <typename Op, typename In, typename Out, typename IdxT>
__global__ void binary_sv(const In* a, const In* b, Out* out, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
out[index] = Op{}(a[0], b[index]);
}
}
template <typename Op, typename In, typename Out, typename IdxT>
__global__ void binary_vs(const In* a, const In* b, Out* out, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
out[index] = Op{}(a[index], b[0]);
}
}
template <typename Op, typename In, typename Out, typename IdxT>
__global__ void binary_vv(const In* a, const In* b, Out* out, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
out[index] = Op{}(a[index], b[index]);
}
}
template <typename Op, typename In, typename Out, typename IdxT, int NDIM>
__global__ void binary_g_nd(
const In* a,
const In* b,
Out* out,
IdxT size,
const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
const __grid_constant__ cuda::std::array<int64_t, NDIM> a_strides,
const __grid_constant__ cuda::std::array<int64_t, NDIM> b_strides) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto [a_idx, b_idx] = elem_to_loc_nd<NDIM>(
index, shape.data(), a_strides.data(), b_strides.data());
out[index] = Op{}(a[a_idx], b[b_idx]);
}
}
template <typename Op, typename In, typename Out, typename IdxT>
__global__ void binary_g(
const In* a,
const In* b,
Out* out,
IdxT size,
const __grid_constant__ Shape shape,
const __grid_constant__ Strides a_strides,
const __grid_constant__ Strides b_strides,
int ndim) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto [a_idx, b_idx] = elem_to_loc_4d(
index, shape.data(), a_strides.data(), b_strides.data(), ndim);
out[index] = Op{}(a[a_idx], b[b_idx]);
}
}
template <typename Op, typename In, typename Out>
constexpr bool supports_binary_op() {
if (std::is_same_v<Op, Add> || std::is_same_v<Op, Divide> ||
std::is_same_v<Op, Maximum> || std::is_same_v<Op, Minimum> ||
std::is_same_v<Op, Multiply> || std::is_same_v<Op, Subtract> ||
std::is_same_v<Op, Power> || std::is_same_v<Op, Remainder>) {
return std::is_same_v<In, Out>;
}
if (std::is_same_v<Op, Equal> || std::is_same_v<Op, Greater> ||
std::is_same_v<Op, GreaterEqual> || std::is_same_v<Op, Less> ||
std::is_same_v<Op, LessEqual> || std::is_same_v<Op, NotEqual>) {
return std::is_same_v<Out, bool>;
}
if (std::is_same_v<Op, LogicalAnd> || std::is_same_v<Op, LogicalOr>) {
return std::is_same_v<Out, bool> && std::is_same_v<In, bool>;
}
if (std::is_same_v<Op, NaNEqual>) {
return std::is_same_v<Out, bool> &&
(is_floating_v<In> || std::is_same_v<In, complex64_t>);
}
if (std::is_same_v<Op, LogAddExp> || std::is_same_v<Op, ArcTan2>) {
return std::is_same_v<In, Out> && is_floating_v<In>;
}
if (std::is_same_v<Op, BitwiseAnd> || std::is_same_v<Op, BitwiseOr> ||
std::is_same_v<Op, BitwiseXor>) {
return std::is_same_v<In, Out> && std::is_integral_v<In>;
}
if (std::is_same_v<Op, LeftShift> || std::is_same_v<Op, RightShift>) {
return std::is_same_v<In, Out> && std::is_integral_v<In> &&
!std::is_same_v<In, bool>;
}
return false;
}
} // namespace cu
template <typename Op>
void binary_op_gpu_inplace(
const std::vector<array>& inputs,
std::vector<array>& outputs,
std::string_view op,
const Stream& s) {
assert(inputs.size() > 1);
const auto& a = inputs[0];
const auto& b = inputs[1];
auto& out = outputs[0];
if (out.size() == 0) {
return;
}
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_ALL_TYPES(a.dtype(), CTYPE_IN, {
MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, {
if constexpr (cu::supports_binary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
using InType = cuda_type_t<CTYPE_IN>;
using OutType = cuda_type_t<CTYPE_OUT>;
auto bopt = get_binary_op_type(a, b);
if (bopt == BinaryOpType::General) {
auto [shape, strides] = collapse_contiguous_dims(a, b, out);
auto& a_strides = strides[0];
auto& b_strides = strides[1];
bool large = a.data_size() > UINT32_MAX ||
b.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX;
MLX_SWITCH_BOOL(large, LARGE, {
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
int ndim = shape.size();
if (ndim <= 3) {
MLX_SWITCH_1_2_3(ndim, NDIM, {
auto kernel =
&cu::binary_g_nd<Op, InType, OutType, IdxT, NDIM>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, out, large);
kernel<<<num_blocks, block_dims, 0, stream>>>(
a.data<InType>(),
b.data<InType>(),
out.data<OutType>(),
out.data_size(),
const_param<NDIM>(shape),
const_param<NDIM>(a_strides),
const_param<NDIM>(b_strides));
});
} else {
auto kernel = cu::binary_g<Op, InType, OutType, IdxT>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, out, large);
kernel<<<num_blocks, block_dims, 0, stream>>>(
a.data<InType>(),
b.data<InType>(),
out.data<OutType>(),
out.data_size(),
const_param(shape),
const_param(a_strides),
const_param(b_strides),
ndim);
}
});
} else {
MLX_SWITCH_BOOL(out.data_size() > UINT32_MAX, LARGE, {
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
auto kernel = cu::binary_ss<Op, InType, OutType, IdxT>;
if (bopt == BinaryOpType::ScalarVector) {
kernel = cu::binary_sv<Op, InType, OutType, IdxT>;
} else if (bopt == BinaryOpType::VectorScalar) {
kernel = cu::binary_vs<Op, InType, OutType, IdxT>;
} else if (bopt == BinaryOpType::VectorVector) {
kernel = cu::binary_vv<Op, InType, OutType, IdxT>;
}
auto [num_blocks, block_dims] =
get_launch_args(kernel, out, LARGE);
kernel<<<num_blocks, block_dims, 0, stream>>>(
a.data<InType>(),
b.data<InType>(),
out.data<OutType>(),
out.data_size());
});
}
} else {
throw std::runtime_error(fmt::format(
"Can not do binary op {} on inputs of {} with result of {}.",
op,
dtype_to_string(a.dtype()),
dtype_to_string(out.dtype())));
}
});
});
});
}
template <typename Op>
void binary_op_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs,
std::string_view op,
const Stream& s) {
auto& a = inputs[0];
auto& b = inputs[1];
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, outputs[0], bopt);
set_binary_op_output_data(a, b, outputs[1], bopt);
binary_op_gpu_inplace<Op>(inputs, outputs, op, s);
}
template <typename Op>
void binary_op_gpu(
const std::vector<array>& inputs,
array& out,
std::string_view op,
const Stream& s) {
auto& a = inputs[0];
auto& b = inputs[1];
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
std::vector<array> outputs{out};
binary_op_gpu_inplace<Op>(inputs, outputs, op, s);
}
#define BINARY_GPU(func) \
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
nvtx3::scoped_range r(#func "::eval_gpu"); \
auto& s = out.primitive().stream(); \
binary_op_gpu<cu::func>(inputs, out, get_primitive_string(this), s); \
}
#define BINARY_GPU_MULTI(func) \
void func::eval_gpu( \
const std::vector<array>& inputs, std::vector<array>& outputs) { \
nvtx3::scoped_range r(#func "::eval_gpu"); \
auto& s = outputs[0].primitive().stream(); \
binary_op_gpu<cu::func>(inputs, outputs, get_primitive_string(this), s); \
}
BINARY_GPU(Add)
BINARY_GPU(ArcTan2)
BINARY_GPU(Divide)
BINARY_GPU(Remainder)
BINARY_GPU(Equal)
BINARY_GPU(Greater)
BINARY_GPU(GreaterEqual)
BINARY_GPU(Less)
BINARY_GPU(LessEqual)
BINARY_GPU(LogicalAnd)
BINARY_GPU(LogicalOr)
BINARY_GPU(LogAddExp)
BINARY_GPU(Maximum)
BINARY_GPU(Minimum)
BINARY_GPU(Multiply)
BINARY_GPU(NotEqual)
BINARY_GPU(Power)
BINARY_GPU(Subtract)
void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("BitwiseBinary::eval_gpu");
auto& s = out.primitive().stream();
auto op = get_primitive_string(this);
switch (op_) {
case BitwiseBinary::And:
binary_op_gpu<cu::BitwiseAnd>(inputs, out, op, s);
break;
case BitwiseBinary::Or:
binary_op_gpu<cu::BitwiseOr>(inputs, out, op, s);
break;
case BitwiseBinary::Xor:
binary_op_gpu<cu::BitwiseXor>(inputs, out, op, s);
break;
case BitwiseBinary::LeftShift:
binary_op_gpu<cu::LeftShift>(inputs, out, op, s);
break;
case BitwiseBinary::RightShift:
binary_op_gpu<cu::RightShift>(inputs, out, op, s);
break;
}
}
} // namespace mlx::core

View File

@ -13,9 +13,40 @@
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <fmt/format.h>
#include <cuda/cmath>
namespace mlx::core {
// Convert a number between 1~3 to constexpr.
#define MLX_SWITCH_1_2_3(N, NDIM, ...) \
switch (N) { \
case 1: { \
constexpr int NDIM = 1; \
__VA_ARGS__; \
break; \
} \
case 2: { \
constexpr int NDIM = 2; \
__VA_ARGS__; \
break; \
} \
case 3: { \
constexpr int NDIM = 3; \
__VA_ARGS__; \
break; \
} \
}
// Like MLX_SWITCH_ALL_TYPES but for booleans.
#define MLX_SWITCH_BOOL(BOOL, BOOL_ALIAS, ...) \
if (BOOL) { \
constexpr bool BOOL_ALIAS = true; \
__VA_ARGS__; \
} else { \
constexpr bool BOOL_ALIAS = false; \
__VA_ARGS__; \
}
// Maps CPU types to CUDA types.
template <typename T>
struct CTypeToCudaType {
@ -66,4 +97,35 @@ dim3 get_2d_grid_dims(
const Strides& strides,
size_t divisor);
// Return a block size that achieves maximum potential occupancy for kernel.
template <typename T>
inline uint max_occupancy_block_dim(T kernel) {
int _, block_dim;
CHECK_CUDA_ERROR(cudaOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel));
return block_dim;
}
// Get the num_blocks and block_dims that maximize occupancy for |kernel|,
// assuming each thread handles |work_per_thread| elements of |arr|.
template <typename T>
inline std::tuple<dim3, uint> get_launch_args(
T kernel,
const array& arr,
bool large,
int work_per_thread = 1) {
size_t nthreads = cuda::ceil_div(arr.size(), work_per_thread);
uint block_dim = max_occupancy_block_dim(kernel);
if (block_dim > nthreads) {
block_dim = nthreads;
}
dim3 num_blocks;
if (large) {
num_blocks = get_2d_grid_dims(arr.shape(), arr.strides(), work_per_thread);
num_blocks.x = cuda::ceil_div(num_blocks.x, block_dim);
} else {
num_blocks.x = cuda::ceil_div(nthreads, block_dim);
}
return std::make_tuple(num_blocks, block_dim);
}
} // namespace mlx::core

View File

@ -0,0 +1,278 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/kernels/fp16_math.cuh"
#include <cuComplex.h>
#include <cuda/std/array>
namespace mlx::core::cu {
struct Add {
template <typename T>
__device__ T operator()(T x, T y) {
return x + y;
}
};
struct FloorDivide {
template <typename T>
__device__ T operator()(T x, T y) {
if constexpr (cuda::std::is_integral_v<T>) {
return x / y;
} else {
return trunc(x / y);
}
}
};
struct Divide {
template <typename T>
__device__ T operator()(T x, T y) {
return x / y;
}
};
struct Remainder {
template <typename T>
__device__ T operator()(T x, T y) {
if constexpr (cuda::std::is_integral_v<T>) {
if constexpr (cuda::std::is_signed_v<T>) {
auto r = x % y;
if (r != 0 && (r < 0 != y < 0)) {
r += y;
}
return r;
} else {
return x % y;
}
} else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
return x % y;
} else {
T r = fmod(x, y);
if (r != 0 && (r < 0 != y < 0)) {
r = r + y;
}
return r;
}
}
};
struct Equal {
template <typename T>
__device__ bool operator()(T x, T y) {
return x == y;
}
};
struct NaNEqual {
template <typename T>
__device__ bool operator()(T x, T y) {
if constexpr (std::is_same_v<T, cuComplex>) {
return x == y ||
(isnan(cuCrealf(x)) && isnan(cuCrealf(y)) && isnan(cuCimagf(x)) &&
isnan(cuCimagf(y))) ||
(cuCrealf(x) == cuCrealf(y) && isnan(cuCimagf(x)) &&
isnan(cuCimagf(y))) ||
(isnan(cuCrealf(x)) && isnan(cuCrealf(y)) &&
cuCimagf(x) == cuCimagf(y));
} else {
return x == y || (isnan(x) && isnan(y));
}
}
};
struct Greater {
template <typename T>
__device__ bool operator()(T x, T y) {
return x > y;
}
};
struct GreaterEqual {
template <typename T>
__device__ bool operator()(T x, T y) {
return x >= y;
}
};
struct Less {
template <typename T>
__device__ bool operator()(T x, T y) {
return x < y;
}
};
struct LessEqual {
template <typename T>
__device__ bool operator()(T x, T y) {
return x <= y;
}
};
struct LogAddExp {
template <typename T>
__device__ T operator()(T x, T y) {
if (isnan(x) || isnan(y)) {
return cuda::std::numeric_limits<T>::quiet_NaN();
}
T maxval = max(x, y);
T minval = min(x, y);
return (minval == -cuda::std::numeric_limits<T>::infinity() ||
maxval == cuda::std::numeric_limits<T>::infinity())
? maxval
: T(float(maxval) + log1p(expf(minval - maxval)));
};
};
struct Maximum {
template <typename T>
__device__ T operator()(T x, T y) {
if constexpr (cuda::std::is_integral_v<T>) {
return max(x, y);
} else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
if (isnan(cuCrealf(x)) || isnan(cuCimagf(x))) {
return x;
}
return x > y ? x : y;
} else {
if (isnan(x)) {
return x;
}
return x > y ? x : y;
}
}
};
struct Minimum {
template <typename T>
__device__ T operator()(T x, T y) {
if constexpr (cuda::std::is_integral_v<T>) {
return min(x, y);
} else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
if (isnan(cuCrealf(x)) || isnan(cuCimagf(x))) {
return x;
}
return x < y ? x : y;
} else {
if (isnan(x)) {
return x;
}
return x < y ? x : y;
}
}
};
struct Multiply {
template <typename T>
__device__ T operator()(T x, T y) {
return x * y;
}
};
struct NotEqual {
template <typename T>
__device__ bool operator()(T x, T y) {
if constexpr (std::is_same_v<T, cuComplex>) {
return cuCrealf(x) != cuCrealf(y) || cuCimagf(x) != cuCimagf(y);
} else {
return x != y;
}
}
};
struct Power {
template <typename T>
__device__ T operator()(T base, T exp) {
if constexpr (cuda::std::is_integral_v<T>) {
T res = 1;
while (exp) {
if (exp & 1) {
res *= base;
}
exp >>= 1;
base *= base;
}
return res;
} else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
auto x_theta = atan2f(base.y, base.x);
auto x_ln_r = 0.5 * logf(base.x * base.x + base.y * base.y);
auto mag = expf(exp.x * x_ln_r - exp.y * x_theta);
auto phase = exp.y * x_ln_r + exp.x * x_theta;
return make_cuFloatComplex(mag * cosf(phase), mag * sinf(phase));
} else {
return powf(base, exp);
}
}
};
struct Subtract {
template <typename T>
__device__ T operator()(T x, T y) {
return x - y;
}
};
struct LogicalAnd {
template <typename T>
__device__ T operator()(T x, T y) {
return x && y;
};
};
struct LogicalOr {
template <typename T>
__device__ T operator()(T x, T y) {
return x || y;
};
};
struct BitwiseAnd {
template <typename T>
__device__ T operator()(T x, T y) {
return x & y;
};
};
struct BitwiseOr {
template <typename T>
__device__ T operator()(T x, T y) {
return x | y;
};
};
struct BitwiseXor {
template <typename T>
__device__ T operator()(T x, T y) {
return x ^ y;
};
};
struct LeftShift {
template <typename T>
__device__ T operator()(T x, T y) {
return x << y;
};
};
struct RightShift {
template <typename T>
__device__ T operator()(T x, T y) {
return x >> y;
};
};
struct ArcTan2 {
template <typename T>
__device__ T operator()(T y, T x) {
return atan2f(y, x);
}
};
struct DivMod {
template <typename T>
__device__ cuda::std::array<T, 2> operator()(T x, T y) {
return {FloorDivide{}(x, y), Remainder{}(x, y)};
};
};
} // namespace mlx::core::cu

View File

@ -81,6 +81,52 @@ MLX_DEFINE_UNARY_OP_FALLBCK(tanh)
#undef MLX_DEFINE_UNARY_OP
#undef MLX_DEFINE_UNARY_OP_FALLBCK
///////////////////////////////////////////////////////////////////////////////
// Binary ops for half types.
///////////////////////////////////////////////////////////////////////////////
#if CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800
#define MLX_DEFINE_BINARY_OP(NAME, HALF_OP) \
template <typename T> \
__forceinline__ __device__ auto NAME(T x, T y) { \
if constexpr (cuda::std::is_same_v<T, __half>) { \
return HALF_OP(x, y); \
} else { \
return ::NAME(x, y); \
} \
}
#else
#define MLX_DEFINE_BINARY_OP(NAME, HALF_OP) \
template <typename T> \
__forceinline__ __device__ auto NAME(T x, T y) { \
if constexpr (cuda::std::is_same_v<T, __half>) { \
return HALF_OP(x, y); \
} else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) { \
return HALF_OP(x, y); \
} else { \
return ::NAME(x, y); \
} \
}
#endif
MLX_DEFINE_BINARY_OP(max, __hmax)
MLX_DEFINE_BINARY_OP(min, __hmin)
#undef MLX_DEFINE_BINARY_OP
template <typename T>
__forceinline__ __device__ T fmod(T x, T y) {
if constexpr (cuda::std::is_same_v<T, __half>) {
return __float2half(::fmod(__half2float(x), __half2float(y)));
#if CUDART_VERSION >= 12000 || __CUDA_ARCH__ >= 800
} else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) {
return __float2bfloat16(::fmod(__bfloat162float(x), __bfloat162float(y)));
#endif
} else {
return ::fmod(x, y);
}
}
///////////////////////////////////////////////////////////////////////////////
// Additional C++ operator overrides between half types and native types.
///////////////////////////////////////////////////////////////////////////////

View File

@ -11,6 +11,7 @@
#include <cuComplex.h>
#include <cuda/std/array>
#include <cuda/std/limits>
#include <cuda/std/tuple>
namespace mlx::core::cu {
@ -40,4 +41,64 @@ elem_to_loc(IdxT elem, const int* shape, const int64_t* strides, int ndim) {
return loc;
}
// Optimize when the ndim is known at compile time.
template <int NDIM, typename IdxT = int64_t>
inline __host__ __device__ IdxT
elem_to_loc_nd(IdxT elem, const int* shape, const int64_t* strides) {
IdxT loc = 0;
#pragma unroll
for (int i = NDIM - 1; i >= 0; --i) {
loc += (elem % shape[i]) * IdxT(strides[i]);
elem /= shape[i];
}
return loc;
}
template <int NDIM, typename IdxT = int64_t>
inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_nd(
IdxT elem,
const int* shape,
const int64_t* a_strides,
const int64_t* b_strides) {
IdxT a_loc = 0;
IdxT b_loc = 0;
#pragma unroll
for (int i = NDIM - 1; i >= 0; --i) {
int dim_idx = elem % shape[i];
a_loc += dim_idx * a_strides[i];
b_loc += dim_idx * b_strides[i];
elem /= shape[i];
}
return cuda::std::make_tuple(a_loc, b_loc);
}
// Optimized version when ndim is larger than 4.
template <typename IdxT = int64_t>
inline __host__ __device__ IdxT
elem_to_loc_4d(IdxT elem, const int* shape, const int64_t* strides, int ndim) {
IdxT loc = elem_to_loc_nd<3>(elem, shape, strides);
for (int i = ndim - 1; i >= 3; --i) {
loc += (elem % shape[i]) * IdxT(strides[i]);
elem /= shape[i];
}
return loc;
}
template <typename IdxT = int64_t>
inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_4d(
IdxT elem,
const int* shape,
const int64_t* a_strides,
const int64_t* b_strides,
int ndim) {
auto [a_loc, b_loc] = elem_to_loc_nd<3>(elem, shape, a_strides, b_strides);
for (int i = ndim - 1; i >= 3; --i) {
int dim_idx = elem % shape[i];
a_loc += dim_idx * a_strides[i];
b_loc += dim_idx * b_strides[i];
elem /= shape[i];
}
return cuda::std::make_tuple(a_loc, b_loc);
}
} // namespace mlx::core::cu

View File

@ -71,43 +71,25 @@ bool fast::ScaledDotProductAttention::use_fallback(
throw std::runtime_error(#func " has no CUDA implementation."); \
}
NO_GPU(Add)
NO_GPU(ArcTan2)
NO_GPU(ArgPartition)
NO_GPU(ArgReduce)
NO_GPU(ArgSort)
NO_GPU(BitwiseBinary)
NO_GPU(BlockMaskedMM)
NO_GPU_MULTI(Compiled)
NO_GPU(Convolution)
NO_GPU(Divide)
NO_GPU_MULTI(DivMod)
NO_GPU(DynamicSlice)
NO_GPU(DynamicSliceUpdate)
NO_GPU(Remainder)
NO_GPU(Equal)
NO_GPU(FFT)
NO_GPU(Gather)
NO_GPU(GatherAxis)
NO_GPU(GatherMM)
NO_GPU(GatherQMM)
NO_GPU(Greater)
NO_GPU(GreaterEqual)
NO_GPU(Hadamard)
NO_GPU(Less)
NO_GPU(LessEqual)
NO_GPU(Load)
NO_GPU(LogicalAnd)
NO_GPU(LogicalOr)
NO_GPU(LogAddExp)
NO_GPU(LogSumExp)
NO_GPU_MULTI(LUF)
NO_GPU(Maximum)
NO_GPU(Minimum)
NO_GPU(Multiply)
NO_GPU(NotEqual)
NO_GPU(Partition)
NO_GPU(Power)
NO_GPU_MULTI(QRF)
NO_GPU(QuantizedMatmul)
NO_GPU(RandomBits)
@ -119,7 +101,6 @@ NO_GPU(Select)
NO_GPU(SliceUpdate)
NO_GPU(Softmax)
NO_GPU(Sort)
NO_GPU(Subtract)
NO_GPU_MULTI(SVD)
NO_GPU(Inverse)
NO_GPU(Cholesky)