CUDA backend: backbone (#2075)

This commit is contained in:
Cheng 2025-05-07 13:26:46 +09:00 committed by GitHub
parent 5a1a5d5ed1
commit 0cae0bdac8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
22 changed files with 1582 additions and 2 deletions

View File

@ -34,6 +34,7 @@ option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF)
option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
option(MLX_BUILD_METAL "Build metal backend" ON)
option(MLX_BUILD_CPU "Build cpu backend" ON)
option(MLX_BUILD_CUDA "Build cuda backend" OFF)
option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
@ -83,6 +84,10 @@ if(MLX_BUILD_METAL)
set(QUARTZ_LIB "-framework QuartzCore")
endif()
if(MLX_BUILD_CUDA)
enable_language(CUDA)
endif()
if(MLX_BUILD_METAL AND NOT METAL_LIB)
message(STATUS "Metal not found. Unable to build GPU")
set(MLX_BUILD_METAL OFF)

View File

@ -47,10 +47,18 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/distributed)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io)
if(MLX_BUILD_METAL)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/gpu)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/metal)
else()
target_sources(mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/no_metal.cpp)
endif()
if(MLX_BUILD_CUDA)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda)
endif()
if(MLX_BUILD_METAL OR MLX_BUILD_CUDA)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/gpu)
else()
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_gpu)
endif()

View File

@ -0,0 +1,57 @@
# Filename rules in cuda backend:
#
# * Use .cu/.cuh if code contains device code, and .cpp/.h if not.
# * Device-only kernel code should be put in kernels/ subdir.
# * Files in kernels/ subdir should not include files outside.
target_sources(
mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
${CMAKE_CURRENT_SOURCE_DIR}/event.cu
${CMAKE_CURRENT_SOURCE_DIR}/fence.cu
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)
target_compile_definitions(mlx PUBLIC MLX_USE_CUDA)
# Enable defining device lambda functions.
target_compile_options(mlx
PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--extended-lambda>")
# Compute capability 7 is required for synchronization between CPU/GPU with
# managed memory. TODO: Add more architectures for potential performance gain.
set(MLX_CUDA_ARCHITECTURES
"75;80"
CACHE STRING "CUDA architectures")
message(STATUS "CUDA architectures: ${MLX_CUDA_ARCHITECTURES}")
set_target_properties(mlx PROPERTIES CUDA_ARCHITECTURES
"${MLX_CUDA_ARCHITECTURES}")
# Use fixed version of CCCL.
FetchContent_Declare(
cccl
URL "https://github.com/NVIDIA/cccl/releases/download/v2.8.1/cccl-v2.8.1.zip")
FetchContent_MakeAvailable(cccl)
target_include_directories(mlx PRIVATE BEFORE "${cccl_SOURCE_DIR}/include")
# Use fixed version of NVTX.
FetchContent_Declare(
nvtx3
GIT_REPOSITORY https://github.com/NVIDIA/NVTX.git
GIT_TAG v3.1.1
GIT_SHALLOW TRUE
SOURCE_SUBDIR c EXCLUDE_FROM_ALL)
FetchContent_MakeAvailable(nvtx3)
target_link_libraries(mlx PUBLIC $<BUILD_INTERFACE:nvtx3-cpp>)
# Make cuda runtime APIs available in non-cuda files.
find_package(CUDAToolkit REQUIRED)
target_include_directories(mlx PRIVATE ${CUDAToolkit_INCLUDE_DIRS})
# Suppress nvcc warnings on MLX headers.
target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe
--diag_suppress=997>)

View File

@ -0,0 +1,154 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/allocator.h"
#include "mlx/backend/cuda/utils.h"
#include "mlx/backend/cuda/worker.h"
#include <cuda_runtime.h>
#include <fmt/format.h>
#include <cassert>
namespace mlx::core {
namespace cu {
CudaAllocator::CudaAllocator() {
// TODO: Set memory limit for multi-device.
size_t free, total;
CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
memory_limit_ = total * 0.8;
}
Buffer CudaAllocator::malloc(size_t size) {
// TODO: Check memory limit.
auto* buf = new CudaBuffer{nullptr, size};
cudaError_t err = cudaMallocManaged(&buf->data, size);
if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
throw std::runtime_error(
fmt::format("cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
}
std::lock_guard lock(mutex_);
active_memory_ += size;
peak_memory_ = std::max(active_memory_, peak_memory_);
return Buffer{buf};
}
void CudaAllocator::free(Buffer buffer) {
auto* buf = static_cast<CudaBuffer*>(buffer.ptr());
if (!buf) {
return;
}
// If free() is called from a unregistered thread, reschedule the call to
// worker.
{
std::lock_guard lock(worker_mutex_);
if (allowed_threads_.count(std::this_thread::get_id()) == 0) {
if (!worker_) {
worker_.reset(new Worker);
}
worker_->add_task([buffer]() { allocator().free(buffer); });
worker_->end_batch();
worker_->commit();
return;
}
}
size_t size = buf->size;
cudaFree(buf->data);
delete buf;
std::lock_guard lock(mutex_);
active_memory_ -= size;
}
size_t CudaAllocator::size(Buffer buffer) const {
auto* buf = static_cast<CudaBuffer*>(buffer.ptr());
if (!buf) {
return 0;
}
return buf->size;
}
void CudaAllocator::register_this_thread() {
std::lock_guard lock(worker_mutex_);
allowed_threads_.insert(std::this_thread::get_id());
}
size_t CudaAllocator::get_active_memory() const {
return active_memory_;
}
size_t CudaAllocator::get_peak_memory() const {
return peak_memory_;
}
void CudaAllocator::reset_peak_memory() {
std::lock_guard lock(mutex_);
peak_memory_ = 0;
}
size_t CudaAllocator::get_memory_limit() {
return memory_limit_;
}
size_t CudaAllocator::set_memory_limit(size_t limit) {
std::lock_guard lock(mutex_);
std::swap(limit, memory_limit_);
return limit;
}
CudaAllocator& allocator() {
// By creating the |allocator_| on heap, the destructor of CudaAllocator
// will not be called on exit and buffers in the cache will be leaked. This
// can save some time at program exit.
static CudaAllocator* allocator_ = new CudaAllocator;
return *allocator_;
}
} // namespace cu
namespace allocator {
Allocator& allocator() {
return cu::allocator();
}
void* Buffer::raw_ptr() {
if (!ptr_) {
return nullptr;
}
return static_cast<cu::CudaBuffer*>(ptr_)->data;
}
} // namespace allocator
size_t get_active_memory() {
return cu::allocator().get_active_memory();
}
size_t get_peak_memory() {
return cu::allocator().get_peak_memory();
}
void reset_peak_memory() {
return cu::allocator().reset_peak_memory();
}
size_t set_memory_limit(size_t limit) {
return cu::allocator().set_memory_limit(limit);
}
size_t get_memory_limit() {
return cu::allocator().get_memory_limit();
}
// TODO: Implement buffer cache.
size_t get_cache_memory() {
return 0;
}
size_t set_cache_limit(size_t) {
return 0;
}
size_t set_wired_limit(size_t) {
return 0;
}
void clear_cache() {}
} // namespace mlx::core

View File

@ -0,0 +1,58 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/allocator.h"
#include <mutex>
#include <set>
#include <thread>
#include <utility>
namespace mlx::core::cu {
class Worker;
using allocator::Buffer;
// Stores cuda-managed unified memory.
struct CudaBuffer {
void* data;
size_t size;
};
class CudaAllocator : public allocator::Allocator {
public:
Buffer malloc(size_t size) override;
void free(Buffer buffer) override;
size_t size(Buffer buffer) const override;
// Register current thread as safe to free buffers.
// In cuda freeing a buffer implicitly synchronizes stream, and for threads
// that may be waited by gpu stream (for example cpu stream threads), freeing
// buffers there would result in dead lock.
void register_this_thread();
size_t get_active_memory() const;
size_t get_peak_memory() const;
void reset_peak_memory();
size_t get_memory_limit();
size_t set_memory_limit(size_t limit);
private:
CudaAllocator();
friend CudaAllocator& allocator();
std::mutex worker_mutex_;
std::unique_ptr<Worker> worker_;
std::set<std::thread::id> allowed_threads_;
std::mutex mutex_;
size_t memory_limit_;
size_t active_memory_{0};
size_t peak_memory_{0};
};
CudaAllocator& allocator();
} // namespace mlx::core::cu

26
mlx/backend/cuda/copy.cpp Normal file
View File

@ -0,0 +1,26 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/gpu/copy.h"
namespace mlx::core {
void copy_gpu_inplace(
const array& in,
array& out,
const Shape& data_shape,
const Strides& strides_in_pre,
const Strides& strides_out_pre,
int64_t inp_offset,
int64_t out_offset,
CopyType ctype,
const Stream& s,
const std::optional<array>& dynamic_i_offset /* = std::nullopt */,
const std::optional<array>& dynamic_o_offset /* = std::nullopt */) {
throw std::runtime_error("copy_gpu_inplace not implemented in CUDA backend.");
}
void fill_gpu(const array& val, array& out, const Stream& s) {
throw std::runtime_error("fill_gpu not implemented in CUDA backend.");
}
} // namespace mlx::core

117
mlx/backend/cuda/device.cpp Normal file
View File

@ -0,0 +1,117 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/worker.h"
#include "mlx/backend/metal/metal.h"
#include <fmt/format.h>
#include <nvtx3/nvtx3.hpp>
namespace mlx::core {
namespace cu {
DeviceStream::DeviceStream(Device& device) : device_(device), stream_(device) {}
void DeviceStream::synchronize() {
cudaStreamSynchronize(stream_);
}
cudaStream_t DeviceStream::schedule_cuda_stream() {
// TODO: Return a stream that maximizes parallelism.
return stream_;
}
cudaStream_t DeviceStream::last_cuda_stream() {
return stream_;
}
CommandEncoder& DeviceStream::get_encoder() {
if (!encoder_) {
encoder_ = std::make_unique<CommandEncoder>(*this);
}
return *encoder_;
}
Device::Device(int device) : device_(device) {
// Validate the requirements of device.
int attr = 0;
cudaDeviceGetAttribute(&attr, cudaDevAttrConcurrentManagedAccess, device_);
if (attr != 1) {
throw std::runtime_error(fmt::format(
"Device {} does not support synchronization in managed memory.",
device_));
}
}
void Device::make_current() {
// We need to set/get current CUDA device very frequently, cache it to reduce
// actual calls of CUDA APIs. This function assumes single-thread in host.
static int current = 0;
if (current != device_) {
CHECK_CUDA_ERROR(cudaSetDevice(device_));
current = device_;
}
}
DeviceStream& Device::get_stream(Stream s) {
auto it = streams_.find(s.index);
if (it == streams_.end()) {
it = streams_.try_emplace(s.index, *this).first;
}
return it->second;
}
CommandEncoder::CommandEncoder(DeviceStream& s)
: device_(s.device()), stream_(s) {}
void CommandEncoder::add_completed_handler(std::function<void()> task) {
worker_.add_task(std::move(task));
}
void CommandEncoder::end_encoding() {
if (!temporaries_.empty()) {
add_completed_handler([temporaries = std::move(temporaries_)]() {});
}
// There is no kernel running, run completion handlers immediately.
if (!has_gpu_work_) {
worker_.consume_in_this_thread();
return;
}
has_gpu_work_ = false;
// Put completion handlers in a batch.
worker_.end_batch();
// Signaling kernel completion is expensive, delay until enough batches.
// TODO: This number is arbitrarily picked, profile for a better stragety.
if (worker_.uncommited_batches() > 8) {
commit();
}
}
void CommandEncoder::commit() {
worker_.commit(stream_.last_cuda_stream());
}
Device& device(mlx::core::Device device) {
static std::unordered_map<int, Device> devices;
auto it = devices.find(device.index);
if (it == devices.end()) {
it = devices.try_emplace(device.index, device.index).first;
}
return it->second;
}
DeviceStream& get_stream(Stream s) {
return device(s.device).get_stream(s);
}
CommandEncoder& get_command_encoder(Stream s) {
return get_stream(s).get_encoder();
}
} // namespace cu
} // namespace mlx::core

131
mlx/backend/cuda/device.h Normal file
View File

@ -0,0 +1,131 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/array.h"
#include "mlx/backend/cuda/worker.h"
#include "mlx/stream.h"
#include <thrust/execution_policy.h>
#include <unordered_map>
namespace mlx::core::cu {
class Device;
class CommandEncoder;
class DeviceStream {
public:
explicit DeviceStream(Device& device);
DeviceStream(const DeviceStream&) = delete;
DeviceStream& operator=(const DeviceStream&) = delete;
// Wait until kernels in the stream complete.
void synchronize();
// Return a cuda stream for launching kernels.
cudaStream_t schedule_cuda_stream();
// Return the last cuda stream used.
cudaStream_t last_cuda_stream();
CommandEncoder& get_encoder();
Device& device() {
return device_;
}
private:
Device& device_;
CudaStream stream_;
std::unique_ptr<CommandEncoder> encoder_;
};
class Device {
public:
explicit Device(int device);
Device(const Device&) = delete;
Device& operator=(const Device&) = delete;
// Make this device the current cuda device, required by some cuda calls.
void make_current();
DeviceStream& get_stream(Stream s);
int cuda_device() const {
return device_;
}
private:
int device_;
std::unordered_map<int, DeviceStream> streams_;
};
class CommandEncoder {
public:
explicit CommandEncoder(DeviceStream& stream);
CommandEncoder(const CommandEncoder&) = delete;
CommandEncoder& operator=(const CommandEncoder&) = delete;
void set_input_array(const array& arr) {}
void set_output_array(const array& arr) {}
void add_temporary(const array& arr) {
temporaries_.push_back(arr.data_shared_ptr());
}
void add_completed_handler(std::function<void()> task);
void end_encoding();
void commit();
// Schedule a cuda stream for |fun| to launch kernels, and check error
// afterwards.
template <typename F>
void launch_kernel(F&& fun) {
launch_kernel(stream_.schedule_cuda_stream(), std::forward<F>(fun));
}
template <typename F>
void launch_kernel(cudaStream_t stream, F&& fun) {
device_.make_current();
fun(stream);
check_cuda_error("kernel launch", cudaGetLastError());
has_gpu_work_ = true;
}
Device& device() {
return device_;
}
DeviceStream& stream() {
return stream_;
}
bool has_gpu_work() const {
return has_gpu_work_;
}
private:
Device& device_;
DeviceStream& stream_;
Worker worker_;
bool has_gpu_work_{false};
std::vector<std::shared_ptr<array::Data>> temporaries_;
};
Device& device(mlx::core::Device device);
DeviceStream& get_stream(Stream s);
CommandEncoder& get_command_encoder(Stream s);
// Return an execution policy that does not sync for result.
// Note that not all thrust APIs support async policy, confirm before using.
inline auto thrust_policy(cudaStream_t stream) {
// TODO: Connect thrust's custom allocator with mlx's allocator.
return thrust::cuda::par_nosync.on(stream);
}
} // namespace mlx::core::cu

View File

@ -0,0 +1,35 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <cuComplex.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
namespace mlx::core {
// Maps CPU types to CUDA types.
template <typename T>
struct CTypeToCudaType {
using type = T;
};
template <>
struct CTypeToCudaType<float16_t> {
using type = __half;
};
template <>
struct CTypeToCudaType<bfloat16_t> {
using type = __nv_bfloat16;
};
template <>
struct CTypeToCudaType<complex64_t> {
using type = cuComplex;
};
template <typename T>
using cuda_type_t = typename CTypeToCudaType<T>::type;
} // namespace mlx::core

68
mlx/backend/cuda/eval.cpp Normal file
View File

@ -0,0 +1,68 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/gpu/eval.h"
#include "mlx/backend/cuda/allocator.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/gpu/available.h"
#include "mlx/primitives.h"
#include <nvtx3/nvtx3.hpp>
namespace mlx::core::gpu {
bool is_available() {
return true;
}
void new_stream(Stream s) {
// Force initalization of cuda, so cuda runtime get destroyed at last.
cudaFree(nullptr);
// Ensure the static stream objects get created.
cu::get_command_encoder(s);
// The main thread is safe to free buffers.
cu::allocator().register_this_thread();
}
void eval(array& arr) {
nvtx3::scoped_range r("gpu::eval");
auto outputs = arr.outputs();
{
// If the array is a tracer hold a reference
// to its inputs so they don't get donated
std::vector<array> inputs;
if (arr.is_tracer()) {
inputs = arr.inputs();
}
arr.primitive().eval_gpu(arr.inputs(), outputs);
}
auto& encoder = cu::get_command_encoder(arr.primitive().stream());
if (encoder.has_gpu_work()) {
// Keep used buffers alive until kernel finishes running.
std::unordered_set<std::shared_ptr<array::Data>> buffers;
for (auto& in : arr.inputs()) {
buffers.insert(in.data_shared_ptr());
}
for (auto& s : arr.siblings()) {
buffers.insert(s.data_shared_ptr());
}
// Remove the output if it was donated to by an input.
if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) {
buffers.erase(it);
}
encoder.add_completed_handler([buffers = std::move(buffers)]() {});
}
encoder.end_encoding();
}
void finalize(Stream s) {
nvtx3::scoped_range r("gpu::finalize");
cu::get_command_encoder(s).commit();
}
void synchronize(Stream s) {
nvtx3::scoped_range r("gpu::synchronize");
cu::get_stream(s).synchronize();
}
} // namespace mlx::core::gpu

265
mlx/backend/cuda/event.cu Normal file
View File

@ -0,0 +1,265 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/event.h"
#include "mlx/backend/cuda/utils.h"
#include "mlx/event.h"
#include "mlx/scheduler.h"
#include <nvtx3/nvtx3.hpp>
namespace mlx::core {
namespace cu {
///////////////////////////////////////////////////////////////////////////////
// CudaEvent implementations
///////////////////////////////////////////////////////////////////////////////
// Cuda event managed with RAII.
class CudaEventHandle {
public:
CudaEventHandle() {
CHECK_CUDA_ERROR(cudaEventCreateWithFlags(
&event_, cudaEventDisableTiming | cudaEventBlockingSync));
}
~CudaEventHandle() {
CHECK_CUDA_ERROR(cudaEventDestroy(event_));
}
CudaEventHandle(const CudaEventHandle&) = delete;
CudaEventHandle& operator=(const CudaEventHandle&) = delete;
operator cudaEvent_t() const {
return event_;
}
private:
cudaEvent_t event_;
};
CudaEvent::CudaEvent() : event_(std::make_shared<CudaEventHandle>()) {}
void CudaEvent::wait() {
nvtx3::scoped_range r("cu::CudaEvent::wait");
if (!recorded_) {
throw std::runtime_error("Should not wait on a CudaEvent before record.");
}
cudaEventSynchronize(*event_);
}
void CudaEvent::wait(cudaStream_t stream) {
if (!recorded_) {
throw std::runtime_error("Should not wait on a CudaEvent before record.");
}
cudaStreamWaitEvent(stream, *event_);
}
void CudaEvent::wait(Stream s) {
if (s.device == mlx::core::Device::cpu) {
scheduler::enqueue(s, [*this]() mutable { wait(); });
} else {
wait(cu::get_stream(s).last_cuda_stream());
}
}
void CudaEvent::record(cudaStream_t stream) {
cudaEventRecord(*event_, stream);
recorded_ = true;
}
void CudaEvent::record(Stream s) {
if (s.device == mlx::core::Device::cpu) {
throw std::runtime_error("CudaEvent can not wait on cpu stream.");
} else {
record(cu::get_stream(s).last_cuda_stream());
}
}
bool CudaEvent::completed() const {
return cudaEventQuery(*event_) == cudaSuccess;
}
///////////////////////////////////////////////////////////////////////////////
// SharedEvent implementations
///////////////////////////////////////////////////////////////////////////////
namespace {
__host__ __device__ void event_wait(SharedEvent::Atomic* ac, uint64_t value) {
uint64_t current;
while ((current = ac->load()) < value) {
ac->wait(current);
}
}
__host__ __device__ void event_signal(SharedEvent::Atomic* ac, uint64_t value) {
ac->store(value);
ac->notify_all();
}
__global__ void event_wait_kernel(SharedEvent::Atomic* ac, uint64_t value) {
event_wait(ac, value);
}
__global__ void event_signal_kernel(SharedEvent::Atomic* ac, uint64_t value) {
event_signal(ac, value);
}
} // namespace
SharedEvent::SharedEvent() {
// Allocate cuda::atomic on managed memory.
allocator::Buffer buffer = allocator::malloc(sizeof(Atomic));
Atomic* ac = static_cast<Atomic*>(buffer.raw_ptr());
new (ac) Atomic(0);
ac_ = std::shared_ptr<Atomic>(ac, [buffer](Atomic* ptr) {
ptr->~Atomic();
allocator::free(buffer);
});
}
void SharedEvent::wait(uint64_t value) {
nvtx3::scoped_range r("cu::SharedEvent::wait");
event_wait(ac_.get(), value);
}
void SharedEvent::wait(cudaStream_t stream, uint64_t value) {
event_wait_kernel<<<1, 1, 0, stream>>>(ac_.get(), value);
}
void SharedEvent::wait(Stream s, uint64_t value) {
nvtx3::scoped_range r("cu::SharedEvent::wait(s)");
if (s.device == mlx::core::Device::cpu) {
scheduler::enqueue(s, [*this, value]() mutable { wait(value); });
} else {
auto& encoder = get_command_encoder(s);
encoder.launch_kernel(
encoder.stream().last_cuda_stream(),
[this, value](cudaStream_t stream) { wait(stream, value); });
encoder.add_completed_handler([ac = ac_]() {});
encoder.end_encoding();
}
}
void SharedEvent::signal(uint64_t value) {
nvtx3::scoped_range r("cu::SharedEvent::signal");
event_signal(ac_.get(), value);
}
void SharedEvent::signal(cudaStream_t stream, uint64_t value) {
event_signal_kernel<<<1, 1, 0, stream>>>(ac_.get(), value);
}
void SharedEvent::signal(Stream s, uint64_t value) {
nvtx3::scoped_range r("cu::SharedEvent::signal(s)");
if (s.device == mlx::core::Device::cpu) {
scheduler::enqueue(s, [*this, value]() mutable { signal(value); });
} else {
auto& encoder = get_command_encoder(s);
encoder.launch_kernel(
encoder.stream().last_cuda_stream(),
[this, value](cudaStream_t stream) { signal(stream, value); });
encoder.add_completed_handler([ac = ac_]() {});
encoder.end_encoding();
}
}
bool SharedEvent::is_signaled(uint64_t value) const {
nvtx3::scoped_range r("cu::SharedEvent::is_signaled");
return ac_->load() >= value;
}
uint64_t SharedEvent::value() const {
nvtx3::scoped_range r("cu::SharedEvent::value");
return ac_->load();
}
} // namespace cu
///////////////////////////////////////////////////////////////////////////////
// Event implementations
///////////////////////////////////////////////////////////////////////////////
namespace {
struct EventImpl {
// CudaEvent is preferred when possible because it is fast, however we have
// to fallback to SharedEvent in following cases:
// 1. the event is used to wait/signal a cpu stream;
// 2. signal value other than 1 has been specified.
std::unique_ptr<cu::CudaEvent> cuda;
std::unique_ptr<cu::SharedEvent> shared;
bool is_created() const {
return cuda || shared;
}
void ensure_created(Stream s, uint64_t signal_value) {
if (is_created()) {
return;
}
if (s.device == mlx::core::Device::cpu || signal_value > 1) {
nvtx3::mark("Using slow SharedEvent");
shared = std::make_unique<cu::SharedEvent>();
} else {
cuda = std::make_unique<cu::CudaEvent>();
}
}
};
} // namespace
Event::Event(Stream s) : stream_(s) {
event_ = std::shared_ptr<void>(
new EventImpl(), [](void* ptr) { delete static_cast<EventImpl*>(ptr); });
}
void Event::wait() {
auto* event = static_cast<EventImpl*>(event_.get());
assert(event->is_created());
if (event->cuda) {
assert(value() == 1);
event->cuda->wait();
} else {
event->shared->wait(value());
}
}
void Event::wait(Stream s) {
auto* event = static_cast<EventImpl*>(event_.get());
assert(event->is_created());
if (event->cuda) {
assert(value() == 1);
event->cuda->wait(s);
} else {
event->shared->wait(s, value());
}
}
void Event::signal(Stream s) {
auto* event = static_cast<EventImpl*>(event_.get());
event->ensure_created(s, value());
if (event->cuda) {
assert(value() == 1);
event->cuda->record(s);
} else {
event->shared->signal(s, value());
}
}
bool Event::is_signaled() const {
auto* event = static_cast<EventImpl*>(event_.get());
if (!event->is_created()) {
return false;
}
if (event->cuda) {
assert(value() == 1);
return event->cuda->recorded() && event->cuda->completed();
} else {
return event->shared->is_signaled(value());
}
}
} // namespace mlx::core

66
mlx/backend/cuda/event.h Normal file
View File

@ -0,0 +1,66 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/stream.h"
#include <cuda_runtime.h>
#include <cuda/atomic>
#include <memory>
namespace mlx::core::cu {
class CudaEventHandle;
// Wrapper of native cuda event. It can synchronize between GPU streams, or wait
// on GPU stream in CPU stream, but can not wait on CPU stream.
class CudaEvent {
public:
CudaEvent();
void wait();
void wait(cudaStream_t stream);
void wait(Stream s);
void record(cudaStream_t stream);
void record(Stream s);
// Return whether the recorded kernels have completed. Note that this method
// returns true if record() has not been called.
bool completed() const;
bool recorded() const {
return recorded_;
}
private:
bool recorded_{false};
std::shared_ptr<CudaEventHandle> event_;
};
// Event that can synchronize between CPU and GPU. It is much slower than
// CudaEvent so the latter should always be preferred when possible.
class SharedEvent {
public:
using Atomic = cuda::atomic<uint64_t>;
SharedEvent();
void wait(uint64_t value);
void wait(cudaStream_t stream, uint64_t value);
void wait(Stream s, uint64_t value);
void signal(uint64_t value);
void signal(cudaStream_t stream, uint64_t value);
void signal(Stream s, uint64_t value);
bool is_signaled(uint64_t value) const;
uint64_t value() const;
const std::shared_ptr<Atomic>& atomic() const {
return ac_;
}
private:
std::shared_ptr<Atomic> ac_;
};
} // namespace mlx::core::cu

70
mlx/backend/cuda/fence.cu Normal file
View File

@ -0,0 +1,70 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/event.h"
#include "mlx/fence.h"
#include "mlx/scheduler.h"
#include <nvtx3/nvtx3.hpp>
namespace mlx::core {
namespace {
__host__ __device__ void busy_wait(cuda::atomic<uint64_t>* ac, uint64_t value) {
while (true) {
// In theory the atomic_thread_fence is not needed, but for CUDA 11 without
// it the load() may never return new value.
cuda::atomic_thread_fence(cuda::memory_order_seq_cst);
uint64_t current = ac->load();
if (current >= value) {
break;
}
}
}
__global__ void busy_wait_kernel(cuda::atomic<uint64_t>* ac, uint64_t value) {
busy_wait(ac, value);
}
} // namespace
struct FenceImpl {
uint32_t count;
cu::SharedEvent event;
};
Fence::Fence(Stream s) {
fence_ = std::shared_ptr<void>(
new FenceImpl{0}, [](void* ptr) { delete static_cast<FenceImpl*>(ptr); });
}
void Fence::wait(Stream s, const array&) {
auto* fence = static_cast<FenceImpl*>(fence_.get());
// We can't use SharedEvent::wait because it could hang in CUDA 11, see also:
// https://github.com/ml-explore/mlx/issues/2137
const auto& ac = fence->event.atomic();
if (s.device == mlx::core::Device::cpu) {
scheduler::enqueue(s, [ac, count = fence->count]() {
nvtx3::scoped_range r("Fence::wait()");
busy_wait(ac.get(), count);
});
} else {
nvtx3::scoped_range r("Fence::wait(s)");
auto& encoder = cu::get_command_encoder(s);
encoder.launch_kernel(
encoder.stream().last_cuda_stream(), [&](cudaStream_t stream) {
busy_wait_kernel<<<1, 1, 0>>>(ac.get(), fence->count);
});
encoder.add_completed_handler([ac]() {});
encoder.end_encoding();
}
}
void Fence::update(Stream s, const array&) {
auto* fence = static_cast<FenceImpl*>(fence_.get());
fence->count++;
fence->event.signal(s, fence->count);
}
} // namespace mlx::core

View File

@ -0,0 +1,15 @@
// Copyright © 2025 Apple Inc.
namespace mlx::core::cu {
template <typename T>
struct Arange {
const T start;
const T step;
__device__ T operator()(uint32_t i) const {
return start + i * step;
}
};
} // namespace mlx::core::cu

View File

@ -0,0 +1,107 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <cuda_fp16.h>
#include <cuda/std/limits>
#include <cuda/std/type_traits>
namespace mlx::core::cu {
///////////////////////////////////////////////////////////////////////////////
// Missing C++ operator overrides for CUDA 7.
///////////////////////////////////////////////////////////////////////////////
#if CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800
#define MLX_DEFINE_BF16_OP(OP) \
__forceinline__ __device__ __nv_bfloat16 operator OP( \
__nv_bfloat16 x, __nv_bfloat16 y) { \
return __float2bfloat16(__bfloat162float(x) OP __bfloat162float(y)); \
}
#define MLX_DEFINE_BF16_CMP(OP) \
__forceinline__ __device__ bool operator OP( \
__nv_bfloat16 x, __nv_bfloat16 y) { \
return __float2bfloat16(__bfloat162float(x) OP __bfloat162float(y)); \
}
MLX_DEFINE_BF16_OP(+)
MLX_DEFINE_BF16_OP(-)
MLX_DEFINE_BF16_OP(*)
MLX_DEFINE_BF16_OP(/)
MLX_DEFINE_BF16_CMP(>)
MLX_DEFINE_BF16_CMP(<)
MLX_DEFINE_BF16_CMP(>=)
MLX_DEFINE_BF16_CMP(<=)
#undef MLX_DEFINE_BF16_OP
#undef MLX_DEFINE_BF16_CMP
#endif // CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800
///////////////////////////////////////////////////////////////////////////////
// Additional C++ operator overrides between half types and native types.
///////////////////////////////////////////////////////////////////////////////
template <typename T, typename U>
constexpr bool is_integral_except =
cuda::std::is_integral_v<T> && !cuda::std::is_same_v<T, U>;
template <typename T, typename U>
constexpr bool is_arithmetic_except =
cuda::std::is_arithmetic_v<T> && !cuda::std::is_same_v<T, U>;
#define MLX_DEFINE_HALF_OP(HALF, HALF2FLOAT, FLOAT2HALF, OP) \
template < \
typename T, \
typename = cuda::std::enable_if_t<is_integral_except<T, HALF>>> \
__forceinline__ __device__ HALF operator OP(HALF x, T y) { \
return FLOAT2HALF(HALF2FLOAT(x) OP static_cast<float>(y)); \
} \
template < \
typename T, \
typename = cuda::std::enable_if_t<is_integral_except<T, HALF>>> \
__forceinline__ __device__ HALF operator OP(T x, HALF y) { \
return FLOAT2HALF(static_cast<float>(x) OP HALF2FLOAT(y)); \
}
#define MLX_DEFINE_HALF_CMP(HALF, HALF2FLOAT, OP) \
template < \
typename T, \
typename = cuda::std::enable_if_t<is_arithmetic_except<T, HALF>>> \
__forceinline__ __device__ bool operator OP(HALF x, T y) { \
return HALF2FLOAT(x) OP static_cast<float>(y); \
} \
template < \
typename T, \
typename = cuda::std::enable_if_t<is_arithmetic_except<T, HALF>>> \
__forceinline__ __device__ bool operator OP(T x, HALF y) { \
return static_cast<float>(y) OP HALF2FLOAT(x); \
}
MLX_DEFINE_HALF_OP(__half, __half2float, __float2half, +)
MLX_DEFINE_HALF_OP(__half, __half2float, __float2half, -)
MLX_DEFINE_HALF_OP(__half, __half2float, __float2half, *)
MLX_DEFINE_HALF_OP(__half, __half2float, __float2half, /)
MLX_DEFINE_HALF_OP(__nv_bfloat16, __bfloat162float, __float2bfloat16, +)
MLX_DEFINE_HALF_OP(__nv_bfloat16, __bfloat162float, __float2bfloat16, -)
MLX_DEFINE_HALF_OP(__nv_bfloat16, __bfloat162float, __float2bfloat16, *)
MLX_DEFINE_HALF_OP(__nv_bfloat16, __bfloat162float, __float2bfloat16, /)
MLX_DEFINE_HALF_CMP(__half, __half2float, <)
MLX_DEFINE_HALF_CMP(__half, __half2float, >)
MLX_DEFINE_HALF_CMP(__half, __half2float, <=)
MLX_DEFINE_HALF_CMP(__half, __half2float, >=)
MLX_DEFINE_HALF_CMP(__half, __half2float, ==)
MLX_DEFINE_HALF_CMP(__half, __half2float, !=)
MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, <)
MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, >)
MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, <=)
MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, >=)
MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, ==)
MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, !=)
#undef MLX_DEFINE_HALF_OP
#undef MLX_DEFINE_HALF_CMP
} // namespace mlx::core::cu

View File

@ -0,0 +1,163 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/dtype_utils.cuh"
#include "mlx/backend/cuda/kernels/arange.cuh"
#include "mlx/backend/cuda/kernels/fp16_math.cuh"
#include "mlx/distributed/primitives.h"
#include "mlx/dtype_utils.h"
#include "mlx/fast_primitives.h"
#include "mlx/primitives.h"
#include <nvtx3/nvtx3.hpp>
#include <thrust/device_ptr.h>
#include <thrust/transform.h>
#include <cassert>
namespace mlx::core {
void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Arange::eval_gpu");
assert(inputs.size() == 0);
out.set_data(allocator::malloc(out.nbytes()));
if (out.size() == 0) {
return;
}
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
encoder.set_output_array(out);
encoder.launch_kernel([&, this](cudaStream_t stream) {
MLX_SWITCH_INT_FLOAT_TYPES_CHECKED(out.dtype(), "Arange", CTYPE, {
using OutType = cuda_type_t<CTYPE>;
CTYPE step =
static_cast<CTYPE>(start_ + step_) - static_cast<CTYPE>(start_);
thrust::transform(
cu::thrust_policy(stream),
thrust::counting_iterator<uint32_t>(0),
thrust::counting_iterator<uint32_t>(out.data_size()),
thrust::device_pointer_cast(out.data<OutType>()),
cu::Arange<OutType>{
static_cast<OutType>(start_), static_cast<OutType>(step)});
});
});
}
#define NO_GPU_MULTI(func) \
void func::eval_gpu( \
const std::vector<array>& inputs, std::vector<array>& outputs) { \
throw std::runtime_error(#func " has no CUDA implementation."); \
}
#define NO_GPU(func) \
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
throw std::runtime_error(#func " has no CUDA implementation."); \
}
NO_GPU(Abs)
NO_GPU(Add)
NO_GPU(AddMM)
NO_GPU(ArcCos)
NO_GPU(ArcCosh)
NO_GPU(ArcSin)
NO_GPU(ArcSinh)
NO_GPU(ArcTan)
NO_GPU(ArcTan2)
NO_GPU(ArcTanh)
NO_GPU(ArgPartition)
NO_GPU(ArgReduce)
NO_GPU(ArgSort)
NO_GPU(BitwiseBinary)
NO_GPU(BitwiseInvert)
NO_GPU(BlockMaskedMM)
NO_GPU(Ceil)
NO_GPU_MULTI(Compiled)
NO_GPU(Conjugate)
NO_GPU(Convolution)
NO_GPU(Cos)
NO_GPU(Cosh)
NO_GPU(Divide)
NO_GPU_MULTI(DivMod)
NO_GPU(DynamicSlice)
NO_GPU(DynamicSliceUpdate)
NO_GPU(Remainder)
NO_GPU(Equal)
NO_GPU(Erf)
NO_GPU(ErfInv)
NO_GPU(Exp)
NO_GPU(Expm1)
NO_GPU(FFT)
NO_GPU(Floor)
NO_GPU(Gather)
NO_GPU(GatherAxis)
NO_GPU(GatherMM)
NO_GPU(GatherQMM)
NO_GPU(Greater)
NO_GPU(GreaterEqual)
NO_GPU(Hadamard)
NO_GPU(Imag)
NO_GPU(Less)
NO_GPU(LessEqual)
NO_GPU(Load)
NO_GPU(Log)
NO_GPU(Log1p)
NO_GPU(LogicalNot)
NO_GPU(LogicalAnd)
NO_GPU(LogicalOr)
NO_GPU(LogAddExp)
NO_GPU(LogSumExp)
NO_GPU_MULTI(LUF)
NO_GPU(Matmul)
NO_GPU(Maximum)
NO_GPU(Minimum)
NO_GPU(Multiply)
NO_GPU(Negative)
NO_GPU(NotEqual)
NO_GPU(Partition)
NO_GPU(Power)
NO_GPU_MULTI(QRF)
NO_GPU(QuantizedMatmul)
NO_GPU(RandomBits)
NO_GPU(Real)
NO_GPU(Reduce)
NO_GPU(Round)
NO_GPU(Scan)
NO_GPU(Scatter)
NO_GPU(ScatterAxis)
NO_GPU(Select)
NO_GPU(Sigmoid)
NO_GPU(Sign)
NO_GPU(Sin)
NO_GPU(Sinh)
NO_GPU(SliceUpdate)
NO_GPU(Softmax)
NO_GPU(Sort)
NO_GPU(Square)
NO_GPU(Sqrt)
NO_GPU(Subtract)
NO_GPU_MULTI(SVD)
NO_GPU(Tan)
NO_GPU(Tanh)
NO_GPU(Inverse)
NO_GPU(Cholesky)
NO_GPU_MULTI(Eigh)
namespace fast {
NO_GPU_MULTI(LayerNorm)
NO_GPU_MULTI(LayerNormVJP)
NO_GPU_MULTI(RMSNorm)
NO_GPU_MULTI(RMSNormVJP)
NO_GPU_MULTI(RoPE)
NO_GPU(ScaledDotProductAttention)
NO_GPU_MULTI(AffineQuantize)
NO_GPU_MULTI(CustomKernel)
} // namespace fast
namespace distributed {
NO_GPU_MULTI(AllReduce)
NO_GPU_MULTI(AllGather)
NO_GPU_MULTI(Send)
NO_GPU_MULTI(Recv)
} // namespace distributed
} // namespace mlx::core

View File

@ -0,0 +1,15 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/gpu/slicing.h"
namespace mlx::core {
void concatenate_gpu(
const std::vector<array>& inputs,
array& out,
int axis,
const Stream& s) {
throw std::runtime_error("concatenate_gpu not implemented in CUDA backend.");
}
} // namespace mlx::core

View File

@ -0,0 +1,26 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/utils.h"
#include "mlx/backend/cuda/device.h"
#include <fmt/format.h>
namespace mlx::core {
CudaStream::CudaStream(cu::Device& device) {
device.make_current();
CHECK_CUDA_ERROR(cudaStreamCreateWithFlags(&stream_, cudaStreamNonBlocking));
}
CudaStream::~CudaStream() {
CHECK_CUDA_ERROR(cudaStreamDestroy(stream_));
}
void check_cuda_error(const char* name, cudaError_t err) {
if (err != cudaSuccess) {
throw std::runtime_error(
fmt::format("{} failed: {}", name, cudaGetErrorString(err)));
}
}
} // namespace mlx::core

36
mlx/backend/cuda/utils.h Normal file
View File

@ -0,0 +1,36 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <cuda_runtime.h>
namespace mlx::core {
namespace cu {
class Device;
}
// Cuda stream managed with RAII.
class CudaStream {
public:
explicit CudaStream(cu::Device& device);
~CudaStream();
CudaStream(const CudaStream&) = delete;
CudaStream& operator=(const CudaStream&) = delete;
operator cudaStream_t() const {
return stream_;
}
private:
cudaStream_t stream_;
};
// Throw exception if the cuda API does not succeed.
void check_cuda_error(const char* name, cudaError_t err);
// The macro version that prints the command that failed.
#define CHECK_CUDA_ERROR(cmd) check_cuda_error(#cmd, (cmd))
} // namespace mlx::core

View File

@ -0,0 +1,90 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/worker.h"
#include "mlx/backend/cuda/allocator.h"
#include "mlx/backend/cuda/device.h"
namespace mlx::core::cu {
Worker::Worker()
: signal_stream_(device(mlx::core::Device::gpu)),
worker_(&Worker::thread_fn, this) {}
Worker::~Worker() {
{
std::lock_guard lock(worker_mutex_);
stop_ = true;
}
worker_event_.signal(batch_ + 1);
worker_.join();
}
void Worker::add_task(std::function<void()> task) {
pending_tasks_.push_back(std::move(task));
}
void Worker::consume_in_this_thread() {
for (auto& task : pending_tasks_) {
task();
}
pending_tasks_.clear();
}
void Worker::end_batch() {
batch_++;
{
std::lock_guard lock(worker_mutex_);
worker_tasks_[batch_] = std::move(pending_tasks_);
}
uncommited_batches_++;
}
void Worker::commit() {
if (uncommited_batches_ == 0) {
return;
}
uncommited_batches_ = 0;
worker_event_.signal(batch_);
}
void Worker::commit(cudaStream_t stream) {
if (uncommited_batches_ == 0) {
return;
}
uncommited_batches_ = 0;
// Signal the |worker_event_| in |signal_stream_| after the kernels in
// |stream_| finish running.
signal_event_.record(stream);
signal_event_.wait(signal_stream_);
worker_event_.signal(signal_stream_, batch_);
}
void Worker::thread_fn() {
// The worker thread is safe to free buffers.
allocator().register_this_thread();
while (!stop_) {
uint64_t batch = worker_event_.value();
Tasks tasks;
{
std::lock_guard lock(worker_mutex_);
// Move tasks in signaled batches.
auto end = worker_tasks_.upper_bound(batch);
for (auto it = worker_tasks_.begin(); it != end; ++it) {
if (tasks.empty()) {
tasks = std::move(it->second);
} else {
std::move(
it->second.begin(), it->second.end(), std::back_inserter(tasks));
}
}
worker_tasks_.erase(worker_tasks_.begin(), end);
}
for (auto& task : tasks) {
task();
}
worker_event_.wait(batch + 1);
}
}
} // namespace mlx::core::cu

68
mlx/backend/cuda/worker.h Normal file
View File

@ -0,0 +1,68 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/backend/cuda/event.h"
#include "mlx/backend/cuda/utils.h"
#include <functional>
#include <map>
#include <mutex>
#include <thread>
namespace mlx::core::cu {
// Run tasks in worker thread, synchronized with cuda stream.
class Worker {
public:
Worker();
~Worker();
Worker(const Worker&) = delete;
Worker& operator=(const Worker&) = delete;
// Add a pending |task| that will run when consumed or commited.
void add_task(std::function<void()> task);
// Run pending tasks immediately in current thread.
void consume_in_this_thread();
// Put pending tasks in a batch.
void end_batch();
// Inform worker thread to run current batches now.
void commit();
// Inform worker thread to run current batches after kernels in |stream|
// finish running.
void commit(cudaStream_t stream);
// Return how many batches have been added but not committed yet.
size_t uncommited_batches() const {
return uncommited_batches_;
}
private:
void thread_fn();
uint64_t batch_{0};
size_t uncommited_batches_{0};
// Cuda stream and event for signaling kernel completion.
CudaStream signal_stream_;
CudaEvent signal_event_;
// Worker thread.
SharedEvent worker_event_;
std::thread worker_;
std::mutex worker_mutex_;
bool stop_{false};
// Tasks are put in |pending_tasks_| first, and then moved to
// |worker_tasks_| when end_batch() is called.
using Tasks = std::vector<std::function<void()>>;
Tasks pending_tasks_;
std::map<uint64_t, Tasks> worker_tasks_;
};
} // namespace mlx::core::cu

View File

@ -9,7 +9,7 @@ FetchContent_MakeAvailable(doctest)
add_executable(tests ${PROJECT_SOURCE_DIR}/tests/tests.cpp)
if(MLX_BUILD_METAL)
if(MLX_BUILD_METAL OR MLX_BUILD_CUDA)
set(METAL_TEST_SOURCES gpu_tests.cpp)
endif()