From 0cae0bdac83bbf5b3d1da3ca53f1f7eb95981d30 Mon Sep 17 00:00:00 2001 From: Cheng Date: Wed, 7 May 2025 13:26:46 +0900 Subject: [PATCH 1/4] CUDA backend: backbone (#2075) --- CMakeLists.txt | 5 + mlx/CMakeLists.txt | 10 +- mlx/backend/cuda/CMakeLists.txt | 57 ++++++ mlx/backend/cuda/allocator.cpp | 154 ++++++++++++++ mlx/backend/cuda/allocator.h | 58 ++++++ mlx/backend/cuda/copy.cpp | 26 +++ mlx/backend/cuda/device.cpp | 117 +++++++++++ mlx/backend/cuda/device.h | 131 ++++++++++++ mlx/backend/cuda/dtype_utils.cuh | 35 ++++ mlx/backend/cuda/eval.cpp | 68 +++++++ mlx/backend/cuda/event.cu | 265 +++++++++++++++++++++++++ mlx/backend/cuda/event.h | 66 ++++++ mlx/backend/cuda/fence.cu | 70 +++++++ mlx/backend/cuda/kernels/arange.cuh | 15 ++ mlx/backend/cuda/kernels/fp16_math.cuh | 107 ++++++++++ mlx/backend/cuda/primitives.cu | 163 +++++++++++++++ mlx/backend/cuda/slicing.cpp | 15 ++ mlx/backend/cuda/utils.cpp | 26 +++ mlx/backend/cuda/utils.h | 36 ++++ mlx/backend/cuda/worker.cpp | 90 +++++++++ mlx/backend/cuda/worker.h | 68 +++++++ tests/CMakeLists.txt | 2 +- 22 files changed, 1582 insertions(+), 2 deletions(-) create mode 100644 mlx/backend/cuda/CMakeLists.txt create mode 100644 mlx/backend/cuda/allocator.cpp create mode 100644 mlx/backend/cuda/allocator.h create mode 100644 mlx/backend/cuda/copy.cpp create mode 100644 mlx/backend/cuda/device.cpp create mode 100644 mlx/backend/cuda/device.h create mode 100644 mlx/backend/cuda/dtype_utils.cuh create mode 100644 mlx/backend/cuda/eval.cpp create mode 100644 mlx/backend/cuda/event.cu create mode 100644 mlx/backend/cuda/event.h create mode 100644 mlx/backend/cuda/fence.cu create mode 100644 mlx/backend/cuda/kernels/arange.cuh create mode 100644 mlx/backend/cuda/kernels/fp16_math.cuh create mode 100644 mlx/backend/cuda/primitives.cu create mode 100644 mlx/backend/cuda/slicing.cpp create mode 100644 mlx/backend/cuda/utils.cpp create mode 100644 mlx/backend/cuda/utils.h create mode 100644 mlx/backend/cuda/worker.cpp create mode 100644 mlx/backend/cuda/worker.h diff --git a/CMakeLists.txt b/CMakeLists.txt index e2002fc94..ab8aea443 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -34,6 +34,7 @@ option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF) option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF) option(MLX_BUILD_METAL "Build metal backend" ON) option(MLX_BUILD_CPU "Build cpu backend" ON) +option(MLX_BUILD_CUDA "Build cuda backend" OFF) option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF) option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF) option(MLX_BUILD_GGUF "Include support for GGUF format" ON) @@ -83,6 +84,10 @@ if(MLX_BUILD_METAL) set(QUARTZ_LIB "-framework QuartzCore") endif() +if(MLX_BUILD_CUDA) + enable_language(CUDA) +endif() + if(MLX_BUILD_METAL AND NOT METAL_LIB) message(STATUS "Metal not found. 
Unable to build GPU") set(MLX_BUILD_METAL OFF) diff --git a/mlx/CMakeLists.txt b/mlx/CMakeLists.txt index 00898e73e..4ba9b33dd 100644 --- a/mlx/CMakeLists.txt +++ b/mlx/CMakeLists.txt @@ -47,10 +47,18 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/distributed) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io) if(MLX_BUILD_METAL) - add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/gpu) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/metal) else() target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/no_metal.cpp) +endif() + +if(MLX_BUILD_CUDA) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda) +endif() + +if(MLX_BUILD_METAL OR MLX_BUILD_CUDA) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/gpu) +else() add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_gpu) endif() diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt new file mode 100644 index 000000000..54d651005 --- /dev/null +++ b/mlx/backend/cuda/CMakeLists.txt @@ -0,0 +1,57 @@ +# Filename rules in cuda backend: +# +# * Use .cu/.cuh if code contains device code, and .cpp/.h if not. +# * Device-only kernel code should be put in kernels/ subdir. +# * Files in kernels/ subdir should not include files outside. +target_sources( + mlx + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/event.cu + ${CMAKE_CURRENT_SOURCE_DIR}/fence.cu + ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu + ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp) + +target_compile_definitions(mlx PUBLIC MLX_USE_CUDA) + +# Enable defining device lambda functions. +target_compile_options(mlx + PRIVATE "$<$:--extended-lambda>") + +# Compute capability 7 is required for synchronization between CPU/GPU with +# managed memory. TODO: Add more architectures for potential performance gain. +set(MLX_CUDA_ARCHITECTURES + "75;80" + CACHE STRING "CUDA architectures") +message(STATUS "CUDA architectures: ${MLX_CUDA_ARCHITECTURES}") +set_target_properties(mlx PROPERTIES CUDA_ARCHITECTURES + "${MLX_CUDA_ARCHITECTURES}") + +# Use fixed version of CCCL. +FetchContent_Declare( + cccl + URL "https://github.com/NVIDIA/cccl/releases/download/v2.8.1/cccl-v2.8.1.zip") +FetchContent_MakeAvailable(cccl) +target_include_directories(mlx PRIVATE BEFORE "${cccl_SOURCE_DIR}/include") + +# Use fixed version of NVTX. +FetchContent_Declare( + nvtx3 + GIT_REPOSITORY https://github.com/NVIDIA/NVTX.git + GIT_TAG v3.1.1 + GIT_SHALLOW TRUE + SOURCE_SUBDIR c EXCLUDE_FROM_ALL) +FetchContent_MakeAvailable(nvtx3) +target_link_libraries(mlx PUBLIC $) + +# Make cuda runtime APIs available in non-cuda files. +find_package(CUDAToolkit REQUIRED) +target_include_directories(mlx PRIVATE ${CUDAToolkit_INCLUDE_DIRS}) + +# Suppress nvcc warnings on MLX headers. +target_compile_options(mlx PRIVATE $<$:-Xcudafe + --diag_suppress=997>) diff --git a/mlx/backend/cuda/allocator.cpp b/mlx/backend/cuda/allocator.cpp new file mode 100644 index 000000000..203534e21 --- /dev/null +++ b/mlx/backend/cuda/allocator.cpp @@ -0,0 +1,154 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/allocator.h" +#include "mlx/backend/cuda/utils.h" +#include "mlx/backend/cuda/worker.h" + +#include +#include + +#include + +namespace mlx::core { + +namespace cu { + +CudaAllocator::CudaAllocator() { + // TODO: Set memory limit for multi-device. 
+ size_t free, total; + CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total)); + memory_limit_ = total * 0.8; +} + +Buffer CudaAllocator::malloc(size_t size) { + // TODO: Check memory limit. + auto* buf = new CudaBuffer{nullptr, size}; + cudaError_t err = cudaMallocManaged(&buf->data, size); + if (err != cudaSuccess && err != cudaErrorMemoryAllocation) { + throw std::runtime_error( + fmt::format("cudaMallocManaged failed: {}.", cudaGetErrorString(err))); + } + std::lock_guard lock(mutex_); + active_memory_ += size; + peak_memory_ = std::max(active_memory_, peak_memory_); + return Buffer{buf}; +} + +void CudaAllocator::free(Buffer buffer) { + auto* buf = static_cast(buffer.ptr()); + if (!buf) { + return; + } + + // If free() is called from an unregistered thread, reschedule the call to + // the worker. + { + std::lock_guard lock(worker_mutex_); + if (allowed_threads_.count(std::this_thread::get_id()) == 0) { + if (!worker_) { + worker_.reset(new Worker); + } + worker_->add_task([buffer]() { allocator().free(buffer); }); + worker_->end_batch(); + worker_->commit(); + return; + } + } + + size_t size = buf->size; + cudaFree(buf->data); + delete buf; + std::lock_guard lock(mutex_); + active_memory_ -= size; +} + +size_t CudaAllocator::size(Buffer buffer) const { + auto* buf = static_cast(buffer.ptr()); + if (!buf) { + return 0; + } + return buf->size; +} + +void CudaAllocator::register_this_thread() { + std::lock_guard lock(worker_mutex_); + allowed_threads_.insert(std::this_thread::get_id()); +} + +size_t CudaAllocator::get_active_memory() const { + return active_memory_; +} + +size_t CudaAllocator::get_peak_memory() const { + return peak_memory_; +} + +void CudaAllocator::reset_peak_memory() { + std::lock_guard lock(mutex_); + peak_memory_ = 0; +} + +size_t CudaAllocator::get_memory_limit() { + return memory_limit_; +} + +size_t CudaAllocator::set_memory_limit(size_t limit) { + std::lock_guard lock(mutex_); + std::swap(limit, memory_limit_); + return limit; +} + +CudaAllocator& allocator() { + // By creating the |allocator_| on the heap, the destructor of CudaAllocator + // will not be called on exit and buffers in the cache will be leaked. This + // can save some time at program exit. + static CudaAllocator* allocator_ = new CudaAllocator; + return *allocator_; +} + +} // namespace cu + +namespace allocator { + +Allocator& allocator() { + return cu::allocator(); +} + +void* Buffer::raw_ptr() { + if (!ptr_) { + return nullptr; + } + return static_cast(ptr_)->data; +} + +} // namespace allocator + +size_t get_active_memory() { + return cu::allocator().get_active_memory(); +} +size_t get_peak_memory() { + return cu::allocator().get_peak_memory(); +} +void reset_peak_memory() { + return cu::allocator().reset_peak_memory(); +} +size_t set_memory_limit(size_t limit) { + return cu::allocator().set_memory_limit(limit); +} +size_t get_memory_limit() { + return cu::allocator().get_memory_limit(); +} + +// TODO: Implement buffer cache. +size_t get_cache_memory() { + return 0; +} +size_t set_cache_limit(size_t) { + return 0; +} +size_t set_wired_limit(size_t) { + return 0; +} +void clear_cache() {} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/allocator.h b/mlx/backend/cuda/allocator.h new file mode 100644 index 000000000..6c418ee7e --- /dev/null +++ b/mlx/backend/cuda/allocator.h @@ -0,0 +1,58 @@ +// Copyright © 2025 Apple Inc.
+ +#pragma once + +#include "mlx/allocator.h" + +#include +#include +#include +#include + +namespace mlx::core::cu { + +class Worker; + +using allocator::Buffer; + +// Stores cuda-managed unified memory. +struct CudaBuffer { + void* data; + size_t size; +}; + +class CudaAllocator : public allocator::Allocator { + public: + Buffer malloc(size_t size) override; + void free(Buffer buffer) override; + size_t size(Buffer buffer) const override; + + // Register current thread as safe to free buffers. + // In CUDA, freeing a buffer implicitly synchronizes the stream, and for threads + // that the GPU stream may wait on (for example CPU stream threads), freeing + // buffers there would result in deadlock. + void register_this_thread(); + + size_t get_active_memory() const; + size_t get_peak_memory() const; + void reset_peak_memory(); + size_t get_memory_limit(); + size_t set_memory_limit(size_t limit); + + private: + CudaAllocator(); + friend CudaAllocator& allocator(); + + std::mutex worker_mutex_; + std::unique_ptr worker_; + std::set allowed_threads_; + + std::mutex mutex_; + size_t memory_limit_; + size_t active_memory_{0}; + size_t peak_memory_{0}; +}; + +CudaAllocator& allocator(); + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/copy.cpp b/mlx/backend/cuda/copy.cpp new file mode 100644 index 000000000..d0413d989 --- /dev/null +++ b/mlx/backend/cuda/copy.cpp @@ -0,0 +1,26 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/gpu/copy.h" + +namespace mlx::core { + +void copy_gpu_inplace( + const array& in, + array& out, + const Shape& data_shape, + const Strides& strides_in_pre, + const Strides& strides_out_pre, + int64_t inp_offset, + int64_t out_offset, + CopyType ctype, + const Stream& s, + const std::optional& dynamic_i_offset /* = std::nullopt */, + const std::optional& dynamic_o_offset /* = std::nullopt */) { + throw std::runtime_error("copy_gpu_inplace not implemented in CUDA backend."); +} + +void fill_gpu(const array& val, array& out, const Stream& s) { + throw std::runtime_error("fill_gpu not implemented in CUDA backend."); +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/device.cpp b/mlx/backend/cuda/device.cpp new file mode 100644 index 000000000..a28ffa35e --- /dev/null +++ b/mlx/backend/cuda/device.cpp @@ -0,0 +1,117 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/worker.h" +#include "mlx/backend/metal/metal.h" + +#include +#include + +namespace mlx::core { + +namespace cu { + +DeviceStream::DeviceStream(Device& device) : device_(device), stream_(device) {} + +void DeviceStream::synchronize() { + cudaStreamSynchronize(stream_); +} + +cudaStream_t DeviceStream::schedule_cuda_stream() { + // TODO: Return a stream that maximizes parallelism. + return stream_; +} + +cudaStream_t DeviceStream::last_cuda_stream() { + return stream_; +} + +CommandEncoder& DeviceStream::get_encoder() { + if (!encoder_) { + encoder_ = std::make_unique(*this); + } + return *encoder_; +} + +Device::Device(int device) : device_(device) { + // Validate the requirements of the device. + int attr = 0; + cudaDeviceGetAttribute(&attr, cudaDevAttrConcurrentManagedAccess, device_); + if (attr != 1) { + throw std::runtime_error(fmt::format( + "Device {} does not support synchronization in managed memory.", + device_)); + } + } +} + +void Device::make_current() { + // We need to set/get the current CUDA device very frequently, so cache it to + // reduce actual CUDA API calls. This function assumes a single host thread.
+ static int current = 0; + if (current != device_) { + CHECK_CUDA_ERROR(cudaSetDevice(device_)); + current = device_; + } +} + +DeviceStream& Device::get_stream(Stream s) { + auto it = streams_.find(s.index); + if (it == streams_.end()) { + it = streams_.try_emplace(s.index, *this).first; + } + return it->second; +} + +CommandEncoder::CommandEncoder(DeviceStream& s) + : device_(s.device()), stream_(s) {} + +void CommandEncoder::add_completed_handler(std::function task) { + worker_.add_task(std::move(task)); +} + +void CommandEncoder::end_encoding() { + if (!temporaries_.empty()) { + add_completed_handler([temporaries = std::move(temporaries_)]() {}); + } + + // If there is no kernel running, run completion handlers immediately. + if (!has_gpu_work_) { + worker_.consume_in_this_thread(); + return; + } + has_gpu_work_ = false; + + // Put completion handlers in a batch. + worker_.end_batch(); + + // Signaling kernel completion is expensive, so delay it until there are enough batches. + // TODO: This number is arbitrarily picked; profile for a better strategy. + if (worker_.uncommited_batches() > 8) { + commit(); + } +} + +void CommandEncoder::commit() { + worker_.commit(stream_.last_cuda_stream()); +} + +Device& device(mlx::core::Device device) { + static std::unordered_map devices; + auto it = devices.find(device.index); + if (it == devices.end()) { + it = devices.try_emplace(device.index, device.index).first; + } + return it->second; +} + +DeviceStream& get_stream(Stream s) { + return device(s.device).get_stream(s); +} + +CommandEncoder& get_command_encoder(Stream s) { + return get_stream(s).get_encoder(); +} + +} // namespace cu + +} // namespace mlx::core diff --git a/mlx/backend/cuda/device.h b/mlx/backend/cuda/device.h new file mode 100644 index 000000000..a65a87d54 --- /dev/null +++ b/mlx/backend/cuda/device.h @@ -0,0 +1,131 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/array.h" +#include "mlx/backend/cuda/worker.h" +#include "mlx/stream.h" + +#include + +#include + +namespace mlx::core::cu { + +class Device; +class CommandEncoder; + +class DeviceStream { + public: + explicit DeviceStream(Device& device); + + DeviceStream(const DeviceStream&) = delete; + DeviceStream& operator=(const DeviceStream&) = delete; + + // Wait until kernels in the stream complete. + void synchronize(); + + // Return a cuda stream for launching kernels. + cudaStream_t schedule_cuda_stream(); + + // Return the last cuda stream used. + cudaStream_t last_cuda_stream(); + + CommandEncoder& get_encoder(); + + Device& device() { + return device_; + } + + private: + Device& device_; + CudaStream stream_; + std::unique_ptr encoder_; +}; + +class Device { + public: + explicit Device(int device); + + Device(const Device&) = delete; + Device& operator=(const Device&) = delete; + + // Make this device the current cuda device, required by some cuda calls.
+ void make_current(); + + DeviceStream& get_stream(Stream s); + + int cuda_device() const { + return device_; + } + + private: + int device_; + std::unordered_map streams_; +}; + +class CommandEncoder { + public: + explicit CommandEncoder(DeviceStream& stream); + + CommandEncoder(const CommandEncoder&) = delete; + CommandEncoder& operator=(const CommandEncoder&) = delete; + + void set_input_array(const array& arr) {} + void set_output_array(const array& arr) {} + + void add_temporary(const array& arr) { + temporaries_.push_back(arr.data_shared_ptr()); + } + + void add_completed_handler(std::function task); + void end_encoding(); + void commit(); + + // Schedule a cuda stream for |fun| to launch kernels, and check errors + // afterwards. + template + void launch_kernel(F&& fun) { + launch_kernel(stream_.schedule_cuda_stream(), std::forward(fun)); + } + + template + void launch_kernel(cudaStream_t stream, F&& fun) { + device_.make_current(); + fun(stream); + check_cuda_error("kernel launch", cudaGetLastError()); + has_gpu_work_ = true; + } + + Device& device() { + return device_; + } + + DeviceStream& stream() { + return stream_; + } + + bool has_gpu_work() const { + return has_gpu_work_; + } + + private: + Device& device_; + DeviceStream& stream_; + Worker worker_; + bool has_gpu_work_{false}; + std::vector> temporaries_; +}; + +Device& device(mlx::core::Device device); +DeviceStream& get_stream(Stream s); +CommandEncoder& get_command_encoder(Stream s); + +// Return an execution policy that does not sync for the result. +// Note that not all thrust APIs support an async policy; confirm before using. +inline auto thrust_policy(cudaStream_t stream) { + // TODO: Connect thrust's custom allocator with mlx's allocator. + return thrust::cuda::par_nosync.on(stream); +} + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/dtype_utils.cuh b/mlx/backend/cuda/dtype_utils.cuh new file mode 100644 index 000000000..9b7f8ba65 --- /dev/null +++ b/mlx/backend/cuda/dtype_utils.cuh @@ -0,0 +1,35 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include +#include +#include + +namespace mlx::core { + +// Maps CPU types to CUDA types. +template +struct CTypeToCudaType { + using type = T; +}; + +template <> +struct CTypeToCudaType { + using type = __half; +}; + +template <> +struct CTypeToCudaType { + using type = __nv_bfloat16; +}; + +template <> +struct CTypeToCudaType { + using type = cuComplex; +}; + +template +using cuda_type_t = typename CTypeToCudaType::type; + +} // namespace mlx::core diff --git a/mlx/backend/cuda/eval.cpp b/mlx/backend/cuda/eval.cpp new file mode 100644 index 000000000..b309ad60e --- /dev/null +++ b/mlx/backend/cuda/eval.cpp @@ -0,0 +1,68 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/gpu/eval.h" +#include "mlx/backend/cuda/allocator.h" +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/gpu/available.h" +#include "mlx/primitives.h" + +#include + +namespace mlx::core::gpu { + +bool is_available() { + return true; +} + +void new_stream(Stream s) { + // Force initialization of cuda, so the cuda runtime gets destroyed last. + cudaFree(nullptr); + // Ensure the static stream objects get created. + cu::get_command_encoder(s); + // The main thread is safe to free buffers.
+ cu::allocator().register_this_thread(); +} + +void eval(array& arr) { + nvtx3::scoped_range r("gpu::eval"); + auto outputs = arr.outputs(); + { + // If the array is a tracer hold a reference + // to its inputs so they don't get donated + std::vector inputs; + if (arr.is_tracer()) { + inputs = arr.inputs(); + } + arr.primitive().eval_gpu(arr.inputs(), outputs); + } + + auto& encoder = cu::get_command_encoder(arr.primitive().stream()); + if (encoder.has_gpu_work()) { + // Keep used buffers alive until kernel finishes running. + std::unordered_set> buffers; + for (auto& in : arr.inputs()) { + buffers.insert(in.data_shared_ptr()); + } + for (auto& s : arr.siblings()) { + buffers.insert(s.data_shared_ptr()); + } + // Remove the output if it was donated to by an input. + if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) { + buffers.erase(it); + } + encoder.add_completed_handler([buffers = std::move(buffers)]() {}); + } + encoder.end_encoding(); +} + +void finalize(Stream s) { + nvtx3::scoped_range r("gpu::finalize"); + cu::get_command_encoder(s).commit(); +} + +void synchronize(Stream s) { + nvtx3::scoped_range r("gpu::synchronize"); + cu::get_stream(s).synchronize(); +} + +} // namespace mlx::core::gpu diff --git a/mlx/backend/cuda/event.cu b/mlx/backend/cuda/event.cu new file mode 100644 index 000000000..a487f45b4 --- /dev/null +++ b/mlx/backend/cuda/event.cu @@ -0,0 +1,265 @@ +// Copyright © 2024 Apple Inc. + +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/event.h" +#include "mlx/backend/cuda/utils.h" +#include "mlx/event.h" +#include "mlx/scheduler.h" + +#include + +namespace mlx::core { + +namespace cu { + +/////////////////////////////////////////////////////////////////////////////// +// CudaEvent implementations +/////////////////////////////////////////////////////////////////////////////// + +// Cuda event managed with RAII. 
+class CudaEventHandle { + public: + CudaEventHandle() { + CHECK_CUDA_ERROR(cudaEventCreateWithFlags( + &event_, cudaEventDisableTiming | cudaEventBlockingSync)); + } + + ~CudaEventHandle() { + CHECK_CUDA_ERROR(cudaEventDestroy(event_)); + } + + CudaEventHandle(const CudaEventHandle&) = delete; + CudaEventHandle& operator=(const CudaEventHandle&) = delete; + + operator cudaEvent_t() const { + return event_; + } + + private: + cudaEvent_t event_; +}; + +CudaEvent::CudaEvent() : event_(std::make_shared()) {} + +void CudaEvent::wait() { + nvtx3::scoped_range r("cu::CudaEvent::wait"); + if (!recorded_) { + throw std::runtime_error("Should not wait on a CudaEvent before record."); + } + cudaEventSynchronize(*event_); +} + +void CudaEvent::wait(cudaStream_t stream) { + if (!recorded_) { + throw std::runtime_error("Should not wait on a CudaEvent before record."); + } + cudaStreamWaitEvent(stream, *event_); +} + +void CudaEvent::wait(Stream s) { + if (s.device == mlx::core::Device::cpu) { + scheduler::enqueue(s, [*this]() mutable { wait(); }); + } else { + wait(cu::get_stream(s).last_cuda_stream()); + } +} + +void CudaEvent::record(cudaStream_t stream) { + cudaEventRecord(*event_, stream); + recorded_ = true; +} + +void CudaEvent::record(Stream s) { + if (s.device == mlx::core::Device::cpu) { + throw std::runtime_error("CudaEvent can not wait on cpu stream."); + } else { + record(cu::get_stream(s).last_cuda_stream()); + } +} + +bool CudaEvent::completed() const { + return cudaEventQuery(*event_) == cudaSuccess; +} + +/////////////////////////////////////////////////////////////////////////////// +// SharedEvent implementations +/////////////////////////////////////////////////////////////////////////////// + +namespace { + +__host__ __device__ void event_wait(SharedEvent::Atomic* ac, uint64_t value) { + uint64_t current; + while ((current = ac->load()) < value) { + ac->wait(current); + } +} + +__host__ __device__ void event_signal(SharedEvent::Atomic* ac, uint64_t value) { + ac->store(value); + ac->notify_all(); +} + +__global__ void event_wait_kernel(SharedEvent::Atomic* ac, uint64_t value) { + event_wait(ac, value); +} + +__global__ void event_signal_kernel(SharedEvent::Atomic* ac, uint64_t value) { + event_signal(ac, value); +} + +} // namespace + +SharedEvent::SharedEvent() { + // Allocate cuda::atomic on managed memory. 
+ allocator::Buffer buffer = allocator::malloc(sizeof(Atomic)); + Atomic* ac = static_cast(buffer.raw_ptr()); + new (ac) Atomic(0); + ac_ = std::shared_ptr(ac, [buffer](Atomic* ptr) { + ptr->~Atomic(); + allocator::free(buffer); + }); +} + +void SharedEvent::wait(uint64_t value) { + nvtx3::scoped_range r("cu::SharedEvent::wait"); + event_wait(ac_.get(), value); +} + +void SharedEvent::wait(cudaStream_t stream, uint64_t value) { + event_wait_kernel<<<1, 1, 0, stream>>>(ac_.get(), value); +} + +void SharedEvent::wait(Stream s, uint64_t value) { + nvtx3::scoped_range r("cu::SharedEvent::wait(s)"); + if (s.device == mlx::core::Device::cpu) { + scheduler::enqueue(s, [*this, value]() mutable { wait(value); }); + } else { + auto& encoder = get_command_encoder(s); + encoder.launch_kernel( + encoder.stream().last_cuda_stream(), + [this, value](cudaStream_t stream) { wait(stream, value); }); + encoder.add_completed_handler([ac = ac_]() {}); + encoder.end_encoding(); + } +} + +void SharedEvent::signal(uint64_t value) { + nvtx3::scoped_range r("cu::SharedEvent::signal"); + event_signal(ac_.get(), value); +} + +void SharedEvent::signal(cudaStream_t stream, uint64_t value) { + event_signal_kernel<<<1, 1, 0, stream>>>(ac_.get(), value); +} + +void SharedEvent::signal(Stream s, uint64_t value) { + nvtx3::scoped_range r("cu::SharedEvent::signal(s)"); + if (s.device == mlx::core::Device::cpu) { + scheduler::enqueue(s, [*this, value]() mutable { signal(value); }); + } else { + auto& encoder = get_command_encoder(s); + encoder.launch_kernel( + encoder.stream().last_cuda_stream(), + [this, value](cudaStream_t stream) { signal(stream, value); }); + encoder.add_completed_handler([ac = ac_]() {}); + encoder.end_encoding(); + } +} + +bool SharedEvent::is_signaled(uint64_t value) const { + nvtx3::scoped_range r("cu::SharedEvent::is_signaled"); + return ac_->load() >= value; +} + +uint64_t SharedEvent::value() const { + nvtx3::scoped_range r("cu::SharedEvent::value"); + return ac_->load(); +} + +} // namespace cu + +/////////////////////////////////////////////////////////////////////////////// +// Event implementations +/////////////////////////////////////////////////////////////////////////////// + +namespace { + +struct EventImpl { + // CudaEvent is preferred when possible because it is fast, however we have + // to fallback to SharedEvent in following cases: + // 1. the event is used to wait/signal a cpu stream; + // 2. signal value other than 1 has been specified. 
+ std::unique_ptr cuda; + std::unique_ptr shared; + + bool is_created() const { + return cuda || shared; + } + + void ensure_created(Stream s, uint64_t signal_value) { + if (is_created()) { + return; + } + if (s.device == mlx::core::Device::cpu || signal_value > 1) { + nvtx3::mark("Using slow SharedEvent"); + shared = std::make_unique(); + } else { + cuda = std::make_unique(); + } + } +}; + +} // namespace + +Event::Event(Stream s) : stream_(s) { + event_ = std::shared_ptr( + new EventImpl(), [](void* ptr) { delete static_cast(ptr); }); +} + +void Event::wait() { + auto* event = static_cast(event_.get()); + assert(event->is_created()); + if (event->cuda) { + assert(value() == 1); + event->cuda->wait(); + } else { + event->shared->wait(value()); + } +} + +void Event::wait(Stream s) { + auto* event = static_cast(event_.get()); + assert(event->is_created()); + if (event->cuda) { + assert(value() == 1); + event->cuda->wait(s); + } else { + event->shared->wait(s, value()); + } +} + +void Event::signal(Stream s) { + auto* event = static_cast(event_.get()); + event->ensure_created(s, value()); + if (event->cuda) { + assert(value() == 1); + event->cuda->record(s); + } else { + event->shared->signal(s, value()); + } +} + +bool Event::is_signaled() const { + auto* event = static_cast(event_.get()); + if (!event->is_created()) { + return false; + } + if (event->cuda) { + assert(value() == 1); + return event->cuda->recorded() && event->cuda->completed(); + } else { + return event->shared->is_signaled(value()); + } +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/event.h b/mlx/backend/cuda/event.h new file mode 100644 index 000000000..4b56e2e3b --- /dev/null +++ b/mlx/backend/cuda/event.h @@ -0,0 +1,66 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/stream.h" + +#include +#include + +#include + +namespace mlx::core::cu { + +class CudaEventHandle; + +// Wrapper of native cuda event. It can synchronize between GPU streams, or wait +// on GPU stream in CPU stream, but can not wait on CPU stream. +class CudaEvent { + public: + CudaEvent(); + + void wait(); + void wait(cudaStream_t stream); + void wait(Stream s); + void record(cudaStream_t stream); + void record(Stream s); + + // Return whether the recorded kernels have completed. Note that this method + // returns true if record() has not been called. + bool completed() const; + + bool recorded() const { + return recorded_; + } + + private: + bool recorded_{false}; + std::shared_ptr event_; +}; + +// Event that can synchronize between CPU and GPU. It is much slower than +// CudaEvent so the latter should always be preferred when possible. +class SharedEvent { + public: + using Atomic = cuda::atomic; + + SharedEvent(); + + void wait(uint64_t value); + void wait(cudaStream_t stream, uint64_t value); + void wait(Stream s, uint64_t value); + void signal(uint64_t value); + void signal(cudaStream_t stream, uint64_t value); + void signal(Stream s, uint64_t value); + bool is_signaled(uint64_t value) const; + uint64_t value() const; + + const std::shared_ptr& atomic() const { + return ac_; + } + + private: + std::shared_ptr ac_; +}; + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/fence.cu b/mlx/backend/cuda/fence.cu new file mode 100644 index 000000000..091b252c1 --- /dev/null +++ b/mlx/backend/cuda/fence.cu @@ -0,0 +1,70 @@ +// Copyright © 2025 Apple Inc. 
+ +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/event.h" +#include "mlx/fence.h" +#include "mlx/scheduler.h" + +#include + +namespace mlx::core { + +namespace { + +__host__ __device__ void busy_wait(cuda::atomic* ac, uint64_t value) { + while (true) { + // In theory the atomic_thread_fence is not needed, but for CUDA 11 without + // it the load() may never return new value. + cuda::atomic_thread_fence(cuda::memory_order_seq_cst); + uint64_t current = ac->load(); + if (current >= value) { + break; + } + } +} + +__global__ void busy_wait_kernel(cuda::atomic* ac, uint64_t value) { + busy_wait(ac, value); +} + +} // namespace + +struct FenceImpl { + uint32_t count; + cu::SharedEvent event; +}; + +Fence::Fence(Stream s) { + fence_ = std::shared_ptr( + new FenceImpl{0}, [](void* ptr) { delete static_cast(ptr); }); +} + +void Fence::wait(Stream s, const array&) { + auto* fence = static_cast(fence_.get()); + // We can't use SharedEvent::wait because it could hang in CUDA 11, see also: + // https://github.com/ml-explore/mlx/issues/2137 + const auto& ac = fence->event.atomic(); + if (s.device == mlx::core::Device::cpu) { + scheduler::enqueue(s, [ac, count = fence->count]() { + nvtx3::scoped_range r("Fence::wait()"); + busy_wait(ac.get(), count); + }); + } else { + nvtx3::scoped_range r("Fence::wait(s)"); + auto& encoder = cu::get_command_encoder(s); + encoder.launch_kernel( + encoder.stream().last_cuda_stream(), [&](cudaStream_t stream) { + busy_wait_kernel<<<1, 1, 0>>>(ac.get(), fence->count); + }); + encoder.add_completed_handler([ac]() {}); + encoder.end_encoding(); + } +} + +void Fence::update(Stream s, const array&) { + auto* fence = static_cast(fence_.get()); + fence->count++; + fence->event.signal(s, fence->count); +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/kernels/arange.cuh b/mlx/backend/cuda/kernels/arange.cuh new file mode 100644 index 000000000..53c261e34 --- /dev/null +++ b/mlx/backend/cuda/kernels/arange.cuh @@ -0,0 +1,15 @@ +// Copyright © 2025 Apple Inc. + +namespace mlx::core::cu { + +template +struct Arange { + const T start; + const T step; + + __device__ T operator()(uint32_t i) const { + return start + i * step; + } +}; + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/kernels/fp16_math.cuh b/mlx/backend/cuda/kernels/fp16_math.cuh new file mode 100644 index 000000000..931c55ff7 --- /dev/null +++ b/mlx/backend/cuda/kernels/fp16_math.cuh @@ -0,0 +1,107 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include +#include +#include + +namespace mlx::core::cu { + +/////////////////////////////////////////////////////////////////////////////// +// Missing C++ operator overrides for CUDA 7. 
+/////////////////////////////////////////////////////////////////////////////// + +#if CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800 + +#define MLX_DEFINE_BF16_OP(OP) \ + __forceinline__ __device__ __nv_bfloat16 operator OP( \ + __nv_bfloat16 x, __nv_bfloat16 y) { \ + return __float2bfloat16(__bfloat162float(x) OP __bfloat162float(y)); \ + } + +#define MLX_DEFINE_BF16_CMP(OP) \ + __forceinline__ __device__ bool operator OP( \ + __nv_bfloat16 x, __nv_bfloat16 y) { \ + return __float2bfloat16(__bfloat162float(x) OP __bfloat162float(y)); \ + } + +MLX_DEFINE_BF16_OP(+) +MLX_DEFINE_BF16_OP(-) +MLX_DEFINE_BF16_OP(*) +MLX_DEFINE_BF16_OP(/) +MLX_DEFINE_BF16_CMP(>) +MLX_DEFINE_BF16_CMP(<) +MLX_DEFINE_BF16_CMP(>=) +MLX_DEFINE_BF16_CMP(<=) + +#undef MLX_DEFINE_BF16_OP +#undef MLX_DEFINE_BF16_CMP + +#endif // CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800 + +/////////////////////////////////////////////////////////////////////////////// +// Additional C++ operator overrides between half types and native types. +/////////////////////////////////////////////////////////////////////////////// + +template +constexpr bool is_integral_except = + cuda::std::is_integral_v && !cuda::std::is_same_v; + +template +constexpr bool is_arithmetic_except = + cuda::std::is_arithmetic_v && !cuda::std::is_same_v; + +#define MLX_DEFINE_HALF_OP(HALF, HALF2FLOAT, FLOAT2HALF, OP) \ + template < \ + typename T, \ + typename = cuda::std::enable_if_t>> \ + __forceinline__ __device__ HALF operator OP(HALF x, T y) { \ + return FLOAT2HALF(HALF2FLOAT(x) OP static_cast(y)); \ + } \ + template < \ + typename T, \ + typename = cuda::std::enable_if_t>> \ + __forceinline__ __device__ HALF operator OP(T x, HALF y) { \ + return FLOAT2HALF(static_cast(x) OP HALF2FLOAT(y)); \ + } + +#define MLX_DEFINE_HALF_CMP(HALF, HALF2FLOAT, OP) \ + template < \ + typename T, \ + typename = cuda::std::enable_if_t>> \ + __forceinline__ __device__ bool operator OP(HALF x, T y) { \ + return HALF2FLOAT(x) OP static_cast(y); \ + } \ + template < \ + typename T, \ + typename = cuda::std::enable_if_t>> \ + __forceinline__ __device__ bool operator OP(T x, HALF y) { \ + return static_cast(y) OP HALF2FLOAT(x); \ + } + +MLX_DEFINE_HALF_OP(__half, __half2float, __float2half, +) +MLX_DEFINE_HALF_OP(__half, __half2float, __float2half, -) +MLX_DEFINE_HALF_OP(__half, __half2float, __float2half, *) +MLX_DEFINE_HALF_OP(__half, __half2float, __float2half, /) +MLX_DEFINE_HALF_OP(__nv_bfloat16, __bfloat162float, __float2bfloat16, +) +MLX_DEFINE_HALF_OP(__nv_bfloat16, __bfloat162float, __float2bfloat16, -) +MLX_DEFINE_HALF_OP(__nv_bfloat16, __bfloat162float, __float2bfloat16, *) +MLX_DEFINE_HALF_OP(__nv_bfloat16, __bfloat162float, __float2bfloat16, /) +MLX_DEFINE_HALF_CMP(__half, __half2float, <) +MLX_DEFINE_HALF_CMP(__half, __half2float, >) +MLX_DEFINE_HALF_CMP(__half, __half2float, <=) +MLX_DEFINE_HALF_CMP(__half, __half2float, >=) +MLX_DEFINE_HALF_CMP(__half, __half2float, ==) +MLX_DEFINE_HALF_CMP(__half, __half2float, !=) +MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, <) +MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, >) +MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, <=) +MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, >=) +MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, ==) +MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, !=) + +#undef MLX_DEFINE_HALF_OP +#undef MLX_DEFINE_HALF_CMP + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/primitives.cu b/mlx/backend/cuda/primitives.cu new file mode 100644 index 000000000..dc6edf606 
--- /dev/null +++ b/mlx/backend/cuda/primitives.cu @@ -0,0 +1,163 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/dtype_utils.cuh" +#include "mlx/backend/cuda/kernels/arange.cuh" +#include "mlx/backend/cuda/kernels/fp16_math.cuh" +#include "mlx/distributed/primitives.h" +#include "mlx/dtype_utils.h" +#include "mlx/fast_primitives.h" +#include "mlx/primitives.h" + +#include +#include +#include + +#include + +namespace mlx::core { + +void Arange::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("Arange::eval_gpu"); + assert(inputs.size() == 0); + out.set_data(allocator::malloc(out.nbytes())); + if (out.size() == 0) { + return; + } + auto& s = stream(); + auto& encoder = cu::get_command_encoder(s); + encoder.set_output_array(out); + encoder.launch_kernel([&, this](cudaStream_t stream) { + MLX_SWITCH_INT_FLOAT_TYPES_CHECKED(out.dtype(), "Arange", CTYPE, { + using OutType = cuda_type_t; + CTYPE step = + static_cast(start_ + step_) - static_cast(start_); + thrust::transform( + cu::thrust_policy(stream), + thrust::counting_iterator(0), + thrust::counting_iterator(out.data_size()), + thrust::device_pointer_cast(out.data()), + cu::Arange{ + static_cast(start_), static_cast(step)}); + }); + }); +} + +#define NO_GPU_MULTI(func) \ + void func::eval_gpu( \ + const std::vector& inputs, std::vector& outputs) { \ + throw std::runtime_error(#func " has no CUDA implementation."); \ + } + +#define NO_GPU(func) \ + void func::eval_gpu(const std::vector& inputs, array& out) { \ + throw std::runtime_error(#func " has no CUDA implementation."); \ + } + +NO_GPU(Abs) +NO_GPU(Add) +NO_GPU(AddMM) +NO_GPU(ArcCos) +NO_GPU(ArcCosh) +NO_GPU(ArcSin) +NO_GPU(ArcSinh) +NO_GPU(ArcTan) +NO_GPU(ArcTan2) +NO_GPU(ArcTanh) +NO_GPU(ArgPartition) +NO_GPU(ArgReduce) +NO_GPU(ArgSort) +NO_GPU(BitwiseBinary) +NO_GPU(BitwiseInvert) +NO_GPU(BlockMaskedMM) +NO_GPU(Ceil) +NO_GPU_MULTI(Compiled) +NO_GPU(Conjugate) +NO_GPU(Convolution) +NO_GPU(Cos) +NO_GPU(Cosh) +NO_GPU(Divide) +NO_GPU_MULTI(DivMod) +NO_GPU(DynamicSlice) +NO_GPU(DynamicSliceUpdate) +NO_GPU(Remainder) +NO_GPU(Equal) +NO_GPU(Erf) +NO_GPU(ErfInv) +NO_GPU(Exp) +NO_GPU(Expm1) +NO_GPU(FFT) +NO_GPU(Floor) +NO_GPU(Gather) +NO_GPU(GatherAxis) +NO_GPU(GatherMM) +NO_GPU(GatherQMM) +NO_GPU(Greater) +NO_GPU(GreaterEqual) +NO_GPU(Hadamard) +NO_GPU(Imag) +NO_GPU(Less) +NO_GPU(LessEqual) +NO_GPU(Load) +NO_GPU(Log) +NO_GPU(Log1p) +NO_GPU(LogicalNot) +NO_GPU(LogicalAnd) +NO_GPU(LogicalOr) +NO_GPU(LogAddExp) +NO_GPU(LogSumExp) +NO_GPU_MULTI(LUF) +NO_GPU(Matmul) +NO_GPU(Maximum) +NO_GPU(Minimum) +NO_GPU(Multiply) +NO_GPU(Negative) +NO_GPU(NotEqual) +NO_GPU(Partition) +NO_GPU(Power) +NO_GPU_MULTI(QRF) +NO_GPU(QuantizedMatmul) +NO_GPU(RandomBits) +NO_GPU(Real) +NO_GPU(Reduce) +NO_GPU(Round) +NO_GPU(Scan) +NO_GPU(Scatter) +NO_GPU(ScatterAxis) +NO_GPU(Select) +NO_GPU(Sigmoid) +NO_GPU(Sign) +NO_GPU(Sin) +NO_GPU(Sinh) +NO_GPU(SliceUpdate) +NO_GPU(Softmax) +NO_GPU(Sort) +NO_GPU(Square) +NO_GPU(Sqrt) +NO_GPU(Subtract) +NO_GPU_MULTI(SVD) +NO_GPU(Tan) +NO_GPU(Tanh) +NO_GPU(Inverse) +NO_GPU(Cholesky) +NO_GPU_MULTI(Eigh) + +namespace fast { +NO_GPU_MULTI(LayerNorm) +NO_GPU_MULTI(LayerNormVJP) +NO_GPU_MULTI(RMSNorm) +NO_GPU_MULTI(RMSNormVJP) +NO_GPU_MULTI(RoPE) +NO_GPU(ScaledDotProductAttention) +NO_GPU_MULTI(AffineQuantize) +NO_GPU_MULTI(CustomKernel) +} // namespace fast + +namespace distributed { +NO_GPU_MULTI(AllReduce) +NO_GPU_MULTI(AllGather) +NO_GPU_MULTI(Send) +NO_GPU_MULTI(Recv) +} // namespace distributed + +} // 
namespace mlx::core diff --git a/mlx/backend/cuda/slicing.cpp b/mlx/backend/cuda/slicing.cpp new file mode 100644 index 000000000..bfa742c74 --- /dev/null +++ b/mlx/backend/cuda/slicing.cpp @@ -0,0 +1,15 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/gpu/slicing.h" + +namespace mlx::core { + +void concatenate_gpu( + const std::vector& inputs, + array& out, + int axis, + const Stream& s) { + throw std::runtime_error("concatenate_gpu not implemented in CUDA backend."); +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/utils.cpp b/mlx/backend/cuda/utils.cpp new file mode 100644 index 000000000..2a11a518e --- /dev/null +++ b/mlx/backend/cuda/utils.cpp @@ -0,0 +1,26 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/utils.h" +#include "mlx/backend/cuda/device.h" + +#include + +namespace mlx::core { + +CudaStream::CudaStream(cu::Device& device) { + device.make_current(); + CHECK_CUDA_ERROR(cudaStreamCreateWithFlags(&stream_, cudaStreamNonBlocking)); +} + +CudaStream::~CudaStream() { + CHECK_CUDA_ERROR(cudaStreamDestroy(stream_)); +} + +void check_cuda_error(const char* name, cudaError_t err) { + if (err != cudaSuccess) { + throw std::runtime_error( + fmt::format("{} failed: {}", name, cudaGetErrorString(err))); + } +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/utils.h b/mlx/backend/cuda/utils.h new file mode 100644 index 000000000..58d508765 --- /dev/null +++ b/mlx/backend/cuda/utils.h @@ -0,0 +1,36 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include + +namespace mlx::core { + +namespace cu { +class Device; +} + +// Cuda stream managed with RAII. +class CudaStream { + public: + explicit CudaStream(cu::Device& device); + ~CudaStream(); + + CudaStream(const CudaStream&) = delete; + CudaStream& operator=(const CudaStream&) = delete; + + operator cudaStream_t() const { + return stream_; + } + + private: + cudaStream_t stream_; +}; + +// Throw exception if the cuda API does not succeed. +void check_cuda_error(const char* name, cudaError_t err); + +// The macro version that prints the command that failed. +#define CHECK_CUDA_ERROR(cmd) check_cuda_error(#cmd, (cmd)) + +} // namespace mlx::core diff --git a/mlx/backend/cuda/worker.cpp b/mlx/backend/cuda/worker.cpp new file mode 100644 index 000000000..64b5c7679 --- /dev/null +++ b/mlx/backend/cuda/worker.cpp @@ -0,0 +1,90 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/worker.h" +#include "mlx/backend/cuda/allocator.h" +#include "mlx/backend/cuda/device.h" + +namespace mlx::core::cu { + +Worker::Worker() + : signal_stream_(device(mlx::core::Device::gpu)), + worker_(&Worker::thread_fn, this) {} + +Worker::~Worker() { + { + std::lock_guard lock(worker_mutex_); + stop_ = true; + } + worker_event_.signal(batch_ + 1); + worker_.join(); +} + +void Worker::add_task(std::function task) { + pending_tasks_.push_back(std::move(task)); +} + +void Worker::consume_in_this_thread() { + for (auto& task : pending_tasks_) { + task(); + } + pending_tasks_.clear(); +} + +void Worker::end_batch() { + batch_++; + { + std::lock_guard lock(worker_mutex_); + worker_tasks_[batch_] = std::move(pending_tasks_); + } + uncommited_batches_++; +} + +void Worker::commit() { + if (uncommited_batches_ == 0) { + return; + } + uncommited_batches_ = 0; + worker_event_.signal(batch_); +} + +void Worker::commit(cudaStream_t stream) { + if (uncommited_batches_ == 0) { + return; + } + uncommited_batches_ = 0; + // Signal the |worker_event_| in |signal_stream_| after the kernels in + // |stream_| finish running. 
+ signal_event_.record(stream); + signal_event_.wait(signal_stream_); + worker_event_.signal(signal_stream_, batch_); +} + +void Worker::thread_fn() { + // The worker thread is safe to free buffers. + allocator().register_this_thread(); + + while (!stop_) { + uint64_t batch = worker_event_.value(); + Tasks tasks; + { + std::lock_guard lock(worker_mutex_); + // Move tasks in signaled batches. + auto end = worker_tasks_.upper_bound(batch); + for (auto it = worker_tasks_.begin(); it != end; ++it) { + if (tasks.empty()) { + tasks = std::move(it->second); + } else { + std::move( + it->second.begin(), it->second.end(), std::back_inserter(tasks)); + } + } + worker_tasks_.erase(worker_tasks_.begin(), end); + } + for (auto& task : tasks) { + task(); + } + worker_event_.wait(batch + 1); + } +} + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/worker.h b/mlx/backend/cuda/worker.h new file mode 100644 index 000000000..d28e22e95 --- /dev/null +++ b/mlx/backend/cuda/worker.h @@ -0,0 +1,68 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/backend/cuda/event.h" +#include "mlx/backend/cuda/utils.h" + +#include +#include +#include +#include + +namespace mlx::core::cu { + +// Run tasks in a worker thread, synchronized with the cuda stream. +class Worker { + public: + Worker(); + ~Worker(); + + Worker(const Worker&) = delete; + Worker& operator=(const Worker&) = delete; + + // Add a pending |task| that will run when consumed or committed. + void add_task(std::function task); + + // Run pending tasks immediately in the current thread. + void consume_in_this_thread(); + + // Put pending tasks in a batch. + void end_batch(); + + // Inform the worker thread to run the current batches now. + void commit(); + + // Inform the worker thread to run the current batches after kernels in |stream| + // finish running. + void commit(cudaStream_t stream); + + // Return how many batches have been added but not committed yet. + size_t uncommited_batches() const { + return uncommited_batches_; + } + + private: + void thread_fn(); + + uint64_t batch_{0}; + size_t uncommited_batches_{0}; + + // Cuda stream and event for signaling kernel completion. + CudaStream signal_stream_; + CudaEvent signal_event_; + + // Worker thread. + SharedEvent worker_event_; + std::thread worker_; + std::mutex worker_mutex_; + bool stop_{false}; + + // Tasks are put in |pending_tasks_| first, and then moved to + // |worker_tasks_| when end_batch() is called.
+ using Tasks = std::vector>; + Tasks pending_tasks_; + std::map worker_tasks_; +}; + +} // namespace mlx::core::cu diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index cf0ba3d5d..cb174865d 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -9,7 +9,7 @@ FetchContent_MakeAvailable(doctest) add_executable(tests ${PROJECT_SOURCE_DIR}/tests/tests.cpp) -if(MLX_BUILD_METAL) +if(MLX_BUILD_METAL OR MLX_BUILD_CUDA) set(METAL_TEST_SOURCES gpu_tests.cpp) endif() From a7fae8a176fad114c89ca66ed0e0be8f3064e3e8 Mon Sep 17 00:00:00 2001 From: ATurker <53705368+aturker1@users.noreply.github.com> Date: Fri, 9 May 2025 20:26:52 +0300 Subject: [PATCH 2/4] fix: conv_general differences between gpu, cpu (#2070) * fix general_conv padding * fix bugs * add test --------- Co-authored-by: Awni Hannun --- mlx/backend/cpu/conv.cpp | 574 +++++++++++++++++++++---------------- mlx/backend/metal/conv.cpp | 6 +- mlx/ops.cpp | 1 + mlx/primitives.cpp | 48 ++-- mlx/primitives.h | 12 +- python/tests/test_conv.py | 42 +++ 6 files changed, 413 insertions(+), 270 deletions(-) diff --git a/mlx/backend/cpu/conv.cpp b/mlx/backend/cpu/conv.cpp index d52f92f8b..e5636b3b8 100644 --- a/mlx/backend/cpu/conv.cpp +++ b/mlx/backend/cpu/conv.cpp @@ -22,7 +22,8 @@ void slow_conv_1D( const array& in, const array& wt, array out, - const std::vector& padding, + const std::vector& padding_lo, + const std::vector& padding_hi, const std::vector& wt_strides, const std::vector& wt_dilation, const std::vector& in_dilation, @@ -60,7 +61,8 @@ void slow_conv_1D( out_stride_O = out.strides()[2], flip, - padding = padding[0], + padding_lo = padding_lo[0], + padding_hi = padding_hi[0], wt_stride = wt_strides[0], wt_dilation = wt_dilation[0], in_dilation = in_dilation[0]]() mutable { @@ -77,7 +79,7 @@ void slow_conv_1D( const T* wt_ptr = filter_wt_ptr + wh * wt_stride_H; int wh_flip = flip ? 
(wH - wh - 1) : wh; - int ih = oh * wt_stride - padding + wh_flip * wt_dilation; + int ih = oh * wt_stride - padding_lo + wh_flip * wt_dilation; auto ih_div = std::div(ih, in_dilation); @@ -109,7 +111,8 @@ void slow_conv_2D( const array& in, const array& wt, array out, - const std::vector& padding, + const std::vector& padding_lo, + const std::vector& padding_hi, const std::vector& wt_strides, const std::vector& wt_dilation, const std::vector& in_dilation, @@ -120,230 +123,235 @@ void slow_conv_2D( encoder.set_input_array(wt); encoder.set_output_array(out); - encoder.dispatch([st_wt_ptr = wt.data(), - st_in_ptr = in.data(), - st_out_ptr = out.data(), + encoder.dispatch( + [st_wt_ptr = wt.data(), + st_in_ptr = in.data(), + st_out_ptr = out.data(), - N = in.shape( - 0), // Batch size, should be the same as out.shape(0) - iH = 1 + - in_dilation[0] * (in.shape(1) - 1), // Input spatial dim - iW = 1 + - in_dilation[1] * (in.shape(2) - 1), // Input spatial dim - C = in.shape(3), // In channels - oH = out.shape(1), // Output spatial dim - oW = out.shape(2), // Output spatial dim - O = wt.shape(0), // Out channels - wH = wt.shape(1), // Weight spatial dim - wW = wt.shape(2), // Weight spatial dim + N = in.shape(0), // Batch size, should be the same as out.shape(0) + iH = 1 + in_dilation[0] * (in.shape(1) - 1), // Input spatial dim + iW = 1 + in_dilation[1] * (in.shape(2) - 1), // Input spatial dim + C = in.shape(3), // In channels + oH = out.shape(1), // Output spatial dim + oW = out.shape(2), // Output spatial dim + O = wt.shape(0), // Out channels + wH = wt.shape(1), // Weight spatial dim + wW = wt.shape(2), // Weight spatial dim - groups = in.shape(3) / wt.shape(3), - C_per_group = wt.shape(3), + groups = in.shape(3) / wt.shape(3), + C_per_group = wt.shape(3), - in_stride_N = in.strides()[0], - in_stride_H = in.strides()[1], - in_stride_W = in.strides()[2], - in_stride_C = in.strides()[3], + in_stride_N = in.strides()[0], + in_stride_H = in.strides()[1], + in_stride_W = in.strides()[2], + in_stride_C = in.strides()[3], - wt_stride_O = wt.strides()[0], - wt_stride_H = wt.strides()[1], - wt_stride_W = wt.strides()[2], - wt_stride_C = wt.strides()[3], + wt_stride_O = wt.strides()[0], + wt_stride_H = wt.strides()[1], + wt_stride_W = wt.strides()[2], + wt_stride_C = wt.strides()[3], - out_stride_N = out.strides()[0], - out_stride_H = out.strides()[1], - out_stride_W = out.strides()[2], - out_stride_O = out.strides()[3], + out_stride_N = out.strides()[0], + out_stride_H = out.strides()[1], + out_stride_W = out.strides()[2], + out_stride_O = out.strides()[3], - padding, - wt_strides, - wt_dilation, - in_dilation, - flip]() mutable { - bool is_idil_one = in_dilation[0] == 1 && in_dilation[1] == 1; + padding_lo, + padding_hi, + wt_strides, + wt_dilation, + in_dilation, + flip]() mutable { + bool is_idil_one = in_dilation[0] == 1 && in_dilation[1] == 1; - const int O_per_group = O / groups; - auto pt_conv_no_checks = [&](const T* in_ptr, - const T* wt_ptr, - T* out_ptr, - int oh, - int ow) { - out_ptr += oh * out_stride_H + ow * out_stride_W; - int ih_base = oh * wt_strides[0] - padding[0]; - int iw_base = ow * wt_strides[1] - padding[1]; + const int O_per_group = O / groups; + auto pt_conv_no_checks = + [&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) { + out_ptr += oh * out_stride_H + ow * out_stride_W; + int ih_base = oh * wt_strides[0] - padding_lo[0]; + int iw_base = ow * wt_strides[1] - padding_lo[1]; - for (int g = 0; g < groups; ++g) { - for (int o = g * O_per_group; o < (g + 1) 
* O_per_group; ++o) { - float r = 0.; + for (int g = 0; g < groups; ++g) { + for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) { + float r = 0.; - for (int wh = 0; wh < wH; ++wh) { - for (int ww = 0; ww < wW; ++ww) { - int wh_flip = flip ? wH - wh - 1 : wh; - int ww_flip = flip ? wW - ww - 1 : ww; - int ih = ih_base + wh_flip * wt_dilation[0]; - int iw = iw_base + ww_flip * wt_dilation[1]; + for (int wh = 0; wh < wH; ++wh) { + for (int ww = 0; ww < wW; ++ww) { + int wh_flip = flip ? wH - wh - 1 : wh; + int ww_flip = flip ? wW - ww - 1 : ww; + int ih = ih_base + wh_flip * wt_dilation[0]; + int iw = iw_base + ww_flip * wt_dilation[1]; - const T* wt_ptr_pt = wt_ptr + wh * wt_stride_H + ww * wt_stride_W; - const T* in_ptr_pt = in_ptr + ih * in_stride_H + iw * in_stride_W; + const T* wt_ptr_pt = + wt_ptr + wh * wt_stride_H + ww * wt_stride_W; + const T* in_ptr_pt = + in_ptr + ih * in_stride_H + iw * in_stride_W; - for (int c = g * C_per_group; c < (g + 1) * C_per_group; ++c) { - r += static_cast(in_ptr_pt[c * in_stride_C]) * - static_cast( - wt_ptr_pt[(c % C_per_group) * wt_stride_C]); - } // c - } // ww - } // wh + for (int c = g * C_per_group; c < (g + 1) * C_per_group; + ++c) { + r += static_cast(in_ptr_pt[c * in_stride_C]) * + static_cast( + wt_ptr_pt[(c % C_per_group) * wt_stride_C]); + } // c + } // ww + } // wh - out_ptr[0] = static_cast(r); - out_ptr += out_stride_O; - wt_ptr += wt_stride_O; - } // o - } // g - }; + out_ptr[0] = static_cast(r); + out_ptr += out_stride_O; + wt_ptr += wt_stride_O; + } // o + } // g + }; - int jump_h = flip ? -wt_dilation[0] : wt_dilation[0]; - int jump_w = flip ? -wt_dilation[1] : wt_dilation[1]; + int jump_h = flip ? -wt_dilation[0] : wt_dilation[0]; + int jump_w = flip ? -wt_dilation[1] : wt_dilation[1]; - int init_h = (flip ? (wH - 1) * wt_dilation[0] : 0); - int init_w = (flip ? (wW - 1) * wt_dilation[1] : 0); + int init_h = (flip ? (wH - 1) * wt_dilation[0] : 0); + int init_w = (flip ? 
(wW - 1) * wt_dilation[1] : 0); - int f_wgt_jump_h = - std::lcm(in_dilation[0], wt_dilation[0]) / wt_dilation[0]; - int f_wgt_jump_w = - std::lcm(in_dilation[1], wt_dilation[1]) / wt_dilation[1]; + int f_wgt_jump_h = + std::lcm(in_dilation[0], wt_dilation[0]) / wt_dilation[0]; + int f_wgt_jump_w = + std::lcm(in_dilation[1], wt_dilation[1]) / wt_dilation[1]; - int f_out_jump_h = std::lcm(in_dilation[0], wt_strides[0]) / wt_strides[0]; - int f_out_jump_w = std::lcm(in_dilation[1], wt_strides[1]) / wt_strides[1]; + int f_out_jump_h = + std::lcm(in_dilation[0], wt_strides[0]) / wt_strides[0]; + int f_out_jump_w = + std::lcm(in_dilation[1], wt_strides[1]) / wt_strides[1]; - std::vector base_h(f_out_jump_h); - std::vector base_w(f_out_jump_w); + std::vector base_h(f_out_jump_h); + std::vector base_w(f_out_jump_w); - for (int i = 0; i < f_out_jump_h; ++i) { - int ih_loop = i * wt_strides[0] - padding[0] + init_h; + for (int i = 0; i < f_out_jump_h; ++i) { + int ih_loop = i * wt_strides[0] - padding_lo[0] + init_h; - int wh_base = 0; - while (wh_base < wH && ih_loop % in_dilation[0] != 0) { - wh_base++; - ih_loop += jump_h; - } + int wh_base = 0; + while (wh_base < wH && ih_loop % in_dilation[0] != 0) { + wh_base++; + ih_loop += jump_h; + } - base_h[i] = wh_base; - } + base_h[i] = wh_base; + } - for (int j = 0; j < f_out_jump_w; ++j) { - int iw_loop = j * wt_strides[1] - padding[1] + init_w; + for (int j = 0; j < f_out_jump_w; ++j) { + int iw_loop = j * wt_strides[1] - padding_lo[1] + init_w; - int ww_base = 0; - while (ww_base < wW && iw_loop % in_dilation[1] != 0) { - ww_base++; - iw_loop += jump_w; - } + int ww_base = 0; + while (ww_base < wW && iw_loop % in_dilation[1] != 0) { + ww_base++; + iw_loop += jump_w; + } - base_w[j] = ww_base; - } + base_w[j] = ww_base; + } - auto pt_conv_all_checks = - [&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) { - out_ptr += oh * out_stride_H + ow * out_stride_W; + auto pt_conv_all_checks = + [&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) { + out_ptr += oh * out_stride_H + ow * out_stride_W; - int ih_base = oh * wt_strides[0] - padding[0]; - int iw_base = ow * wt_strides[1] - padding[1]; + int ih_base = oh * wt_strides[0] - padding_lo[0]; + int iw_base = ow * wt_strides[1] - padding_lo[1]; - int wh_base = base_h[oh % f_out_jump_h]; - int ww_base = base_w[ow % f_out_jump_w]; + int wh_base = base_h[oh % f_out_jump_h]; + int ww_base = base_w[ow % f_out_jump_w]; - for (int g = 0; g < groups; ++g) { - for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) { - float r = 0.; + for (int g = 0; g < groups; ++g) { + for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) { + float r = 0.; - for (int wh = wh_base; wh < wH; wh += f_wgt_jump_h) { - for (int ww = ww_base; ww < wW; ww += f_wgt_jump_w) { - int wh_flip = flip ? wH - wh - 1 : wh; - int ww_flip = flip ? wW - ww - 1 : ww; - int ih = ih_base + wh_flip * wt_dilation[0]; - int iw = iw_base + ww_flip * wt_dilation[1]; + for (int wh = wh_base; wh < wH; wh += f_wgt_jump_h) { + for (int ww = ww_base; ww < wW; ww += f_wgt_jump_w) { + int wh_flip = flip ? wH - wh - 1 : wh; + int ww_flip = flip ? wW - ww - 1 : ww; + int ih = ih_base + wh_flip * wt_dilation[0]; + int iw = iw_base + ww_flip * wt_dilation[1]; - if (ih >= 0 && ih < iH && iw >= 0 && iw < iW) { - const T* wt_ptr_pt = - wt_ptr + wh * wt_stride_H + ww * wt_stride_W; + if (ih >= 0 && ih < iH && iw >= 0 && iw < iW) { + const T* wt_ptr_pt = + wt_ptr + wh * wt_stride_H + ww * wt_stride_W; - int ih_dil = !is_idil_one ? 
(ih / in_dilation[0]) : ih; - int iw_dil = !is_idil_one ? (iw / in_dilation[1]) : iw; + int ih_dil = !is_idil_one ? (ih / in_dilation[0]) : ih; + int iw_dil = !is_idil_one ? (iw / in_dilation[1]) : iw; - const T* in_ptr_pt = - in_ptr + ih_dil * in_stride_H + iw_dil * in_stride_W; + const T* in_ptr_pt = in_ptr + ih_dil * in_stride_H + + iw_dil * in_stride_W; - for (int c = g * C_per_group; c < (g + 1) * C_per_group; - ++c) { - r += static_cast(in_ptr_pt[c * in_stride_C]) * - static_cast( - wt_ptr_pt[(c % C_per_group) * wt_stride_C]); - } // c + for (int c = g * C_per_group; c < (g + 1) * C_per_group; + ++c) { + r += static_cast(in_ptr_pt[c * in_stride_C]) * + static_cast( + wt_ptr_pt[(c % C_per_group) * wt_stride_C]); + } // c - } // ih, iw check - } // ww - } // wh + } // ih, iw check + } // ww + } // wh - out_ptr[0] = static_cast(r); - out_ptr += out_stride_O; - wt_ptr += wt_stride_O; - } // o - } // g - }; + out_ptr[0] = static_cast(r); + out_ptr += out_stride_O; + wt_ptr += wt_stride_O; + } // o + } // g + }; - int oH_border_0 = 0; - int oH_border_1 = - is_idil_one ? ((padding[0] + wt_strides[0] - 1) / wt_strides[0]) : oH; - int oH_border_2 = std::max( - oH_border_1, (iH + padding[0] - wH * wt_dilation[0]) / wt_strides[0]); - int oH_border_3 = oH; + int oH_border_0 = 0; + int oH_border_1 = is_idil_one + ? ((padding_lo[0] + wt_strides[0] - 1) / wt_strides[0]) + : oH; + int oH_border_2 = std::max( + oH_border_1, + (iH + padding_lo[0] - wH * wt_dilation[0]) / wt_strides[0]); + int oH_border_3 = oH; - int oW_border_0 = 0; - int oW_border_1 = - is_idil_one ? ((padding[1] + wt_strides[1] - 1) / wt_strides[1]) : oW; - int oW_border_2 = std::max( - oW_border_1, (iW + padding[1] - wW * wt_dilation[1]) / wt_strides[1]); - int oW_border_3 = oW; + int oW_border_0 = 0; + int oW_border_1 = is_idil_one + ? 
((padding_lo[1] + wt_strides[1] - 1) / wt_strides[1]) + : oW; + int oW_border_2 = std::max( + oW_border_1, + (iW + padding_lo[1] - wW * wt_dilation[1]) / wt_strides[1]); + int oW_border_3 = oW; - for (int n = 0; n < N; ++n) { - // Case 1: oh might put us out of bounds - for (int oh = oH_border_0; oh < oH_border_1; ++oh) { - for (int ow = 0; ow < oW; ++ow) { - pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); - } // ow - } // oh + for (int n = 0; n < N; ++n) { + // Case 1: oh might put us out of bounds + for (int oh = oH_border_0; oh < oH_border_1; ++oh) { + for (int ow = 0; ow < oW; ++ow) { + pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); + } // ow + } // oh - // Case 2: oh in bounds - for (int oh = oH_border_1; oh < oH_border_2; ++oh) { - // Case a: ow might put us out of bounds - for (int ow = oW_border_0; ow < oW_border_1; ++ow) { - pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); - } // ow + // Case 2: oh in bounds + for (int oh = oH_border_1; oh < oH_border_2; ++oh) { + // Case a: ow might put us out of bounds + for (int ow = oW_border_0; ow < oW_border_1; ++ow) { + pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); + } // ow - // Case b: ow in bounds - for (int ow = oW_border_1; ow < oW_border_2; ++ow) { - pt_conv_no_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); - } // ow + // Case b: ow in bounds + for (int ow = oW_border_1; ow < oW_border_2; ++ow) { + pt_conv_no_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); + } // ow - // Case c: ow might put us out of bounds - for (int ow = oW_border_2; ow < oW_border_3; ++ow) { - pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); - } // ow + // Case c: ow might put us out of bounds + for (int ow = oW_border_2; ow < oW_border_3; ++ow) { + pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); + } // ow - } // oh + } // oh - // Case 3: oh might put us out of bounds - for (int oh = oH_border_2; oh < oH_border_3; ++oh) { - for (int ow = 0; ow < oW; ++ow) { - pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); - } // ow - } // oh + // Case 3: oh might put us out of bounds + for (int oh = oH_border_2; oh < oH_border_3; ++oh) { + for (int ow = 0; ow < oW; ++ow) { + pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); + } // ow + } // oh - st_in_ptr += in_stride_N; - st_out_ptr += out_stride_N; + st_in_ptr += in_stride_N; + st_out_ptr += out_stride_N; - } // n - }); + } // n + }); } template @@ -351,7 +359,8 @@ void slow_conv_3D( const array& in, const array& wt, array out, - const std::vector& padding, + const std::vector& padding_lo, + const std::vector& padding_hi, const std::vector& wt_strides, const std::vector& wt_dilation, const std::vector& in_dilation, @@ -400,7 +409,8 @@ void slow_conv_3D( out_stride_H = out.strides()[2], out_stride_W = out.strides()[3], out_stride_O = out.strides()[4], - padding, + padding_lo, + padding_hi, wt_strides, wt_dilation, in_dilation, @@ -415,9 +425,9 @@ void slow_conv_3D( int oh, int ow) { out_ptr += od * out_stride_D + oh * out_stride_H + ow * out_stride_W; - int id_base = od * wt_strides[0] - padding[0]; - int ih_base = oh * wt_strides[1] - padding[1]; - int iw_base = ow * wt_strides[2] - padding[2]; + int id_base = od * wt_strides[0] - padding_lo[0]; + int ih_base = oh * wt_strides[1] - padding_lo[1]; + int iw_base = ow * wt_strides[2] - padding_lo[2]; for (int o = 0; o < O; ++o) { float r = 0.; @@ -478,7 +488,7 @@ void slow_conv_3D( std::vector base_w(f_out_jump_w); for (int i = 0; i < f_out_jump_d; ++i) { - int 
id_loop = i * wt_strides[0] - padding[0] + init_d; + int id_loop = i * wt_strides[0] - padding_lo[0] + init_d; int wd_base = 0; while (wd_base < wD && id_loop % in_dilation[0] != 0) { @@ -490,7 +500,7 @@ void slow_conv_3D( } for (int i = 0; i < f_out_jump_h; ++i) { - int ih_loop = i * wt_strides[1] - padding[1] + init_h; + int ih_loop = i * wt_strides[1] - padding_lo[1] + init_h; int wh_base = 0; while (wh_base < wH && ih_loop % in_dilation[1] != 0) { @@ -502,7 +512,7 @@ void slow_conv_3D( } for (int j = 0; j < f_out_jump_w; ++j) { - int iw_loop = j * wt_strides[2] - padding[2] + init_w; + int iw_loop = j * wt_strides[2] - padding_lo[2] + init_w; int ww_base = 0; while (ww_base < wW && iw_loop % in_dilation[2] != 0) { @@ -521,9 +531,9 @@ void slow_conv_3D( int ow) { out_ptr += od * out_stride_D + oh * out_stride_H + ow * out_stride_W; - int id_base = od * wt_strides[0] - padding[0]; - int ih_base = oh * wt_strides[1] - padding[1]; - int iw_base = ow * wt_strides[2] - padding[2]; + int id_base = od * wt_strides[0] - padding_lo[0]; + int ih_base = oh * wt_strides[1] - padding_lo[1]; + int iw_base = ow * wt_strides[2] - padding_lo[2]; int wd_base = base_d[od % f_out_jump_d]; int wh_base = base_h[oh % f_out_jump_h]; @@ -573,24 +583,30 @@ void slow_conv_3D( }; int oD_border_0 = 0; - int oD_border_1 = - is_idil_one ? ((padding[0] + wt_strides[0] - 1) / wt_strides[0]) : oD; + int oD_border_1 = is_idil_one + ? ((padding_lo[0] + wt_strides[0] - 1) / wt_strides[0]) + : oD; int oD_border_2 = std::max( - oD_border_1, (iD + padding[0] - wD * wt_dilation[0]) / wt_strides[0]); + oD_border_1, + (iD + padding_lo[0] - wD * wt_dilation[0]) / wt_strides[0]); int oD_border_3 = oD; int oH_border_0 = 0; - int oH_border_1 = - is_idil_one ? ((padding[1] + wt_strides[1] - 1) / wt_strides[1]) : oH; + int oH_border_1 = is_idil_one + ? ((padding_lo[1] + wt_strides[1] - 1) / wt_strides[1]) + : oH; int oH_border_2 = std::max( - oH_border_1, (iH + padding[1] - wH * wt_dilation[1]) / wt_strides[1]); + oH_border_1, + (iH + padding_lo[1] - wH * wt_dilation[1]) / wt_strides[1]); int oH_border_3 = oH; int oW_border_0 = 0; - int oW_border_1 = - is_idil_one ? ((padding[2] + wt_strides[2] - 1) / wt_strides[2]) : oW; + int oW_border_1 = is_idil_one + ? 
((padding_lo[2] + wt_strides[2] - 1) / wt_strides[2]) + : oW; int oW_border_2 = std::max( - oW_border_1, (iW + padding[2] - wW * wt_dilation[2]) / wt_strides[2]); + oW_border_1, + (iW + padding_lo[2] - wW * wt_dilation[2]) / wt_strides[2]); int oW_border_3 = oW; for (int n = 0; n < N; ++n) { @@ -658,7 +674,8 @@ void dispatch_slow_conv_1D( const array& in, const array& wt, array out, - const std::vector& padding, + const std::vector& padding_lo, + const std::vector& padding_hi, const std::vector& wt_strides, const std::vector& wt_dilation, const std::vector& in_dilation, @@ -669,7 +686,8 @@ void dispatch_slow_conv_1D( in, wt, out, - padding, + padding_lo, + padding_hi, wt_strides, wt_dilation, in_dilation, @@ -680,7 +698,8 @@ void dispatch_slow_conv_1D( in, wt, out, - padding, + padding_lo, + padding_hi, wt_strides, wt_dilation, in_dilation, @@ -691,7 +710,8 @@ void dispatch_slow_conv_1D( in, wt, out, - padding, + padding_lo, + padding_hi, wt_strides, wt_dilation, in_dilation, @@ -707,7 +727,8 @@ void dispatch_slow_conv_2D( const array& in, const array& wt, array out, - const std::vector& padding, + const std::vector& padding_lo, + const std::vector& padding_hi, const std::vector& wt_strides, const std::vector& wt_dilation, const std::vector& in_dilation, @@ -718,7 +739,8 @@ void dispatch_slow_conv_2D( in, wt, out, - padding, + padding_lo, + padding_hi, wt_strides, wt_dilation, in_dilation, @@ -729,7 +751,8 @@ void dispatch_slow_conv_2D( in, wt, out, - padding, + padding_lo, + padding_hi, wt_strides, wt_dilation, in_dilation, @@ -740,7 +763,8 @@ void dispatch_slow_conv_2D( in, wt, out, - padding, + padding_lo, + padding_hi, wt_strides, wt_dilation, in_dilation, @@ -756,7 +780,8 @@ void dispatch_slow_conv_3D( const array& in, const array& wt, array out, - const std::vector& padding, + const std::vector& padding_lo, + const std::vector& padding_hi, const std::vector& wt_strides, const std::vector& wt_dilation, const std::vector& in_dilation, @@ -767,7 +792,8 @@ void dispatch_slow_conv_3D( in, wt, out, - padding, + padding_lo, + padding_hi, wt_strides, wt_dilation, in_dilation, @@ -778,7 +804,8 @@ void dispatch_slow_conv_3D( in, wt, out, - padding, + padding_lo, + padding_hi, wt_strides, wt_dilation, in_dilation, @@ -789,7 +816,8 @@ void dispatch_slow_conv_3D( in, wt, out, - padding, + padding_lo, + padding_hi, wt_strides, wt_dilation, in_dilation, @@ -829,7 +857,8 @@ void explicit_gemm_conv_1D_cpu( const array& in, const array& wt, array out, - const std::vector& padding, + const std::vector& padding_lo, + const std::vector& padding_hi, const std::vector& wt_strides, const std::vector& wt_dilation, Stream stream) { @@ -848,7 +877,7 @@ void explicit_gemm_conv_1D_cpu( auto& encoder = cpu::get_command_encoder(stream); // Pad input - Shape padded_shape = {N, iH + 2 * padding[0], C}; + Shape padded_shape = {N, iH + padding_lo[0] + padding_hi[0], C}; array in_padded(padded_shape, conv_dtype, nullptr, {}); // Fill with zeros @@ -857,7 +886,7 @@ void explicit_gemm_conv_1D_cpu( copy(temps.back(), in_padded, CopyType::Scalar, stream); // Pick input slice from padded - size_t data_offset = padding[0] * in_padded.strides()[1]; + size_t data_offset = padding_lo[0] * in_padded.strides()[1]; array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {}); in_padded_slice.copy_shared_buffer( in_padded, @@ -971,7 +1000,8 @@ void explicit_gemm_conv_2D_cpu( const array& in, const array& wt, array out, - const std::vector& padding, + const std::vector& padding_lo, + const std::vector& padding_hi, const 
std::vector& wt_strides, const std::vector& wt_dilation, Stream stream) { @@ -989,7 +1019,11 @@ void explicit_gemm_conv_2D_cpu( auto& encoder = cpu::get_command_encoder(stream); // Pad input - Shape padded_shape = {N, iH + 2 * padding[0], iW + 2 * padding[1], C}; + Shape padded_shape = { + N, + iH + padding_lo[0] + padding_hi[0], + iW + padding_lo[1] + padding_hi[1], + C}; array in_padded(padded_shape, conv_dtype, nullptr, {}); // Fill with zeros @@ -998,8 +1032,8 @@ void explicit_gemm_conv_2D_cpu( copy(temps.back(), in_padded, CopyType::Scalar, stream); // Pick input slice from padded - size_t data_offset = - padding[0] * in_padded.strides()[1] + padding[1] * in_padded.strides()[2]; + size_t data_offset = padding_lo[0] * in_padded.strides()[1] + + padding_lo[1] * in_padded.strides()[2]; array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {}); in_padded_slice.copy_shared_buffer( in_padded, @@ -1091,7 +1125,8 @@ void explicit_gemm_conv_ND_cpu( const array& in, const array& wt, array out, - const std::vector& padding, + const std::vector& padding_lo, + const std::vector& padding_hi, const std::vector& wt_strides, const std::vector& wt_dilation, const bool flip, @@ -1114,7 +1149,7 @@ void explicit_gemm_conv_ND_cpu( Shape padded_shape(in.shape().size()); padded_shape.front() = N; for (size_t i = 0; i < iDim.size(); i++) { - padded_shape[i + 1] = iDim[i] + 2 * padding[i]; + padded_shape[i + 1] = iDim[i] + padding_lo[i] + padding_hi[i]; } padded_shape.back() = C; array in_padded(padded_shape, conv_dtype, nullptr, {}); @@ -1125,9 +1160,10 @@ void explicit_gemm_conv_ND_cpu( // Pick input slice from padded size_t data_offset = 0; - for (size_t i = 0; i < padding.size(); i++) { - data_offset += padding[i] * in_padded.strides()[i + 1]; + for (size_t i = 0; i < padding_lo.size(); i++) { + data_offset += padding_lo[i] * in_padded.strides()[i + 1]; } + array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {}); in_padded_slice.copy_shared_buffer( in_padded, @@ -1261,7 +1297,8 @@ void conv_1D_cpu( const array& in, const array& wt, array out, - const std::vector& padding, + const std::vector& padding_lo, + const std::vector& padding_hi, const std::vector& wt_strides, const std::vector& wt_dilation, const std::vector& in_dilation, @@ -1270,22 +1307,40 @@ void conv_1D_cpu( const int groups = in.shape().back() / wt.shape().back(); if (wt_dilation[0] == 1 && in_dilation[0] == 1 && !flip) { return explicit_gemm_conv_1D_cpu( - in, wt, out, padding, wt_strides, wt_dilation, stream); + in, wt, out, padding_lo, padding_hi, wt_strides, wt_dilation, stream); } if (wt_dilation[0] == 1 && in_dilation[0] == 1 && groups == 1) { return explicit_gemm_conv_ND_cpu( - in, wt, out, padding, wt_strides, wt_dilation, flip, stream); + in, + wt, + out, + padding_lo, + padding_hi, + wt_strides, + wt_dilation, + flip, + stream); } return dispatch_slow_conv_1D( - in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip, stream); + in, + wt, + out, + padding_lo, + padding_hi, + wt_strides, + wt_dilation, + in_dilation, + flip, + stream); } void conv_2D_cpu( const array& in, const array& wt, array out, - const std::vector& padding, + const std::vector& padding_lo, + const std::vector& padding_hi, const std::vector& wt_strides, const std::vector& wt_dilation, const std::vector& in_dilation, @@ -1295,18 +1350,35 @@ void conv_2D_cpu( if (wt_dilation[0] == 1 && wt_dilation[1] == 1 && in_dilation[0] == 1 && in_dilation[1] == 1 && groups == 1) { return explicit_gemm_conv_ND_cpu( - in, wt, out, padding, wt_strides, 
wt_dilation, flip, stream); + in, + wt, + out, + padding_lo, + padding_hi, + wt_strides, + wt_dilation, + flip, + stream); } - return dispatch_slow_conv_2D( - in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip, stream); + in, + wt, + out, + padding_lo, + padding_hi, + wt_strides, + wt_dilation, + in_dilation, + flip, + stream); } void conv_3D_cpu( const array& in, const array& wt, array out, - const std::vector& padding, + const std::vector& padding_lo, + const std::vector& padding_hi, const std::vector& wt_strides, const std::vector& wt_dilation, const std::vector& in_dilation, @@ -1317,11 +1389,28 @@ void conv_3D_cpu( in_dilation[0] == 1 && in_dilation[1] == 1 && in_dilation[2] == 1 && groups == 1) { return explicit_gemm_conv_ND_cpu( - in, wt, out, padding, wt_strides, wt_dilation, flip, stream); + in, + wt, + out, + padding_lo, + padding_hi, + wt_strides, + wt_dilation, + flip, + stream); } return dispatch_slow_conv_3D( - in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip, stream); + in, + wt, + out, + padding_lo, + padding_hi, + wt_strides, + wt_dilation, + in_dilation, + flip, + stream); } } // namespace @@ -1338,7 +1427,8 @@ void Convolution::eval_cpu(const std::vector& inputs, array& out) { in, wt, out, - padding_, + padding_lo_, + padding_hi_, kernel_strides_, kernel_dilation_, input_dilation_, @@ -1351,7 +1441,8 @@ void Convolution::eval_cpu(const std::vector& inputs, array& out) { in, wt, out, - padding_, + padding_lo_, + padding_hi_, kernel_strides_, kernel_dilation_, input_dilation_, @@ -1364,7 +1455,8 @@ void Convolution::eval_cpu(const std::vector& inputs, array& out) { in, wt, out, - padding_, + padding_lo_, + padding_hi_, kernel_strides_, kernel_dilation_, input_dilation_, diff --git a/mlx/backend/metal/conv.cpp b/mlx/backend/metal/conv.cpp index ae31a6cff..35ed3d44e 100644 --- a/mlx/backend/metal/conv.cpp +++ b/mlx/backend/metal/conv.cpp @@ -952,7 +952,7 @@ void Convolution::eval_gpu(const std::vector& inputs, array& out) { in, wt, out, - padding_, + padding_lo_, kernel_strides_, kernel_dilation_, input_dilation_, @@ -967,7 +967,7 @@ void Convolution::eval_gpu(const std::vector& inputs, array& out) { in, wt, out, - padding_, + padding_lo_, kernel_strides_, kernel_dilation_, input_dilation_, @@ -983,7 +983,7 @@ void Convolution::eval_gpu(const std::vector& inputs, array& out) { in, wt, out, - padding_, + padding_lo_, kernel_strides_, kernel_dilation_, input_dilation_, diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 4aa5e88b7..e8c260425 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -3974,6 +3974,7 @@ array conv_general( to_stream(s), stride, padding_lo, + padding_hi, kernel_dilation, input_dilation, groups, diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 7288a4885..03ca06bdd 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -1055,7 +1055,8 @@ array conv_weight_backward_patches( const array& wt, const array& cotan, const std::vector& kernel_strides, - const std::vector& padding, + const std::vector& padding_lo, + const std::vector& padding_hi, StreamOrDevice s) { // Resolve Padded input shapes and strides Shape padding_starts(in.ndim(), 0); @@ -1064,9 +1065,9 @@ array conv_weight_backward_patches( // padded shape for (int i = 1; i < in.ndim() - 1; i++) { - in_padded_shape[i] += 2 * padding[i - 1]; - padding_ends[i] += padding[i - 1]; - padding_starts[i] += padding[i - 1]; + in_padded_shape[i] += padding_lo[i - 1] + padding_hi[i - 1]; + padding_ends[i] += padding_lo[i - 1]; + padding_starts[i] += padding_lo[i - 1]; } // padded 
strides (contiguous) @@ -1078,9 +1079,16 @@ array conv_weight_backward_patches( // Pad input std::vector padded_axes(in.ndim() - 2, 0); std::iota(padded_axes.begin(), padded_axes.end(), 1); - Shape padding_(padding.begin(), padding.end()); - auto in_padded = pad( - in, padded_axes, padding_, padding_, array(0, in.dtype()), "constant", s); + Shape padding_lo_(padding_lo.begin(), padding_lo.end()); + Shape padding_hi_(padding_hi.begin(), padding_hi.end()); + auto in_padded = + pad(in, + padded_axes, + padding_lo_, + padding_hi_, + array(0, in.dtype()), + "constant", + s); // Resolve strided patches @@ -1147,16 +1155,16 @@ std::vector Convolution::vjp( for (int a : argnums) { // Grads for input if (a == 0) { - std::vector padding_lo = padding_; - std::vector padding_hi = padding_; + std::vector padding_lo = padding_lo_; + std::vector padding_hi = padding_hi_; for (int i = 0; i < padding_lo.size(); ++i) { int wt_size = 1 + kernel_dilation_[i] * (wt.shape(1 + i) - 1); - padding_lo[i] = wt_size - padding_[i] - 1; + padding_lo[i] = wt_size - padding_lo_[i] - 1; int in_size = 1 + input_dilation_[i] * (in.shape(1 + i) - 1); int out_size = 1 + kernel_strides_[i] * (cotan.shape(1 + i) - 1); - padding_hi[i] = in_size - out_size + padding_[i]; + padding_hi[i] = in_size - out_size + padding_hi_[i]; } // Check for negative padding @@ -1226,18 +1234,12 @@ std::vector Convolution::vjp( if (no_dilation && !flip_ && groups_ == 1) { auto grad = conv_weight_backward_patches( - in, wt, cotan, kernel_strides_, padding_, stream()); + in, wt, cotan, kernel_strides_, padding_lo_, padding_hi_, stream()); grads.push_back(grad); } else { - std::vector padding_lo = padding_; - std::vector padding_hi = padding_; + std::vector padding_lo = padding_lo_; + std::vector padding_hi = padding_hi_; - for (int i = 0; i < padding_hi.size(); ++i) { - int in_size = 1 + input_dilation_[i] * (in.shape(1 + i) - 1); - int out_size = 1 + kernel_strides_[i] * (cotan.shape(1 + i) - 1); - int wt_size = 1 + kernel_dilation_[i] * (wt.shape(1 + i) - 1); - padding_hi[i] = out_size - in_size + wt_size - padding_[i] - 1; - } auto cotan_trans = swapaxes(cotan, 0, -1, stream()); auto in_trans = group_transpose(in, -1, 0, -1); @@ -1283,7 +1285,8 @@ std::pair, std::vector> Convolution::vmap( in, w, kernel_strides_, - padding_, + padding_lo_, + padding_hi_, kernel_dilation_, input_dilation_, groups, @@ -1332,7 +1335,8 @@ std::pair, std::vector> Convolution::vmap( bool Convolution::is_equivalent(const Primitive& other) const { const Convolution& c_other = static_cast(other); - return padding_ == c_other.padding_ && + return padding_lo_ == c_other.padding_lo_ && + padding_hi_ == c_other.padding_hi_ && kernel_strides_ == c_other.kernel_strides_ && kernel_dilation_ == c_other.kernel_dilation_ && input_dilation_ == c_other.input_dilation_ && diff --git a/mlx/primitives.h b/mlx/primitives.h index 3753e43c5..2caed8477 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -689,13 +689,15 @@ class Convolution : public UnaryPrimitive { explicit Convolution( Stream stream, const std::vector& kernel_strides, - const std::vector& padding, + const std::vector& padding_lo, + const std::vector& padding_hi, const std::vector& kernel_dilation, const std::vector& input_dilation, const int groups = 1, const bool flip = false) : UnaryPrimitive(stream), - padding_(padding), + padding_lo_(padding_lo), + padding_hi_(padding_hi), kernel_strides_(kernel_strides), kernel_dilation_(kernel_dilation), input_dilation_(input_dilation), @@ -716,7 +718,8 @@ class Convolution : 
public UnaryPrimitive { bool is_equivalent(const Primitive& other) const override; auto state() const { return std::make_tuple( - padding_, + padding_lo_, + padding_hi_, kernel_strides_, kernel_dilation_, input_dilation_, @@ -725,7 +728,8 @@ class Convolution : public UnaryPrimitive { } private: - std::vector padding_; + std::vector padding_lo_; + std::vector padding_hi_; std::vector kernel_strides_; std::vector kernel_dilation_; std::vector input_dilation_; diff --git a/python/tests/test_conv.py b/python/tests/test_conv.py index 671c86a32..35dcf42ac 100644 --- a/python/tests/test_conv.py +++ b/python/tests/test_conv.py @@ -1088,6 +1088,48 @@ class TestConv(mlx_tests.MLXTestCase): atol=2e-5 if dtype == np.float32 else 5e-4, ) + @unittest.skipIf(not has_torch, "requires Torch") + def test_asymmetric_padding(self): + inputs = np.random.normal(size=(2, 8, 8, 8, 3)).astype(np.float32) + kernel = np.random.normal(size=(2, 3, 3, 3, 3)).astype(np.float32) + strides = (2, 2, 2) + + pt_out = torch.conv3d( + torch.permute(torch.tensor(inputs), (0, 4, 1, 2, 3)), + torch.permute(torch.tensor(kernel), (0, 4, 1, 2, 3)), + stride=strides, + padding=2, + ) + pt_out = torch.permute(pt_out, (0, 2, 3, 4, 1))[:, 1:, 1:, 1:, :].numpy() + + mx_out = mx.conv_general( + mx.array(inputs), + mx.array(kernel), + stride=strides, + padding=([0, 0, 0], [1, 1, 1]), + ) + + self.assertTrue(mx.allclose(mx_out, mx.array(pt_out), atol=1e-3, rtol=1e-3)) + + inputs = np.random.normal(size=(2, 10, 10, 3)).astype(np.float32) + kernel = np.random.normal(size=(2, 2, 2, 3)).astype(np.float32) + + pt_out = torch.conv2d( + torch.permute(torch.tensor(inputs), (0, 3, 1, 2)), + torch.permute(torch.tensor(kernel), (0, 3, 1, 2)), + stride=1, + padding=(1, 0), + ) + pt_out = torch.permute(pt_out, (0, 2, 3, 1))[:, 1:].numpy() + + mx_out = mx.conv_general( + mx.array(inputs), + mx.array(kernel), + stride=1, + padding=([0, 0], [1, 0]), + ) + self.assertTrue(mx.allclose(mx_out, mx.array(pt_out), atol=1e-3, rtol=1e-3)) + if __name__ == "__main__": unittest.main() From 6661387066b38ef7221d29d7dad6c25d07d6e96a Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 9 May 2025 14:25:12 -0700 Subject: [PATCH 3/4] Fix fft for integer overflow (#2161) --- mlx/backend/metal/fft.cpp | 4 +--- mlx/backend/metal/kernels/fft/readwrite.h | 28 ++++++++++++----------- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/mlx/backend/metal/fft.cpp b/mlx/backend/metal/fft.cpp index 011eb7ebb..1e23160a6 100644 --- a/mlx/backend/metal/fft.cpp +++ b/mlx/backend/metal/fft.cpp @@ -632,7 +632,7 @@ void fft_op( func_consts.push_back(make_int(&rader_m, 3)); // The overall number of FFTs we're going to compute for this input - int size = out.dtype() == float32 ? out.size() : in.size(); + size_t size = out.dtype() == float32 ? out.size() : in.size(); if (real && inverse && four_step_params.required) { size = out.size(); } @@ -659,8 +659,6 @@ void fft_op( // We can perform 2 RFFTs at once so the batch size is halved. batch_size = (batch_size + 2 - 1) / 2; } - int out_buffer_size = out.size(); - auto& compute_encoder = d.get_command_encoder(s.index); auto in_type_str = in.dtype() == float32 ? "float" : "float2"; auto out_type_str = out.dtype() == float32 ? 
"float" : "float2"; diff --git a/mlx/backend/metal/kernels/fft/readwrite.h b/mlx/backend/metal/kernels/fft/readwrite.h index f6724820d..0dc62992e 100644 --- a/mlx/backend/metal/kernels/fft/readwrite.h +++ b/mlx/backend/metal/kernels/fft/readwrite.h @@ -98,7 +98,7 @@ struct ReadWriter { } METAL_FUNC void load() const { - int batch_idx = elem.x * grid.y * n; + size_t batch_idx = size_t(elem.x * grid.y) * n; short tg_idx = elem.y * grid.z + elem.z; short max_index = grid.y * n - 2; @@ -121,7 +121,7 @@ struct ReadWriter { } METAL_FUNC void write() const { - int batch_idx = elem.x * grid.y * n; + size_t batch_idx = size_t(elem.x * grid.y) * n; short tg_idx = elem.y * grid.z + elem.z; short max_index = grid.y * n - 2; @@ -144,7 +144,7 @@ struct ReadWriter { // Padded IO for Bluestein's algorithm METAL_FUNC void load_padded(int length, const device float2* w_k) const { - int batch_idx = elem.x * grid.y * length + elem.y * length; + size_t batch_idx = size_t(elem.x * grid.y) * length + elem.y * length; int fft_idx = elem.z; int m = grid.z; @@ -161,7 +161,7 @@ struct ReadWriter { } METAL_FUNC void write_padded(int length, const device float2* w_k) const { - int batch_idx = elem.x * grid.y * length + elem.y * length; + size_t batch_idx = size_t(elem.x * grid.y) * length + elem.y * length; int fft_idx = elem.z; int m = grid.z; float2 inv_factor = {1.0f / n, -1.0f / n}; @@ -261,7 +261,7 @@ METAL_FUNC bool ReadWriter::out_of_bounds() const { template <> METAL_FUNC void ReadWriter::load() const { - int batch_idx = elem.x * grid.y * n * 2 + elem.y * n * 2; + size_t batch_idx = size_t(elem.x * grid.y) * n * 2 + elem.y * n * 2; threadgroup float2* seq_buf = buf + elem.y * n; // No out of bounds accesses on odd batch sizes @@ -283,7 +283,8 @@ template <> METAL_FUNC void ReadWriter::write() const { short n_over_2 = (n / 2) + 1; - int batch_idx = elem.x * grid.y * n_over_2 * 2 + elem.y * n_over_2 * 2; + size_t batch_idx = + size_t(elem.x * grid.y) * n_over_2 * 2 + elem.y * n_over_2 * 2; threadgroup float2* seq_buf = buf + elem.y * n; int grid_index = elem.x * grid.y + elem.y; @@ -317,7 +318,7 @@ template <> METAL_FUNC void ReadWriter::load_padded( int length, const device float2* w_k) const { - int batch_idx = elem.x * grid.y * length * 2 + elem.y * length * 2; + size_t batch_idx = size_t(elem.x * grid.y) * length * 2 + elem.y * length * 2; threadgroup float2* seq_buf = buf + elem.y * n; // No out of bounds accesses on odd batch sizes @@ -345,8 +346,8 @@ METAL_FUNC void ReadWriter::write_padded( int length, const device float2* w_k) const { int length_over_2 = (length / 2) + 1; - int batch_idx = - elem.x * grid.y * length_over_2 * 2 + elem.y * length_over_2 * 2; + size_t batch_idx = + size_t(elem.x * grid.y) * length_over_2 * 2 + elem.y * length_over_2 * 2; threadgroup float2* seq_buf = buf + elem.y * n + length - 1; int grid_index = elem.x * grid.y + elem.y; @@ -397,7 +398,8 @@ METAL_FUNC bool ReadWriter::out_of_bounds() const { template <> METAL_FUNC void ReadWriter::load() const { short n_over_2 = (n / 2) + 1; - int batch_idx = elem.x * grid.y * n_over_2 * 2 + elem.y * n_over_2 * 2; + size_t batch_idx = + size_t(elem.x * grid.y) * n_over_2 * 2 + elem.y * n_over_2 * 2; threadgroup float2* seq_buf = buf + elem.y * n; // No out of bounds accesses on odd batch sizes @@ -458,8 +460,8 @@ METAL_FUNC void ReadWriter::load_padded( int n_over_2 = (n / 2) + 1; int length_over_2 = (length / 2) + 1; - int batch_idx = - elem.x * grid.y * length_over_2 * 2 + elem.y * length_over_2 * 2; + size_t batch_idx = + 
size_t(elem.x * grid.y) * length_over_2 * 2 + elem.y * length_over_2 * 2; threadgroup float2* seq_buf = buf + elem.y * n; // No out of bounds accesses on odd batch sizes @@ -503,7 +505,7 @@ template <> METAL_FUNC void ReadWriter::write_padded( int length, const device float2* w_k) const { - int batch_idx = elem.x * grid.y * length * 2 + elem.y * length * 2; + size_t batch_idx = size_t(elem.x * grid.y) * length * 2 + elem.y * length * 2; threadgroup float2* seq_buf = buf + elem.y * n + length - 1; int grid_index = elem.x * grid.y + elem.y; From 659a51919fd3d70798e91e9e112075680b95556e Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 9 May 2025 14:35:14 -0700 Subject: [PATCH 4/4] patch bump (#2162) --- mlx/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlx/version.h b/mlx/version.h index 8340e1e8c..c573c45c9 100644 --- a/mlx/version.h +++ b/mlx/version.h @@ -4,7 +4,7 @@ #define MLX_VERSION_MAJOR 0 #define MLX_VERSION_MINOR 25 -#define MLX_VERSION_PATCH 1 +#define MLX_VERSION_PATCH 2 #define MLX_VERSION_NUMERIC \ (100000 * MLX_VERSION_MAJOR + 1000 * MLX_VERSION_MINOR + MLX_VERSION_PATCH)
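A note on the padded-buffer bookkeeping introduced by the asymmetric padding change: the explicit-GEMM conv paths now allocate a scratch input whose spatial extents grow by padding_lo[i] + padding_hi[i] (rather than 2 * padding[i]), and the slice that receives the original data is offset by the low padding only. The following is a minimal standalone C++ sketch of that shape/offset arithmetic; the names and the simplified row-major layout are illustrative assumptions, not the MLX implementation.

// Shape/offset bookkeeping for an asymmetrically padded buffer, mirroring
// the padded_shape / data_offset logic in explicit_gemm_conv_ND_cpu.
#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  // Spatial extents of the input and the low/high padding per axis
  // (hypothetical values; layout is contiguous row-major over these axes).
  std::vector<int> in_dims = {8, 10};
  std::vector<int> padding_lo = {0, 1};
  std::vector<int> padding_hi = {1, 0};

  // Padded extent per axis: input extent plus low and high padding.
  // With symmetric padding this reduces to the old in_dims[i] + 2 * pad[i].
  std::vector<int> padded_dims(in_dims.size());
  for (size_t i = 0; i < in_dims.size(); ++i) {
    padded_dims[i] = in_dims[i] + padding_lo[i] + padding_hi[i];
  }

  // Row-major strides of the padded buffer.
  std::vector<size_t> strides(padded_dims.size(), 1);
  for (int i = static_cast<int>(padded_dims.size()) - 2; i >= 0; --i) {
    strides[i] = strides[i + 1] * static_cast<size_t>(padded_dims[i + 1]);
  }

  // Element offset of the original data inside the padded buffer:
  // only the low padding shifts the start of the copy destination.
  size_t data_offset = 0;
  for (size_t i = 0; i < padding_lo.size(); ++i) {
    data_offset += static_cast<size_t>(padding_lo[i]) * strides[i];
  }

  std::cout << "padded dims: " << padded_dims[0] << " x " << padded_dims[1]
            << ", data offset: " << data_offset << "\n";
  return 0;
}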
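The FFT patch replaces int batch indices with size_t because the index arithmetic in the readwrite kernels is a product of a batch coordinate, a grid extent, and the transform length; once that product exceeds what 32 bits can hold, the stored index wraps and the kernel touches the wrong elements. Below is a small host-side C++ illustration of the failure mode and of why casting before the final multiply fixes it. The launch parameters are hypothetical, and the wrap is modeled with uint32_t so the demo is well defined (the original kernels used a signed int, where the overflow is undefined behavior).

// Why size_t(elem.x * grid.y) * n is safe where an int-typed
// elem.x * grid.y * n is not: the promoted multiply cannot wrap at 2^32.
#include <cstddef>
#include <cstdint>
#include <iostream>

int main() {
  // Hypothetical launch parameters, large enough that the full product
  // (70000 * 512 * 256) exceeds the 32-bit range.
  uint32_t elem_x = 70000;  // batch element
  uint32_t grid_y = 512;    // transforms per grid row
  uint32_t n = 256;         // FFT length

  // 32-bit arithmetic: the product wraps modulo 2^32.
  uint32_t wrapped = elem_x * grid_y * n;

  // Patched form: promote before the multiply that overflows, so the rest
  // of the arithmetic happens at size_t width (64-bit on the targets of
  // interest here).
  size_t exact = size_t(elem_x * grid_y) * n;

  std::cout << "wrapped 32-bit index: " << wrapped << "\n";
  std::cout << "promoted index:       " << exact << "\n";
  return 0;
}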
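The version bump in PATCH 4/4 also changes the derived MLX_VERSION_NUMERIC value, which packs major/minor/patch into a single integer. A quick check of that formula for 0.25.2, written as plain C++ that mirrors the macro rather than including the MLX header:

// MLX_VERSION_NUMERIC packs the version as 100000*major + 1000*minor + patch.
#include <iostream>

int main() {
  int major = 0, minor = 25, patch = 2;
  int numeric = 100000 * major + 1000 * minor + patch;
  std::cout << "MLX_VERSION_NUMERIC for 0.25.2 = " << numeric << "\n";  // prints 25002
  return 0;
}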