From 0cae0bdac83bbf5b3d1da3ca53f1f7eb95981d30 Mon Sep 17 00:00:00 2001 From: Cheng Date: Wed, 7 May 2025 13:26:46 +0900 Subject: [PATCH] CUDA backend: backbone (#2075) --- CMakeLists.txt | 5 + mlx/CMakeLists.txt | 10 +- mlx/backend/cuda/CMakeLists.txt | 57 ++++++ mlx/backend/cuda/allocator.cpp | 154 ++++++++++++++ mlx/backend/cuda/allocator.h | 58 ++++++ mlx/backend/cuda/copy.cpp | 26 +++ mlx/backend/cuda/device.cpp | 117 +++++++++++ mlx/backend/cuda/device.h | 131 ++++++++++++ mlx/backend/cuda/dtype_utils.cuh | 35 ++++ mlx/backend/cuda/eval.cpp | 68 +++++++ mlx/backend/cuda/event.cu | 265 +++++++++++++++++++++++++ mlx/backend/cuda/event.h | 66 ++++++ mlx/backend/cuda/fence.cu | 70 +++++++ mlx/backend/cuda/kernels/arange.cuh | 15 ++ mlx/backend/cuda/kernels/fp16_math.cuh | 107 ++++++++++ mlx/backend/cuda/primitives.cu | 163 +++++++++++++++ mlx/backend/cuda/slicing.cpp | 15 ++ mlx/backend/cuda/utils.cpp | 26 +++ mlx/backend/cuda/utils.h | 36 ++++ mlx/backend/cuda/worker.cpp | 90 +++++++++ mlx/backend/cuda/worker.h | 68 +++++++ tests/CMakeLists.txt | 2 +- 22 files changed, 1582 insertions(+), 2 deletions(-) create mode 100644 mlx/backend/cuda/CMakeLists.txt create mode 100644 mlx/backend/cuda/allocator.cpp create mode 100644 mlx/backend/cuda/allocator.h create mode 100644 mlx/backend/cuda/copy.cpp create mode 100644 mlx/backend/cuda/device.cpp create mode 100644 mlx/backend/cuda/device.h create mode 100644 mlx/backend/cuda/dtype_utils.cuh create mode 100644 mlx/backend/cuda/eval.cpp create mode 100644 mlx/backend/cuda/event.cu create mode 100644 mlx/backend/cuda/event.h create mode 100644 mlx/backend/cuda/fence.cu create mode 100644 mlx/backend/cuda/kernels/arange.cuh create mode 100644 mlx/backend/cuda/kernels/fp16_math.cuh create mode 100644 mlx/backend/cuda/primitives.cu create mode 100644 mlx/backend/cuda/slicing.cpp create mode 100644 mlx/backend/cuda/utils.cpp create mode 100644 mlx/backend/cuda/utils.h create mode 100644 mlx/backend/cuda/worker.cpp create mode 100644 mlx/backend/cuda/worker.h diff --git a/CMakeLists.txt b/CMakeLists.txt index e2002fc94..ab8aea443 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -34,6 +34,7 @@ option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF) option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF) option(MLX_BUILD_METAL "Build metal backend" ON) option(MLX_BUILD_CPU "Build cpu backend" ON) +option(MLX_BUILD_CUDA "Build cuda backend" OFF) option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF) option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF) option(MLX_BUILD_GGUF "Include support for GGUF format" ON) @@ -83,6 +84,10 @@ if(MLX_BUILD_METAL) set(QUARTZ_LIB "-framework QuartzCore") endif() +if(MLX_BUILD_CUDA) + enable_language(CUDA) +endif() + if(MLX_BUILD_METAL AND NOT METAL_LIB) message(STATUS "Metal not found. 
Unable to build GPU") set(MLX_BUILD_METAL OFF) diff --git a/mlx/CMakeLists.txt b/mlx/CMakeLists.txt index 00898e73e..4ba9b33dd 100644 --- a/mlx/CMakeLists.txt +++ b/mlx/CMakeLists.txt @@ -47,10 +47,18 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/distributed) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io) if(MLX_BUILD_METAL) - add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/gpu) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/metal) else() target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/no_metal.cpp) +endif() + +if(MLX_BUILD_CUDA) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda) +endif() + +if(MLX_BUILD_METAL OR MLX_BUILD_CUDA) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/gpu) +else() add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_gpu) endif() diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt new file mode 100644 index 000000000..54d651005 --- /dev/null +++ b/mlx/backend/cuda/CMakeLists.txt @@ -0,0 +1,57 @@ +# Filename rules in cuda backend: +# +# * Use .cu/.cuh if code contains device code, and .cpp/.h if not. +# * Device-only kernel code should be put in kernels/ subdir. +# * Files in kernels/ subdir should not include files outside. +target_sources( + mlx + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/event.cu + ${CMAKE_CURRENT_SOURCE_DIR}/fence.cu + ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu + ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp) + +target_compile_definitions(mlx PUBLIC MLX_USE_CUDA) + +# Enable defining device lambda functions. +target_compile_options(mlx + PRIVATE "$<$:--extended-lambda>") + +# Compute capability 7 is required for synchronization between CPU/GPU with +# managed memory. TODO: Add more architectures for potential performance gain. +set(MLX_CUDA_ARCHITECTURES + "75;80" + CACHE STRING "CUDA architectures") +message(STATUS "CUDA architectures: ${MLX_CUDA_ARCHITECTURES}") +set_target_properties(mlx PROPERTIES CUDA_ARCHITECTURES + "${MLX_CUDA_ARCHITECTURES}") + +# Use fixed version of CCCL. +FetchContent_Declare( + cccl + URL "https://github.com/NVIDIA/cccl/releases/download/v2.8.1/cccl-v2.8.1.zip") +FetchContent_MakeAvailable(cccl) +target_include_directories(mlx PRIVATE BEFORE "${cccl_SOURCE_DIR}/include") + +# Use fixed version of NVTX. +FetchContent_Declare( + nvtx3 + GIT_REPOSITORY https://github.com/NVIDIA/NVTX.git + GIT_TAG v3.1.1 + GIT_SHALLOW TRUE + SOURCE_SUBDIR c EXCLUDE_FROM_ALL) +FetchContent_MakeAvailable(nvtx3) +target_link_libraries(mlx PUBLIC $) + +# Make cuda runtime APIs available in non-cuda files. +find_package(CUDAToolkit REQUIRED) +target_include_directories(mlx PRIVATE ${CUDAToolkit_INCLUDE_DIRS}) + +# Suppress nvcc warnings on MLX headers. +target_compile_options(mlx PRIVATE $<$:-Xcudafe + --diag_suppress=997>) diff --git a/mlx/backend/cuda/allocator.cpp b/mlx/backend/cuda/allocator.cpp new file mode 100644 index 000000000..203534e21 --- /dev/null +++ b/mlx/backend/cuda/allocator.cpp @@ -0,0 +1,154 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/allocator.h" +#include "mlx/backend/cuda/utils.h" +#include "mlx/backend/cuda/worker.h" + +#include +#include + +#include + +namespace mlx::core { + +namespace cu { + +CudaAllocator::CudaAllocator() { + // TODO: Set memory limit for multi-device. 
+ size_t free, total; + CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total)); + memory_limit_ = total * 0.8; +} + +Buffer CudaAllocator::malloc(size_t size) { + // TODO: Check memory limit. + auto* buf = new CudaBuffer{nullptr, size}; + cudaError_t err = cudaMallocManaged(&buf->data, size); + if (err != cudaSuccess && err != cudaErrorMemoryAllocation) { + throw std::runtime_error( + fmt::format("cudaMallocManaged failed: {}.", cudaGetErrorString(err))); + } + std::lock_guard lock(mutex_); + active_memory_ += size; + peak_memory_ = std::max(active_memory_, peak_memory_); + return Buffer{buf}; +} + +void CudaAllocator::free(Buffer buffer) { + auto* buf = static_cast(buffer.ptr()); + if (!buf) { + return; + } + + // If free() is called from a unregistered thread, reschedule the call to + // worker. + { + std::lock_guard lock(worker_mutex_); + if (allowed_threads_.count(std::this_thread::get_id()) == 0) { + if (!worker_) { + worker_.reset(new Worker); + } + worker_->add_task([buffer]() { allocator().free(buffer); }); + worker_->end_batch(); + worker_->commit(); + return; + } + } + + size_t size = buf->size; + cudaFree(buf->data); + delete buf; + std::lock_guard lock(mutex_); + active_memory_ -= size; +} + +size_t CudaAllocator::size(Buffer buffer) const { + auto* buf = static_cast(buffer.ptr()); + if (!buf) { + return 0; + } + return buf->size; +} + +void CudaAllocator::register_this_thread() { + std::lock_guard lock(worker_mutex_); + allowed_threads_.insert(std::this_thread::get_id()); +} + +size_t CudaAllocator::get_active_memory() const { + return active_memory_; +} + +size_t CudaAllocator::get_peak_memory() const { + return peak_memory_; +} + +void CudaAllocator::reset_peak_memory() { + std::lock_guard lock(mutex_); + peak_memory_ = 0; +} + +size_t CudaAllocator::get_memory_limit() { + return memory_limit_; +} + +size_t CudaAllocator::set_memory_limit(size_t limit) { + std::lock_guard lock(mutex_); + std::swap(limit, memory_limit_); + return limit; +} + +CudaAllocator& allocator() { + // By creating the |allocator_| on heap, the destructor of CudaAllocator + // will not be called on exit and buffers in the cache will be leaked. This + // can save some time at program exit. + static CudaAllocator* allocator_ = new CudaAllocator; + return *allocator_; +} + +} // namespace cu + +namespace allocator { + +Allocator& allocator() { + return cu::allocator(); +} + +void* Buffer::raw_ptr() { + if (!ptr_) { + return nullptr; + } + return static_cast(ptr_)->data; +} + +} // namespace allocator + +size_t get_active_memory() { + return cu::allocator().get_active_memory(); +} +size_t get_peak_memory() { + return cu::allocator().get_peak_memory(); +} +void reset_peak_memory() { + return cu::allocator().reset_peak_memory(); +} +size_t set_memory_limit(size_t limit) { + return cu::allocator().set_memory_limit(limit); +} +size_t get_memory_limit() { + return cu::allocator().get_memory_limit(); +} + +// TODO: Implement buffer cache. +size_t get_cache_memory() { + return 0; +} +size_t set_cache_limit(size_t) { + return 0; +} +size_t set_wired_limit(size_t) { + return 0; +} +void clear_cache() {} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/allocator.h b/mlx/backend/cuda/allocator.h new file mode 100644 index 000000000..6c418ee7e --- /dev/null +++ b/mlx/backend/cuda/allocator.h @@ -0,0 +1,58 @@ +// Copyright © 2025 Apple Inc. 
+ +#pragma once + +#include "mlx/allocator.h" + +#include +#include +#include +#include + +namespace mlx::core::cu { + +class Worker; + +using allocator::Buffer; + +// Stores cuda-managed unified memory. +struct CudaBuffer { + void* data; + size_t size; +}; + +class CudaAllocator : public allocator::Allocator { + public: + Buffer malloc(size_t size) override; + void free(Buffer buffer) override; + size_t size(Buffer buffer) const override; + + // Register current thread as safe to free buffers. + // In cuda freeing a buffer implicitly synchronizes stream, and for threads + // that may be waited by gpu stream (for example cpu stream threads), freeing + // buffers there would result in dead lock. + void register_this_thread(); + + size_t get_active_memory() const; + size_t get_peak_memory() const; + void reset_peak_memory(); + size_t get_memory_limit(); + size_t set_memory_limit(size_t limit); + + private: + CudaAllocator(); + friend CudaAllocator& allocator(); + + std::mutex worker_mutex_; + std::unique_ptr worker_; + std::set allowed_threads_; + + std::mutex mutex_; + size_t memory_limit_; + size_t active_memory_{0}; + size_t peak_memory_{0}; +}; + +CudaAllocator& allocator(); + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/copy.cpp b/mlx/backend/cuda/copy.cpp new file mode 100644 index 000000000..d0413d989 --- /dev/null +++ b/mlx/backend/cuda/copy.cpp @@ -0,0 +1,26 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/gpu/copy.h" + +namespace mlx::core { + +void copy_gpu_inplace( + const array& in, + array& out, + const Shape& data_shape, + const Strides& strides_in_pre, + const Strides& strides_out_pre, + int64_t inp_offset, + int64_t out_offset, + CopyType ctype, + const Stream& s, + const std::optional& dynamic_i_offset /* = std::nullopt */, + const std::optional& dynamic_o_offset /* = std::nullopt */) { + throw std::runtime_error("copy_gpu_inplace not implemented in CUDA backend."); +} + +void fill_gpu(const array& val, array& out, const Stream& s) { + throw std::runtime_error("fill_gpu not implemented in CUDA backend."); +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/device.cpp b/mlx/backend/cuda/device.cpp new file mode 100644 index 000000000..a28ffa35e --- /dev/null +++ b/mlx/backend/cuda/device.cpp @@ -0,0 +1,117 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/worker.h" +#include "mlx/backend/metal/metal.h" + +#include +#include + +namespace mlx::core { + +namespace cu { + +DeviceStream::DeviceStream(Device& device) : device_(device), stream_(device) {} + +void DeviceStream::synchronize() { + cudaStreamSynchronize(stream_); +} + +cudaStream_t DeviceStream::schedule_cuda_stream() { + // TODO: Return a stream that maximizes parallelism. + return stream_; +} + +cudaStream_t DeviceStream::last_cuda_stream() { + return stream_; +} + +CommandEncoder& DeviceStream::get_encoder() { + if (!encoder_) { + encoder_ = std::make_unique(*this); + } + return *encoder_; +} + +Device::Device(int device) : device_(device) { + // Validate the requirements of device. + int attr = 0; + cudaDeviceGetAttribute(&attr, cudaDevAttrConcurrentManagedAccess, device_); + if (attr != 1) { + throw std::runtime_error(fmt::format( + "Device {} does not support synchronization in managed memory.", + device_)); + } +} + +void Device::make_current() { + // We need to set/get current CUDA device very frequently, cache it to reduce + // actual calls of CUDA APIs. This function assumes single-thread in host. 
+ static int current = 0; + if (current != device_) { + CHECK_CUDA_ERROR(cudaSetDevice(device_)); + current = device_; + } +} + +DeviceStream& Device::get_stream(Stream s) { + auto it = streams_.find(s.index); + if (it == streams_.end()) { + it = streams_.try_emplace(s.index, *this).first; + } + return it->second; +} + +CommandEncoder::CommandEncoder(DeviceStream& s) + : device_(s.device()), stream_(s) {} + +void CommandEncoder::add_completed_handler(std::function task) { + worker_.add_task(std::move(task)); +} + +void CommandEncoder::end_encoding() { + if (!temporaries_.empty()) { + add_completed_handler([temporaries = std::move(temporaries_)]() {}); + } + + // There is no kernel running, run completion handlers immediately. + if (!has_gpu_work_) { + worker_.consume_in_this_thread(); + return; + } + has_gpu_work_ = false; + + // Put completion handlers in a batch. + worker_.end_batch(); + + // Signaling kernel completion is expensive, delay until enough batches. + // TODO: This number is arbitrarily picked, profile for a better stragety. + if (worker_.uncommited_batches() > 8) { + commit(); + } +} + +void CommandEncoder::commit() { + worker_.commit(stream_.last_cuda_stream()); +} + +Device& device(mlx::core::Device device) { + static std::unordered_map devices; + auto it = devices.find(device.index); + if (it == devices.end()) { + it = devices.try_emplace(device.index, device.index).first; + } + return it->second; +} + +DeviceStream& get_stream(Stream s) { + return device(s.device).get_stream(s); +} + +CommandEncoder& get_command_encoder(Stream s) { + return get_stream(s).get_encoder(); +} + +} // namespace cu + +} // namespace mlx::core diff --git a/mlx/backend/cuda/device.h b/mlx/backend/cuda/device.h new file mode 100644 index 000000000..a65a87d54 --- /dev/null +++ b/mlx/backend/cuda/device.h @@ -0,0 +1,131 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/array.h" +#include "mlx/backend/cuda/worker.h" +#include "mlx/stream.h" + +#include + +#include + +namespace mlx::core::cu { + +class Device; +class CommandEncoder; + +class DeviceStream { + public: + explicit DeviceStream(Device& device); + + DeviceStream(const DeviceStream&) = delete; + DeviceStream& operator=(const DeviceStream&) = delete; + + // Wait until kernels in the stream complete. + void synchronize(); + + // Return a cuda stream for launching kernels. + cudaStream_t schedule_cuda_stream(); + + // Return the last cuda stream used. + cudaStream_t last_cuda_stream(); + + CommandEncoder& get_encoder(); + + Device& device() { + return device_; + } + + private: + Device& device_; + CudaStream stream_; + std::unique_ptr encoder_; +}; + +class Device { + public: + explicit Device(int device); + + Device(const Device&) = delete; + Device& operator=(const Device&) = delete; + + // Make this device the current cuda device, required by some cuda calls. 
+ void make_current(); + + DeviceStream& get_stream(Stream s); + + int cuda_device() const { + return device_; + } + + private: + int device_; + std::unordered_map streams_; +}; + +class CommandEncoder { + public: + explicit CommandEncoder(DeviceStream& stream); + + CommandEncoder(const CommandEncoder&) = delete; + CommandEncoder& operator=(const CommandEncoder&) = delete; + + void set_input_array(const array& arr) {} + void set_output_array(const array& arr) {} + + void add_temporary(const array& arr) { + temporaries_.push_back(arr.data_shared_ptr()); + } + + void add_completed_handler(std::function task); + void end_encoding(); + void commit(); + + // Schedule a cuda stream for |fun| to launch kernels, and check error + // afterwards. + template + void launch_kernel(F&& fun) { + launch_kernel(stream_.schedule_cuda_stream(), std::forward(fun)); + } + + template + void launch_kernel(cudaStream_t stream, F&& fun) { + device_.make_current(); + fun(stream); + check_cuda_error("kernel launch", cudaGetLastError()); + has_gpu_work_ = true; + } + + Device& device() { + return device_; + } + + DeviceStream& stream() { + return stream_; + } + + bool has_gpu_work() const { + return has_gpu_work_; + } + + private: + Device& device_; + DeviceStream& stream_; + Worker worker_; + bool has_gpu_work_{false}; + std::vector> temporaries_; +}; + +Device& device(mlx::core::Device device); +DeviceStream& get_stream(Stream s); +CommandEncoder& get_command_encoder(Stream s); + +// Return an execution policy that does not sync for result. +// Note that not all thrust APIs support async policy, confirm before using. +inline auto thrust_policy(cudaStream_t stream) { + // TODO: Connect thrust's custom allocator with mlx's allocator. + return thrust::cuda::par_nosync.on(stream); +} + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/dtype_utils.cuh b/mlx/backend/cuda/dtype_utils.cuh new file mode 100644 index 000000000..9b7f8ba65 --- /dev/null +++ b/mlx/backend/cuda/dtype_utils.cuh @@ -0,0 +1,35 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include +#include +#include + +namespace mlx::core { + +// Maps CPU types to CUDA types. +template +struct CTypeToCudaType { + using type = T; +}; + +template <> +struct CTypeToCudaType { + using type = __half; +}; + +template <> +struct CTypeToCudaType { + using type = __nv_bfloat16; +}; + +template <> +struct CTypeToCudaType { + using type = cuComplex; +}; + +template +using cuda_type_t = typename CTypeToCudaType::type; + +} // namespace mlx::core diff --git a/mlx/backend/cuda/eval.cpp b/mlx/backend/cuda/eval.cpp new file mode 100644 index 000000000..b309ad60e --- /dev/null +++ b/mlx/backend/cuda/eval.cpp @@ -0,0 +1,68 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/gpu/eval.h" +#include "mlx/backend/cuda/allocator.h" +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/gpu/available.h" +#include "mlx/primitives.h" + +#include + +namespace mlx::core::gpu { + +bool is_available() { + return true; +} + +void new_stream(Stream s) { + // Force initalization of cuda, so cuda runtime get destroyed at last. + cudaFree(nullptr); + // Ensure the static stream objects get created. + cu::get_command_encoder(s); + // The main thread is safe to free buffers. 
+ cu::allocator().register_this_thread(); +} + +void eval(array& arr) { + nvtx3::scoped_range r("gpu::eval"); + auto outputs = arr.outputs(); + { + // If the array is a tracer hold a reference + // to its inputs so they don't get donated + std::vector inputs; + if (arr.is_tracer()) { + inputs = arr.inputs(); + } + arr.primitive().eval_gpu(arr.inputs(), outputs); + } + + auto& encoder = cu::get_command_encoder(arr.primitive().stream()); + if (encoder.has_gpu_work()) { + // Keep used buffers alive until kernel finishes running. + std::unordered_set> buffers; + for (auto& in : arr.inputs()) { + buffers.insert(in.data_shared_ptr()); + } + for (auto& s : arr.siblings()) { + buffers.insert(s.data_shared_ptr()); + } + // Remove the output if it was donated to by an input. + if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) { + buffers.erase(it); + } + encoder.add_completed_handler([buffers = std::move(buffers)]() {}); + } + encoder.end_encoding(); +} + +void finalize(Stream s) { + nvtx3::scoped_range r("gpu::finalize"); + cu::get_command_encoder(s).commit(); +} + +void synchronize(Stream s) { + nvtx3::scoped_range r("gpu::synchronize"); + cu::get_stream(s).synchronize(); +} + +} // namespace mlx::core::gpu diff --git a/mlx/backend/cuda/event.cu b/mlx/backend/cuda/event.cu new file mode 100644 index 000000000..a487f45b4 --- /dev/null +++ b/mlx/backend/cuda/event.cu @@ -0,0 +1,265 @@ +// Copyright © 2024 Apple Inc. + +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/event.h" +#include "mlx/backend/cuda/utils.h" +#include "mlx/event.h" +#include "mlx/scheduler.h" + +#include + +namespace mlx::core { + +namespace cu { + +/////////////////////////////////////////////////////////////////////////////// +// CudaEvent implementations +/////////////////////////////////////////////////////////////////////////////// + +// Cuda event managed with RAII. 
+class CudaEventHandle { + public: + CudaEventHandle() { + CHECK_CUDA_ERROR(cudaEventCreateWithFlags( + &event_, cudaEventDisableTiming | cudaEventBlockingSync)); + } + + ~CudaEventHandle() { + CHECK_CUDA_ERROR(cudaEventDestroy(event_)); + } + + CudaEventHandle(const CudaEventHandle&) = delete; + CudaEventHandle& operator=(const CudaEventHandle&) = delete; + + operator cudaEvent_t() const { + return event_; + } + + private: + cudaEvent_t event_; +}; + +CudaEvent::CudaEvent() : event_(std::make_shared()) {} + +void CudaEvent::wait() { + nvtx3::scoped_range r("cu::CudaEvent::wait"); + if (!recorded_) { + throw std::runtime_error("Should not wait on a CudaEvent before record."); + } + cudaEventSynchronize(*event_); +} + +void CudaEvent::wait(cudaStream_t stream) { + if (!recorded_) { + throw std::runtime_error("Should not wait on a CudaEvent before record."); + } + cudaStreamWaitEvent(stream, *event_); +} + +void CudaEvent::wait(Stream s) { + if (s.device == mlx::core::Device::cpu) { + scheduler::enqueue(s, [*this]() mutable { wait(); }); + } else { + wait(cu::get_stream(s).last_cuda_stream()); + } +} + +void CudaEvent::record(cudaStream_t stream) { + cudaEventRecord(*event_, stream); + recorded_ = true; +} + +void CudaEvent::record(Stream s) { + if (s.device == mlx::core::Device::cpu) { + throw std::runtime_error("CudaEvent can not wait on cpu stream."); + } else { + record(cu::get_stream(s).last_cuda_stream()); + } +} + +bool CudaEvent::completed() const { + return cudaEventQuery(*event_) == cudaSuccess; +} + +/////////////////////////////////////////////////////////////////////////////// +// SharedEvent implementations +/////////////////////////////////////////////////////////////////////////////// + +namespace { + +__host__ __device__ void event_wait(SharedEvent::Atomic* ac, uint64_t value) { + uint64_t current; + while ((current = ac->load()) < value) { + ac->wait(current); + } +} + +__host__ __device__ void event_signal(SharedEvent::Atomic* ac, uint64_t value) { + ac->store(value); + ac->notify_all(); +} + +__global__ void event_wait_kernel(SharedEvent::Atomic* ac, uint64_t value) { + event_wait(ac, value); +} + +__global__ void event_signal_kernel(SharedEvent::Atomic* ac, uint64_t value) { + event_signal(ac, value); +} + +} // namespace + +SharedEvent::SharedEvent() { + // Allocate cuda::atomic on managed memory. 
+ allocator::Buffer buffer = allocator::malloc(sizeof(Atomic)); + Atomic* ac = static_cast(buffer.raw_ptr()); + new (ac) Atomic(0); + ac_ = std::shared_ptr(ac, [buffer](Atomic* ptr) { + ptr->~Atomic(); + allocator::free(buffer); + }); +} + +void SharedEvent::wait(uint64_t value) { + nvtx3::scoped_range r("cu::SharedEvent::wait"); + event_wait(ac_.get(), value); +} + +void SharedEvent::wait(cudaStream_t stream, uint64_t value) { + event_wait_kernel<<<1, 1, 0, stream>>>(ac_.get(), value); +} + +void SharedEvent::wait(Stream s, uint64_t value) { + nvtx3::scoped_range r("cu::SharedEvent::wait(s)"); + if (s.device == mlx::core::Device::cpu) { + scheduler::enqueue(s, [*this, value]() mutable { wait(value); }); + } else { + auto& encoder = get_command_encoder(s); + encoder.launch_kernel( + encoder.stream().last_cuda_stream(), + [this, value](cudaStream_t stream) { wait(stream, value); }); + encoder.add_completed_handler([ac = ac_]() {}); + encoder.end_encoding(); + } +} + +void SharedEvent::signal(uint64_t value) { + nvtx3::scoped_range r("cu::SharedEvent::signal"); + event_signal(ac_.get(), value); +} + +void SharedEvent::signal(cudaStream_t stream, uint64_t value) { + event_signal_kernel<<<1, 1, 0, stream>>>(ac_.get(), value); +} + +void SharedEvent::signal(Stream s, uint64_t value) { + nvtx3::scoped_range r("cu::SharedEvent::signal(s)"); + if (s.device == mlx::core::Device::cpu) { + scheduler::enqueue(s, [*this, value]() mutable { signal(value); }); + } else { + auto& encoder = get_command_encoder(s); + encoder.launch_kernel( + encoder.stream().last_cuda_stream(), + [this, value](cudaStream_t stream) { signal(stream, value); }); + encoder.add_completed_handler([ac = ac_]() {}); + encoder.end_encoding(); + } +} + +bool SharedEvent::is_signaled(uint64_t value) const { + nvtx3::scoped_range r("cu::SharedEvent::is_signaled"); + return ac_->load() >= value; +} + +uint64_t SharedEvent::value() const { + nvtx3::scoped_range r("cu::SharedEvent::value"); + return ac_->load(); +} + +} // namespace cu + +/////////////////////////////////////////////////////////////////////////////// +// Event implementations +/////////////////////////////////////////////////////////////////////////////// + +namespace { + +struct EventImpl { + // CudaEvent is preferred when possible because it is fast, however we have + // to fallback to SharedEvent in following cases: + // 1. the event is used to wait/signal a cpu stream; + // 2. signal value other than 1 has been specified. 
+ std::unique_ptr cuda; + std::unique_ptr shared; + + bool is_created() const { + return cuda || shared; + } + + void ensure_created(Stream s, uint64_t signal_value) { + if (is_created()) { + return; + } + if (s.device == mlx::core::Device::cpu || signal_value > 1) { + nvtx3::mark("Using slow SharedEvent"); + shared = std::make_unique(); + } else { + cuda = std::make_unique(); + } + } +}; + +} // namespace + +Event::Event(Stream s) : stream_(s) { + event_ = std::shared_ptr( + new EventImpl(), [](void* ptr) { delete static_cast(ptr); }); +} + +void Event::wait() { + auto* event = static_cast(event_.get()); + assert(event->is_created()); + if (event->cuda) { + assert(value() == 1); + event->cuda->wait(); + } else { + event->shared->wait(value()); + } +} + +void Event::wait(Stream s) { + auto* event = static_cast(event_.get()); + assert(event->is_created()); + if (event->cuda) { + assert(value() == 1); + event->cuda->wait(s); + } else { + event->shared->wait(s, value()); + } +} + +void Event::signal(Stream s) { + auto* event = static_cast(event_.get()); + event->ensure_created(s, value()); + if (event->cuda) { + assert(value() == 1); + event->cuda->record(s); + } else { + event->shared->signal(s, value()); + } +} + +bool Event::is_signaled() const { + auto* event = static_cast(event_.get()); + if (!event->is_created()) { + return false; + } + if (event->cuda) { + assert(value() == 1); + return event->cuda->recorded() && event->cuda->completed(); + } else { + return event->shared->is_signaled(value()); + } +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/event.h b/mlx/backend/cuda/event.h new file mode 100644 index 000000000..4b56e2e3b --- /dev/null +++ b/mlx/backend/cuda/event.h @@ -0,0 +1,66 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/stream.h" + +#include +#include + +#include + +namespace mlx::core::cu { + +class CudaEventHandle; + +// Wrapper of native cuda event. It can synchronize between GPU streams, or wait +// on GPU stream in CPU stream, but can not wait on CPU stream. +class CudaEvent { + public: + CudaEvent(); + + void wait(); + void wait(cudaStream_t stream); + void wait(Stream s); + void record(cudaStream_t stream); + void record(Stream s); + + // Return whether the recorded kernels have completed. Note that this method + // returns true if record() has not been called. + bool completed() const; + + bool recorded() const { + return recorded_; + } + + private: + bool recorded_{false}; + std::shared_ptr event_; +}; + +// Event that can synchronize between CPU and GPU. It is much slower than +// CudaEvent so the latter should always be preferred when possible. +class SharedEvent { + public: + using Atomic = cuda::atomic; + + SharedEvent(); + + void wait(uint64_t value); + void wait(cudaStream_t stream, uint64_t value); + void wait(Stream s, uint64_t value); + void signal(uint64_t value); + void signal(cudaStream_t stream, uint64_t value); + void signal(Stream s, uint64_t value); + bool is_signaled(uint64_t value) const; + uint64_t value() const; + + const std::shared_ptr& atomic() const { + return ac_; + } + + private: + std::shared_ptr ac_; +}; + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/fence.cu b/mlx/backend/cuda/fence.cu new file mode 100644 index 000000000..091b252c1 --- /dev/null +++ b/mlx/backend/cuda/fence.cu @@ -0,0 +1,70 @@ +// Copyright © 2025 Apple Inc. 
+ +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/event.h" +#include "mlx/fence.h" +#include "mlx/scheduler.h" + +#include + +namespace mlx::core { + +namespace { + +__host__ __device__ void busy_wait(cuda::atomic* ac, uint64_t value) { + while (true) { + // In theory the atomic_thread_fence is not needed, but for CUDA 11 without + // it the load() may never return new value. + cuda::atomic_thread_fence(cuda::memory_order_seq_cst); + uint64_t current = ac->load(); + if (current >= value) { + break; + } + } +} + +__global__ void busy_wait_kernel(cuda::atomic* ac, uint64_t value) { + busy_wait(ac, value); +} + +} // namespace + +struct FenceImpl { + uint32_t count; + cu::SharedEvent event; +}; + +Fence::Fence(Stream s) { + fence_ = std::shared_ptr( + new FenceImpl{0}, [](void* ptr) { delete static_cast(ptr); }); +} + +void Fence::wait(Stream s, const array&) { + auto* fence = static_cast(fence_.get()); + // We can't use SharedEvent::wait because it could hang in CUDA 11, see also: + // https://github.com/ml-explore/mlx/issues/2137 + const auto& ac = fence->event.atomic(); + if (s.device == mlx::core::Device::cpu) { + scheduler::enqueue(s, [ac, count = fence->count]() { + nvtx3::scoped_range r("Fence::wait()"); + busy_wait(ac.get(), count); + }); + } else { + nvtx3::scoped_range r("Fence::wait(s)"); + auto& encoder = cu::get_command_encoder(s); + encoder.launch_kernel( + encoder.stream().last_cuda_stream(), [&](cudaStream_t stream) { + busy_wait_kernel<<<1, 1, 0>>>(ac.get(), fence->count); + }); + encoder.add_completed_handler([ac]() {}); + encoder.end_encoding(); + } +} + +void Fence::update(Stream s, const array&) { + auto* fence = static_cast(fence_.get()); + fence->count++; + fence->event.signal(s, fence->count); +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/kernels/arange.cuh b/mlx/backend/cuda/kernels/arange.cuh new file mode 100644 index 000000000..53c261e34 --- /dev/null +++ b/mlx/backend/cuda/kernels/arange.cuh @@ -0,0 +1,15 @@ +// Copyright © 2025 Apple Inc. + +namespace mlx::core::cu { + +template +struct Arange { + const T start; + const T step; + + __device__ T operator()(uint32_t i) const { + return start + i * step; + } +}; + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/kernels/fp16_math.cuh b/mlx/backend/cuda/kernels/fp16_math.cuh new file mode 100644 index 000000000..931c55ff7 --- /dev/null +++ b/mlx/backend/cuda/kernels/fp16_math.cuh @@ -0,0 +1,107 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include +#include +#include + +namespace mlx::core::cu { + +/////////////////////////////////////////////////////////////////////////////// +// Missing C++ operator overrides for CUDA 7. 
+/////////////////////////////////////////////////////////////////////////////// + +#if CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800 + +#define MLX_DEFINE_BF16_OP(OP) \ + __forceinline__ __device__ __nv_bfloat16 operator OP( \ + __nv_bfloat16 x, __nv_bfloat16 y) { \ + return __float2bfloat16(__bfloat162float(x) OP __bfloat162float(y)); \ + } + +#define MLX_DEFINE_BF16_CMP(OP) \ + __forceinline__ __device__ bool operator OP( \ + __nv_bfloat16 x, __nv_bfloat16 y) { \ + return __float2bfloat16(__bfloat162float(x) OP __bfloat162float(y)); \ + } + +MLX_DEFINE_BF16_OP(+) +MLX_DEFINE_BF16_OP(-) +MLX_DEFINE_BF16_OP(*) +MLX_DEFINE_BF16_OP(/) +MLX_DEFINE_BF16_CMP(>) +MLX_DEFINE_BF16_CMP(<) +MLX_DEFINE_BF16_CMP(>=) +MLX_DEFINE_BF16_CMP(<=) + +#undef MLX_DEFINE_BF16_OP +#undef MLX_DEFINE_BF16_CMP + +#endif // CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800 + +/////////////////////////////////////////////////////////////////////////////// +// Additional C++ operator overrides between half types and native types. +/////////////////////////////////////////////////////////////////////////////// + +template +constexpr bool is_integral_except = + cuda::std::is_integral_v && !cuda::std::is_same_v; + +template +constexpr bool is_arithmetic_except = + cuda::std::is_arithmetic_v && !cuda::std::is_same_v; + +#define MLX_DEFINE_HALF_OP(HALF, HALF2FLOAT, FLOAT2HALF, OP) \ + template < \ + typename T, \ + typename = cuda::std::enable_if_t>> \ + __forceinline__ __device__ HALF operator OP(HALF x, T y) { \ + return FLOAT2HALF(HALF2FLOAT(x) OP static_cast(y)); \ + } \ + template < \ + typename T, \ + typename = cuda::std::enable_if_t>> \ + __forceinline__ __device__ HALF operator OP(T x, HALF y) { \ + return FLOAT2HALF(static_cast(x) OP HALF2FLOAT(y)); \ + } + +#define MLX_DEFINE_HALF_CMP(HALF, HALF2FLOAT, OP) \ + template < \ + typename T, \ + typename = cuda::std::enable_if_t>> \ + __forceinline__ __device__ bool operator OP(HALF x, T y) { \ + return HALF2FLOAT(x) OP static_cast(y); \ + } \ + template < \ + typename T, \ + typename = cuda::std::enable_if_t>> \ + __forceinline__ __device__ bool operator OP(T x, HALF y) { \ + return static_cast(y) OP HALF2FLOAT(x); \ + } + +MLX_DEFINE_HALF_OP(__half, __half2float, __float2half, +) +MLX_DEFINE_HALF_OP(__half, __half2float, __float2half, -) +MLX_DEFINE_HALF_OP(__half, __half2float, __float2half, *) +MLX_DEFINE_HALF_OP(__half, __half2float, __float2half, /) +MLX_DEFINE_HALF_OP(__nv_bfloat16, __bfloat162float, __float2bfloat16, +) +MLX_DEFINE_HALF_OP(__nv_bfloat16, __bfloat162float, __float2bfloat16, -) +MLX_DEFINE_HALF_OP(__nv_bfloat16, __bfloat162float, __float2bfloat16, *) +MLX_DEFINE_HALF_OP(__nv_bfloat16, __bfloat162float, __float2bfloat16, /) +MLX_DEFINE_HALF_CMP(__half, __half2float, <) +MLX_DEFINE_HALF_CMP(__half, __half2float, >) +MLX_DEFINE_HALF_CMP(__half, __half2float, <=) +MLX_DEFINE_HALF_CMP(__half, __half2float, >=) +MLX_DEFINE_HALF_CMP(__half, __half2float, ==) +MLX_DEFINE_HALF_CMP(__half, __half2float, !=) +MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, <) +MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, >) +MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, <=) +MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, >=) +MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, ==) +MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, !=) + +#undef MLX_DEFINE_HALF_OP +#undef MLX_DEFINE_HALF_CMP + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/primitives.cu b/mlx/backend/cuda/primitives.cu new file mode 100644 index 000000000..dc6edf606 
--- /dev/null +++ b/mlx/backend/cuda/primitives.cu @@ -0,0 +1,163 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/dtype_utils.cuh" +#include "mlx/backend/cuda/kernels/arange.cuh" +#include "mlx/backend/cuda/kernels/fp16_math.cuh" +#include "mlx/distributed/primitives.h" +#include "mlx/dtype_utils.h" +#include "mlx/fast_primitives.h" +#include "mlx/primitives.h" + +#include +#include +#include + +#include + +namespace mlx::core { + +void Arange::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("Arange::eval_gpu"); + assert(inputs.size() == 0); + out.set_data(allocator::malloc(out.nbytes())); + if (out.size() == 0) { + return; + } + auto& s = stream(); + auto& encoder = cu::get_command_encoder(s); + encoder.set_output_array(out); + encoder.launch_kernel([&, this](cudaStream_t stream) { + MLX_SWITCH_INT_FLOAT_TYPES_CHECKED(out.dtype(), "Arange", CTYPE, { + using OutType = cuda_type_t; + CTYPE step = + static_cast(start_ + step_) - static_cast(start_); + thrust::transform( + cu::thrust_policy(stream), + thrust::counting_iterator(0), + thrust::counting_iterator(out.data_size()), + thrust::device_pointer_cast(out.data()), + cu::Arange{ + static_cast(start_), static_cast(step)}); + }); + }); +} + +#define NO_GPU_MULTI(func) \ + void func::eval_gpu( \ + const std::vector& inputs, std::vector& outputs) { \ + throw std::runtime_error(#func " has no CUDA implementation."); \ + } + +#define NO_GPU(func) \ + void func::eval_gpu(const std::vector& inputs, array& out) { \ + throw std::runtime_error(#func " has no CUDA implementation."); \ + } + +NO_GPU(Abs) +NO_GPU(Add) +NO_GPU(AddMM) +NO_GPU(ArcCos) +NO_GPU(ArcCosh) +NO_GPU(ArcSin) +NO_GPU(ArcSinh) +NO_GPU(ArcTan) +NO_GPU(ArcTan2) +NO_GPU(ArcTanh) +NO_GPU(ArgPartition) +NO_GPU(ArgReduce) +NO_GPU(ArgSort) +NO_GPU(BitwiseBinary) +NO_GPU(BitwiseInvert) +NO_GPU(BlockMaskedMM) +NO_GPU(Ceil) +NO_GPU_MULTI(Compiled) +NO_GPU(Conjugate) +NO_GPU(Convolution) +NO_GPU(Cos) +NO_GPU(Cosh) +NO_GPU(Divide) +NO_GPU_MULTI(DivMod) +NO_GPU(DynamicSlice) +NO_GPU(DynamicSliceUpdate) +NO_GPU(Remainder) +NO_GPU(Equal) +NO_GPU(Erf) +NO_GPU(ErfInv) +NO_GPU(Exp) +NO_GPU(Expm1) +NO_GPU(FFT) +NO_GPU(Floor) +NO_GPU(Gather) +NO_GPU(GatherAxis) +NO_GPU(GatherMM) +NO_GPU(GatherQMM) +NO_GPU(Greater) +NO_GPU(GreaterEqual) +NO_GPU(Hadamard) +NO_GPU(Imag) +NO_GPU(Less) +NO_GPU(LessEqual) +NO_GPU(Load) +NO_GPU(Log) +NO_GPU(Log1p) +NO_GPU(LogicalNot) +NO_GPU(LogicalAnd) +NO_GPU(LogicalOr) +NO_GPU(LogAddExp) +NO_GPU(LogSumExp) +NO_GPU_MULTI(LUF) +NO_GPU(Matmul) +NO_GPU(Maximum) +NO_GPU(Minimum) +NO_GPU(Multiply) +NO_GPU(Negative) +NO_GPU(NotEqual) +NO_GPU(Partition) +NO_GPU(Power) +NO_GPU_MULTI(QRF) +NO_GPU(QuantizedMatmul) +NO_GPU(RandomBits) +NO_GPU(Real) +NO_GPU(Reduce) +NO_GPU(Round) +NO_GPU(Scan) +NO_GPU(Scatter) +NO_GPU(ScatterAxis) +NO_GPU(Select) +NO_GPU(Sigmoid) +NO_GPU(Sign) +NO_GPU(Sin) +NO_GPU(Sinh) +NO_GPU(SliceUpdate) +NO_GPU(Softmax) +NO_GPU(Sort) +NO_GPU(Square) +NO_GPU(Sqrt) +NO_GPU(Subtract) +NO_GPU_MULTI(SVD) +NO_GPU(Tan) +NO_GPU(Tanh) +NO_GPU(Inverse) +NO_GPU(Cholesky) +NO_GPU_MULTI(Eigh) + +namespace fast { +NO_GPU_MULTI(LayerNorm) +NO_GPU_MULTI(LayerNormVJP) +NO_GPU_MULTI(RMSNorm) +NO_GPU_MULTI(RMSNormVJP) +NO_GPU_MULTI(RoPE) +NO_GPU(ScaledDotProductAttention) +NO_GPU_MULTI(AffineQuantize) +NO_GPU_MULTI(CustomKernel) +} // namespace fast + +namespace distributed { +NO_GPU_MULTI(AllReduce) +NO_GPU_MULTI(AllGather) +NO_GPU_MULTI(Send) +NO_GPU_MULTI(Recv) +} // namespace distributed + +} // 
namespace mlx::core diff --git a/mlx/backend/cuda/slicing.cpp b/mlx/backend/cuda/slicing.cpp new file mode 100644 index 000000000..bfa742c74 --- /dev/null +++ b/mlx/backend/cuda/slicing.cpp @@ -0,0 +1,15 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/gpu/slicing.h" + +namespace mlx::core { + +void concatenate_gpu( + const std::vector& inputs, + array& out, + int axis, + const Stream& s) { + throw std::runtime_error("concatenate_gpu not implemented in CUDA backend."); +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/utils.cpp b/mlx/backend/cuda/utils.cpp new file mode 100644 index 000000000..2a11a518e --- /dev/null +++ b/mlx/backend/cuda/utils.cpp @@ -0,0 +1,26 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/utils.h" +#include "mlx/backend/cuda/device.h" + +#include + +namespace mlx::core { + +CudaStream::CudaStream(cu::Device& device) { + device.make_current(); + CHECK_CUDA_ERROR(cudaStreamCreateWithFlags(&stream_, cudaStreamNonBlocking)); +} + +CudaStream::~CudaStream() { + CHECK_CUDA_ERROR(cudaStreamDestroy(stream_)); +} + +void check_cuda_error(const char* name, cudaError_t err) { + if (err != cudaSuccess) { + throw std::runtime_error( + fmt::format("{} failed: {}", name, cudaGetErrorString(err))); + } +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/utils.h b/mlx/backend/cuda/utils.h new file mode 100644 index 000000000..58d508765 --- /dev/null +++ b/mlx/backend/cuda/utils.h @@ -0,0 +1,36 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include + +namespace mlx::core { + +namespace cu { +class Device; +} + +// Cuda stream managed with RAII. +class CudaStream { + public: + explicit CudaStream(cu::Device& device); + ~CudaStream(); + + CudaStream(const CudaStream&) = delete; + CudaStream& operator=(const CudaStream&) = delete; + + operator cudaStream_t() const { + return stream_; + } + + private: + cudaStream_t stream_; +}; + +// Throw exception if the cuda API does not succeed. +void check_cuda_error(const char* name, cudaError_t err); + +// The macro version that prints the command that failed. +#define CHECK_CUDA_ERROR(cmd) check_cuda_error(#cmd, (cmd)) + +} // namespace mlx::core diff --git a/mlx/backend/cuda/worker.cpp b/mlx/backend/cuda/worker.cpp new file mode 100644 index 000000000..64b5c7679 --- /dev/null +++ b/mlx/backend/cuda/worker.cpp @@ -0,0 +1,90 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/worker.h" +#include "mlx/backend/cuda/allocator.h" +#include "mlx/backend/cuda/device.h" + +namespace mlx::core::cu { + +Worker::Worker() + : signal_stream_(device(mlx::core::Device::gpu)), + worker_(&Worker::thread_fn, this) {} + +Worker::~Worker() { + { + std::lock_guard lock(worker_mutex_); + stop_ = true; + } + worker_event_.signal(batch_ + 1); + worker_.join(); +} + +void Worker::add_task(std::function task) { + pending_tasks_.push_back(std::move(task)); +} + +void Worker::consume_in_this_thread() { + for (auto& task : pending_tasks_) { + task(); + } + pending_tasks_.clear(); +} + +void Worker::end_batch() { + batch_++; + { + std::lock_guard lock(worker_mutex_); + worker_tasks_[batch_] = std::move(pending_tasks_); + } + uncommited_batches_++; +} + +void Worker::commit() { + if (uncommited_batches_ == 0) { + return; + } + uncommited_batches_ = 0; + worker_event_.signal(batch_); +} + +void Worker::commit(cudaStream_t stream) { + if (uncommited_batches_ == 0) { + return; + } + uncommited_batches_ = 0; + // Signal the |worker_event_| in |signal_stream_| after the kernels in + // |stream_| finish running. 
+ signal_event_.record(stream); + signal_event_.wait(signal_stream_); + worker_event_.signal(signal_stream_, batch_); +} + +void Worker::thread_fn() { + // The worker thread is safe to free buffers. + allocator().register_this_thread(); + + while (!stop_) { + uint64_t batch = worker_event_.value(); + Tasks tasks; + { + std::lock_guard lock(worker_mutex_); + // Move tasks in signaled batches. + auto end = worker_tasks_.upper_bound(batch); + for (auto it = worker_tasks_.begin(); it != end; ++it) { + if (tasks.empty()) { + tasks = std::move(it->second); + } else { + std::move( + it->second.begin(), it->second.end(), std::back_inserter(tasks)); + } + } + worker_tasks_.erase(worker_tasks_.begin(), end); + } + for (auto& task : tasks) { + task(); + } + worker_event_.wait(batch + 1); + } +} + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/worker.h b/mlx/backend/cuda/worker.h new file mode 100644 index 000000000..d28e22e95 --- /dev/null +++ b/mlx/backend/cuda/worker.h @@ -0,0 +1,68 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/backend/cuda/event.h" +#include "mlx/backend/cuda/utils.h" + +#include +#include +#include +#include + +namespace mlx::core::cu { + +// Run tasks in worker thread, synchronized with cuda stream. +class Worker { + public: + Worker(); + ~Worker(); + + Worker(const Worker&) = delete; + Worker& operator=(const Worker&) = delete; + + // Add a pending |task| that will run when consumed or commited. + void add_task(std::function task); + + // Run pending tasks immediately in current thread. + void consume_in_this_thread(); + + // Put pending tasks in a batch. + void end_batch(); + + // Inform worker thread to run current batches now. + void commit(); + + // Inform worker thread to run current batches after kernels in |stream| + // finish running. + void commit(cudaStream_t stream); + + // Return how many batches have been added but not committed yet. + size_t uncommited_batches() const { + return uncommited_batches_; + } + + private: + void thread_fn(); + + uint64_t batch_{0}; + size_t uncommited_batches_{0}; + + // Cuda stream and event for signaling kernel completion. + CudaStream signal_stream_; + CudaEvent signal_event_; + + // Worker thread. + SharedEvent worker_event_; + std::thread worker_; + std::mutex worker_mutex_; + bool stop_{false}; + + // Tasks are put in |pending_tasks_| first, and then moved to + // |worker_tasks_| when end_batch() is called. + using Tasks = std::vector>; + Tasks pending_tasks_; + std::map worker_tasks_; +}; + +} // namespace mlx::core::cu diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index cf0ba3d5d..cb174865d 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -9,7 +9,7 @@ FetchContent_MakeAvailable(doctest) add_executable(tests ${PROJECT_SOURCE_DIR}/tests/tests.cpp) -if(MLX_BUILD_METAL) +if(MLX_BUILD_METAL OR MLX_BUILD_CUDA) set(METAL_TEST_SOURCES gpu_tests.cpp) endif()
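
A minimal usage sketch, not part of this patch: the backend routes GPU work through cu::CommandEncoder, which schedules a cuda stream, runs the launch lambda, checks cudaGetLastError(), and defers completion handlers to the Worker thread. The fragment below mirrors the Arange::eval_gpu pattern from primitives.cu to show how a later kernel could plug in. zeros_gpu_sketch is a hypothetical helper, the float-only path is an assumption, and the dtype dispatch a real primitive would do (via the MLX_SWITCH_* macros) is omitted. Per the filename rules in mlx/backend/cuda/CMakeLists.txt, such code would live in a .cu file since it touches device code.

#include <thrust/device_ptr.h>
#include <thrust/fill.h>

#include "mlx/allocator.h"
#include "mlx/backend/cuda/device.h"

namespace mlx::core {

// Hypothetical helper (not in this patch): zero-fill an output array on the
// GPU by going through the CommandEncoder, mirroring Arange::eval_gpu.
// Float-only for brevity; a real implementation would dispatch on out.dtype().
void zeros_gpu_sketch(array& out, const Stream& s) {
  out.set_data(allocator::malloc(out.nbytes()));
  if (out.size() == 0) {
    return;
  }
  auto& encoder = cu::get_command_encoder(s);
  encoder.set_output_array(out);
  encoder.launch_kernel([&](cudaStream_t stream) {
    // launch_kernel() makes the device current, passes a scheduled cuda
    // stream to this lambda, and checks cudaGetLastError() afterwards.
    thrust::fill(
        cu::thrust_policy(stream),
        thrust::device_pointer_cast(out.data<float>()),
        thrust::device_pointer_cast(out.data<float>() + out.data_size()),
        0.0f);
  });
  // Handlers registered with encoder.add_completed_handler() run on the
  // Worker thread once the batch is committed and the kernels finish.
}

} // namespace mlx::core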