diff --git a/CMakeLists.txt b/CMakeLists.txt index 4bf8d2d3e..158170647 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -35,6 +35,7 @@ option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF) option(MLX_BUILD_METAL "Build metal backend" ON) option(MLX_BUILD_CPU "Build cpu backend" ON) option(MLX_BUILD_CUDA "Build cuda backend" OFF) +option(MLX_BUILD_ROCM "Build ROCm backend" OFF) option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF) option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF) option(MLX_BUILD_GGUF "Include support for GGUF format" ON) @@ -88,6 +89,10 @@ if(MLX_BUILD_CUDA) enable_language(CUDA) endif() +if(MLX_BUILD_ROCM) + enable_language(HIP) +endif() + if(MLX_BUILD_METAL AND NOT METAL_LIB) message(STATUS "Metal not found. Unable to build GPU") set(MLX_BUILD_METAL OFF) diff --git a/mlx/CMakeLists.txt b/mlx/CMakeLists.txt index 7aa648533..a4e6260e9 100644 --- a/mlx/CMakeLists.txt +++ b/mlx/CMakeLists.txt @@ -60,7 +60,16 @@ else() PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda/no_cuda.cpp) endif() -if(MLX_BUILD_METAL OR MLX_BUILD_CUDA) +if(MLX_BUILD_ROCM) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/rocm) +else() + target_sources(mlx + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/rocm/no_rocm.cpp) +endif() + +if(MLX_BUILD_METAL + OR MLX_BUILD_CUDA + OR MLX_BUILD_ROCM) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/gpu) else() add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_gpu) diff --git a/mlx/backend/rocm/CMakeLists.txt b/mlx/backend/rocm/CMakeLists.txt new file mode 100644 index 000000000..260c5128e --- /dev/null +++ b/mlx/backend/rocm/CMakeLists.txt @@ -0,0 +1,85 @@ +# Filename rules in ROCm backend: +# +# * Use .hip/.hpp if code contains device code, and .cpp/.h if not. +# * Device-only code should be put in device/ subdir. +# * Files in device/ subdir should not include files outside. +target_sources( + mlx + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.hip + ${CMAKE_CURRENT_SOURCE_DIR}/binary.hip + ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/copy.hip + ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/event.hip + ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.hip + ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/layer_norm.hip + ${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.hip + ${CMAKE_CURRENT_SOURCE_DIR}/primitives.hip + ${CMAKE_CURRENT_SOURCE_DIR}/random.hip + ${CMAKE_CURRENT_SOURCE_DIR}/reduce.hip + ${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.hip + ${CMAKE_CURRENT_SOURCE_DIR}/rope.hip + ${CMAKE_CURRENT_SOURCE_DIR}/rocm.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/softmax.hip + ${CMAKE_CURRENT_SOURCE_DIR}/sort.hip + ${CMAKE_CURRENT_SOURCE_DIR}/ternary.hip + ${CMAKE_CURRENT_SOURCE_DIR}/unary.hip + ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp) + +target_compile_definitions(mlx PRIVATE MLX_USE_ROCM) + +# Embed kernel sources in binary for JIT compilation. 
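+#
+# Every header under device/ is globbed below, the relative paths are joined
+# with ':' (a ';' separator would be re-split by CMake when passed through
+# -D), and bin2h.cmake is run in script mode to generate
+# gen/rocm_jit_sources.h. The generated header is expected to expose a
+# path -> source map along the lines of
+#
+#   const std::unordered_map<std::string, std::string> kernel_sources = {
+#       {"device/utils.hpp", "..."},
+#   };
+#
+# so the runtime JIT can hand these sources to HIPRTC without touching the
+# filesystem.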
+file( + GLOB MLX_JIT_SOURCES + RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} + "${CMAKE_CURRENT_SOURCE_DIR}/device/*.h" + "${CMAKE_CURRENT_SOURCE_DIR}/device/*.hpp") +string(JOIN ":" MLX_JIT_SOURCES_ARG ${MLX_JIT_SOURCES}) +add_custom_command( + OUTPUT gen/rocm_jit_sources.h + COMMAND + ${CMAKE_COMMAND} -DMLX_SOURCE_ROOT=${CMAKE_CURRENT_SOURCE_DIR} + -DMLX_JIT_SOURCES=${MLX_JIT_SOURCES_ARG} -P + "${CMAKE_CURRENT_SOURCE_DIR}/bin2h.cmake" + DEPENDS bin2h.cmake ${MLX_JIT_SOURCES}) +add_custom_target(rocm_jit_sources DEPENDS gen/rocm_jit_sources.h) +add_dependencies(mlx rocm_jit_sources) +target_include_directories(mlx PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/gen") + +# Find ROCm installation +find_package(hip REQUIRED) +find_package(rocblas REQUIRED) + +# Link with ROCm libraries +target_link_libraries(mlx PRIVATE hip::device roc::rocblas) + +# Set GPU architectures for ROCm Common ROCm architectures: gfx900, gfx906, +# gfx908, gfx90a, gfx1030, gfx1100 +set(MLX_ROCM_ARCHITECTURES + "gfx900;gfx906;gfx908;gfx90a;gfx1030;gfx1100" + CACHE STRING "ROCm GPU architectures") +message(STATUS "ROCm GPU architectures: ${MLX_ROCM_ARCHITECTURES}") + +# Set GPU targets for HIP compilation +set_property(TARGET mlx PROPERTY HIP_ARCHITECTURES "${MLX_ROCM_ARCHITECTURES}") + +# Enable HIP language support +enable_language(HIP) + +# Set HIP compiler flags +target_compile_options( + mlx + PRIVATE "$<$:-fgpu-rdc>" + "$<$:-Xcompiler=-Wall>" + "$<$:-Xcompiler=-Wextra>") + +# Add ROCm include directories +target_include_directories(mlx PRIVATE ${hip_INCLUDE_DIRS}) +target_include_directories(mlx PRIVATE ${rocblas_INCLUDE_DIRS}) diff --git a/mlx/backend/rocm/allocator.cpp b/mlx/backend/rocm/allocator.cpp new file mode 100644 index 000000000..016757f12 --- /dev/null +++ b/mlx/backend/rocm/allocator.cpp @@ -0,0 +1,206 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/rocm/allocator.h" +#include "mlx/backend/rocm/utils.h" +#include "mlx/backend/rocm/worker.h" + +#include +#include +#include + +#include + +namespace mlx::core { + +namespace rocm { + +RocmAllocator::RocmAllocator() + : buffer_cache_( + getpagesize(), + [](RocmBuffer* buf) { return buf->size; }, + [this](RocmBuffer* buf) { + rocm_free(buf->data); + delete buf; + }) { + // TODO: Set memory limit for multi-device. + size_t free, total; + CHECK_HIP_ERROR(hipMemGetInfo(&free, &total)); + memory_limit_ = total * 0.8; + max_pool_size_ = memory_limit_; +} + +Buffer RocmAllocator::malloc(size_t size) { + // Find available buffer from cache. + std::unique_lock lock(mutex_); + RocmBuffer* buf = buffer_cache_.reuse_from_cache(size); + if (!buf) { + // If we have a lot of memory pressure or are over the maximum cache size, + // try to reclaim memory from the cache. + size_t mem_required = get_active_memory() + get_cache_memory() + size; + if (mem_required >= memory_limit_) { + buffer_cache_.release_cached_buffers(mem_required - memory_limit_); + } + + lock.unlock(); + buf = new RocmBuffer{nullptr, size}; + hipError_t err = hipMallocManaged(&buf->data, size); + if (err != hipSuccess && err != hipErrorMemoryAllocation) { + throw std::runtime_error( + fmt::format("hipMallocManaged failed: {}.", hipGetErrorString(err))); + } + lock.lock(); + } + active_memory_ += size; + peak_memory_ = std::max(active_memory_, peak_memory_); + + // Maintain the cache below the requested limit. 
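+  // Trimming releases buffers through the cache's destructor callback, which
+  // calls rocm_free(); if this thread has not been registered as free-safe,
+  // that call is rescheduled onto the worker thread instead of freeing here.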
+ if (get_cache_memory() > max_pool_size_) { + buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_); + } + + return Buffer{buf}; +} + +void RocmAllocator::free(Buffer buffer) { + auto* buf = static_cast(buffer.ptr()); + if (!buf) { + return; + } + + std::unique_lock lock(mutex_); + active_memory_ -= buf->size; + if (get_cache_memory() < max_pool_size_) { + buffer_cache_.recycle_to_cache(buf); + } else { + lock.unlock(); + rocm_free(buf->data); + delete buf; + } +} + +size_t RocmAllocator::size(Buffer buffer) const { + auto* buf = static_cast(buffer.ptr()); + if (!buf) { + return 0; + } + return buf->size; +} + +void RocmAllocator::register_this_thread() { + std::lock_guard lock(worker_mutex_); + allowed_threads_.insert(std::this_thread::get_id()); +} + +void RocmAllocator::rocm_free(void* buf) { + // If rocm_free() is called from a unregistered thread, reschedule the call to + // worker. + { + std::lock_guard lock(worker_mutex_); + if (allowed_threads_.count(std::this_thread::get_id()) == 0) { + if (!worker_) { + worker_.reset(new Worker); + } + worker_->add_task([this, buf]() { this->rocm_free(buf); }); + worker_->end_batch(); + worker_->commit(); + return; + } + } + + hipFree(buf); +} + +size_t RocmAllocator::get_active_memory() const { + return active_memory_; +} + +size_t RocmAllocator::get_peak_memory() const { + return peak_memory_; +} + +void RocmAllocator::reset_peak_memory() { + std::lock_guard lock(mutex_); + peak_memory_ = 0; +} + +size_t RocmAllocator::get_memory_limit() { + return memory_limit_; +} + +size_t RocmAllocator::set_memory_limit(size_t limit) { + std::lock_guard lock(mutex_); + std::swap(limit, memory_limit_); + return limit; +} + +size_t RocmAllocator::get_cache_memory() const { + return buffer_cache_.cache_size(); +} + +size_t RocmAllocator::set_cache_limit(size_t limit) { + std::lock_guard lk(mutex_); + std::swap(limit, max_pool_size_); + return limit; +} + +void RocmAllocator::clear_cache() { + std::lock_guard lk(mutex_); + buffer_cache_.clear(); +} + +RocmAllocator& allocator() { + // By creating the |allocator_| on heap, the destructor of RocmAllocator + // will not be called on exit and buffers in the cache will be leaked. This + // can save some time at program exit. + static RocmAllocator* allocator_ = new RocmAllocator; + return *allocator_; +} + +} // namespace rocm + +namespace allocator { + +Allocator& allocator() { + return rocm::allocator(); +} + +void* Buffer::raw_ptr() { + if (!ptr_) { + return nullptr; + } + return static_cast(ptr_)->data; +} + +} // namespace allocator + +size_t get_active_memory() { + return rocm::allocator().get_active_memory(); +} +size_t get_peak_memory() { + return rocm::allocator().get_peak_memory(); +} +void reset_peak_memory() { + return rocm::allocator().reset_peak_memory(); +} +size_t set_memory_limit(size_t limit) { + return rocm::allocator().set_memory_limit(limit); +} +size_t get_memory_limit() { + return rocm::allocator().get_memory_limit(); +} +size_t get_cache_memory() { + return rocm::allocator().get_cache_memory(); +} +size_t set_cache_limit(size_t limit) { + return rocm::allocator().set_cache_limit(limit); +} +void clear_cache() { + rocm::allocator().clear_cache(); +} + +// Not supported in ROCm. 
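+// set_wired_limit() is a Metal-backend concept (how much memory may stay
+// wired for the GPU); HIP managed memory has no equivalent, so the call is
+// accepted and a limit of 0 is reported.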
+size_t set_wired_limit(size_t) { + return 0; +} + +} // namespace mlx::core \ No newline at end of file diff --git a/mlx/backend/rocm/allocator.h b/mlx/backend/rocm/allocator.h new file mode 100644 index 000000000..af1d3fb94 --- /dev/null +++ b/mlx/backend/rocm/allocator.h @@ -0,0 +1,67 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/allocator.h" +#include "mlx/backend/common/buffer_cache.h" + +#include +#include +#include +#include + +namespace mlx::core::rocm { + +class Worker; + +using allocator::Buffer; + +// Stores ROCm-managed unified memory. +struct RocmBuffer { + void* data; + size_t size; +}; + +class RocmAllocator : public allocator::Allocator { + public: + Buffer malloc(size_t size) override; + void free(Buffer buffer) override; + size_t size(Buffer buffer) const override; + + // Register current thread as safe to free buffers. + // In ROCm freeing a buffer implicitly synchronizes stream, and for threads + // that may be waited by gpu stream (for example cpu stream threads), freeing + // buffers there would result in dead lock. + void register_this_thread(); + + // Call hipFree in the safe thread. + void rocm_free(void* buf); + + size_t get_active_memory() const; + size_t get_peak_memory() const; + void reset_peak_memory(); + size_t get_memory_limit(); + size_t set_memory_limit(size_t limit); + size_t get_cache_memory() const; + size_t set_cache_limit(size_t limit); + void clear_cache(); + + private: + RocmAllocator(); + friend RocmAllocator& allocator(); + + std::mutex worker_mutex_; + std::unique_ptr worker_; + std::set allowed_threads_; + + std::mutex mutex_; + size_t memory_limit_; + size_t max_pool_size_; + BufferCache buffer_cache_; + size_t active_memory_{0}; + size_t peak_memory_{0}; +}; + +RocmAllocator& allocator(); + +} // namespace mlx::core::rocm \ No newline at end of file diff --git a/mlx/backend/rocm/arg_reduce.hip b/mlx/backend/rocm/arg_reduce.hip new file mode 100644 index 000000000..068625b35 --- /dev/null +++ b/mlx/backend/rocm/arg_reduce.hip @@ -0,0 +1,28 @@ +// Copyright © 2025 Apple Inc. + +#include + +namespace mlx::core::rocm { + +__global__ void argmax_kernel(float* input, int* output, int n) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + + // Simple argmax placeholder + if (idx == 0) { + int max_idx = 0; + float max_val = input[0]; + for (int i = 1; i < n; i++) { + if (input[i] > max_val) { + max_val = input[i]; + max_idx = i; + } + } + output[0] = max_idx; + } +} + +void launch_argmax(float* input, int* output, int n, hipStream_t stream) { + hipLaunchKernelGGL(argmax_kernel, dim3(1), dim3(1), 0, stream, input, output, n); +} + +} // namespace mlx::core::rocm \ No newline at end of file diff --git a/mlx/backend/rocm/bin2h.cmake b/mlx/backend/rocm/bin2h.cmake new file mode 100644 index 000000000..1766b27c9 --- /dev/null +++ b/mlx/backend/rocm/bin2h.cmake @@ -0,0 +1,47 @@ +# Copyright © 2025 Apple Inc. 
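+#
+# Expected to be run in CMake script mode by the custom command in
+# ../CMakeLists.txt, roughly:
+#
+#   cmake -DMLX_SOURCE_ROOT=<path to backend/rocm> \
+#         -DMLX_JIT_SOURCES=<':'-separated relative paths> \
+#         -P bin2h.cmake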
+ +# Script to embed kernel source files as header for JIT compilation + +set(MLX_OUTPUT_FILE "${CMAKE_CURRENT_BINARY_DIR}/gen/rocm_jit_sources.h") +set(MLX_KERNEL_HEADER + "#pragma once\n\n#include \n#include \n\nnamespace mlx::core::rocm {\n\n" +) +set(MLX_KERNEL_FOOTER "\n} // namespace mlx::core::rocm\n") + +# Create output directory +get_filename_component(MLX_OUTPUT_DIR ${MLX_OUTPUT_FILE} DIRECTORY) +file(MAKE_DIRECTORY ${MLX_OUTPUT_DIR}) + +# Write header +file(WRITE ${MLX_OUTPUT_FILE} ${MLX_KERNEL_HEADER}) + +# Process JIT sources +string(REPLACE ":" ";" MLX_JIT_SOURCES_LIST ${MLX_JIT_SOURCES}) + +set(MLX_SOURCE_MAP + "const std::unordered_map kernel_sources = {\n") + +foreach(source IN LISTS MLX_JIT_SOURCES_LIST) + set(source_file "${MLX_SOURCE_ROOT}/${source}") + if(EXISTS ${source_file}) + # Read source file + file(READ ${source_file} source_content) + + # Escape content for C++ string literal + string(REPLACE "\\" "\\\\" source_content "${source_content}") + string(REPLACE "\"" "\\\"" source_content "${source_content}") + string(REPLACE "\n" "\\n\"\n\"" source_content "${source_content}") + + # Add to map + set(MLX_SOURCE_MAP + "${MLX_SOURCE_MAP} {\"${source}\", \"${source_content}\"},\n") + endif() +endforeach() + +set(MLX_SOURCE_MAP "${MLX_SOURCE_MAP}};\n") + +# Write source map +file(APPEND ${MLX_OUTPUT_FILE} ${MLX_SOURCE_MAP}) + +# Write footer +file(APPEND ${MLX_OUTPUT_FILE} ${MLX_KERNEL_FOOTER}) diff --git a/mlx/backend/rocm/binary.hip b/mlx/backend/rocm/binary.hip new file mode 100644 index 000000000..8976befa2 --- /dev/null +++ b/mlx/backend/rocm/binary.hip @@ -0,0 +1,312 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/common/binary.h" +#include "mlx/backend/rocm/device.h" +#include "mlx/backend/rocm/device/binary_ops.hpp" +#include "mlx/backend/rocm/kernel_utils.hpp" +#include "mlx/dtype_utils.h" +#include "mlx/primitives.h" + +#include + +namespace mlx::core { + +namespace rocm { + +namespace cg = cooperative_groups; + +template +__global__ void binary_ss(const In* a, const In* b, Out* out, IdxT size) { + IdxT index = cg::this_grid().thread_rank(); + if (index < size) { + out[index] = Op{}(a[0], b[0]); + } +} + +template +__global__ void binary_sv(const In* a, const In* b, Out* out, IdxT size) { + IdxT index = cg::this_grid().thread_rank(); + if (index < size) { + out[index] = Op{}(a[0], b[index]); + } +} + +template +__global__ void binary_vs(const In* a, const In* b, Out* out, IdxT size) { + IdxT index = cg::this_grid().thread_rank(); + if (index < size) { + out[index] = Op{}(a[index], b[0]); + } +} + +template +__global__ void binary_vv(const In* a, const In* b, Out* out, IdxT size) { + IdxT index = cg::this_grid().thread_rank(); + if (index < size) { + out[index] = Op{}(a[index], b[index]); + } +} + +template +__global__ void binary_g_nd( + const In* a, + const In* b, + Out* out, + IdxT size, + const hip_array shape, + const hip_array a_strides, + const hip_array b_strides) { + IdxT index = cg::this_grid().thread_rank(); + if (index < size) { + auto [a_idx, b_idx] = elem_to_loc_nd( + index, shape.data(), a_strides.data(), b_strides.data()); + out[index] = Op{}(a[a_idx], b[b_idx]); + } +} + +template +__global__ void binary_g( + const In* a, + const In* b, + Out* out, + IdxT size, + const hip_array shape, + const hip_array a_strides, + const hip_array b_strides, + int ndim) { + IdxT index = cg::this_grid().thread_rank(); + if (index < size) { + auto [a_idx, b_idx] = elem_to_loc_4d( + index, shape.data(), a_strides.data(), b_strides.data(), ndim); + 
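+    // elem_to_loc_4d maps the flat output index to an element offset for
+    // each input using the common shape and that input's strides; ndim is a
+    // runtime argument here, unlike the fixed-rank binary_g_nd variant above.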
out[index] = Op{}(a[a_idx], b[b_idx]); + } +} + +// Binary operation support checking +template +constexpr bool supports_binary_op() { + if (std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) { + return std::is_same_v; + } + if (std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) { + return std::is_same_v; + } + if (std::is_same_v || std::is_same_v) { + return std::is_same_v && std::is_same_v; + } + if (std::is_same_v) { + return std::is_same_v && is_inexact_v; + } + if (std::is_same_v) { + return std::is_same_v && is_inexact_v; + } + if (std::is_same_v) { + return std::is_same_v && is_floating_v; + } + if (std::is_same_v || std::is_same_v || + std::is_same_v) { + return std::is_same_v && std::is_integral_v; + } + if (std::is_same_v || std::is_same_v) { + return std::is_same_v && std::is_integral_v && + !std::is_same_v; + } + return false; +} + +} // namespace rocm + +template +void binary_op_gpu_inplace( + const std::vector& inputs, + std::vector& outputs, + std::string_view op, + const Stream& s) { + assert(inputs.size() > 1); + const auto& a = inputs[0]; + const auto& b = inputs[1]; + auto& out = outputs[0]; + if (out.size() == 0) { + return; + } + + auto& encoder = rocm::get_command_encoder(s); + encoder.set_input_array(a); + encoder.set_input_array(b); + encoder.set_output_array(out); + encoder.launch_kernel([&](hipStream_t stream) { + MLX_SWITCH_ALL_TYPES(a.dtype(), CTYPE_IN, { + MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, { + if constexpr (rocm::supports_binary_op()) { + using InType = hip_type_t; + using OutType = hip_type_t; + + auto bopt = get_binary_op_type(a, b); + if (bopt == BinaryOpType::General) { + auto [shape, strides] = collapse_contiguous_dims(a, b, out); + auto& a_strides = strides[0]; + auto& b_strides = strides[1]; + bool large = a.data_size() > INT32_MAX || + b.data_size() > INT32_MAX || out.data_size() > INT32_MAX; + MLX_SWITCH_BOOL(large, LARGE, { + using IdxT = std::conditional_t; + int ndim = shape.size(); + if (ndim <= 3) { + MLX_SWITCH_1_2_3(ndim, NDIM, { + auto kernel = + &rocm::binary_g_nd; + auto [num_blocks, block_dims] = + get_launch_args(kernel, out, large); + hipLaunchKernelGGL(kernel, num_blocks, block_dims, 0, stream, + a.data(), + b.data(), + out.data(), + out.size(), + make_hip_array(shape), + make_hip_array(a_strides), + make_hip_array(b_strides)); + }); + } else { + auto kernel = rocm::binary_g; + auto [num_blocks, block_dims] = + get_launch_args(kernel, out, large); + hipLaunchKernelGGL(kernel, num_blocks, block_dims, 0, stream, + a.data(), + b.data(), + out.data(), + out.size(), + make_hip_array(shape), + make_hip_array(a_strides), + make_hip_array(b_strides), + ndim); + } + }); + } else { + MLX_SWITCH_BOOL(out.data_size() > UINT32_MAX, LARGE, { + using IdxT = std::conditional_t; + auto kernel = rocm::binary_ss; + if (bopt == BinaryOpType::ScalarVector) { + kernel = rocm::binary_sv; + } else if (bopt == BinaryOpType::VectorScalar) { + kernel = rocm::binary_vs; + } else if (bopt == BinaryOpType::VectorVector) { + kernel = rocm::binary_vv; + } + auto [num_blocks, block_dims] = get_launch_args( + kernel, out.data_size(), out.shape(), out.strides(), LARGE); + hipLaunchKernelGGL(kernel, num_blocks, block_dims, 0, stream, + a.data(), + b.data(), + out.data(), + out.data_size()); + }); + } + } else { + throw std::runtime_error(fmt::format( + "Can not do binary op {} on inputs of {} with result of 
{}.", + op, + dtype_to_string(a.dtype()), + dtype_to_string(out.dtype()))); + } + }); + }); + }); +} + +template +void binary_op_gpu( + const std::vector& inputs, + std::vector& outputs, + std::string_view op, + const Stream& s) { + auto& a = inputs[0]; + auto& b = inputs[1]; + auto bopt = get_binary_op_type(a, b); + set_binary_op_output_data(a, b, outputs[0], bopt); + set_binary_op_output_data(a, b, outputs[1], bopt); + binary_op_gpu_inplace(inputs, outputs, op, s); +} + +template +void binary_op_gpu( + const std::vector& inputs, + array& out, + std::string_view op, + const Stream& s) { + auto& a = inputs[0]; + auto& b = inputs[1]; + auto bopt = get_binary_op_type(a, b); + set_binary_op_output_data(a, b, out, bopt); + std::vector outputs{out}; + binary_op_gpu_inplace(inputs, outputs, op, s); +} + +#define BINARY_GPU(func) \ + void func::eval_gpu(const std::vector& inputs, array& out) { \ + auto& s = out.primitive().stream(); \ + binary_op_gpu(inputs, out, get_primitive_string(this), s); \ + } + +#define BINARY_GPU_MULTI(func) \ + void func::eval_gpu( \ + const std::vector& inputs, std::vector& outputs) { \ + auto& s = outputs[0].primitive().stream(); \ + binary_op_gpu(inputs, outputs, get_primitive_string(this), s); \ + } + +BINARY_GPU(Add) +BINARY_GPU(ArcTan2) +BINARY_GPU(Divide) +BINARY_GPU(Remainder) +BINARY_GPU(Greater) +BINARY_GPU(GreaterEqual) +BINARY_GPU(Less) +BINARY_GPU(LessEqual) +BINARY_GPU(LogicalAnd) +BINARY_GPU(LogicalOr) +BINARY_GPU(LogAddExp) +BINARY_GPU(Maximum) +BINARY_GPU(Minimum) +BINARY_GPU(Multiply) +BINARY_GPU(NotEqual) +BINARY_GPU(Power) +BINARY_GPU(Subtract) + +void Equal::eval_gpu(const std::vector& inputs, array& out) { + auto& s = out.primitive().stream(); + auto op = get_primitive_string(this); + if (equal_nan_) { + binary_op_gpu(inputs, out, op, s); + } else { + binary_op_gpu(inputs, out, op, s); + } +} + +void BitwiseBinary::eval_gpu(const std::vector& inputs, array& out) { + auto& s = out.primitive().stream(); + auto op = get_primitive_string(this); + switch (op_) { + case BitwiseBinary::And: + binary_op_gpu(inputs, out, op, s); + break; + case BitwiseBinary::Or: + binary_op_gpu(inputs, out, op, s); + break; + case BitwiseBinary::Xor: + binary_op_gpu(inputs, out, op, s); + break; + case BitwiseBinary::LeftShift: + binary_op_gpu(inputs, out, op, s); + break; + case BitwiseBinary::RightShift: + binary_op_gpu(inputs, out, op, s); + break; + } +} + +} // namespace mlx::core \ No newline at end of file diff --git a/mlx/backend/rocm/compiled.cpp b/mlx/backend/rocm/compiled.cpp new file mode 100644 index 000000000..a41bc433c --- /dev/null +++ b/mlx/backend/rocm/compiled.cpp @@ -0,0 +1,9 @@ +// Copyright © 2025 Apple Inc. + +namespace mlx::core::rocm { + +void compile() { + // Placeholder for ROCm compilation +} + +} // namespace mlx::core::rocm \ No newline at end of file diff --git a/mlx/backend/rocm/copy.hip b/mlx/backend/rocm/copy.hip new file mode 100644 index 000000000..4419a2db2 --- /dev/null +++ b/mlx/backend/rocm/copy.hip @@ -0,0 +1,20 @@ +// Copyright © 2025 Apple Inc. 
+ +#include + +namespace mlx::core::rocm { + +__global__ void copy_kernel(float* src, float* dst, int n) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + dst[idx] = src[idx]; + } +} + +void launch_copy(float* src, float* dst, int n, hipStream_t stream) { + int threads = 256; + int blocks = (n + threads - 1) / threads; + hipLaunchKernelGGL(copy_kernel, dim3(blocks), dim3(threads), 0, stream, src, dst, n); +} + +} // namespace mlx::core::rocm \ No newline at end of file diff --git a/mlx/backend/rocm/copy/copy.hpp b/mlx/backend/rocm/copy/copy.hpp new file mode 100644 index 000000000..1747dded2 --- /dev/null +++ b/mlx/backend/rocm/copy/copy.hpp @@ -0,0 +1,60 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include +#include + +namespace mlx::core::rocm { + +// Copy function declarations +void copy_contiguous( + const void* src, + void* dst, + size_t size, + hipStream_t stream); + +void copy_general( + const void* src, + void* dst, + const int* src_shape, + const size_t* src_strides, + const int* dst_shape, + const size_t* dst_strides, + int ndim, + size_t size, + size_t dtype_size, + hipStream_t stream); + +void copy_general_dynamic( + const void* src, + void* dst, + const int* src_shape, + const size_t* src_strides, + const int* dst_shape, + const size_t* dst_strides, + int ndim, + size_t size, + size_t dtype_size, + hipStream_t stream); + +void copy_general_input( + const void* src, + void* dst, + const int* src_shape, + const size_t* src_strides, + const int* dst_shape, + const size_t* dst_strides, + int ndim, + size_t size, + size_t dtype_size, + hipStream_t stream); + +// Utility functions for element location calculation +__device__ size_t +elem_to_loc(size_t elem, const int* shape, const size_t* strides, int ndim); + +__device__ size_t +loc_to_elem(size_t loc, const int* shape, const size_t* strides, int ndim); + +} // namespace mlx::core::rocm \ No newline at end of file diff --git a/mlx/backend/rocm/copy/copy_contiguous.hip b/mlx/backend/rocm/copy/copy_contiguous.hip new file mode 100644 index 000000000..9ddac5800 --- /dev/null +++ b/mlx/backend/rocm/copy/copy_contiguous.hip @@ -0,0 +1,38 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/rocm/copy/copy.hpp" +#include "mlx/backend/rocm/kernel_utils.hpp" + +#include + +namespace mlx::core::rocm { + +__global__ void copy_contiguous_kernel( + const char* src, + char* dst, + size_t size) { + size_t tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid < size) { + dst[tid] = src[tid]; + } +} + +void copy_contiguous( + const void* src, + void* dst, + size_t size, + hipStream_t stream) { + if (size == 0) { + return; + } + + const int threads_per_block = 256; + const int blocks = (size + threads_per_block - 1) / threads_per_block; + + copy_contiguous_kernel<<>>( + static_cast(src), + static_cast(dst), + size); +} + +} // namespace mlx::core::rocm \ No newline at end of file diff --git a/mlx/backend/rocm/device.cpp b/mlx/backend/rocm/device.cpp new file mode 100644 index 000000000..88fb997bc --- /dev/null +++ b/mlx/backend/rocm/device.cpp @@ -0,0 +1,130 @@ +// Copyright © 2025 Apple Inc. 
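+//
+// Per-device state for the ROCm backend: Device owns the rocBLAS handle and
+// one DeviceStream per MLX stream index; a DeviceStream lazily creates its
+// CommandEncoder, which launches kernels on a HIP stream and defers
+// completion handlers to a Worker.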
+ +#include "mlx/backend/rocm/device.h" +#include "mlx/backend/metal/metal.h" +#include "mlx/backend/rocm/worker.h" + +#include + +namespace mlx::core { + +namespace rocm { + +DeviceStream::DeviceStream(Device& device) : device_(device), stream_(device) {} + +void DeviceStream::synchronize() { + CHECK_HIP_ERROR(hipStreamSynchronize(stream_)); +} + +hipStream_t DeviceStream::schedule_hip_stream() { + // TODO: Return a stream that maximizes parallelism. + return stream_; +} + +hipStream_t DeviceStream::last_hip_stream() { + return stream_; +} + +CommandEncoder& DeviceStream::get_encoder() { + if (!encoder_) { + encoder_ = std::make_unique(*this); + } + return *encoder_; +} + +Device::Device(int device) : device_(device) { + CHECK_HIP_ERROR(hipDeviceGetAttribute( + &compute_capability_major_, + hipDeviceAttributeComputeCapabilityMajor, + device_)); + CHECK_HIP_ERROR(hipDeviceGetAttribute( + &compute_capability_minor_, + hipDeviceAttributeComputeCapabilityMinor, + device_)); + + // Validate device requirements + int attr = 0; + CHECK_HIP_ERROR(hipDeviceGetAttribute( + &attr, hipDeviceAttributeConcurrentManagedAccess, device_)); + if (attr != 1) { + // ROCm unified memory might not be available on all devices + // This is a warning rather than an error for ROCm + // TODO: Add proper ROCm unified memory checking + } + + // Create rocBLAS handle + make_current(); + CHECK_HIP_ERROR( + static_cast(rocblas_create_handle(&rocblas_handle_))); +} + +Device::~Device() { + if (rocblas_handle_) { + rocblas_destroy_handle(rocblas_handle_); + } +} + +void Device::make_current() { + // Cache current device to reduce HIP API calls + static int current = 0; + if (current != device_) { + CHECK_HIP_ERROR(hipSetDevice(device_)); + current = device_; + } +} + +DeviceStream& Device::get_stream(Stream s) { + auto it = streams_.find(s.index); + if (it == streams_.end()) { + it = streams_.try_emplace(s.index, *this).first; + } + return it->second; +} + +CommandEncoder::CommandEncoder(DeviceStream& s) + : device_(s.device()), stream_(s) {} + +void CommandEncoder::add_completed_handler(std::function task) { + worker_.add_task(std::move(task)); +} + +void CommandEncoder::end_encoding() { + if (!temporaries_.empty()) { + add_completed_handler([temporaries = std::move(temporaries_)]() {}); + } + + // There is no kernel running, run completion handlers immediately. + if (!has_gpu_work_) { + worker_.consume_in_this_thread(); + return; + } + has_gpu_work_ = false; + + // Commit tasks + commit(); +} + +void CommandEncoder::commit() { + worker_.commit(stream_.last_hip_stream()); +} + +Device& device(mlx::core::Device device) { + static std::unordered_map devices; + auto it = devices.find(device.index); + if (it == devices.end()) { + it = devices.try_emplace(device.index, device.index).first; + } + return it->second; +} + +DeviceStream& get_stream(Stream s) { + return device(s.device).get_stream(s); +} + +CommandEncoder& get_command_encoder(Stream s) { + return get_stream(s).get_encoder(); +} + +} // namespace rocm + +} // namespace mlx::core \ No newline at end of file diff --git a/mlx/backend/rocm/device.h b/mlx/backend/rocm/device.h new file mode 100644 index 000000000..6a9c18a07 --- /dev/null +++ b/mlx/backend/rocm/device.h @@ -0,0 +1,146 @@ +// Copyright © 2025 Apple Inc. 
+ +#pragma once + +#include "mlx/array.h" +#include "mlx/backend/rocm/utils.h" +#include "mlx/backend/rocm/worker.h" +#include "mlx/stream.h" + +#include +#include + +#include + +namespace mlx::core { + +namespace rocm { + +class Device; +class CommandEncoder; + +class DeviceStream { + public: + explicit DeviceStream(Device& device); + + DeviceStream(const DeviceStream&) = delete; + DeviceStream& operator=(const DeviceStream&) = delete; + + // Wait until kernels in the stream complete. + void synchronize(); + + // Return a HIP stream for launching kernels. + hipStream_t schedule_hip_stream(); + + // Return the last HIP stream used. + hipStream_t last_hip_stream(); + + CommandEncoder& get_encoder(); + + Device& device() { + return device_; + } + + private: + Device& device_; + HipStream stream_; + std::unique_ptr encoder_; +}; + +class Device { + public: + explicit Device(int device); + ~Device(); + + Device(const Device&) = delete; + Device& operator=(const Device&) = delete; + + // Make this device the current HIP device, required by some HIP calls. + void make_current(); + + DeviceStream& get_stream(Stream s); + + int hip_device() const { + return device_; + } + int compute_capability_major() const { + return compute_capability_major_; + } + int compute_capability_minor() const { + return compute_capability_minor_; + } + rocblas_handle rocblas_handle() const { + return rocblas_handle_; + } + + private: + int device_; + int compute_capability_major_; + int compute_capability_minor_; + rocblas_handle rocblas_handle_; + std::unordered_map streams_; +}; + +class CommandEncoder { + public: + explicit CommandEncoder(DeviceStream& stream); + + CommandEncoder(const CommandEncoder&) = delete; + CommandEncoder& operator=(const CommandEncoder&) = delete; + + void set_input_array(const array& arr) {} + void set_output_array(const array& arr) {} + + void add_temporary(const array& arr) { + temporaries_.push_back(arr.data_shared_ptr()); + } + + void add_completed_handler(std::function task); + void end_encoding(); + void commit(); + + // Schedule a HIP stream for |fun| to launch kernels, and check error + // afterwards. + template + void launch_kernel(F&& fun) { + launch_kernel(stream_.schedule_hip_stream(), std::forward(fun)); + } + + template + void launch_kernel(hipStream_t stream, F&& fun) { + device_.make_current(); + fun(stream); + check_hip_error("kernel launch", hipGetLastError()); + has_gpu_work_ = true; + } + + Device& device() { + return device_; + } + + DeviceStream& stream() { + return stream_; + } + + bool has_gpu_work() const { + return has_gpu_work_; + } + + private: + Device& device_; + DeviceStream& stream_; + Worker worker_; + bool has_gpu_work_{false}; + std::vector> temporaries_; +}; + +Device& device(mlx::core::Device device); +DeviceStream& get_stream(Stream s); +CommandEncoder& get_command_encoder(Stream s); + +// Utility function to check HIP errors +void check_hip_error(const char* msg, hipError_t error); + +} // namespace rocm + +} // namespace mlx::core \ No newline at end of file diff --git a/mlx/backend/rocm/device/arange.hpp b/mlx/backend/rocm/device/arange.hpp new file mode 100644 index 000000000..3bd28a0a0 --- /dev/null +++ b/mlx/backend/rocm/device/arange.hpp @@ -0,0 +1,17 @@ +// Copyright © 2025 Apple Inc. 
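+//
+// Fills out[i] = start + i * step for i < size; the host-side launch is not
+// part of this header.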
+ +#pragma once + +#include + +namespace mlx::core::rocm { + +template +__global__ void arange_kernel(T* out, T start, T step, size_t size) { + size_t tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid < size) { + out[tid] = start + static_cast(tid) * step; + } +} + +} // namespace mlx::core::rocm \ No newline at end of file diff --git a/mlx/backend/rocm/device/atomic_ops.hpp b/mlx/backend/rocm/device/atomic_ops.hpp new file mode 100644 index 000000000..4f924a170 --- /dev/null +++ b/mlx/backend/rocm/device/atomic_ops.hpp @@ -0,0 +1,36 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include + +namespace mlx::core::rocm { + +// Atomic operations for HIP +__device__ inline float atomicAddFloat(float* address, float val) { + return atomicAdd(address, val); +} + +__device__ inline double atomicAddDouble(double* address, double val) { + return atomicAdd(address, val); +} + +__device__ inline int atomicAddInt(int* address, int val) { + return atomicAdd(address, val); +} + +__device__ inline unsigned int atomicAddUInt( + unsigned int* address, + unsigned int val) { + return atomicAdd(address, val); +} + +__device__ inline float atomicMaxFloat(float* address, float val) { + return atomicMax(address, val); +} + +__device__ inline float atomicMinFloat(float* address, float val) { + return atomicMin(address, val); +} + +} // namespace mlx::core::rocm \ No newline at end of file diff --git a/mlx/backend/rocm/device/binary_ops.hpp b/mlx/backend/rocm/device/binary_ops.hpp new file mode 100644 index 000000000..01766f2cc --- /dev/null +++ b/mlx/backend/rocm/device/binary_ops.hpp @@ -0,0 +1,217 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include +#include +#include +#include + +namespace mlx::core::rocm { + +// Arithmetic operations +struct Add { + template + __device__ T operator()(T a, T b) { + return a + b; + } +}; + +struct Subtract { + template + __device__ T operator()(T a, T b) { + return a - b; + } +}; + +struct Multiply { + template + __device__ T operator()(T a, T b) { + return a * b; + } +}; + +struct Divide { + template + __device__ T operator()(T a, T b) { + return a / b; + } +}; + +struct Power { + template + __device__ T operator()(T a, T b) { + return powf(a, b); + } + + __device__ double operator()(double a, double b) { + return pow(a, b); + } +}; + +struct Remainder { + template + __device__ T operator()(T a, T b) { + return fmodf(a, b); + } + + __device__ double operator()(double a, double b) { + return fmod(a, b); + } +}; + +// Comparison operations +struct Equal { + template + __device__ bool operator()(T a, T b) { + return a == b; + } +}; + +struct NotEqual { + template + __device__ bool operator()(T a, T b) { + return a != b; + } +}; + +struct Greater { + template + __device__ bool operator()(T a, T b) { + return a > b; + } +}; + +struct GreaterEqual { + template + __device__ bool operator()(T a, T b) { + return a >= b; + } +}; + +struct Less { + template + __device__ bool operator()(T a, T b) { + return a < b; + } +}; + +struct LessEqual { + template + __device__ bool operator()(T a, T b) { + return a <= b; + } +}; + +struct NaNEqual { + template + __device__ bool operator()(T a, T b) { + return (isnan(a) && isnan(b)) || (a == b); + } +}; + +// Logic operations +struct LogicalAnd { + __device__ bool operator()(bool a, bool b) { + return a && b; + } +}; + +struct LogicalOr { + __device__ bool operator()(bool a, bool b) { + return a || b; + } +}; + +// Math operations +struct Maximum { + template + __device__ T operator()(T a, T b) { + return fmaxf(a, b); + } + + 
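+  // fmaxf() is single precision; the overload below keeps double inputs from
+  // being narrowed to float.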
__device__ double operator()(double a, double b) { + return fmax(a, b); + } +}; + +struct Minimum { + template + __device__ T operator()(T a, T b) { + return fminf(a, b); + } + + __device__ double operator()(double a, double b) { + return fmin(a, b); + } +}; + +struct LogAddExp { + template + __device__ T operator()(T a, T b) { + T max_val = fmaxf(a, b); + T min_val = fminf(a, b); + if (isinf(max_val)) { + return max_val; + } + return max_val + log1pf(expf(min_val - max_val)); + } + + __device__ double operator()(double a, double b) { + double max_val = fmax(a, b); + double min_val = fmin(a, b); + if (isinf(max_val)) { + return max_val; + } + return max_val + log1p(exp(min_val - max_val)); + } +}; + +struct ArcTan2 { + template + __device__ T operator()(T a, T b) { + return atan2f(a, b); + } + + __device__ double operator()(double a, double b) { + return atan2(a, b); + } +}; + +// Bitwise operations +struct BitwiseAnd { + template + __device__ T operator()(T a, T b) { + return a & b; + } +}; + +struct BitwiseOr { + template + __device__ T operator()(T a, T b) { + return a | b; + } +}; + +struct BitwiseXor { + template + __device__ T operator()(T a, T b) { + return a ^ b; + } +}; + +struct LeftShift { + template + __device__ T operator()(T a, T b) { + return a << b; + } +}; + +struct RightShift { + template + __device__ T operator()(T a, T b) { + return a >> b; + } +}; + +} // namespace mlx::core::rocm \ No newline at end of file diff --git a/mlx/backend/rocm/device/cast_op.hpp b/mlx/backend/rocm/device/cast_op.hpp new file mode 100644 index 000000000..593f61650 --- /dev/null +++ b/mlx/backend/rocm/device/cast_op.hpp @@ -0,0 +1,21 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include + +namespace mlx::core::rocm { + +template +struct CastOp { + __device__ To operator()(From x) const { + return static_cast(x); + } +}; + +template +__device__ inline To cast_op(From x) { + return static_cast(x); +} + +} // namespace mlx::core::rocm \ No newline at end of file diff --git a/mlx/backend/rocm/device/config.h b/mlx/backend/rocm/device/config.h new file mode 100644 index 000000000..3eed48b57 --- /dev/null +++ b/mlx/backend/rocm/device/config.h @@ -0,0 +1,14 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +// ROCm/HIP specific configuration +#define ROCM_MAX_THREADS_PER_BLOCK 1024 +#define ROCM_WARP_SIZE 64 +#define ROCM_MAX_BLOCKS_PER_GRID 65535 + +namespace mlx::core::rocm { +constexpr int kMaxThreadsPerBlock = ROCM_MAX_THREADS_PER_BLOCK; +constexpr int kWarpSize = ROCM_WARP_SIZE; +constexpr int kMaxBlocksPerGrid = ROCM_MAX_BLOCKS_PER_GRID; +} // namespace mlx::core::rocm \ No newline at end of file diff --git a/mlx/backend/rocm/device/fp16_math.hpp b/mlx/backend/rocm/device/fp16_math.hpp new file mode 100644 index 000000000..f709bcb8b --- /dev/null +++ b/mlx/backend/rocm/device/fp16_math.hpp @@ -0,0 +1,87 @@ +// Copyright © 2025 Apple Inc. 
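+//
+// Half-precision helpers mirroring CUDA's packed __half2 intrinsics (h2sin,
+// h2exp, ...) on top of ROCm's per-component __half functions, plus abs/neg
+// helpers for __half and, when available, __hip_bfloat16.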
+ +#pragma once + +#include +#include + +namespace mlx::core::rocm { + +// HIP/ROCm equivalents of CUDA half precision math functions +inline __device__ __half2 h2sin(__half2 x) { + return __half2{hsin(x.x), hsin(x.y)}; +} + +inline __device__ __half2 h2cos(__half2 x) { + return __half2{hcos(x.x), hcos(x.y)}; +} + +inline __device__ __half2 h2exp(__half2 x) { + return __half2{hexp(x.x), hexp(x.y)}; +} + +inline __device__ __half2 h2log(__half2 x) { + return __half2{hlog(x.x), hlog(x.y)}; +} + +inline __device__ __half2 h2sqrt(__half2 x) { + return __half2{hsqrt(x.x), hsqrt(x.y)}; +} + +inline __device__ __half2 h2rsqrt(__half2 x) { + return __half2{hrsqrt(x.x), hrsqrt(x.y)}; +} + +inline __device__ __half2 h2ceil(__half2 x) { + return __half2{hceil(x.x), hceil(x.y)}; +} + +inline __device__ __half2 h2floor(__half2 x) { + return __half2{hfloor(x.x), hfloor(x.y)}; +} + +inline __device__ __half2 h2rint(__half2 x) { + return __half2{hrint(x.x), hrint(x.y)}; +} + +inline __device__ __half2 h2trunc(__half2 x) { + return __half2{htrunc(x.x), htrunc(x.y)}; +} + +// Additional math functions for half precision +inline __device__ __half habs(__half x) { + return __half{fabsf(__half2float(x))}; +} + +inline __device__ __half2 h2abs(__half2 x) { + return __half2{habs(x.x), habs(x.y)}; +} + +inline __device__ __half hneg(__half x) { + return __half{-__half2float(x)}; +} + +inline __device__ __half2 h2neg(__half2 x) { + return __half2{hneg(x.x), hneg(x.y)}; +} + +// BFloat16 support functions +#ifdef __HIP_BFLOAT16__ +inline __device__ __hip_bfloat16 habs(__hip_bfloat16 x) { + return __hip_bfloat16{fabsf(__bfloat162float(x))}; +} + +inline __device__ __hip_bfloat162 h2abs(__hip_bfloat162 x) { + return __hip_bfloat162{habs(x.x), habs(x.y)}; +} + +inline __device__ __hip_bfloat16 hneg(__hip_bfloat16 x) { + return __hip_bfloat16{-__bfloat162float(x)}; +} + +inline __device__ __hip_bfloat162 h2neg(__hip_bfloat162 x) { + return __hip_bfloat162{hneg(x.x), hneg(x.y)}; +} +#endif + +} // namespace mlx::core::rocm \ No newline at end of file diff --git a/mlx/backend/rocm/device/hip_complex_math.hpp b/mlx/backend/rocm/device/hip_complex_math.hpp new file mode 100644 index 000000000..b35d00dae --- /dev/null +++ b/mlx/backend/rocm/device/hip_complex_math.hpp @@ -0,0 +1,52 @@ +// Copyright © 2025 Apple Inc. 
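+//
+// Minimal single-precision complex arithmetic (add, sub, mul, div, abs, conj)
+// on hipFloatComplex for use in device code.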
+ +#pragma once + +#include +#include + +namespace mlx::core::rocm { + +// HIP complex math functions +__device__ inline hipFloatComplex hip_complex_add( + hipFloatComplex a, + hipFloatComplex b) { + return make_hipFloatComplex( + hipCrealf(a) + hipCrealf(b), hipCimagf(a) + hipCimagf(b)); +} + +__device__ inline hipFloatComplex hip_complex_sub( + hipFloatComplex a, + hipFloatComplex b) { + return make_hipFloatComplex( + hipCrealf(a) - hipCrealf(b), hipCimagf(a) - hipCimagf(b)); +} + +__device__ inline hipFloatComplex hip_complex_mul( + hipFloatComplex a, + hipFloatComplex b) { + float real = hipCrealf(a) * hipCrealf(b) - hipCimagf(a) * hipCimagf(b); + float imag = hipCrealf(a) * hipCimagf(b) + hipCimagf(a) * hipCrealf(b); + return make_hipFloatComplex(real, imag); +} + +__device__ inline hipFloatComplex hip_complex_div( + hipFloatComplex a, + hipFloatComplex b) { + float denom = hipCrealf(b) * hipCrealf(b) + hipCimagf(b) * hipCimagf(b); + float real = + (hipCrealf(a) * hipCrealf(b) + hipCimagf(a) * hipCimagf(b)) / denom; + float imag = + (hipCimagf(a) * hipCrealf(b) - hipCrealf(a) * hipCimagf(b)) / denom; + return make_hipFloatComplex(real, imag); +} + +__device__ inline float hip_complex_abs(hipFloatComplex z) { + return sqrtf(hipCrealf(z) * hipCrealf(z) + hipCimagf(z) * hipCimagf(z)); +} + +__device__ inline hipFloatComplex hip_complex_conj(hipFloatComplex z) { + return make_hipFloatComplex(hipCrealf(z), -hipCimagf(z)); +} + +} // namespace mlx::core::rocm \ No newline at end of file diff --git a/mlx/backend/rocm/device/ternary_ops.hpp b/mlx/backend/rocm/device/ternary_ops.hpp new file mode 100644 index 000000000..7a33c7599 --- /dev/null +++ b/mlx/backend/rocm/device/ternary_ops.hpp @@ -0,0 +1,16 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include + +namespace mlx::core::rocm { + +struct Select { + template + __device__ T operator()(bool condition, T a, T b) const { + return condition ? a : b; + } +}; + +} // namespace mlx::core::rocm \ No newline at end of file diff --git a/mlx/backend/rocm/device/unary_ops.hpp b/mlx/backend/rocm/device/unary_ops.hpp new file mode 100644 index 000000000..266d50d7d --- /dev/null +++ b/mlx/backend/rocm/device/unary_ops.hpp @@ -0,0 +1,368 @@ +// Copyright © 2025 Apple Inc. 
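+//
+// One functor per unary primitive with a templated __device__ operator();
+// complex (hipFloatComplex) and half/bfloat16 inputs are special-cased with
+// `if constexpr` where the math differs.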
+ +#pragma once + +#include "mlx/backend/rocm/device/fp16_math.hpp" +#include "mlx/backend/rocm/device/utils.hpp" + +#include + +namespace mlx::core::rocm { + +struct Abs { + template + __device__ T operator()(T x) { + if constexpr (std::is_unsigned_v) { + return x; + } else if constexpr (std::is_same_v) { + return { + sqrt(hipCrealf(x) * hipCrealf(x) + hipCimagf(x) * hipCimagf(x)), 0}; + } else { + return abs(x); + } + } +}; + +struct ArcCos { + template + __device__ T operator()(T x) { + return acos(x); + } +}; + +struct ArcCosh { + template + __device__ T operator()(T x) { + return acosh(x); + } +}; + +struct ArcSin { + template + __device__ T operator()(T x) { + return asin(x); + } +}; + +struct ArcSinh { + template + __device__ T operator()(T x) { + return asinh(x); + } +}; + +struct ArcTan { + template + __device__ T operator()(T x) { + return atan(x); + } +}; + +struct ArcTanh { + template + __device__ T operator()(T x) { + return atanh(x); + } +}; + +struct BitwiseInvert { + template + __device__ T operator()(T x) { + return ~x; + } +}; + +struct Ceil { + template + __device__ T operator()(T x) { + if constexpr (std::is_integral_v) { + return x; + } else { + return ceil(x); + } + } +}; + +struct Conjugate { + __device__ hipFloatComplex operator()(hipFloatComplex x) { + return {hipCrealf(x), -hipCimagf(x)}; + } +}; + +struct Cos { + template + __device__ T operator()(T x) { + if constexpr (std::is_same_v) { + return { + cos(hipCrealf(x)) * cosh(hipCimagf(x)), + -sin(hipCrealf(x)) * sinh(hipCimagf(x))}; + } else { + return cos(x); + } + } +}; + +struct Cosh { + template + __device__ T operator()(T x) { + if constexpr (std::is_same_v) { + return { + cosh(hipCrealf(x)) * cos(hipCimagf(x)), + sinh(hipCrealf(x)) * sin(hipCimagf(x))}; + } else { + return cosh(x); + } + } +}; + +struct Erf { + template + __device__ T operator()(T x) { + if constexpr (std::is_same_v) { + return erf(__half2float(x)); + } else if constexpr (std::is_same_v) { + return erf(__bfloat162float(x)); + } else { + return erf(x); + } + } +}; + +struct ErfInv { + template + __device__ T operator()(T x) { + if constexpr (std::is_same_v) { + return erfinv(__half2float(x)); + } else if constexpr (std::is_same_v) { + return erfinv(__bfloat162float(x)); + } else { + return erfinv(x); + } + } +}; + +struct Exp { + template + __device__ T operator()(T x) { + if constexpr (std::is_same_v) { + auto m = exp(hipCrealf(x)); + return {m * cos(hipCimagf(x)), m * sinh(hipCimagf(x))}; + } else { + return exp(x); + } + } +}; + +struct Expm1 { + template + __device__ T operator()(T x) { + if constexpr (std::is_same_v) { + return expm1(__half2float(x)); + } else if constexpr (std::is_same_v) { + return expm1(__bfloat162float(x)); + } else { + return expm1(x); + } + } +}; + +struct Floor { + template + __device__ T operator()(T x) { + if constexpr (std::is_integral_v) { + return x; + } else { + return floor(x); + } + } +}; + +struct Imag { + __device__ float operator()(hipFloatComplex x) { + return hipCimagf(x); + } +}; + +struct Log { + template + __device__ T operator()(T x) { + if constexpr (std::is_same_v) { + auto r = log(hipCrealf(Abs{}(x))); + auto i = atan2f(hipCimagf(x), hipCrealf(x)); + return {r, i}; + } else { + return log(x); + } + } +}; + +struct Log2 { + template + __device__ T operator()(T x) { + if constexpr (std::is_same_v) { + auto y = Log{}(x); + return {hipCrealf(y) / M_LN2, hipCimagf(y) / M_LN2}; + } else { + return log2(x); + } + } +}; + +struct Log10 { + template + __device__ T operator()(T x) { + if constexpr 
(std::is_same_v) { + auto y = Log{}(x); + return {hipCrealf(y) / M_LN10, hipCimagf(y) / M_LN10}; + } else { + return log10(x); + } + } +}; + +struct Log1p { + template + __device__ T operator()(T x) { + return log1p(x); + } +}; + +struct LogicalNot { + __device__ bool operator()(bool x) { + return !x; + } +}; + +struct Negative { + template + __device__ T operator()(T x) { + if constexpr (std::is_same_v) { + return 0 - x; + } else { + return -x; + } + } +}; + +struct Real { + __device__ float operator()(hipFloatComplex x) { + return hipCrealf(x); + } +}; + +struct Round { + template + __device__ T operator()(T x) { + if constexpr (std::is_same_v) { + return {rint(hipCrealf(x)), rint(hipCimagf(x))}; + } else { + return rint(x); + } + } +}; + +struct Rsqrt { + template + __device__ T operator()(T x) { + return rsqrt(x); + } +}; + +struct Sigmoid { + template + __device__ T operator()(T x) { + T y = 1 / (1 + exp(-abs(x))); + return (x < 0) ? 1 - y : y; + } +}; + +struct Sign { + template + __device__ T operator()(T x) { + if constexpr (std::is_unsigned_v) { + return x != 0; + } else if constexpr (std::is_same_v) { + if (hipCrealf(x) == 0 && hipCimagf(x) == 0) { + return x; + } else { + return x / Abs()(x); + } + } else if constexpr (std::is_same_v) { + return static_cast((x > T(0.f)) - (x < T(0.f))); + } else { + return (x > T(0)) - (x < T(0)); + } + } +}; + +struct Sin { + template + __device__ T operator()(T x) { + if constexpr (std::is_same_v) { + return { + sin(hipCrealf(x)) * cosh(hipCimagf(x)), + cos(hipCrealf(x)) * sinh(hipCimagf(x))}; + } else { + return sin(x); + } + } +}; + +struct Sinh { + template + __device__ T operator()(T x) { + if constexpr (std::is_same_v) { + return { + sinh(hipCrealf(x)) * cos(hipCimagf(x)), + cosh(hipCrealf(x)) * sin(hipCimagf(x))}; + } else { + return sinh(x); + } + } +}; + +struct Square { + template + __device__ T operator()(T x) { + return x * x; + } +}; + +struct Sqrt { + template + __device__ T operator()(T x) { + return sqrt(x); + } +}; + +struct Tan { + template + __device__ T operator()(T x) { + if constexpr (std::is_same_v) { + float tan_a = tan(hipCrealf(x)); + float tanh_b = tanh(hipCimagf(x)); + float t1 = tan_a * tanh_b; + float denom = 1. + t1 * t1; + return {(tan_a - tanh_b * t1) / denom, (tanh_b + tan_a * t1) / denom}; + } else { + return tan(x); + } + } +}; + +struct Tanh { + template + __device__ T operator()(T x) { + if constexpr (std::is_same_v) { + float tanh_a = tanh(hipCrealf(x)); + float tan_b = tan(hipCimagf(x)); + float t1 = tanh_a * tan_b; + float denom = 1. + t1 * t1; + return {(tanh_a + tan_b * t1) / denom, (tan_b - tanh_a * t1) / denom}; + } else { + return tanh(x); + } + } +}; + +} // namespace mlx::core::rocm \ No newline at end of file diff --git a/mlx/backend/rocm/device/utils.hpp b/mlx/backend/rocm/device/utils.hpp new file mode 100644 index 000000000..fc3833f72 --- /dev/null +++ b/mlx/backend/rocm/device/utils.hpp @@ -0,0 +1,173 @@ +// Copyright © 2025 Apple Inc. 
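+//
+// Shared device-side helpers: the hip_type<T> mapping, hipFloatComplex
+// convenience functions, thread/block index wrappers, and the math constants
+// kernels would otherwise pull from CUDA's math_constants.h.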
+ +#pragma once + +#include +#include + +namespace mlx::core::rocm { + +// HIP/ROCm type definitions +using hip_complex = hipFloatComplex; + +// Utility functions for HIP device code +template +struct hip_type { + using type = T; +}; + +template <> +struct hip_type { + using type = bool; +}; + +template <> +struct hip_type { + using type = int8_t; +}; + +template <> +struct hip_type { + using type = uint8_t; +}; + +template <> +struct hip_type { + using type = int16_t; +}; + +template <> +struct hip_type { + using type = uint16_t; +}; + +template <> +struct hip_type { + using type = int32_t; +}; + +template <> +struct hip_type { + using type = uint32_t; +}; + +template <> +struct hip_type { + using type = int64_t; +}; + +template <> +struct hip_type { + using type = uint64_t; +}; + +template <> +struct hip_type { + using type = float; +}; + +template <> +struct hip_type { + using type = double; +}; + +#ifdef __HIP_PLATFORM_HCC__ +template <> +struct hip_type<__half> { + using type = __half; +}; + +template <> +struct hip_type<__hip_bfloat16> { + using type = __hip_bfloat16; +}; +#endif + +template +using hip_type_t = typename hip_type::type; + +// Element-wise operations support +template +constexpr bool is_floating_point_v = std::is_floating_point_v; + +template +constexpr bool is_integral_v = std::is_integral_v; + +template +constexpr bool is_signed_v = std::is_signed_v; + +template +constexpr bool is_unsigned_v = std::is_unsigned_v; + +// Complex number helper functions +inline __device__ hipFloatComplex make_complex(float real, float imag) { + return make_hipFloatComplex(real, imag); +} + +inline __device__ float hip_real(hipFloatComplex z) { + return hipCrealf(z); +} + +inline __device__ float hip_imag(hipFloatComplex z) { + return hipCimagf(z); +} + +inline __device__ hipFloatComplex hip_conj(hipFloatComplex z) { + return make_hipFloatComplex(hipCrealf(z), -hipCimagf(z)); +} + +inline __device__ float hip_abs(hipFloatComplex z) { + return sqrtf(hipCrealf(z) * hipCrealf(z) + hipCimagf(z) * hipCimagf(z)); +} + +// Memory access utilities +template +inline __device__ T hip_load_global(const T* ptr) { + return *ptr; +} + +template +inline __device__ void hip_store_global(T* ptr, T value) { + *ptr = value; +} + +// Grid and block utilities +inline __device__ int hip_thread_idx() { + return threadIdx.x; +} + +inline __device__ int hip_block_idx() { + return blockIdx.x; +} + +inline __device__ int hip_block_dim() { + return blockDim.x; +} + +inline __device__ int hip_grid_dim() { + return gridDim.x; +} + +inline __device__ int hip_global_thread_idx() { + return blockIdx.x * blockDim.x + threadIdx.x; +} + +// Synchronization +inline __device__ void hip_sync_threads() { + __syncthreads(); +} + +// Math constants for HIP (equivalent to CUDA's math_constants.h) +#ifndef M_PI +#define M_PI 3.14159265358979323846 +#endif + +#ifndef M_LN2 +#define M_LN2 0.693147180559945309417 +#endif + +#ifndef M_LN10 +#define M_LN10 2.302585092994045684018 +#endif + +} // namespace mlx::core::rocm \ No newline at end of file diff --git a/mlx/backend/rocm/eval.cpp b/mlx/backend/rocm/eval.cpp new file mode 100644 index 000000000..6fd43c668 --- /dev/null +++ b/mlx/backend/rocm/eval.cpp @@ -0,0 +1,11 @@ +// Copyright © 2025 Apple Inc. 
+ +#include "mlx/backend/rocm/utils.h" + +namespace mlx::core::rocm { + +void eval() { + // Placeholder for ROCm evaluation +} + +} // namespace mlx::core::rocm \ No newline at end of file diff --git a/mlx/backend/rocm/event.cpp b/mlx/backend/rocm/event.cpp new file mode 100644 index 000000000..a1ff81622 --- /dev/null +++ b/mlx/backend/rocm/event.cpp @@ -0,0 +1,50 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/rocm/event.h" +#include "mlx/backend/rocm/utils.h" + +namespace mlx::core::rocm { + +HipEvent::HipEvent() { + CHECK_HIP_ERROR(hipEventCreate(&event_)); +} + +HipEvent::~HipEvent() { + CHECK_HIP_ERROR(hipEventDestroy(event_)); +} + +void HipEvent::record(hipStream_t stream) { + CHECK_HIP_ERROR(hipEventRecord(event_, stream)); +} + +void HipEvent::wait() { + CHECK_HIP_ERROR(hipEventSynchronize(event_)); +} + +bool HipEvent::query() const { + hipError_t status = hipEventQuery(event_); + if (status == hipSuccess) { + return true; + } else if (status == hipErrorNotReady) { + return false; + } else { + CHECK_HIP_ERROR(status); + return false; + } +} + +SharedEvent::SharedEvent() = default; + +void SharedEvent::notify() { + std::lock_guard lock(mutex_); + ready_ = true; + cv_.notify_one(); +} + +void SharedEvent::wait() { + std::unique_lock lock(mutex_); + cv_.wait(lock, [this] { return ready_; }); + ready_ = false; +} + +} // namespace mlx::core::rocm \ No newline at end of file diff --git a/mlx/backend/rocm/event.h b/mlx/backend/rocm/event.h new file mode 100644 index 000000000..1a9d5f5a6 --- /dev/null +++ b/mlx/backend/rocm/event.h @@ -0,0 +1,48 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include + +#include +#include +#include + +namespace mlx::core::rocm { + +// HIP event managed with RAII. +class HipEvent { + public: + HipEvent(); + ~HipEvent(); + + HipEvent(const HipEvent&) = delete; + HipEvent& operator=(const HipEvent&) = delete; + + void record(hipStream_t stream); + void wait(); + bool query() const; + + operator hipEvent_t() const { + return event_; + } + + private: + hipEvent_t event_; +}; + +// Shared event for worker thread synchronization. +class SharedEvent { + public: + SharedEvent(); + + void notify(); + void wait(); + + private: + std::mutex mutex_; + std::condition_variable cv_; + bool ready_{false}; +}; + +} // namespace mlx::core::rocm \ No newline at end of file diff --git a/mlx/backend/rocm/event.hip b/mlx/backend/rocm/event.hip new file mode 100644 index 000000000..0358d9e6e --- /dev/null +++ b/mlx/backend/rocm/event.hip @@ -0,0 +1,32 @@ +// Copyright © 2025 Apple Inc. + +#include +#include "mlx/backend/rocm/utils.h" + +namespace mlx::core::rocm { + +class Event { +public: + Event() { + check_hip_error("hipEventCreate", hipEventCreate(&event_)); + } + + ~Event() { + hipEventDestroy(event_); + } + + void record(hipStream_t stream) { + check_hip_error("hipEventRecord", hipEventRecord(event_, stream)); + } + + void wait() { + check_hip_error("hipEventSynchronize", hipEventSynchronize(event_)); + } + + hipEvent_t event() const { return event_; } + +private: + hipEvent_t event_; +}; + +} // namespace mlx::core::rocm \ No newline at end of file diff --git a/mlx/backend/rocm/fence.cpp b/mlx/backend/rocm/fence.cpp new file mode 100644 index 000000000..d96c99c06 --- /dev/null +++ b/mlx/backend/rocm/fence.cpp @@ -0,0 +1,9 @@ +// Copyright © 2025 Apple Inc. 
+ +namespace mlx::core::rocm { + +void fence() { + // Placeholder for ROCm fence operation +} + +} // namespace mlx::core::rocm \ No newline at end of file diff --git a/mlx/backend/rocm/indexing.cpp b/mlx/backend/rocm/indexing.cpp new file mode 100644 index 000000000..25e13c36b --- /dev/null +++ b/mlx/backend/rocm/indexing.cpp @@ -0,0 +1,9 @@ +// Copyright © 2025 Apple Inc. + +namespace mlx::core::rocm { + +void index() { + // Placeholder for ROCm indexing operation +} + +} // namespace mlx::core::rocm \ No newline at end of file diff --git a/mlx/backend/rocm/iterators/general_iterator.hpp b/mlx/backend/rocm/iterators/general_iterator.hpp new file mode 100644 index 000000000..ec3a84441 --- /dev/null +++ b/mlx/backend/rocm/iterators/general_iterator.hpp @@ -0,0 +1,153 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include +#include + +namespace mlx::core::rocm { + +template +struct GeneralIterator { + using difference_type = ptrdiff_t; + using value_type = IdxType; + using pointer = IdxType*; + using reference = IdxType&; + using iterator_category = std::random_access_iterator_tag; + + const IdxType* base_ptr; + IdxType offset; + const int* shape; + const size_t* strides; + int ndim; + size_t size; + + __device__ GeneralIterator( + const IdxType* base_ptr, + IdxType offset, + const int* shape, + const size_t* strides, + int ndim, + size_t size) + : base_ptr(base_ptr), + offset(offset), + shape(shape), + strides(strides), + ndim(ndim), + size(size) {} + + __device__ GeneralIterator operator+(difference_type n) const { + return GeneralIterator(base_ptr, offset + n, shape, strides, ndim, size); + } + + __device__ GeneralIterator operator-(difference_type n) const { + return GeneralIterator(base_ptr, offset - n, shape, strides, ndim, size); + } + + __device__ difference_type operator-(const GeneralIterator& other) const { + return offset - other.offset; + } + + __device__ GeneralIterator& operator+=(difference_type n) { + offset += n; + return *this; + } + + __device__ GeneralIterator& operator-=(difference_type n) { + offset -= n; + return *this; + } + + __device__ GeneralIterator& operator++() { + ++offset; + return *this; + } + + __device__ GeneralIterator operator++(int) { + GeneralIterator temp = *this; + ++offset; + return temp; + } + + __device__ GeneralIterator& operator--() { + --offset; + return *this; + } + + __device__ GeneralIterator operator--(int) { + GeneralIterator temp = *this; + --offset; + return temp; + } + + __device__ bool operator==(const GeneralIterator& other) const { + return offset == other.offset; + } + + __device__ bool operator!=(const GeneralIterator& other) const { + return offset != other.offset; + } + + __device__ bool operator<(const GeneralIterator& other) const { + return offset < other.offset; + } + + __device__ bool operator>(const GeneralIterator& other) const { + return offset > other.offset; + } + + __device__ bool operator<=(const GeneralIterator& other) const { + return offset <= other.offset; + } + + __device__ bool operator>=(const GeneralIterator& other) const { + return offset >= other.offset; + } + + __device__ IdxType operator*() const { + return base_ptr[elem_to_loc(offset, shape, strides, ndim)]; + } + + __device__ IdxType operator[](difference_type n) const { + return base_ptr[elem_to_loc(offset + n, shape, strides, ndim)]; + } + + private: + __device__ size_t elem_to_loc( + size_t elem, + const int* shape, + const size_t* strides, + int ndim) const { + size_t loc = 0; + for (int i = ndim - 1; i >= 0; --i) { + auto q_and_r = 
div(elem, static_cast(shape[i])); + loc += q_and_r.rem * strides[i]; + elem = q_and_r.quot; + } + return loc; + } + + __device__ div_t div(size_t numer, size_t denom) const { + div_t result; + result.quot = numer / denom; + result.rem = numer % denom; + return result; + } +}; + +template +__device__ std::pair, GeneralIterator> +make_general_iterators( + const IdxType* base_ptr, + size_t size, + const int* shape, + const size_t* strides, + int ndim) { + auto begin = + GeneralIterator(base_ptr, 0, shape, strides, ndim, size); + auto end = + GeneralIterator(base_ptr, size, shape, strides, ndim, size); + return std::make_pair(begin, end); +} + +} // namespace mlx::core::rocm \ No newline at end of file diff --git a/mlx/backend/rocm/iterators/strided_iterator.hpp b/mlx/backend/rocm/iterators/strided_iterator.hpp new file mode 100644 index 000000000..a4fd104a5 --- /dev/null +++ b/mlx/backend/rocm/iterators/strided_iterator.hpp @@ -0,0 +1,106 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include +#include + +namespace mlx::core::rocm { + +template +struct StridedIterator { + using difference_type = ptrdiff_t; + using value_type = T; + using pointer = T*; + using reference = T&; + using iterator_category = std::random_access_iterator_tag; + + T* ptr; + size_t stride; + + __device__ StridedIterator(T* ptr, size_t stride) + : ptr(ptr), stride(stride) {} + + __device__ StridedIterator operator+(difference_type n) const { + return StridedIterator(ptr + n * stride, stride); + } + + __device__ StridedIterator operator-(difference_type n) const { + return StridedIterator(ptr - n * stride, stride); + } + + __device__ difference_type operator-(const StridedIterator& other) const { + return (ptr - other.ptr) / stride; + } + + __device__ StridedIterator& operator+=(difference_type n) { + ptr += n * stride; + return *this; + } + + __device__ StridedIterator& operator-=(difference_type n) { + ptr -= n * stride; + return *this; + } + + __device__ StridedIterator& operator++() { + ptr += stride; + return *this; + } + + __device__ StridedIterator operator++(int) { + StridedIterator temp = *this; + ptr += stride; + return temp; + } + + __device__ StridedIterator& operator--() { + ptr -= stride; + return *this; + } + + __device__ StridedIterator operator--(int) { + StridedIterator temp = *this; + ptr -= stride; + return temp; + } + + __device__ bool operator==(const StridedIterator& other) const { + return ptr == other.ptr; + } + + __device__ bool operator!=(const StridedIterator& other) const { + return ptr != other.ptr; + } + + __device__ bool operator<(const StridedIterator& other) const { + return ptr < other.ptr; + } + + __device__ bool operator>(const StridedIterator& other) const { + return ptr > other.ptr; + } + + __device__ bool operator<=(const StridedIterator& other) const { + return ptr <= other.ptr; + } + + __device__ bool operator>=(const StridedIterator& other) const { + return ptr >= other.ptr; + } + + __device__ T& operator*() const { + return *ptr; + } + + __device__ T& operator[](difference_type n) const { + return *(ptr + n * stride); + } +}; + +template +__device__ StridedIterator make_strided_iterator(T* ptr, size_t stride) { + return StridedIterator(ptr, stride); +} + +} // namespace mlx::core::rocm \ No newline at end of file diff --git a/mlx/backend/rocm/jit_module.cpp b/mlx/backend/rocm/jit_module.cpp new file mode 100644 index 000000000..cdda490d5 --- /dev/null +++ b/mlx/backend/rocm/jit_module.cpp @@ -0,0 +1,167 @@ +// Copyright © 2025 Apple Inc. 
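+// Compilation pipeline implemented in this file: hiprtcCreateProgram ->
+// hiprtcCompileProgram -> hiprtcGetCode -> hipModuleLoadData ->
+// hipModuleGetFunction. Compiled modules are cached in JitCache, keyed on the
+// kernel name, source, template arguments and compiler flags.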
+ +#include "mlx/backend/rocm/jit_module.h" +#include "mlx/backend/rocm/utils.h" + +#include +#include +#include + +namespace mlx::core::rocm { + +JitModule::JitModule( + const std::string& kernel_name, + const std::string& kernel_source, + const std::vector& template_args, + const std::vector& compiler_flags, + bool verbose) { + compile(kernel_name, kernel_source, template_args, compiler_flags, verbose); +} + +JitModule::~JitModule() { + if (kernel_) { + // No hipFunctionDestroy equivalent in HIP + } + if (module_) { + CHECK_HIP_ERROR(hipModuleUnload(module_)); + } + if (program_) { + hiprtcDestroyProgram(&program_); + } +} + +void JitModule::compile( + const std::string& kernel_name, + const std::string& kernel_source, + const std::vector& template_args, + const std::vector& compiler_flags, + bool verbose) { + // Create HIPRTC program + CHECK_HIP_ERROR(hiprtcCreateProgram( + &program_, + kernel_source.c_str(), + kernel_name.c_str(), + 0, + nullptr, + nullptr)); + + // Build compiler options + std::vector options; + std::vector option_strings; + + // Add default options + option_strings.push_back("--std=c++17"); + option_strings.push_back("-O3"); + option_strings.push_back("-DMLX_USE_ROCM"); + + // Add user-provided flags + for (const auto& flag : compiler_flags) { + option_strings.push_back(flag); + } + + // Add template arguments + for (const auto& arg : template_args) { + option_strings.push_back("-D" + arg); + } + + // Convert to char* array + for (const auto& option : option_strings) { + options.push_back(option.c_str()); + } + + // Compile the program + hiprtcResult compile_result = + hiprtcCompileProgram(program_, options.size(), options.data()); + + // Get compilation log + size_t log_size; + CHECK_HIP_ERROR(hiprtcGetProgramLogSize(program_, &log_size)); + + if (log_size > 1) { + std::vector log(log_size); + CHECK_HIP_ERROR(hiprtcGetProgramLog(program_, log.data())); + + if (verbose || compile_result != HIPRTC_SUCCESS) { + fmt::print( + "HIPRTC compilation log for {}:\n{}\n", kernel_name, log.data()); + } + } + + if (compile_result != HIPRTC_SUCCESS) { + throw std::runtime_error( + fmt::format("HIPRTC compilation failed for kernel {}", kernel_name)); + } + + // Get compiled code + size_t code_size; + CHECK_HIP_ERROR(hiprtcGetCodeSize(program_, &code_size)); + + std::vector code(code_size); + CHECK_HIP_ERROR(hiprtcGetCode(program_, code.data())); + + // Load module + CHECK_HIP_ERROR(hipModuleLoadData(&module_, code.data())); + + // Get kernel function + CHECK_HIP_ERROR(hipModuleGetFunction(&kernel_, module_, kernel_name.c_str())); +} + +JitCache& JitCache::instance() { + static JitCache cache; + return cache; +} + +std::shared_ptr JitCache::get_or_create( + const std::string& kernel_name, + const std::string& kernel_source, + const std::vector& template_args, + const std::vector& compiler_flags) { + std::string key = + make_key(kernel_name, kernel_source, template_args, compiler_flags); + + std::lock_guard lock(mutex_); + + auto it = cache_.find(key); + if (it != cache_.end()) { + if (auto module = it->second.lock()) { + return module; + } else { + cache_.erase(it); + } + } + + auto module = std::make_shared( + kernel_name, kernel_source, template_args, compiler_flags); + cache_[key] = module; + return module; +} + +std::string JitCache::make_key( + const std::string& kernel_name, + const std::string& kernel_source, + const std::vector& template_args, + const std::vector& compiler_flags) const { + std::ostringstream oss; + oss << kernel_name << "|" << kernel_source; + + for (const 
auto& arg : template_args) { + oss << "|" << arg; + } + + for (const auto& flag : compiler_flags) { + oss << "|" << flag; + } + + return oss.str(); +} + +std::shared_ptr make_jit_kernel( + const std::string& kernel_name, + const std::string& kernel_source, + const std::vector& template_args, + const std::vector& compiler_flags) { + return JitCache::instance().get_or_create( + kernel_name, kernel_source, template_args, compiler_flags); +} + +} // namespace mlx::core::rocm \ No newline at end of file diff --git a/mlx/backend/rocm/jit_module.h b/mlx/backend/rocm/jit_module.h new file mode 100644 index 000000000..55b655c4d --- /dev/null +++ b/mlx/backend/rocm/jit_module.h @@ -0,0 +1,100 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include +#include + +#include +#include +#include +#include + +namespace mlx::core::rocm { + +// JIT compilation module for ROCm +class JitModule { + public: + JitModule( + const std::string& kernel_name, + const std::string& kernel_source, + const std::vector& template_args = {}, + const std::vector& compiler_flags = {}, + bool verbose = false); + + ~JitModule(); + + JitModule(const JitModule&) = delete; + JitModule& operator=(const JitModule&) = delete; + + // Get the compiled kernel function + hipFunction_t get_kernel() const { + return kernel_; + } + + // Launch the kernel with given arguments + template + void launch( + dim3 grid_dims, + dim3 block_dims, + size_t shared_memory, + hipStream_t stream, + Args&&... args) { + void* kernel_args[] = {(void*)&args...}; + CHECK_HIP_ERROR(hipModuleLaunchKernel( + kernel_, + grid_dims.x, + grid_dims.y, + grid_dims.z, + block_dims.x, + block_dims.y, + block_dims.z, + shared_memory, + stream, + kernel_args, + nullptr)); + } + + private: + void compile( + const std::string& kernel_name, + const std::string& kernel_source, + const std::vector& template_args, + const std::vector& compiler_flags, + bool verbose); + + hiprtcProgram program_{nullptr}; + hipModule_t module_{nullptr}; + hipFunction_t kernel_{nullptr}; +}; + +// JIT cache for compiled modules +class JitCache { + public: + static JitCache& instance(); + + std::shared_ptr get_or_create( + const std::string& kernel_name, + const std::string& kernel_source, + const std::vector& template_args = {}, + const std::vector& compiler_flags = {}); + + private: + std::unordered_map> cache_; + std::mutex mutex_; + + std::string make_key( + const std::string& kernel_name, + const std::string& kernel_source, + const std::vector& template_args, + const std::vector& compiler_flags) const; +}; + +// Helper function to create and cache JIT modules +std::shared_ptr make_jit_kernel( + const std::string& kernel_name, + const std::string& kernel_source, + const std::vector& template_args = {}, + const std::vector& compiler_flags = {}); + +} // namespace mlx::core::rocm \ No newline at end of file diff --git a/mlx/backend/rocm/kernel_utils.hip b/mlx/backend/rocm/kernel_utils.hip new file mode 100644 index 000000000..81b3be805 --- /dev/null +++ b/mlx/backend/rocm/kernel_utils.hip @@ -0,0 +1,29 @@ +// Copyright © 2025 Apple Inc. 
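+// Usage sketch for the JIT API declared above in jit_module.h (illustrative
+// only; the kernel source, names and launch geometry here are assumptions):
+//
+//   auto mod = make_jit_kernel(
+//       "axpy",
+//       R"(extern "C" __global__ void axpy(float* y, const float* x, float a, int n) {
+//            int i = blockIdx.x * blockDim.x + threadIdx.x;
+//            if (i < n) y[i] = a * x[i] + y[i];
+//          })");
+//   int blocks = (n + 255) / 256;
+//   mod->launch(dim3(blocks), dim3(256), /* shared_memory */ 0, stream, y, x, a, n);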
+ +#include + +namespace mlx::core::rocm { + +// Utility functions for HIP kernels + +__device__ inline int get_global_id() { + return blockIdx.x * blockDim.x + threadIdx.x; +} + +__device__ inline int get_local_id() { + return threadIdx.x; +} + +__device__ inline int get_group_id() { + return blockIdx.x; +} + +__device__ inline int get_local_size() { + return blockDim.x; +} + +__device__ inline int get_num_groups() { + return gridDim.x; +} + +} // namespace mlx::core::rocm \ No newline at end of file diff --git a/mlx/backend/rocm/kernel_utils.hpp b/mlx/backend/rocm/kernel_utils.hpp new file mode 100644 index 000000000..f694fd008 --- /dev/null +++ b/mlx/backend/rocm/kernel_utils.hpp @@ -0,0 +1,135 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include +#include + +namespace mlx::core::rocm { + +// Constants +constexpr int MAX_DIMS = 8; + +// HIP array type for passing arrays to kernels +template +using hip_array = std::array; + +// Helper to create hip_array from vector +template +__host__ hip_array make_hip_array(const std::vector& vec) { + hip_array arr; + for (int i = 0; i < N && i < vec.size(); ++i) { + arr[i] = vec[i]; + } + return arr; +} + +template +__host__ hip_array make_hip_array(const std::vector& vec) { + return make_hip_array(vec); +} + +// Type mapping from MLX types to HIP types +template +using hip_type_t = T; + +template <> +using hip_type_t = __half; + +template <> +using hip_type_t = __hip_bfloat16; + +template <> +using hip_type_t = hipFloatComplex; + +// Element to location mapping for general broadcasting +template +__device__ std::pair elem_to_loc_nd( + int64_t elem, + const int32_t* shape, + const int64_t* a_strides, + const int64_t* b_strides) { + int64_t a_idx = 0; + int64_t b_idx = 0; + + for (int i = NDIM - 1; i >= 0; --i) { + int64_t pos_in_dim = elem % shape[i]; + elem /= shape[i]; + a_idx += pos_in_dim * a_strides[i]; + b_idx += pos_in_dim * b_strides[i]; + } + + return {a_idx, b_idx}; +} + +// 4D specialization for performance +__device__ inline std::pair elem_to_loc_4d( + int64_t elem, + const int32_t* shape, + const int64_t* a_strides, + const int64_t* b_strides, + int ndim) { + int64_t a_idx = 0; + int64_t b_idx = 0; + + for (int i = ndim - 1; i >= 0; --i) { + int64_t pos_in_dim = elem % shape[i]; + elem /= shape[i]; + a_idx += pos_in_dim * a_strides[i]; + b_idx += pos_in_dim * b_strides[i]; + } + + return {a_idx, b_idx}; +} + +// Launch configuration calculation +template +std::pair +get_launch_args(Kernel kernel, const array& out, bool large = false) { + int threads_per_block = 256; + int64_t total_threads = out.size(); + + if (large) { + // For large arrays, use more blocks + int64_t blocks = + (total_threads + threads_per_block - 1) / threads_per_block; + return {dim3(blocks), dim3(threads_per_block)}; + } else { + int blocks = (total_threads + threads_per_block - 1) / threads_per_block; + return {dim3(blocks), dim3(threads_per_block)}; + } +} + +template +std::pair get_launch_args( + Kernel kernel, + int64_t size, + const std::vector& shape, + const std::vector& strides, + bool large = false) { + int threads_per_block = 256; + + if (large) { + int64_t blocks = (size + threads_per_block - 1) / threads_per_block; + return {dim3(blocks), dim3(threads_per_block)}; + } else { + int blocks = (size + threads_per_block - 1) / threads_per_block; + return {dim3(blocks), dim3(threads_per_block)}; + } +} + +// Cooperative groups thread rank equivalent +namespace cooperative_groups { +class grid_group { + public: + __device__ int64_t thread_rank() const 
{ + return blockIdx.x * blockDim.x + threadIdx.x; + } +}; + +__device__ grid_group this_grid() { + return grid_group{}; +} +} // namespace cooperative_groups + +} // namespace mlx::core::rocm \ No newline at end of file diff --git a/mlx/backend/rocm/layer_norm.hip b/mlx/backend/rocm/layer_norm.hip new file mode 100644 index 000000000..e0a50cf36 --- /dev/null +++ b/mlx/backend/rocm/layer_norm.hip @@ -0,0 +1,437 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/rocm/device.h" +#include "mlx/backend/rocm/iterators/strided_iterator.hpp" +#include "mlx/backend/rocm/kernel_utils.hpp" +#include "mlx/backend/rocm/reduce/reduce.hpp" +#include "mlx/backend/gpu/copy.h" +#include "mlx/dtype_utils.h" +#include "mlx/fast_primitives.h" + +#include +#include +#include +#include + +namespace mlx::core { + +namespace rocm { + +namespace cg = cooperative_groups; + +inline __device__ float3 plus_f3(const float3& a, const float3& b) { + return {a.x + b.x, a.y + b.y, a.z + b.z}; +} + +// Similar to rocprim::BlockReduce, but result is broadcasted to every thread. +template +struct BlockBroadcastReduce { + static_assert(WARP_SIZE <= BLOCK_DIM && BLOCK_DIM <= WARP_SIZE * WARP_SIZE); + static_assert(BLOCK_DIM % WARP_SIZE == 0); + using TempStorage = T[BLOCK_DIM / WARP_SIZE]; + + cg::thread_block& block; + TempStorage& temp; + + template + __device__ T Reduce(const T& input, const Op& op, const T& init_value) { + auto warp = cg::tiled_partition(block); + T x = cg::reduce(warp, input, op); + if (warp.thread_rank() == 0) { + temp[warp.meta_group_rank()] = x; + } + block.sync(); + x = warp.thread_rank() < warp.meta_group_size() ? temp[warp.thread_rank()] + : init_value; + return cg::reduce(warp, x, op); + } + + __device__ T Sum(const T& input) { + return Reduce(input, hip_plus{}, T{}); + } +}; + +template +__global__ void layer_norm( + const T* x, + const T* w, + const T* b, + T* out, + float eps, + int32_t axis_size, + int64_t w_stride, + int64_t b_stride) { + auto grid = cg::this_grid(); + auto block = cg::this_thread_block(); + + using BlockReduceT = BlockBroadcastReduce; + __shared__ typename BlockReduceT::TempStorage temp; + + x += grid.block_rank() * axis_size; + out += grid.block_rank() * axis_size; + + // Sum. + float sum = 0; + for (int r = 0; r < hip_ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) { + auto index = r * BLOCK_DIM + block.thread_rank(); + T xn[N_READS] = {}; + rocprim::block_load_direct_blocked(index, x, xn, axis_size); + sum += static_cast(rocprim::thread_reduce(xn, hip_plus{})); + } + sum = BlockReduceT{block, temp}.Sum(sum); + + // Mean. + float mean = sum / axis_size; + + // Normalizer. + float normalizer = 0; + for (int r = 0; r < hip_ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) { + auto index = r * BLOCK_DIM + block.thread_rank(); + T xn[N_READS]; + rocprim::block_load_direct_blocked(index, x, xn, axis_size, mean); + for (int i = 0; i < N_READS; ++i) { + float t = static_cast(xn[i]) - mean; + normalizer += t * t; + } + } + normalizer = BlockReduceT{block, temp}.Sum(normalizer); + normalizer = rsqrt(normalizer / axis_size + eps); + + // Outputs. 
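+  // out[i] = w[i] * (x[i] - mean) * rsqrt(var + eps) + b[i]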
+ for (int r = 0; r < hip_ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) { + auto index = r * BLOCK_DIM + block.thread_rank(); + T xn[N_READS]; + T wn[N_READS]; + T bn[N_READS]; + rocprim::block_load_direct_blocked(index, x, xn, axis_size); + rocprim::block_load_direct_blocked(index, strided_iterator(w, w_stride), wn, axis_size); + rocprim::block_load_direct_blocked(index, strided_iterator(b, b_stride), bn, axis_size); + for (int i = 0; i < N_READS; ++i) { + float norm = (static_cast(xn[i]) - mean) * normalizer; + xn[i] = wn[i] * static_cast(norm) + bn[i]; + } + rocprim::block_store_direct_blocked(index, out, xn, axis_size); + } +} + +template +__global__ void layer_norm_vjp( + const T* x, + const T* w, + const T* g, + T* gx, + T* gw, + float eps, + int32_t axis_size, + int64_t w_stride) { + auto grid = cg::this_grid(); + auto block = cg::this_thread_block(); + + using BlockReduceF = BlockBroadcastReduce; + using BlockReduceF3 = BlockBroadcastReduce; + __shared__ union { + typename BlockReduceF::TempStorage f; + typename BlockReduceF3::TempStorage f3; + } temp; + + x += grid.block_rank() * axis_size; + g += grid.block_rank() * axis_size; + gx += grid.block_rank() * axis_size; + gw += grid.block_rank() * axis_size; + + // Sum. + float sum = 0; + for (int r = 0; r < hip_ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) { + auto index = r * BLOCK_DIM + block.thread_rank(); + T xn[N_READS] = {}; + rocprim::block_load_direct_blocked(index, x, xn, axis_size); + sum += static_cast(rocprim::thread_reduce(xn, hip_plus{})); + } + sum = BlockReduceF{block, temp.f}.Sum(sum); + + // Mean. + float mean = sum / axis_size; + + // Normalizer. + float3 factors = {}; + for (int r = 0; r < hip_ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) { + T xn[N_READS]; + T wn[N_READS] = {}; + T gn[N_READS] = {}; + auto index = r * BLOCK_DIM + block.thread_rank(); + rocprim::block_load_direct_blocked(index, x, xn, axis_size, mean); + rocprim::block_load_direct_blocked(index, g, gn, axis_size); + rocprim::block_load_direct_blocked(index, strided_iterator(w, w_stride), wn, axis_size); + for (int i = 0; i < N_READS; i++) { + float t = static_cast(xn[i]) - mean; + float wi = wn[i]; + float gi = gn[i]; + float wg = wi * gi; + factors = plus_f3(factors, {wg, wg * t, t * t}); + } + } + factors = BlockReduceF3{block, temp.f3}.Reduce(factors, plus_f3, {}); + float meanwg = factors.x / axis_size; + float meanwgxc = factors.y / axis_size; + float normalizer2 = 1 / (factors.z / axis_size + eps); + float normalizer = sqrt(normalizer2); + + // Outputs. 
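+  // With x_hat = (x - mean) * normalizer:
+  //   gx[i] = normalizer * (w[i]*g[i] - mean(w*g)) - x_hat[i] * mean(w*g*(x - mean)) * normalizer^2
+  //   gw rows (when HAS_W) accumulate g[i] * x_hat[i] and are column-reduced later.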
+ for (int r = 0; r < hip_ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) { + auto index = r * BLOCK_DIM + block.thread_rank(); + T xn[N_READS]; + T wn[N_READS]; + T gn[N_READS]; + rocprim::block_load_direct_blocked(index, x, xn, axis_size); + rocprim::block_load_direct_blocked(index, g, gn, axis_size); + rocprim::block_load_direct_blocked(index, strided_iterator(w, w_stride), wn, axis_size); + for (int i = 0; i < N_READS; i++) { + float xi = (static_cast(xn[i]) - mean) * normalizer; + float wi = wn[i]; + float gi = gn[i]; + xn[i] = normalizer * (wi * gi - meanwg) - xi * meanwgxc * normalizer2; + if constexpr (HAS_W) { + wn[i] = gi * xi; + } + } + rocprim::block_store_direct_blocked(index, gx, xn, axis_size); + if constexpr (HAS_W) { + rocprim::block_store_direct_blocked(index, gw, wn, axis_size); + } + } +} + +// Utility functions +template +struct hip_plus { + __device__ T operator()(const T& a, const T& b) const { + return a + b; + } +}; + +inline __device__ int hip_ceil_div(int a, int b) { + return (a + b - 1) / b; +} + +template +__device__ inline auto strided_iterator(const T* ptr, int64_t stride) { + return ptr + stride; // Simplified strided iterator +} + +} // namespace rocm + +namespace fast { + +bool LayerNorm::use_fallback(Stream s) { + return s.device == Device::cpu; +} + +// TODO: There are duplicate code with backend/metal/normalization.cpp +void LayerNorm::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + auto& s = stream(); + auto& out = outputs[0]; + + // Make sure that the last dimension is contiguous. + auto set_output = [&s, &out](const array& x) { + bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1; + if (no_copy && x.ndim() > 1) { + auto s = x.strides()[x.ndim() - 2]; + no_copy &= (s == 0 || s == x.shape().back()); + } + if (no_copy) { + if (x.is_donatable()) { + out.copy_shared_buffer(x); + } else { + out.set_data( + allocator::malloc(x.data_size() * x.itemsize()), + x.data_size(), + x.strides(), + x.flags()); + } + return x; + } else { + auto x_copy = array(x.shape(), x.dtype(), nullptr, {}); + copy_gpu(x, x_copy, CopyType::General, s); + out.copy_shared_buffer(x_copy); + return x_copy; + } + }; + + const array x = set_output(inputs[0]); + const array& w = inputs[1]; + const array& b = inputs[2]; + + int32_t axis_size = x.shape().back(); + int32_t n_rows = x.data_size() / axis_size; + int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0; + int64_t b_stride = (b.ndim() == 1) ? b.strides()[0] : 0; + + auto& encoder = rocm::get_command_encoder(s); + encoder.set_input_array(x); + encoder.set_input_array(w); + encoder.set_input_array(b); + encoder.set_output_array(out); + encoder.launch_kernel([&](hipStream_t stream) { + MLX_SWITCH_FLOAT_TYPES_CHECKED(out.dtype(), "layernorm", CTYPE, { + using DataType = hip_type_t; + constexpr uint32_t N_READS = 4; + MLX_SWITCH_BLOCK_DIM(rocm::hip_ceil_div(axis_size, N_READS), BLOCK_DIM, { + auto kernel = rocm::layer_norm; + hipLaunchKernelGGL(kernel, n_rows, BLOCK_DIM, 0, stream, + x.data(), + w.data(), + b.data(), + out.data(), + eps_, + axis_size, + w_stride, + b_stride); + }); + }); + }); +} + +void LayerNormVJP::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + auto& s = stream(); + auto& encoder = rocm::get_command_encoder(s); + + // Ensure row contiguity. We could relax this step by checking that the array + // is contiguous (no broadcasts or holes) and that the input strides are the + // same as the cotangent strides but for now this is simpler. 
+ auto check_input = [&s](const array& x) -> std::pair { + if (x.flags().row_contiguous) { + return {x, false}; + } + array x_copy(x.shape(), x.dtype(), nullptr, {}); + copy_gpu(x, x_copy, CopyType::General, s); + return {x_copy, true}; + }; + bool donate_x = inputs[0].is_donatable(); + bool donate_g = inputs[3].is_donatable(); + auto [x, copied] = check_input(inputs[0]); + donate_x |= copied; + const array& w = inputs[1]; + const array& b = inputs[2]; + auto [g, g_copied] = check_input(inputs[3]); + donate_g |= g_copied; + array& gx = outputs[0]; + array& gw = outputs[1]; + array& gb = outputs[2]; + + // Check whether we had a weight. + bool has_w = w.ndim() != 0; + + // Allocate space for the outputs. + bool g_in_gx = false; + if (donate_x) { + gx.copy_shared_buffer(x); + } else if (donate_g) { + gx.copy_shared_buffer(g); + g_in_gx = true; + } else { + gx.set_data(allocator::malloc(gx.nbytes())); + } + if (g_copied && !g_in_gx) { + encoder.add_temporary(g); + } + + int32_t axis_size = x.shape().back(); + int32_t n_rows = x.data_size() / axis_size; + int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0; + + // Allocate a temporary to store the gradients for w and allocate the output + // gradient accumulators. + array gw_temp = + (has_w) ? array({n_rows, x.shape().back()}, gw.dtype(), nullptr, {}) : w; + if (has_w) { + if (!g_in_gx && donate_g) { + gw_temp.copy_shared_buffer(g); + } else { + gw_temp.set_data(allocator::malloc(gw_temp.nbytes())); + encoder.add_temporary(gw_temp); + } + } + gw.set_data(allocator::malloc(gw.nbytes())); + gb.set_data(allocator::malloc(gb.nbytes())); + + // Finish with the gradient for b in case we had a b. + if (gb.ndim() == 1 && gb.size() == axis_size) { + ReductionPlan plan( + ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size}); + col_reduce(encoder, g, gb, Reduce::ReduceType::Sum, {0}, plan); + } + + encoder.set_input_array(x); + encoder.set_input_array(w); + encoder.set_input_array(g); + encoder.set_output_array(gx); + encoder.set_output_array(gw_temp); + encoder.launch_kernel([&, x = x, g = g](hipStream_t stream) { + MLX_SWITCH_FLOAT_TYPES_CHECKED(gx.dtype(), "layernorm_vjp", CTYPE, { + using DataType = hip_type_t; + constexpr int N_READS = 4; + MLX_SWITCH_BOOL(has_w, HAS_W, { + MLX_SWITCH_BLOCK_DIM(rocm::hip_ceil_div(axis_size, N_READS), BLOCK_DIM, { + auto kernel = rocm::layer_norm_vjp; + hipLaunchKernelGGL(kernel, n_rows, BLOCK_DIM, 0, stream, + x.data(), + w.data(), + g.data(), + gx.data(), + gw_temp.data(), + eps_, + axis_size, + w_stride); + }); + }); + }); + }); + + if (has_w) { + ReductionPlan plan( + ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size}); + col_reduce(encoder, gw_temp, gw, Reduce::ReduceType::Sum, {0}, plan); + } +} + +} // namespace fast + +} // namespace mlx::core + +namespace mlx::core::rocm { + +__global__ void layer_norm_kernel( + float* input, + float* output, + float* gamma, + float* beta, + int n, + float eps) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + + if (idx < n) { + // Simplified layer norm placeholder + // Real implementation would compute mean and variance + output[idx] = gamma[idx] * input[idx] + beta[idx]; + } +} + +void launch_layer_norm( + float* input, + float* output, + float* gamma, + float* beta, + int n, + float eps, + hipStream_t stream) { + int threads = 256; + int blocks = (n + threads - 1) / threads; + hipLaunchKernelGGL(layer_norm_kernel, dim3(blocks), dim3(threads), 0, stream, + input, output, gamma, beta, n, eps); +} + +} // namespace mlx::core::rocm \ No 
newline at end of file diff --git a/mlx/backend/rocm/logsumexp.hip b/mlx/backend/rocm/logsumexp.hip new file mode 100644 index 000000000..94dfc6525 --- /dev/null +++ b/mlx/backend/rocm/logsumexp.hip @@ -0,0 +1,13 @@ +// Copyright © 2025 Apple Inc. + +#include + +namespace mlx::core::rocm { + +__global__ void logsumexp_kernel(float* input, float* output, int n) { + // Placeholder implementation + int idx = blockIdx.x * blockDim.x + threadIdx.x; + (void)input; (void)output; (void)n; (void)idx; +} + +} // namespace mlx::core::rocm \ No newline at end of file diff --git a/mlx/backend/rocm/matmul.cpp b/mlx/backend/rocm/matmul.cpp new file mode 100644 index 000000000..9d6dbc065 --- /dev/null +++ b/mlx/backend/rocm/matmul.cpp @@ -0,0 +1,30 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/rocm/device.h" +#include "mlx/backend/rocm/utils.h" + +namespace mlx::core::rocm { + +void matmul_hip( + float* a, + float* b, + float* c, + int m, + int n, + int k, + hipStream_t stream) { + // This is a placeholder - in a real implementation, this would use rocBLAS + // auto& device = get_current_device(); + // rocblas_sgemm(device.rocblas_handle(), ...); + + // For now, just a placeholder + (void)a; + (void)b; + (void)c; + (void)m; + (void)n; + (void)k; + (void)stream; +} + +} // namespace mlx::core::rocm \ No newline at end of file diff --git a/mlx/backend/rocm/no_rocm.cpp b/mlx/backend/rocm/no_rocm.cpp new file mode 100644 index 000000000..da686f59d --- /dev/null +++ b/mlx/backend/rocm/no_rocm.cpp @@ -0,0 +1,11 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/rocm/rocm.h" + +namespace mlx::core::rocm { + +bool is_available() { + return false; +} + +} // namespace mlx::core::rocm \ No newline at end of file diff --git a/mlx/backend/rocm/primitives.hip b/mlx/backend/rocm/primitives.hip new file mode 100644 index 000000000..c91e36da3 --- /dev/null +++ b/mlx/backend/rocm/primitives.hip @@ -0,0 +1,21 @@ +// Copyright © 2025 Apple Inc. + +#include + +#include "mlx/backend/rocm/utils.h" +#include "mlx/backend/common/primitives.h" + +namespace mlx::core::rocm { + +// Basic kernel implementations will go here +// This is a placeholder for ROCm-specific primitive operations + +void add_hip() { + // Placeholder for HIP add operation +} + +void multiply_hip() { + // Placeholder for HIP multiply operation +} + +} // namespace mlx::core::rocm \ No newline at end of file diff --git a/mlx/backend/rocm/random.hip b/mlx/backend/rocm/random.hip new file mode 100644 index 000000000..d192eb68d --- /dev/null +++ b/mlx/backend/rocm/random.hip @@ -0,0 +1,23 @@ +// Copyright © 2025 Apple Inc. 
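+// The kernel below is a placeholder LCG. A host-side sketch of the intended
+// rocRAND/hipRAND path is left as a comment; the hipRAND calls are assumed to
+// mirror the cuRAND API and are not wired into the backend here.
+//
+//   #include <hiprand/hiprand.h>
+//
+//   void fill_uniform(float* out, size_t n, unsigned long long seed, hipStream_t stream) {
+//     hiprandGenerator_t gen;
+//     hiprandCreateGenerator(&gen, HIPRAND_RNG_PSEUDO_DEFAULT);
+//     hiprandSetPseudoRandomGeneratorSeed(gen, seed);
+//     hiprandSetStream(gen, stream);
+//     hiprandGenerateUniform(gen, out, n); // uniform floats in (0, 1]
+//     hiprandDestroyGenerator(gen);
+//   }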
+ +#include + +namespace mlx::core::rocm { + +__global__ void random_uniform_kernel(float* output, int n, unsigned int seed) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + // Simple LCG placeholder - real implementation would use rocRAND + unsigned int state = seed + idx; + state = state * 1103515245 + 12345; + output[idx] = (float)(state & 0x7FFFFFFF) / (float)0x7FFFFFFF; + } +} + +void launch_random_uniform(float* output, int n, unsigned int seed, hipStream_t stream) { + int threads = 256; + int blocks = (n + threads - 1) / threads; + hipLaunchKernelGGL(random_uniform_kernel, dim3(blocks), dim3(threads), 0, stream, output, n, seed); +} + +} // namespace mlx::core::rocm \ No newline at end of file diff --git a/mlx/backend/rocm/reduce.hip b/mlx/backend/rocm/reduce.hip new file mode 100644 index 000000000..6259e9a57 --- /dev/null +++ b/mlx/backend/rocm/reduce.hip @@ -0,0 +1,24 @@ +// Copyright © 2025 Apple Inc. + +#include + +namespace mlx::core::rocm { + +__global__ void sum_reduce_kernel(float* input, float* output, int n) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + + // Simple reduction placeholder + if (idx == 0) { + float sum = 0.0f; + for (int i = 0; i < n; i++) { + sum += input[i]; + } + output[0] = sum; + } +} + +void launch_sum_reduce(float* input, float* output, int n, hipStream_t stream) { + hipLaunchKernelGGL(sum_reduce_kernel, dim3(1), dim3(1), 0, stream, input, output, n); +} + +} // namespace mlx::core::rocm \ No newline at end of file diff --git a/mlx/backend/rocm/reduce/col_reduce.hip b/mlx/backend/rocm/reduce/col_reduce.hip new file mode 100644 index 000000000..66b779e12 --- /dev/null +++ b/mlx/backend/rocm/reduce/col_reduce.hip @@ -0,0 +1,311 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/rocm/device.h" +#include "mlx/backend/rocm/device/cast_op.hpp" +#include "mlx/backend/rocm/reduce/reduce.hpp" + +#include +#include +#include + +namespace mlx::core { + +namespace rocm { + +namespace cg = cooperative_groups; + +struct ColReduceArgs { + // The size of the contiguous column reduction. + size_t reduction_size; + int64_t reduction_stride; + + // Input shape and strides excluding the reduction axes. + Shape shape; + Strides strides; + int ndim; + + // Input shape and strides of the reduction axes (including last dimension). + Shape reduce_shape; + Strides reduce_strides; + int reduce_ndim; + + // The number of column we are reducing. Namely prod(reduce_shape). 
+ size_t non_col_reductions; + + ColReduceArgs( + const array& in, + const ReductionPlan& plan, + const std::vector& axes) { + assert(!plan.shape.empty()); + reduction_size = plan.shape.back(); + reduction_stride = plan.strides.back(); + + int64_t stride_back = 1; + auto [shape_vec, strides_vec] = shapes_without_reduction_axes(in, axes); + while (!shape_vec.empty() && stride_back < reduction_stride) { + stride_back *= shape_vec.back(); + shape_vec.pop_back(); + strides_vec.pop_back(); + } + std::tie(shape_vec, strides_vec) = + collapse_contiguous_dims(shape_vec, strides_vec); + shape = const_param(shape_vec); + strides = const_param(strides_vec); + ndim = shape_vec.size(); + + reduce_shape = const_param(plan.shape); + reduce_strides = const_param(plan.strides); + reduce_ndim = plan.shape.size(); + + non_col_reductions = 1; + for (int i = 0; i < reduce_ndim - 1; i++) { + non_col_reductions *= reduce_shape[i]; + } + } +}; + +template +__global__ void col_reduce_small( + const T* in, + U* out, + const ColReduceArgs args) { + auto grid = cg::this_grid(); + auto block = cg::this_thread_block(); + + int column = + grid.block_index().x * block.dim_threads().x + block.thread_index().x; + if (column * N_READS >= args.reduction_stride) { + return; + } + + int out_idx = grid.block_rank() / grid.dim_blocks().x; + in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim); + + Op op; + U totals[N_READS]; + for (int i = 0; i < N_READS; i++) { + totals[i] = ReduceInit::value(); + } + + // Read input to local. + LoopedElemToLoc 2)> loop(args.reduce_ndim); + loop.next( + block.thread_index().y, + args.reduce_shape.data(), + args.reduce_strides.data()); + for (size_t r = block.thread_index().y; + r < args.non_col_reductions * args.reduction_size; + r += block.dim_threads().y) { + U vals[N_READS]; + rocprim::block_load_direct_blocked( + column, + make_cast_iterator(in + loop.location()), + vals, + args.reduction_stride, + ReduceInit::value()); + for (int i = 0; i < N_READS; i++) { + totals[i] = op(vals[i], totals[i]); + } + loop.next( + block.dim_threads().y, + args.reduce_shape.data(), + args.reduce_strides.data()); + } + + // Do block reduce when each column has more than 1 element to reduce. + if (block.dim_threads().y > 1) { + __shared__ U shared_vals[32 * 8 * N_READS]; + size_t col = + block.thread_index().y * block.dim_threads().x + block.thread_index().x; + for (int i = 0; i < N_READS; i++) { + shared_vals[col * N_READS + i] = totals[i]; + } + block.sync(); + if (block.thread_index().y == 0) { + for (int i = 0; i < N_READS; i++) { + totals[i] = shared_vals[block.thread_index().x * N_READS + i]; + } + for (int j = 1; j < block.dim_threads().y; j++) { + col = j * block.dim_threads().x + block.thread_index().x; + for (int i = 0; i < N_READS; i++) { + totals[i] = op(shared_vals[col * N_READS + i], totals[i]); + } + } + } + } + + // Write result. 
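+  // Only the first row of threads (thread_index().y == 0) holds the final
+  // per-column partials; each such thread writes its N_READS contiguous outputs.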
+ if (block.thread_index().y == 0) { + rocprim::block_store_direct_blocked( + column, + out + out_idx * args.reduction_stride, + totals, + args.reduction_stride); + } +} + +template < + typename T, + typename U, + typename Op, + int NDIM, + int BM, + int BN, + int N_READS = 4> +__global__ void col_reduce_looped( + const T* in, + U* out, + const ColReduceArgs args) { + auto grid = cg::this_grid(); + auto block = cg::this_thread_block(); + auto warp = cg::tiled_partition(block); + + constexpr int n_warps = BN / N_READS; + + int out_idx = grid.block_rank() / grid.dim_blocks().x; + in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim); + + Op op; + U totals[N_READS]; + for (int i = 0; i < N_READS; i++) { + totals[i] = ReduceInit::value(); + } + + // Read input to local. + int r = block.thread_rank() / n_warps; + int column = block.thread_rank() % n_warps; + int in_offset = grid.block_index().x * BN; + LoopedElemToLoc 2)> loop(args.reduce_ndim); + loop.next(r, args.reduce_shape.data(), args.reduce_strides.data()); + for (; r < args.non_col_reductions * args.reduction_size; r += BM) { + U vals[N_READS]; + rocprim::block_load_direct_blocked( + column, + make_cast_iterator(in + loop.location() + in_offset), + vals, + args.reduction_stride - in_offset, + ReduceInit::value()); + for (int i = 0; i < N_READS; i++) { + totals[i] = op(vals[i], totals[i]); + } + loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data()); + } + + // Do warp reduce for each output. + constexpr int n_outputs = BN / n_warps; + static_assert(BM == 32 && n_outputs == N_READS); + __shared__ U shared_vals[BM * BN]; + size_t col = block.thread_index().y * BN + block.thread_index().x * N_READS; + for (int i = 0; i < N_READS; i++) { + shared_vals[col + i] = totals[i]; + } + block.sync(); + col = warp.thread_rank() * BN + warp.meta_group_rank() * n_outputs; + for (int i = 0; i < n_outputs; i++) { + totals[i] = cg::reduce(warp, shared_vals[col + i], op); + } + + // Write result. 
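+  // Lane 0 of each warp writes its n_outputs reduced values into this block's
+  // BN-wide tile of the output row.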
+ if (warp.thread_rank() == 0) { + size_t out_offset = grid.block_index().x * BN; + rocprim::block_store_direct_blocked( + warp.meta_group_rank(), + out + out_idx * args.reduction_stride + out_offset, + totals, + args.reduction_stride - out_offset); + } +} + +// Utility functions and templates +template +struct LoopedElemToLoc { + size_t location; + + __device__ LoopedElemToLoc(int reduce_ndim) : location(0) {} + + __device__ void next(size_t step, const int* shape, const size_t* strides) { + // Simplified implementation - actual would handle multi-dimensional indexing + location += step; + } +}; + +template +__device__ inline T* make_cast_iterator(const T* ptr) { + return const_cast(ptr); +} + +__device__ inline size_t elem_to_loc( + size_t elem, + const int* shape, + const size_t* strides, + int ndim) { + size_t loc = 0; + for (int i = ndim - 1; i >= 0; --i) { + size_t q = elem / shape[i]; + size_t r = elem % shape[i]; + loc += r * strides[i]; + elem = q; + } + return loc; +} + +} // namespace rocm + +inline auto output_grid_for_col_reduce( + const array& out, + const rocm::ColReduceArgs& args) { + auto out_shape = out.shape(); + auto out_strides = out.strides(); + while (!out_shape.empty() && out_strides.back() < args.reduction_stride) { + out_shape.pop_back(); + out_strides.pop_back(); + } + return get_2d_grid_dims(out_shape, out_strides); +} + +void col_reduce( + rocm::CommandEncoder& encoder, + const array& in, + array& out, + Reduce::ReduceType reduce_type, + const std::vector& axes, + const ReductionPlan& plan) { + rocm::ColReduceArgs args(in, plan, axes); + + encoder.launch_kernel([&](hipStream_t stream) { + MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, { + using InType = hip_type_t; + MLX_SWITCH_REDUCE_OPS(reduce_type, OP, { + using OutType = rocm::ReduceResult::type; + MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, { + constexpr int N_READS = 4; + dim3 block_dims; + dim3 num_blocks = output_grid_for_col_reduce(out, args); + num_blocks.z = num_blocks.y; + num_blocks.y = num_blocks.x; + auto kernel = + rocm::col_reduce_small; + size_t total = args.non_col_reductions * args.reduction_size; + if (total < 32) { + size_t stride_blocks = + hip_ceil_div(args.reduction_stride, N_READS); + block_dims.x = std::min(stride_blocks, 32ul); + block_dims.y = std::min(total, 8ul); + num_blocks.x = hip_ceil_div(stride_blocks, block_dims.x); + } else { + constexpr int BM = 32; + constexpr int BN = 32; + block_dims.x = BM * BN / N_READS; + num_blocks.x = hip_ceil_div(args.reduction_stride, BN); + kernel = rocm:: + col_reduce_looped; + } + hipLaunchKernelGGL(kernel, num_blocks, block_dims, 0, stream, + in.data(), out.data(), args); + }); + }); + }); + }); +} + +} // namespace mlx::core \ No newline at end of file diff --git a/mlx/backend/rocm/reduce/reduce.hpp b/mlx/backend/rocm/reduce/reduce.hpp new file mode 100644 index 000000000..87894b3dd --- /dev/null +++ b/mlx/backend/rocm/reduce/reduce.hpp @@ -0,0 +1,119 @@ +// Copyright © 2025 Apple Inc. 
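+// Usage sketch for the block_reduce helper defined below (illustrative only;
+// sum_kernel and add_op are assumptions, not part of this header):
+//
+//   __device__ float add_op(float a, float b) { return a + b; }
+//
+//   __global__ void sum_kernel(const float* in, float* out, int n) {
+//     int i = blockIdx.x * blockDim.x + threadIdx.x;
+//     float v = (i < n) ? in[i] : 0.0f;
+//     float block_sum = block_reduce(v, add_op); // final value lands in thread 0
+//     if (threadIdx.x == 0) {
+//       atomicAdd(out, block_sum);
+//     }
+//   }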
+ +#pragma once + +#include +#include + +namespace mlx::core::rocm { + +// Reduction operation types +template +struct ReduceInit { + static constexpr T value(); +}; + +template +struct ReduceInit { + static constexpr T value() { + return T(0); + } +}; + +template +struct ReduceInit { + static constexpr T value() { + return -std::numeric_limits::infinity(); + } +}; + +template +struct ReduceInit { + static constexpr T value() { + return std::numeric_limits::infinity(); + } +}; + +// Reduction operations +struct Sum { + template + __device__ T operator()(T a, T b) const { + return a + b; + } +}; + +struct Max { + template + __device__ T operator()(T a, T b) const { + return fmax(a, b); + } +}; + +struct Min { + template + __device__ T operator()(T a, T b) const { + return fmin(a, b); + } +}; + +struct Prod { + template + __device__ T operator()(T a, T b) const { + return a * b; + } +}; + +// Utility functions for reductions +template +__device__ T warp_reduce(T val, T (*op)(T, T)) { + for (int offset = warpSize / 2; offset > 0; offset /= 2) { + val = op(val, __shfl_down(val, offset)); + } + return val; +} + +template +__device__ T block_reduce(T val, T (*op)(T, T)) { + static __shared__ T shared[32]; + int lane = threadIdx.x % warpSize; + int wid = threadIdx.x / warpSize; + + val = warp_reduce(val, op); + + if (lane == 0) + shared[wid] = val; + __syncthreads(); + + val = (threadIdx.x < blockDim.x / warpSize) ? shared[lane] : 0; + if (wid == 0) + val = warp_reduce(val, op); + + return val; +} + +// Column reduction arguments +struct ColReduceArgs { + size_t reduction_size; + int64_t reduction_stride; + int* shape; + size_t* strides; + int ndim; + int* reduce_shape; + size_t* reduce_strides; + int reduce_ndim; + size_t non_col_reductions; +}; + +// Row reduction arguments +struct RowReduceArgs { + size_t reduction_size; + int64_t reduction_stride; + int* shape; + size_t* strides; + int ndim; + int* reduce_shape; + size_t* reduce_strides; + int reduce_ndim; +}; + +} // namespace mlx::core::rocm \ No newline at end of file diff --git a/mlx/backend/rocm/rms_norm.hip b/mlx/backend/rocm/rms_norm.hip new file mode 100644 index 000000000..e58e306d1 --- /dev/null +++ b/mlx/backend/rocm/rms_norm.hip @@ -0,0 +1,375 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/rocm/device.h" +#include "mlx/backend/rocm/iterators/strided_iterator.hpp" +#include "mlx/backend/rocm/kernel_utils.hpp" +#include "mlx/backend/rocm/reduce/reduce.hpp" +#include "mlx/backend/gpu/copy.h" +#include "mlx/dtype_utils.h" +#include "mlx/fast_primitives.h" + +#include +#include +#include +#include + +namespace mlx::core { + +namespace rocm { + +namespace cg = cooperative_groups; + +// Similar to rocprim::BlockReduce, but result is broadcasted to every thread. +template +struct BlockBroadcastReduce { + static_assert(WARP_SIZE <= BLOCK_DIM && BLOCK_DIM <= WARP_SIZE * WARP_SIZE); + static_assert(BLOCK_DIM % WARP_SIZE == 0); + using TempStorage = T[BLOCK_DIM / WARP_SIZE]; + + cg::thread_block& block; + TempStorage& temp; + + template + __device__ T Reduce(const T& input, const Op& op, const T& init_value) { + auto warp = cg::tiled_partition(block); + T x = cg::reduce(warp, input, op); + if (warp.thread_rank() == 0) { + temp[warp.meta_group_rank()] = x; + } + block.sync(); + x = warp.thread_rank() < warp.meta_group_size() ? 
temp[warp.thread_rank()] + : init_value; + return cg::reduce(warp, x, op); + } + + __device__ T Sum(const T& input) { + return Reduce(input, hip_plus{}, T{}); + } +}; + +template +__global__ void rms_norm( + const T* x, + const T* w, + T* out, + float eps, + int32_t axis_size, + int64_t w_stride) { + auto grid = cg::this_grid(); + auto block = cg::this_thread_block(); + + using BlockReduceT = BlockBroadcastReduce; + __shared__ typename BlockReduceT::TempStorage temp; + + x += grid.block_rank() * axis_size; + out += grid.block_rank() * axis_size; + + // Sum of squares. + float sum_sq = 0; + for (int r = 0; r < hip_ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) { + auto index = r * BLOCK_DIM + block.thread_rank(); + T xn[N_READS] = {}; + rocprim::block_load_direct_blocked(index, x, xn, axis_size); + for (int i = 0; i < N_READS; ++i) { + float val = static_cast(xn[i]); + sum_sq += val * val; + } + } + sum_sq = BlockReduceT{block, temp}.Sum(sum_sq); + + // RMS normalizer. + float rms_normalizer = rsqrt(sum_sq / axis_size + eps); + + // Outputs. + for (int r = 0; r < hip_ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) { + auto index = r * BLOCK_DIM + block.thread_rank(); + T xn[N_READS]; + T wn[N_READS]; + rocprim::block_load_direct_blocked(index, x, xn, axis_size); + rocprim::block_load_direct_blocked(index, strided_iterator(w, w_stride), wn, axis_size); + for (int i = 0; i < N_READS; ++i) { + float norm = static_cast(xn[i]) * rms_normalizer; + xn[i] = wn[i] * static_cast(norm); + } + rocprim::block_store_direct_blocked(index, out, xn, axis_size); + } +} + +template +__global__ void rms_norm_vjp( + const T* x, + const T* w, + const T* g, + T* gx, + T* gw, + float eps, + int32_t axis_size, + int64_t w_stride) { + auto grid = cg::this_grid(); + auto block = cg::this_thread_block(); + + using BlockReduceF = BlockBroadcastReduce; + using BlockReduceF2 = BlockBroadcastReduce; + __shared__ union { + typename BlockReduceF::TempStorage f; + typename BlockReduceF2::TempStorage f2; + } temp; + + x += grid.block_rank() * axis_size; + g += grid.block_rank() * axis_size; + gx += grid.block_rank() * axis_size; + gw += grid.block_rank() * axis_size; + + // Sum of squares. + float sum_sq = 0; + for (int r = 0; r < hip_ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) { + auto index = r * BLOCK_DIM + block.thread_rank(); + T xn[N_READS] = {}; + rocprim::block_load_direct_blocked(index, x, xn, axis_size); + for (int i = 0; i < N_READS; ++i) { + float val = static_cast(xn[i]); + sum_sq += val * val; + } + } + sum_sq = BlockReduceF{block, temp.f}.Sum(sum_sq); + + // RMS normalizer. + float rms_normalizer = rsqrt(sum_sq / axis_size + eps); + + // Compute gradient terms. 
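+  // factors.x accumulates sum(w*g) and factors.y accumulates sum(w*g*x) over
+  // the row; their means drive the VJP below.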
+ float2 factors = {}; + for (int r = 0; r < hip_ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) { + T xn[N_READS]; + T wn[N_READS] = {}; + T gn[N_READS] = {}; + auto index = r * BLOCK_DIM + block.thread_rank(); + rocprim::block_load_direct_blocked(index, x, xn, axis_size); + rocprim::block_load_direct_blocked(index, g, gn, axis_size); + rocprim::block_load_direct_blocked(index, strided_iterator(w, w_stride), wn, axis_size); + for (int i = 0; i < N_READS; i++) { + float xi = static_cast(xn[i]); + float wi = wn[i]; + float gi = gn[i]; + float wg = wi * gi; + factors.x += wg; + factors.y += wg * xi; + } + } + auto plus_f2 = [] __device__ (const float2& a, const float2& b) -> float2 { + return {a.x + b.x, a.y + b.y}; + }; + factors = BlockReduceF2{block, temp.f2}.Reduce(factors, plus_f2, {}); + float mean_wg = factors.x / axis_size; + float mean_wgx = factors.y / axis_size; + float rms3 = rms_normalizer * rms_normalizer * rms_normalizer; + + // Outputs. + for (int r = 0; r < hip_ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) { + auto index = r * BLOCK_DIM + block.thread_rank(); + T xn[N_READS]; + T wn[N_READS]; + T gn[N_READS]; + rocprim::block_load_direct_blocked(index, x, xn, axis_size); + rocprim::block_load_direct_blocked(index, g, gn, axis_size); + rocprim::block_load_direct_blocked(index, strided_iterator(w, w_stride), wn, axis_size); + for (int i = 0; i < N_READS; i++) { + float xi = static_cast(xn[i]); + float wi = wn[i]; + float gi = gn[i]; + float norm = xi * rms_normalizer; + xn[i] = rms_normalizer * (wi * gi - mean_wg) - norm * mean_wgx * rms3; + if constexpr (HAS_W) { + wn[i] = gi * norm; + } + } + rocprim::block_store_direct_blocked(index, gx, xn, axis_size); + if constexpr (HAS_W) { + rocprim::block_store_direct_blocked(index, gw, wn, axis_size); + } + } +} + +// Utility functions +template +struct hip_plus { + __device__ T operator()(const T& a, const T& b) const { + return a + b; + } +}; + +inline __device__ int hip_ceil_div(int a, int b) { + return (a + b - 1) / b; +} + +template +__device__ inline auto strided_iterator(const T* ptr, int64_t stride) { + return ptr + stride; // Simplified strided iterator +} + +} // namespace rocm + +namespace fast { + +bool RMSNorm::use_fallback(Stream s) { + return s.device == Device::cpu; +} + +void RMSNorm::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + auto& s = stream(); + auto& out = outputs[0]; + + // Make sure that the last dimension is contiguous. + auto set_output = [&s, &out](const array& x) { + bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1; + if (no_copy && x.ndim() > 1) { + auto s = x.strides()[x.ndim() - 2]; + no_copy &= (s == 0 || s == x.shape().back()); + } + if (no_copy) { + if (x.is_donatable()) { + out.copy_shared_buffer(x); + } else { + out.set_data( + allocator::malloc(x.data_size() * x.itemsize()), + x.data_size(), + x.strides(), + x.flags()); + } + return x; + } else { + auto x_copy = array(x.shape(), x.dtype(), nullptr, {}); + copy_gpu(x, x_copy, CopyType::General, s); + out.copy_shared_buffer(x_copy); + return x_copy; + } + }; + + const array x = set_output(inputs[0]); + const array& w = inputs[1]; + + int32_t axis_size = x.shape().back(); + int32_t n_rows = x.data_size() / axis_size; + int64_t w_stride = (w.ndim() == 1) ? 
w.strides()[0] : 0; + + auto& encoder = rocm::get_command_encoder(s); + encoder.set_input_array(x); + encoder.set_input_array(w); + encoder.set_output_array(out); + encoder.launch_kernel([&](hipStream_t stream) { + MLX_SWITCH_FLOAT_TYPES_CHECKED(out.dtype(), "rmsnorm", CTYPE, { + using DataType = hip_type_t; + constexpr uint32_t N_READS = 4; + MLX_SWITCH_BLOCK_DIM(rocm::hip_ceil_div(axis_size, N_READS), BLOCK_DIM, { + auto kernel = rocm::rms_norm; + hipLaunchKernelGGL(kernel, n_rows, BLOCK_DIM, 0, stream, + x.data(), + w.data(), + out.data(), + eps_, + axis_size, + w_stride); + }); + }); + }); +} + +void RMSNormVJP::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + auto& s = stream(); + auto& encoder = rocm::get_command_encoder(s); + + // Ensure row contiguity. We could relax this step by checking that the array + // is contiguous (no broadcasts or holes) and that the input strides are the + // same as the cotangent strides but for now this is simpler. + auto check_input = [&s](const array& x) -> std::pair { + if (x.flags().row_contiguous) { + return {x, false}; + } + array x_copy(x.shape(), x.dtype(), nullptr, {}); + copy_gpu(x, x_copy, CopyType::General, s); + return {x_copy, true}; + }; + bool donate_x = inputs[0].is_donatable(); + bool donate_g = inputs[2].is_donatable(); + auto [x, copied] = check_input(inputs[0]); + donate_x |= copied; + const array& w = inputs[1]; + auto [g, g_copied] = check_input(inputs[2]); + donate_g |= g_copied; + array& gx = outputs[0]; + array& gw = outputs[1]; + + // Check whether we had a weight. + bool has_w = w.ndim() != 0; + + // Allocate space for the outputs. + bool g_in_gx = false; + if (donate_x) { + gx.copy_shared_buffer(x); + } else if (donate_g) { + gx.copy_shared_buffer(g); + g_in_gx = true; + } else { + gx.set_data(allocator::malloc(gx.nbytes())); + } + if (g_copied && !g_in_gx) { + encoder.add_temporary(g); + } + + int32_t axis_size = x.shape().back(); + int32_t n_rows = x.data_size() / axis_size; + int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0; + + // Allocate a temporary to store the gradients for w and allocate the output + // gradient accumulators. + array gw_temp = + (has_w) ? 
array({n_rows, x.shape().back()}, gw.dtype(), nullptr, {}) : w; + if (has_w) { + if (!g_in_gx && donate_g) { + gw_temp.copy_shared_buffer(g); + } else { + gw_temp.set_data(allocator::malloc(gw_temp.nbytes())); + encoder.add_temporary(gw_temp); + } + } + gw.set_data(allocator::malloc(gw.nbytes())); + + encoder.set_input_array(x); + encoder.set_input_array(w); + encoder.set_input_array(g); + encoder.set_output_array(gx); + encoder.set_output_array(gw_temp); + encoder.launch_kernel([&, x = x, g = g](hipStream_t stream) { + MLX_SWITCH_FLOAT_TYPES_CHECKED(gx.dtype(), "rmsnorm_vjp", CTYPE, { + using DataType = hip_type_t; + constexpr int N_READS = 4; + MLX_SWITCH_BOOL(has_w, HAS_W, { + MLX_SWITCH_BLOCK_DIM(rocm::hip_ceil_div(axis_size, N_READS), BLOCK_DIM, { + auto kernel = rocm::rms_norm_vjp; + hipLaunchKernelGGL(kernel, n_rows, BLOCK_DIM, 0, stream, + x.data(), + w.data(), + g.data(), + gx.data(), + gw_temp.data(), + eps_, + axis_size, + w_stride); + }); + }); + }); + }); + + if (has_w) { + ReductionPlan plan( + ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size}); + col_reduce(encoder, gw_temp, gw, Reduce::ReduceType::Sum, {0}, plan); + } +} + +} // namespace fast + +} // namespace mlx::core \ No newline at end of file diff --git a/mlx/backend/rocm/rocm.cpp b/mlx/backend/rocm/rocm.cpp new file mode 100644 index 000000000..83548423a --- /dev/null +++ b/mlx/backend/rocm/rocm.cpp @@ -0,0 +1,11 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/rocm/rocm.h" + +namespace mlx::core::rocm { + +bool is_available() { + return true; +} + +} // namespace mlx::core::rocm \ No newline at end of file diff --git a/mlx/backend/rocm/rocm.h b/mlx/backend/rocm/rocm.h new file mode 100644 index 000000000..8cc6be67d --- /dev/null +++ b/mlx/backend/rocm/rocm.h @@ -0,0 +1,10 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +namespace mlx::core::rocm { + +/* Check if the ROCm backend is available. */ +bool is_available(); + +} // namespace mlx::core::rocm \ No newline at end of file diff --git a/mlx/backend/rocm/rope.hip b/mlx/backend/rocm/rope.hip new file mode 100644 index 000000000..89ea8279a --- /dev/null +++ b/mlx/backend/rocm/rope.hip @@ -0,0 +1,383 @@ +// Copyright © 2025 Apple Inc. 
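+// Each thread rotates one (x1, x2) pair: with theta = scale * position * inv_freq,
+// the forward pass computes (x1*cos(theta) - x2*sin(theta), x1*sin(theta) + x2*cos(theta))
+// and the backward pass applies the inverse rotation. Traditional RoPE pairs
+// adjacent elements (2i, 2i + 1); otherwise element i is paired with i + dims/2.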
+ +#include "mlx/backend/rocm/device.h" +#include "mlx/backend/rocm/kernel_utils.hpp" +#include "mlx/backend/gpu/copy.h" +#include "mlx/dtype_utils.h" +#include "mlx/fast_primitives.h" + +#include + +namespace mlx::core { + +namespace rocm { + +template +__device__ void rope_single_impl( + const T* in, + T* out, + int32_t offset, + float inv_freq, + float scale, + int64_t stride, + uint2 pos, + uint2 dims) { + float L = scale * static_cast(offset); + + // Compute costheta, sintheta + float theta = L * inv_freq; + float costheta = cos(theta); + float sintheta = sin(theta); + + // Compute the input and output indices + uint index_1, index_2; + if (traditional) { + index_1 = 2 * pos.x + pos.y * stride; + index_2 = index_1 + 1; + } else { + index_1 = pos.x + pos.y * stride; + index_2 = index_1 + dims.x; + } + + // Read and write the output + float x1 = static_cast(in[index_1]); + float x2 = static_cast(in[index_2]); + float rx1; + float rx2; + if (forward) { + rx1 = x1 * costheta - x2 * sintheta; + rx2 = x1 * sintheta + x2 * costheta; + } else { + rx1 = x2 * sintheta + x1 * costheta; + rx2 = x2 * costheta - x1 * sintheta; + } + out[index_1] = static_cast(rx1); + out[index_2] = static_cast(rx2); +} + +template +__global__ void rope_single( + const T* in, + T* out, + const int32_t* offset, + float scale, + float base, + int64_t stride, + uint2 dims) { + uint2 pos = make_uint2( + blockIdx.x * blockDim.x + threadIdx.x, + blockIdx.y * blockDim.y + threadIdx.y); + if (pos.x >= dims.x || pos.y >= dims.y) { + return; + } + + float d = static_cast(pos.x) / static_cast(dims.x); + float inv_freq = exp2(-d * base); + rope_single_impl( + in, out, *offset, inv_freq, scale, stride, pos, dims); +} + +template +__global__ void rope_single_freqs( + const T* in, + T* out, + const int32_t* offset, + const float* freqs, + float scale, + int64_t stride, + uint2 dims, + int64_t freq_stride) { + uint2 pos = make_uint2( + blockIdx.x * blockDim.x + threadIdx.x, + blockIdx.y * blockDim.y + threadIdx.y); + if (pos.x >= dims.x || pos.y >= dims.y) { + return; + } + + float inv_freq = 1.0 / freqs[freq_stride * pos.x]; + rope_single_impl( + in, out, *offset, inv_freq, scale, stride, pos, dims); +} + +template +__device__ void rope_impl( + const T* in, + T* out, + int offset, + float inv_freq, + float scale, + const hip_array strides, + const hip_array out_strides, + int64_t n_batch, + uint3 pos, + uint3 dims) { + float L = scale * static_cast(pos.y + offset); + + // Compute costheta, sintheta + float theta = L * inv_freq; + float costheta = cos(theta); + float sintheta = sin(theta); + + // Compute the input and output indices + size_t in_index_1, in_index_2; + size_t out_index_1, out_index_2; + if (traditional) { + out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] + + N * pos.z * out_strides[0]; + out_index_2 = out_index_1 + 1; + in_index_1 = + 2 * pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0]; + in_index_2 = in_index_1 + strides[2]; + } else { + out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] + + N * pos.z * out_strides[0]; + out_index_2 = out_index_1 + dims.x * out_strides[2]; + in_index_1 = + pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0]; + in_index_2 = in_index_1 + dims.x * strides[2]; + } + for (int i = 0; i < N && pos.z * N + i < n_batch; ++i) { + // Read and write the output + float x1 = static_cast(in[in_index_1]); + float x2 = static_cast(in[in_index_2]); + float rx1; + float rx2; + if (forward) { + rx1 = x1 * costheta - x2 * sintheta; + rx2 = x1 * 
sintheta + x2 * costheta; + } else { + rx1 = x2 * sintheta + x1 * costheta; + rx2 = x2 * costheta - x1 * sintheta; + } + out[out_index_1] = static_cast(rx1); + out[out_index_2] = static_cast(rx2); + in_index_1 += strides[0]; + in_index_2 += strides[0]; + out_index_1 += out_strides[0]; + out_index_2 += out_strides[0]; + } +} + +template +__global__ void rope( + const T* in, + T* out, + const int32_t* offset, + float scale, + float base, + const hip_array strides, + const hip_array out_strides, + int64_t n_batch, + uint3 dims) { + uint3 pos = make_uint3( + blockIdx.x * blockDim.x + threadIdx.x, + blockIdx.y * blockDim.y + threadIdx.y, + blockIdx.z * blockDim.z + threadIdx.z); + if (pos.x >= dims.x || pos.y >= dims.y || pos.z >= dims.z) { + return; + } + + float d = static_cast(pos.x) / static_cast(dims.x); + float inv_freq = exp2(-d * base); + rope_impl( + in, + out, + *offset, + inv_freq, + scale, + strides, + out_strides, + n_batch, + pos, + dims); +} + +template +__global__ void rope_freqs( + const T* in, + T* out, + const int32_t* offset, + const float* freqs, + float scale, + float base, + const hip_array strides, + const hip_array out_strides, + int64_t n_batch, + uint3 dims, + int64_t freq_stride) { + uint3 pos = make_uint3( + blockIdx.x * blockDim.x + threadIdx.x, + blockIdx.y * blockDim.y + threadIdx.y, + blockIdx.z * blockDim.z + threadIdx.z); + if (pos.x >= dims.x || pos.y >= dims.y || pos.z >= dims.z) { + return; + } + + float inv_freq = 1.0 / freqs[freq_stride * pos.x]; + rope_impl( + in, + out, + *offset, + inv_freq, + scale, + strides, + out_strides, + n_batch, + pos, + dims); +} + +} // namespace rocm + +namespace fast { + +bool RoPE::use_fallback(Stream s) { + return s.device == Device::cpu; +} + +void RoPE::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + auto& s = stream(); + auto& in = inputs[0]; + auto& offset = inputs[1]; + auto& out = outputs[0]; + + if (in.ndim() < 3) { + throw std::runtime_error("[RoPE] Input must have at least 3 dimensions"); + } + + hip_array strides; + hip_array out_strides; + bool donated = false; + int ndim = in.ndim(); + int dispatch_ndim = in.ndim(); + while (in.shape(-dispatch_ndim) == 1 && dispatch_ndim > 3) { + dispatch_ndim--; + } + size_t mat_size = in.shape(-2) * in.shape(-1); + + // We apply rope to less that the whole vector so copy to output and then + // apply in-place. + if (dims_ < in.shape(-1)) { + donated = true; + auto ctype = + (in.flags().row_contiguous) ? 
CopyType::Vector : CopyType::General; + copy_gpu(in, out, ctype, s); + strides[0] = mat_size; + strides[1] = out.strides()[ndim - 2]; + strides[2] = out.strides()[ndim - 1]; + } + + // Either copy or apply in-place + else if (in.flags().row_contiguous) { + if (in.is_donatable()) { + donated = true; + out.copy_shared_buffer(in); + } else { + out.set_data(allocator::malloc(out.nbytes())); + } + strides[0] = mat_size; + strides[1] = in.strides()[ndim - 2]; + strides[2] = in.strides()[ndim - 1]; + } else if (dispatch_ndim == 3) { + // Handle non-contiguous 3D inputs + out.set_data(allocator::malloc(out.nbytes())); + strides[0] = in.strides()[ndim - 3]; + strides[1] = in.strides()[ndim - 2]; + strides[2] = in.strides()[ndim - 1]; + } else { + // Copy non-contiguous > 3D inputs into the output and treat + // input as donated + donated = true; + copy_gpu(in, out, CopyType::General, s); + strides[0] = mat_size; + strides[1] = out.strides()[ndim - 2]; + strides[2] = out.strides()[ndim - 1]; + } + out_strides[0] = mat_size; + out_strides[1] = out.strides()[ndim - 2]; + out_strides[2] = out.strides()[ndim - 1]; + + // Some flags to help us dispatch below + bool single = in.flags().row_contiguous && (mat_size == in.shape(-1)); + bool with_freqs = inputs.size() == 3; + + auto& encoder = rocm::get_command_encoder(s); + encoder.set_input_array(donated ? out : in); + encoder.set_input_array(offset); + encoder.set_output_array(out); + encoder.launch_kernel([&](hipStream_t stream) { + MLX_SWITCH_FLOAT_TYPES_CHECKED(in.dtype(), "rope", CTYPE, { + using DataType = hip_type_t; + MLX_SWITCH_BOOL(traditional_, TRADITIONAL, { + MLX_SWITCH_BOOL(forward_, FORWARD, { + if (single && !with_freqs) { + auto kernel = rocm::rope_single; + uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size); + auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1); + hipLaunchKernelGGL(kernel, grid, block, 0, stream, + (donated ? out : in).data(), + out.data(), + offset.data(), + scale_, + std::log2(base_), + mat_size, + dims); + } else if (single) { + auto kernel = rocm::rope_single_freqs; + uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size); + auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1); + hipLaunchKernelGGL(kernel, grid, block, 0, stream, + (donated ? out : in).data(), + out.data(), + offset.data(), + inputs[2].data(), + scale_, + mat_size, + dims, + inputs[2].strides(0)); + } else if (with_freqs) { + auto kernel = rocm::rope_freqs; + uint3 dims = + make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size); + dims.z = (dims.z + 3) / 4; + auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z); + hipLaunchKernelGGL(kernel, grid, block, 0, stream, + (donated ? out : in).data(), + out.data(), + offset.data(), + inputs[2].data(), + scale_, + std::log2(base_), + strides, + out_strides, + in.size() / mat_size, + dims, + inputs[2].strides(0)); + } else { + auto kernel = rocm::rope; + uint3 dims = + make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size); + dims.z = (dims.z + 3) / 4; + auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z); + hipLaunchKernelGGL(kernel, grid, block, 0, stream, + (donated ? 
out : in).data(), + out.data(), + offset.data(), + scale_, + std::log2(base_), + strides, + out_strides, + in.size() / mat_size, + dims); + } + }); + }); + }); + }); +} + +} // namespace fast + +} // namespace mlx::core \ No newline at end of file diff --git a/mlx/backend/rocm/slicing.cpp b/mlx/backend/rocm/slicing.cpp new file mode 100644 index 000000000..2d5c3e54a --- /dev/null +++ b/mlx/backend/rocm/slicing.cpp @@ -0,0 +1,9 @@ +// Copyright © 2025 Apple Inc. + +namespace mlx::core::rocm { + +void slice() { + // Placeholder for ROCm slicing operation +} + +} // namespace mlx::core::rocm \ No newline at end of file diff --git a/mlx/backend/rocm/softmax.hip b/mlx/backend/rocm/softmax.hip new file mode 100644 index 000000000..8799c4498 --- /dev/null +++ b/mlx/backend/rocm/softmax.hip @@ -0,0 +1,179 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/rocm/device.h" +#include "mlx/backend/rocm/device/cast_op.hpp" +#include "mlx/backend/rocm/device/fp16_math.hpp" +#include "mlx/backend/rocm/kernel_utils.hpp" +#include "mlx/backend/gpu/copy.h" +#include "mlx/dtype_utils.h" +#include "mlx/primitives.h" + +#include +#include +#include + +#include + +namespace mlx::core { + +namespace rocm { + +namespace cg = cooperative_groups; + +template +inline __device__ T softmax_exp(T x) { + // Softmax doesn't need high precision exponential cause x is gonna be in + // (-oo, 0] anyway and subsequently it will be divided by sum(exp(x_i)). + return __expf(x); +} + +template +__global__ void softmax(const T* in, T* out, int axis_size) { + auto grid = cg::this_grid(); + auto block = cg::this_thread_block(); + auto warp = cg::tiled_partition(block); + + in += grid.block_rank() * axis_size; + out += grid.block_rank() * axis_size; + + // Thread reduce. + AccT prevmax; + AccT maxval = -INFINITY; + AccT normalizer = 0; + for (int r = 0; r < hip_ceil_div(axis_size, BLOCK_DIM * N_READS); r++) { + AccT vals[N_READS]; + rocprim::block_load_direct_blocked( + r * BLOCK_DIM + block.thread_rank(), + make_cast_iterator(in), + vals, + axis_size, + -INFINITY); + prevmax = maxval; + maxval = fmax(maxval, rocprim::thread_reduce(vals, hip_max())); + // Online normalizer calculation for softmax: + // https://github.com/NVIDIA/online-softmax + normalizer = normalizer * softmax_exp(prevmax - maxval); + for (int i = 0; i < N_READS; i++) { + normalizer = normalizer + softmax_exp(vals[i] - maxval); + } + } + + // First warp reduce. + prevmax = maxval; + maxval = cg::reduce(warp, maxval, hip_max()); + normalizer = normalizer * softmax_exp(prevmax - maxval); + normalizer = cg::reduce(warp, normalizer, hip_plus()); + + __shared__ AccT local_max[WARP_SIZE]; + __shared__ AccT local_normalizer[WARP_SIZE]; + + // Write to shared memory and do second warp reduce. + prevmax = maxval; + if (warp.thread_rank() == 0) { + local_max[warp.meta_group_rank()] = maxval; + } + block.sync(); + maxval = warp.thread_rank() < warp.meta_group_size() + ? local_max[warp.thread_rank()] + : -INFINITY; + maxval = cg::reduce(warp, maxval, hip_max()); + normalizer = normalizer * softmax_exp(prevmax - maxval); + if (warp.thread_rank() == 0) { + local_normalizer[warp.meta_group_rank()] = normalizer; + } + block.sync(); + normalizer = warp.thread_rank() < warp.meta_group_size() + ? local_normalizer[warp.thread_rank()] + : AccT{}; + normalizer = cg::reduce(warp, normalizer, hip_plus()); + normalizer = 1 / normalizer; + + // Write output. 
+ for (int r = 0; r < hip_ceil_div(axis_size, BLOCK_DIM * N_READS); r++) { + auto index = r * BLOCK_DIM + block.thread_rank(); + T vals[N_READS]; + rocprim::block_load_direct_blocked(index, in, vals, axis_size); + for (int i = 0; i < N_READS; i++) { + vals[i] = softmax_exp(static_cast(vals[i]) - maxval) * normalizer; + } + rocprim::block_store_direct_blocked(index, out, vals, axis_size); + } +} + +// Utility functions for ROCm +template +struct hip_max { + __device__ T operator()(const T& a, const T& b) const { + return fmax(a, b); + } +}; + +template +struct hip_plus { + __device__ T operator()(const T& a, const T& b) const { + return a + b; + } +}; + +inline __device__ int hip_ceil_div(int a, int b) { + return (a + b - 1) / b; +} + +template +__device__ inline T* make_cast_iterator(const T* ptr) { + return const_cast(ptr); +} + +} // namespace rocm + +void Softmax::eval_gpu(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + auto& s = stream(); + + // Make sure that the last dimension is contiguous. + auto set_output = [&s, &out](const array& x) { + if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) { + if (x.is_donatable()) { + out.copy_shared_buffer(x); + } else { + out.set_data( + allocator::malloc(x.data_size() * x.itemsize()), + x.data_size(), + x.strides(), + x.flags()); + } + return x; + } else { + auto x_copy = array(x.shape(), x.dtype(), nullptr, {}); + copy_gpu(x, x_copy, CopyType::General, s); + out.copy_shared_buffer(x_copy); + return x_copy; + } + }; + + array in = set_output(inputs[0]); + bool precise = in.dtype() != float32 && precise_; + + int axis_size = in.shape().back(); + int n_rows = in.data_size() / axis_size; + + auto& encoder = rocm::get_command_encoder(s); + encoder.set_input_array(in); + encoder.set_output_array(out); + encoder.launch_kernel([&](hipStream_t stream) { + MLX_SWITCH_FLOAT_TYPES_CHECKED(out.dtype(), "softmax", CTYPE, { + using DataType = hip_type_t; + constexpr int N_READS = 4; + MLX_SWITCH_BLOCK_DIM(rocm::hip_ceil_div(axis_size, N_READS), BLOCK_DIM, { + auto kernel = rocm::softmax; + if (precise) { + kernel = rocm::softmax; + } + hipLaunchKernelGGL(kernel, n_rows, BLOCK_DIM, 0, stream, + in.data(), out.data(), axis_size); + }); + }); + }); +} + +} // namespace mlx::core \ No newline at end of file diff --git a/mlx/backend/rocm/sort.hip b/mlx/backend/rocm/sort.hip new file mode 100644 index 000000000..b694a7f8a --- /dev/null +++ b/mlx/backend/rocm/sort.hip @@ -0,0 +1,178 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/common/utils.h" +#include "mlx/backend/rocm/device.h" +#include "mlx/backend/rocm/kernel_utils.hpp" +#include "mlx/backend/gpu/copy.h" +#include "mlx/dtype_utils.h" +#include "mlx/primitives.h" + +#include +#include +#include +#include + +#include +#include + +namespace mlx::core { + +namespace { + +template +struct ModOp { + T divisor; + __device__ T operator()(T x) { + return x % divisor; + } +}; + +// We can not use any op in eval, make an utility. +array swapaxes_in_eval(const array& in, int axis1, int axis2) { + std::vector axes(in.ndim()); + std::iota(axes.begin(), axes.end(), 0); + std::swap(axes[axis1], axes[axis2]); + // TODO: Share the code with Transpose::eval. 
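+  // Build a view that aliases the input's buffer instead of copying: permute
+  // shape and strides according to `axes`, then recompute the row/column
+  // contiguity flags so later code can still take the fast contiguous path.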
+ Shape shape(axes.size()); + Strides strides(in.ndim()); + for (size_t ax = 0; ax < axes.size(); ++ax) { + shape[ax] = in.shape()[axes[ax]]; + strides[ax] = in.strides()[axes[ax]]; + } + auto flags = in.flags(); + if (flags.contiguous) { + auto [_, row_contiguous, col_contiguous] = check_contiguity(shape, strides); + flags.row_contiguous = row_contiguous; + flags.col_contiguous = col_contiguous; + } + array out(shape, in.dtype(), nullptr, {}); + out.copy_shared_buffer(in, strides, flags, in.data_size()); + return out; +} + +template +void segmented_sort_pairs(rocm::CommandEncoder& encoder, Args&&... args) { + // Allocate temporary storage. + size_t size; + CHECK_HIP_ERROR( + rocprim::segmented_sort_pairs(nullptr, size, args...)); + array temp(allocator::malloc(size), {static_cast(size)}, uint8); + encoder.add_temporary(temp); + // Run op. + CHECK_HIP_ERROR(rocprim::segmented_sort_pairs( + temp.data(), size, args...)); +} + +template +void segmented_sort(rocm::CommandEncoder& encoder, Args&&... args) { + // Allocate temporary storage. + size_t size; + CHECK_HIP_ERROR( + rocprim::segmented_sort_keys(nullptr, size, args...)); + array temp(allocator::malloc(size), {static_cast(size)}, uint8); + encoder.add_temporary(temp); + // Run op. + CHECK_HIP_ERROR(rocprim::segmented_sort_keys( + temp.data(), size, args...)); +} + +void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) { + array out = out_; + auto& encoder = rocm::get_command_encoder(s); + encoder.set_input_array(in); + encoder.set_output_array(out); + + if (axis < 0) { + axis += in.ndim(); + } + int nsort = in.shape(axis); + int nsegments = in.data_size() / nsort; + int last_dim = in.ndim() - 1; + + // If we are not sorting the innermost dimension of a contiguous array, + // transpose and make a copy. + bool is_segmented_sort = in.flags().contiguous && in.strides()[axis] == 1; + if (!is_segmented_sort) { + array trans = swapaxes_in_eval(in, axis, last_dim); + in = array(trans.shape(), trans.dtype(), nullptr, {}); + copy_gpu(trans, in, CopyType::General, s); + encoder.add_temporary(in); + out = array(allocator::malloc(out.nbytes()), in.shape(), out.dtype()); + encoder.add_temporary(out); + } else { + out.set_data(allocator::malloc(out.nbytes())); + } + + encoder.launch_kernel([&](hipStream_t stream) { + MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, { + if constexpr (!std::is_same_v) { + using Type = hip_type_t; + auto offsets = rocthrust::make_transform_iterator( + rocthrust::make_counting_iterator(0), + [nsort] __device__(int i) { return i * nsort; }); + if (argsort) { + // Indices in the sorted dimension. + array indices( + allocator::malloc(out.nbytes()), in.shape(), out.dtype()); + encoder.add_temporary(indices); + rocthrust::transform( + rocm::thrust_policy(stream), + rocthrust::counting_iterator(0), + rocthrust::counting_iterator(indices.data_size()), + rocthrust::device_pointer_cast(indices.data()), + ModOp{static_cast(nsort)}); + + // In argsort though we don't need the result of sorted values, the + // API requires us to provide an array to store it. 
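+        // The sort keys are the input values and the paired payload is the
+        // per-segment index (i % nsort) filled in above, so after the
+        // segmented sort the permuted indices land in `out` while the sorted
+        // values are written to the throwaway `discard` buffer.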
+ array discard(allocator::malloc(in.nbytes()), in.shape(), in.dtype()); + encoder.add_temporary(discard); + + segmented_sort_pairs( + encoder, + in.data(), + discard.data(), + indices.data(), + out.data(), + in.data_size(), + nsegments, + offsets, + offsets + 1, + stream); + } else { + segmented_sort( + encoder, + in.data(), + out.data(), + in.data_size(), + nsegments, + offsets, + offsets + 1, + stream); + } + } else { + throw std::runtime_error( + "ROCm backend does not support sorting complex numbers"); + } + }); + }); + + if (!is_segmented_sort) { + // Swap the sorted axis back. + // TODO: Do in-place transpose instead of using a temporary out array. + copy_gpu(swapaxes_in_eval(out, axis, last_dim), out_, CopyType::General, s); + } +} + +} // namespace + +void ArgSort::eval_gpu(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + gpu_sort(stream(), inputs[0], out, axis_, true); +} + +void Sort::eval_gpu(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + gpu_sort(stream(), inputs[0], out, axis_, false); +} + +} // namespace mlx::core \ No newline at end of file diff --git a/mlx/backend/rocm/ternary.hip b/mlx/backend/rocm/ternary.hip new file mode 100644 index 000000000..57c5d02a7 --- /dev/null +++ b/mlx/backend/rocm/ternary.hip @@ -0,0 +1,148 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/common/ternary.h" +#include "mlx/backend/rocm/device.h" +#include "mlx/backend/rocm/device/ternary_ops.hpp" +#include "mlx/backend/rocm/kernel_utils.hpp" +#include "mlx/dtype_utils.h" +#include "mlx/primitives.h" + +#include +#include +#include + +namespace mlx::core { + +namespace rocm { + +template +constexpr bool supports_ternary_op() { + if (std::is_same_v) { + return std::is_same_v && std::is_same_v && std::is_same_v; + } + return false; +} + +} // namespace rocm + +template +void ternary_op_gpu_inplace( + const std::vector& inputs, + array& out, + const std::string& op, + const Stream& s) { + auto& condition = inputs[0]; + auto& a = inputs[1]; + auto& b = inputs[2]; + + if (condition.size() == 0) { + return; + } + + auto& encoder = rocm::get_command_encoder(s); + encoder.set_input_array(condition); + encoder.set_input_array(a); + encoder.set_input_array(b); + encoder.set_output_array(out); + + encoder.launch_kernel([&](hipStream_t stream) { + MLX_SWITCH_ALL_TYPES(condition.dtype(), CONDITION_TYPE, { + MLX_SWITCH_ALL_TYPES(a.dtype(), A_TYPE, { + MLX_SWITCH_ALL_TYPES(b.dtype(), B_TYPE, { + MLX_SWITCH_ALL_TYPES(out.dtype(), OUT_TYPE, { + if constexpr (rocm::supports_ternary_op()) { + using ConditionType = hip_type_t; + using AType = hip_type_t; + using BType = hip_type_t; + using OutType = hip_type_t; + + auto policy = rocm::thrust_policy(stream); + auto condition_ptr = rocthrust::device_pointer_cast(condition.data()); + auto a_ptr = rocthrust::device_pointer_cast(a.data()); + auto b_ptr = rocthrust::device_pointer_cast(b.data()); + auto out_ptr = rocthrust::device_pointer_cast(out.data()); + + if (condition.flags().contiguous && a.flags().contiguous && b.flags().contiguous) { + auto ternary_op = [=] __device__ (const auto& tuple) -> OutType { + return Op{}(rocthrust::get<0>(tuple), rocthrust::get<1>(tuple), rocthrust::get<2>(tuple)); + }; + + auto zip_begin = rocthrust::make_zip_iterator( + rocthrust::make_tuple(condition_ptr, a_ptr, b_ptr)); + auto zip_end = rocthrust::make_zip_iterator( + rocthrust::make_tuple(condition_ptr + condition.data_size(), + a_ptr + a.data_size(), + b_ptr + b.data_size())); + + rocthrust::transform(policy, 
zip_begin, zip_end, out_ptr, ternary_op); + } else { + // Handle non-contiguous arrays with general iterators + auto [condition_shape, condition_strides] = collapse_contiguous_dims(condition); + auto [a_shape, a_strides] = collapse_contiguous_dims(a); + auto [b_shape, b_strides] = collapse_contiguous_dims(b); + + auto [condition_begin, condition_end] = rocm::make_general_iterators( + condition_ptr, condition.size(), condition_shape, condition_strides); + auto [a_begin, a_end] = rocm::make_general_iterators( + a_ptr, a.size(), a_shape, a_strides); + auto [b_begin, b_end] = rocm::make_general_iterators( + b_ptr, b.size(), b_shape, b_strides); + + auto ternary_op = [=] __device__ (const auto& tuple) -> OutType { + return Op{}(rocthrust::get<0>(tuple), rocthrust::get<1>(tuple), rocthrust::get<2>(tuple)); + }; + + auto zip_begin = rocthrust::make_zip_iterator( + rocthrust::make_tuple(condition_begin, a_begin, b_begin)); + auto zip_end = rocthrust::make_zip_iterator( + rocthrust::make_tuple(condition_end, a_end, b_end)); + + rocthrust::transform(policy, zip_begin, zip_end, out_ptr, ternary_op); + } + } else { + throw std::runtime_error(fmt::format( + "Can not do ternary op {} on inputs of {}, {}, {} with output of {}.", + op, + dtype_to_string(condition.dtype()), + dtype_to_string(a.dtype()), + dtype_to_string(b.dtype()), + dtype_to_string(out.dtype()))); + } + }); + }); + }); + }); + }); +} + +template +void ternary_op_gpu( + const std::vector& inputs, + array& out, + const std::string& op, + const Stream& s) { + set_ternary_output_data(inputs, out); + ternary_op_gpu_inplace(inputs, out, op, s); +} + +void Select::eval_gpu(const std::vector& inputs, array& out) { + auto& s = out.primitive().stream(); + ternary_op_gpu(inputs, out, get_primitive_string(this), s); +} + +} // namespace mlx::core + +__global__ void select_kernel(float* condition, float* a, float* b, float* output, int n) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + output[idx] = (condition[idx] != 0.0f) ? a[idx] : b[idx]; + } +} + +void launch_select(float* condition, float* a, float* b, float* output, int n, hipStream_t stream) { + int threads = 256; + int blocks = (n + threads - 1) / threads; + hipLaunchKernelGGL(select_kernel, dim3(blocks), dim3(threads), 0, stream, condition, a, b, output, n); +} + +} // namespace mlx::core::rocm \ No newline at end of file diff --git a/mlx/backend/rocm/unary.hip b/mlx/backend/rocm/unary.hip new file mode 100644 index 000000000..24f94177f --- /dev/null +++ b/mlx/backend/rocm/unary.hip @@ -0,0 +1,222 @@ +// Copyright © 2025 Apple Inc. 
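+//
+// Unary elementwise ops are dispatched through rocThrust: the nested
+// MLX_SWITCH_ALL_TYPES macros instantiate every (input, output) dtype pair,
+// supports_unary_op() rejects unsupported combinations at compile time, and
+// the surviving case runs rocthrust::transform either directly over the
+// contiguous buffer or through general strided iterators. Reduced sketch of
+// the contiguous path (illustrative only, simplified names):
+//
+//   rocthrust::transform(rocm::thrust_policy(stream),
+//                        in_ptr, in_ptr + in.data_size(), out_ptr, Op());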
+ +#include "mlx/backend/common/unary.h" +#include "mlx/backend/rocm/device.h" +#include "mlx/backend/rocm/device/hip_complex_math.hpp" +#include "mlx/backend/rocm/device/unary_ops.hpp" +#include "mlx/backend/rocm/iterators/general_iterator.hpp" +#include "mlx/backend/rocm/kernel_utils.hpp" +#include "mlx/dtype_utils.h" +#include "mlx/primitives.h" + +#include +#include +#include + +namespace mlx::core { + +namespace rocm { + +template +constexpr bool supports_unary_op() { + if (std::is_same_v || std::is_same_v || + std::is_same_v) { + return std::is_same_v; + } + if (std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) { + return std::is_same_v && is_floating_v; + } + if (std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) { + return std::is_same_v && is_inexact_v; + } + if (std::is_same_v) { + return std::is_same_v && std::is_integral_v && + !std::is_same_v; + } + if (std::is_same_v || std::is_same_v || + std::is_same_v) { + return std::is_same_v && !std::is_same_v; + } + if (std::is_same_v) { + return std::is_same_v && std::is_same_v; + } + if (std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) { + return std::is_same_v && + (is_floating_v || std::is_same_v); + } + if (std::is_same_v || std::is_same_v) { + return std::is_same_v && std::is_same_v; + } + if (std::is_same_v) { + return std::is_same_v && std::is_same_v; + } + return false; +} + +} // namespace rocm + +template +void unary_op_gpu_inplace( + const std::vector& inputs, + array& out, + const std::string& op, + const Stream& s) { + auto& in = inputs[0]; + if (in.size() == 0) { + return; + } + + auto& encoder = rocm::get_command_encoder(s); + encoder.set_input_array(in); + encoder.set_output_array(out); + encoder.launch_kernel([&](hipStream_t stream) { + MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE_IN, { + MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, { + if constexpr (rocm::supports_unary_op()) { + using InType = hip_type_t; + using OutType = hip_type_t; + auto policy = rocm::thrust_policy(stream); + auto in_ptr = rocthrust::device_pointer_cast(in.data()); + auto out_ptr = rocthrust::device_pointer_cast(out.data()); + if (in.flags().contiguous) { + rocthrust::transform( + policy, in_ptr, in_ptr + in.data_size(), out_ptr, Op()); + } else { + auto [shape, strides] = collapse_contiguous_dims(in); + auto [in_begin, in_end] = rocm::make_general_iterators( + in_ptr, in.size(), shape, strides); + rocthrust::transform(policy, in_begin, in_end, out_ptr, Op()); + } + } else { + throw std::runtime_error(fmt::format( + "Can not do unary op {} on input of {} with output of {}.", + op, + dtype_to_string(in.dtype()), + dtype_to_string(out.dtype()))); + } + }); + }); + }); +} + +template +void unary_op_gpu( + const std::vector& inputs, + array& out, + const std::string& op, + const Stream& s) { + set_unary_output_data(inputs[0], out); + unary_op_gpu_inplace(inputs, out, op, s); +} + +#define UNARY_GPU(func) \ + void func::eval_gpu(const std::vector& inputs, array& out) { \ + auto& s = out.primitive().stream(); \ + unary_op_gpu(inputs, out, get_primitive_string(this), s); \ + } + +UNARY_GPU(Abs) +UNARY_GPU(ArcCos) +UNARY_GPU(ArcCosh) +UNARY_GPU(ArcSin) +UNARY_GPU(ArcSinh) +UNARY_GPU(ArcTan) +UNARY_GPU(ArcTanh) +UNARY_GPU(BitwiseInvert) +UNARY_GPU(Ceil) 
+UNARY_GPU(Conjugate) +UNARY_GPU(Cos) +UNARY_GPU(Cosh) +UNARY_GPU(Erf) +UNARY_GPU(ErfInv) +UNARY_GPU(Exp) +UNARY_GPU(Expm1) +UNARY_GPU(Floor) +UNARY_GPU(Imag) +UNARY_GPU(Log1p) +UNARY_GPU(LogicalNot) +UNARY_GPU(Negative) +UNARY_GPU(Real) +UNARY_GPU(Sigmoid) +UNARY_GPU(Sign) +UNARY_GPU(Sin) +UNARY_GPU(Sinh) +UNARY_GPU(Square) +UNARY_GPU(Tan) +UNARY_GPU(Tanh) + +void Log::eval_gpu(const std::vector& inputs, array& out) { + auto& s = out.primitive().stream(); + auto op = get_primitive_string(this); + switch (base_) { + case Base::e: + unary_op_gpu(inputs, out, op, s); + break; + case Base::two: + unary_op_gpu(inputs, out, op, s); + break; + case Base::ten: + unary_op_gpu(inputs, out, op, s); + break; + } +} + +void Round::eval_gpu(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + const auto& in = inputs[0]; + auto& s = out.primitive().stream(); + if (issubdtype(in.dtype(), inexact)) { + unary_op_gpu(inputs, out, get_primitive_string(this), s); + } else { + // No-op integer types + out.copy_shared_buffer(in); + } +} + +void Sqrt::eval_gpu(const std::vector& inputs, array& out) { + auto& s = out.primitive().stream(); + if (recip_) { + unary_op_gpu(inputs, out, "Rsqrt", s); + } else { + unary_op_gpu(inputs, out, "Sqrt", s); + } +} + +} // namespace mlx::core + +__global__ void relu_kernel(float* input, float* output, int n) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + output[idx] = fmaxf(0.0f, input[idx]); + } +} + +__global__ void sigmoid_kernel(float* input, float* output, int n) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + output[idx] = 1.0f / (1.0f + expf(-input[idx])); + } +} + +void launch_relu(float* input, float* output, int n, hipStream_t stream) { + int threads = 256; + int blocks = (n + threads - 1) / threads; + hipLaunchKernelGGL(relu_kernel, dim3(blocks), dim3(threads), 0, stream, input, output, n); +} + +void launch_sigmoid(float* input, float* output, int n, hipStream_t stream) { + int threads = 256; + int blocks = (n + threads - 1) / threads; + hipLaunchKernelGGL(sigmoid_kernel, dim3(blocks), dim3(threads), 0, stream, input, output, n); +} + +} // namespace mlx::core::rocm \ No newline at end of file diff --git a/mlx/backend/rocm/utils.cpp b/mlx/backend/rocm/utils.cpp new file mode 100644 index 000000000..1d4668b96 --- /dev/null +++ b/mlx/backend/rocm/utils.cpp @@ -0,0 +1,46 @@ +// Copyright © 2025 Apple Inc. 
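+//
+// Host-side helpers shared across the ROCm backend: an RAII wrapper around
+// hipStream_t, the check_hip_error/CHECK_HIP_ERROR error-reporting pair
+// declared in utils.h, and a Dtype -> HIP C++ type-name mapping used when
+// JIT-compiling kernels. Illustrative use with placeholder arguments:
+//
+//   CHECK_HIP_ERROR(hipMemsetAsync(ptr, 0, nbytes, stream));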
+ +#include "mlx/backend/rocm/utils.h" +#include "mlx/backend/rocm/device.h" +#include "mlx/dtype_utils.h" + +#include + +namespace mlx::core { + +HipStream::HipStream(rocm::Device& device) { + device.make_current(); + CHECK_HIP_ERROR(hipStreamCreateWithFlags(&stream_, hipStreamNonBlocking)); +} + +HipStream::~HipStream() { + CHECK_HIP_ERROR(hipStreamDestroy(stream_)); +} + +void check_hip_error(const char* name, hipError_t err) { + if (err != hipSuccess) { + throw std::runtime_error( + fmt::format("{} failed: {}", name, hipGetErrorString(err))); + } +} + +const char* dtype_to_hip_type(const Dtype& dtype) { + if (dtype == float16) { + return "__half"; + } + if (dtype == bfloat16) { + return "__hip_bfloat16"; + } + if (dtype == complex64) { + return "hipFloatComplex"; + } +#define SPECIALIZE_DtypeToString(CPP_TYPE, DTYPE) \ + if (dtype == DTYPE) { \ + return #CPP_TYPE; \ + } + MLX_FORALL_DTYPES(SPECIALIZE_DtypeToString) +#undef SPECIALIZE_DtypeToString + return nullptr; +} + +} // namespace mlx::core \ No newline at end of file diff --git a/mlx/backend/rocm/utils.h b/mlx/backend/rocm/utils.h new file mode 100644 index 000000000..679828896 --- /dev/null +++ b/mlx/backend/rocm/utils.h @@ -0,0 +1,43 @@ +// Copyright © 2025 Apple Inc. + +// This file includes utilities that are used by C++ code (i.e. .cpp files). + +#pragma once + +#include + +namespace mlx::core { + +namespace rocm { +class Device; +} + +struct Dtype; + +// HIP stream managed with RAII. +class HipStream { + public: + explicit HipStream(rocm::Device& device); + ~HipStream(); + + HipStream(const HipStream&) = delete; + HipStream& operator=(const HipStream&) = delete; + + operator hipStream_t() const { + return stream_; + } + + private: + hipStream_t stream_; +}; + +// Throw exception if the HIP API does not succeed. +void check_hip_error(const char* name, hipError_t err); + +// The macro version that prints the command that failed. +#define CHECK_HIP_ERROR(cmd) check_hip_error(#cmd, (cmd)) + +// Convert Dtype to HIP C++ types. +const char* dtype_to_hip_type(const Dtype& dtype); + +} // namespace mlx::core \ No newline at end of file diff --git a/mlx/backend/rocm/worker.cpp b/mlx/backend/rocm/worker.cpp new file mode 100644 index 000000000..db9d0b45b --- /dev/null +++ b/mlx/backend/rocm/worker.cpp @@ -0,0 +1,76 @@ +// Copyright © 2025 Apple Inc. 
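+//
+// Worker runs host-side completion callbacks for the ROCm backend. Tasks are
+// queued under a mutex and executed either by the background thread (after
+// commit()) or inline on the calling thread after a blocking
+// hipStreamSynchronize (commit(stream)). Illustrative use with placeholder
+// names:
+//
+//   worker.add_task([buf]() { /* release buf once the stream is done */ });
+//   worker.commit(stream);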
+ +#include "mlx/backend/rocm/worker.h" +#include "mlx/backend/rocm/utils.h" + +namespace mlx::core::rocm { + +Worker::Worker() : worker_thread_(&Worker::worker_loop, this) {} + +Worker::~Worker() { + { + std::lock_guard lock(mutex_); + stop_ = true; + } + cv_.notify_all(); + if (worker_thread_.joinable()) { + worker_thread_.join(); + } +} + +void Worker::add_task(std::function task) { + { + std::lock_guard lock(mutex_); + tasks_.push(task); + } + cv_.notify_one(); +} + +void Worker::consume_in_this_thread() { + std::queue> local_tasks; + { + std::lock_guard lock(mutex_); + local_tasks.swap(tasks_); + } + + while (!local_tasks.empty()) { + auto task = local_tasks.front(); + local_tasks.pop(); + task(); + } +} + +void Worker::commit(hipStream_t stream) { + // Synchronize with stream and then process tasks + CHECK_HIP_ERROR(hipStreamSynchronize(stream)); + consume_in_this_thread(); +} + +void Worker::commit() { + cv_.notify_all(); +} + +void Worker::worker_loop() { + while (true) { + std::function task; + { + std::unique_lock lock(mutex_); + cv_.wait(lock, [this] { return stop_ || !tasks_.empty(); }); + + if (stop_) { + break; + } + + if (!tasks_.empty()) { + task = tasks_.front(); + tasks_.pop(); + } + } + + if (task) { + task(); + } + } +} + +} // namespace mlx::core::rocm \ No newline at end of file diff --git a/mlx/backend/rocm/worker.h b/mlx/backend/rocm/worker.h new file mode 100644 index 000000000..b41fb75c5 --- /dev/null +++ b/mlx/backend/rocm/worker.h @@ -0,0 +1,46 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include + +#include +#include +#include +#include +#include + +namespace mlx::core::rocm { + +// Simple worker for async task execution synchronized with HIP streams. +class Worker { + public: + Worker(); + ~Worker(); + + Worker(const Worker&) = delete; + Worker& operator=(const Worker&) = delete; + + // Add a task to be executed + void add_task(std::function task); + + // Run pending tasks immediately in current thread. + void consume_in_this_thread(); + + // Commit tasks to be run after stream completion + void commit(hipStream_t stream); + + // Simple commit without stream dependency + void commit(); + + private: + void worker_loop(); + + std::thread worker_thread_; + std::queue> tasks_; + std::mutex mutex_; + std::condition_variable cv_; + bool stop_{false}; +}; + +} // namespace mlx::core::rocm \ No newline at end of file diff --git a/mlx/device.cpp b/mlx/device.cpp index ec17a509a..aec5f40b0 100644 --- a/mlx/device.cpp +++ b/mlx/device.cpp @@ -6,10 +6,23 @@ #include "mlx/backend/gpu/available.h" #include "mlx/device.h" +#ifdef MLX_USE_ROCM +#include "mlx/backend/rocm/rocm.h" +#endif + namespace mlx::core { Device& mutable_default_device() { - static Device default_device{gpu::is_available() ? Device::gpu : Device::cpu}; + Device::DeviceType default_type = Device::cpu; + if (gpu::is_available()) { + default_type = Device::gpu; + } +#ifdef MLX_USE_ROCM + else if (rocm::is_available()) { + default_type = Device::gpu; // ROCm devices use the generic gpu type + } +#endif + static Device default_device{default_type}; return default_device; } @@ -38,7 +51,11 @@ bool is_available(const Device& d) { case Device::cpu: return cpu::is_available(); case Device::gpu: +#ifdef MLX_USE_ROCM + return gpu::is_available() || rocm::is_available(); +#else return gpu::is_available(); +#endif } // appease compiler return false;