// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/arange.cuh"
#include "mlx/backend/cuda/device/fp16_math.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/distributed/primitives.h"
#include "mlx/dtype_utils.h"
#include "mlx/fast_primitives.h"
#include "mlx/primitives.h"

#include <nvtx3/nvtx3.hpp>
#include <thrust/device_ptr.h>
#include <thrust/transform.h>

#include <cassert>

namespace mlx::core {

// Fill the output with start_, start_ + step_, start_ + 2 * step_, ... by
// applying a thrust transform to a counting iterator.
void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
  nvtx3::scoped_range r("Arange::eval_gpu");
  assert(inputs.size() == 0);
  out.set_data(allocator::malloc(out.nbytes()));
  if (out.size() == 0) {
    return;
  }
  auto& s = stream();
  auto& encoder = cu::get_command_encoder(s);
  encoder.set_output_array(out);
  encoder.launch_kernel([&, this](cudaStream_t stream) {
    MLX_SWITCH_INT_FLOAT_TYPES_CHECKED(out.dtype(), "Arange", CTYPE, {
      using OutType = cuda_type_t<CTYPE>;
      // Compute the effective step in the output dtype's precision: cast
      // start_ and start_ + step_ separately, then subtract.
      CTYPE step =
          static_cast<CTYPE>(start_ + step_) - static_cast<CTYPE>(start_);
      thrust::transform(
          cu::thrust_policy(stream),
          thrust::counting_iterator<uint32_t>(0),
          thrust::counting_iterator<uint32_t>(out.data_size()),
          thrust::device_pointer_cast(out.data<OutType>()),
          cu::Arange<OutType>{
              static_cast<OutType>(start_), static_cast<OutType>(step)});
    });
  });
}

// Always route scaled dot-product attention to the fallback path; there is
// no CUDA kernel for it yet.
bool fast::ScaledDotProductAttention::use_fallback(
    const array& q,
    const array& k,
    const array& v,
    bool has_mask,
    bool has_arr_mask,
    bool do_causal,
    Stream s) {
  return true;
}

// Stubs for primitives that have no CUDA implementation yet. NO_GPU_MULTI
// covers multi-output primitives; NO_GPU_USE_FALLBACK also defines
// use_fallback to return true so the generic implementation is used instead.
#define NO_GPU_MULTI(func)                                             \
  void func::eval_gpu(                                                 \
      const std::vector<array>& inputs, std::vector<array>& outputs) { \
    throw std::runtime_error(#func " has no CUDA implementation.");    \
  }

#define NO_GPU_USE_FALLBACK(func)     \
  bool func::use_fallback(Stream s) { \
    return true;                      \
  }                                   \
  NO_GPU_MULTI(func)

#define NO_GPU(func)                                                  \
  void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
    throw std::runtime_error(#func " has no CUDA implementation.");   \
  }

NO_GPU(ArgPartition)
NO_GPU(BlockMaskedMM)
NO_GPU(Convolution)
NO_GPU_MULTI(DivMod)
NO_GPU(DynamicSlice)
NO_GPU(DynamicSliceUpdate)
NO_GPU(FFT)
NO_GPU(GatherMM)
NO_GPU(GatherQMM)
NO_GPU(Hadamard)
NO_GPU(Load)
NO_GPU_MULTI(LUF)
NO_GPU(Partition)
NO_GPU_MULTI(QRF)
NO_GPU(QuantizedMatmul)
NO_GPU(Scan)
NO_GPU(Select)
NO_GPU(SliceUpdate)
NO_GPU_MULTI(SVD)
NO_GPU(Inverse)
NO_GPU(Cholesky)
NO_GPU_MULTI(Eig)
NO_GPU_MULTI(Eigh)

namespace fast {
NO_GPU_USE_FALLBACK(RMSNorm)
NO_GPU_MULTI(RMSNormVJP)
NO_GPU_USE_FALLBACK(RoPE)
NO_GPU(ScaledDotProductAttention)
NO_GPU_MULTI(AffineQuantize)
NO_GPU_MULTI(CustomKernel)
} // namespace fast

namespace distributed {
NO_GPU_MULTI(AllReduce)
NO_GPU_MULTI(AllGather)
NO_GPU_MULTI(Send)
NO_GPU_MULTI(Recv)
} // namespace distributed

} // namespace mlx::core
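
// For reference, each stub macro above expands to a method definition that
// throws at runtime. For example, NO_GPU(FFT) expands to:
//
//   void FFT::eval_gpu(const std::vector<array>& inputs, array& out) {
//     throw std::runtime_error("FFT has no CUDA implementation.");
//   }
//
// while NO_GPU_USE_FALLBACK(RMSNorm) additionally defines
// RMSNorm::use_fallback(Stream) to return true, so that primitive is
// evaluated via the generic fallback path rather than throwing.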