From 8bb8b76ae49402fab8f8ebe14cb581b61f86c77c Mon Sep 17 00:00:00 2001 From: Nripesh Niketan Date: Mon, 16 Jun 2025 22:42:56 +0100 Subject: [PATCH] [Experiment] ROCM backend initial push --- CMakeLists.txt | 5 ++ mlx/CMakeLists.txt | 11 ++- mlx/backend/rocm/CMakeLists.txt | 85 ++++++++++++++++++ mlx/backend/rocm/allocator.cpp | 20 +++++ mlx/backend/rocm/allocator.h | 12 +++ mlx/backend/rocm/arg_reduce.hip | 28 ++++++ mlx/backend/rocm/bin2h.cmake | 47 ++++++++++ mlx/backend/rocm/binary.hip | 36 ++++++++ mlx/backend/rocm/compiled.cpp | 9 ++ mlx/backend/rocm/copy.hip | 20 +++++ mlx/backend/rocm/device.cpp | 104 ++++++++++++++++++++++ mlx/backend/rocm/device.h | 141 ++++++++++++++++++++++++++++++ mlx/backend/rocm/eval.cpp | 11 +++ mlx/backend/rocm/event.hip | 32 +++++++ mlx/backend/rocm/fence.cpp | 9 ++ mlx/backend/rocm/indexing.cpp | 9 ++ mlx/backend/rocm/kernel_utils.hip | 29 ++++++ mlx/backend/rocm/layer_norm.hip | 37 ++++++++ mlx/backend/rocm/logsumexp.hip | 13 +++ mlx/backend/rocm/matmul.cpp | 30 +++++++ mlx/backend/rocm/no_rocm.cpp | 11 +++ mlx/backend/rocm/primitives.hip | 21 +++++ mlx/backend/rocm/random.hip | 23 +++++ mlx/backend/rocm/reduce.hip | 24 +++++ mlx/backend/rocm/rms_norm.hip | 13 +++ mlx/backend/rocm/rocm.cpp | 11 +++ mlx/backend/rocm/rocm.h | 10 +++ mlx/backend/rocm/rope.hip | 13 +++ mlx/backend/rocm/slicing.cpp | 9 ++ mlx/backend/rocm/softmax.hip | 22 +++++ mlx/backend/rocm/sort.hip | 1 + mlx/backend/rocm/ternary.hip | 20 +++++ mlx/backend/rocm/unary.hip | 33 +++++++ mlx/backend/rocm/utils.cpp | 17 ++++ mlx/backend/rocm/utils.h | 12 +++ mlx/backend/rocm/worker.cpp | 61 +++++++++++++ mlx/backend/rocm/worker.h | 38 ++++++++ mlx/device.cpp | 19 +++- 38 files changed, 1044 insertions(+), 2 deletions(-) create mode 100644 mlx/backend/rocm/CMakeLists.txt create mode 100644 mlx/backend/rocm/allocator.cpp create mode 100644 mlx/backend/rocm/allocator.h create mode 100644 mlx/backend/rocm/arg_reduce.hip create mode 100644 mlx/backend/rocm/bin2h.cmake create mode 100644 mlx/backend/rocm/binary.hip create mode 100644 mlx/backend/rocm/compiled.cpp create mode 100644 mlx/backend/rocm/copy.hip create mode 100644 mlx/backend/rocm/device.cpp create mode 100644 mlx/backend/rocm/device.h create mode 100644 mlx/backend/rocm/eval.cpp create mode 100644 mlx/backend/rocm/event.hip create mode 100644 mlx/backend/rocm/fence.cpp create mode 100644 mlx/backend/rocm/indexing.cpp create mode 100644 mlx/backend/rocm/kernel_utils.hip create mode 100644 mlx/backend/rocm/layer_norm.hip create mode 100644 mlx/backend/rocm/logsumexp.hip create mode 100644 mlx/backend/rocm/matmul.cpp create mode 100644 mlx/backend/rocm/no_rocm.cpp create mode 100644 mlx/backend/rocm/primitives.hip create mode 100644 mlx/backend/rocm/random.hip create mode 100644 mlx/backend/rocm/reduce.hip create mode 100644 mlx/backend/rocm/rms_norm.hip create mode 100644 mlx/backend/rocm/rocm.cpp create mode 100644 mlx/backend/rocm/rocm.h create mode 100644 mlx/backend/rocm/rope.hip create mode 100644 mlx/backend/rocm/slicing.cpp create mode 100644 mlx/backend/rocm/softmax.hip create mode 100644 mlx/backend/rocm/sort.hip create mode 100644 mlx/backend/rocm/ternary.hip create mode 100644 mlx/backend/rocm/unary.hip create mode 100644 mlx/backend/rocm/utils.cpp create mode 100644 mlx/backend/rocm/utils.h create mode 100644 mlx/backend/rocm/worker.cpp create mode 100644 mlx/backend/rocm/worker.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 4bf8d2d3e..158170647 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -35,6 +35,7 @@ option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF) option(MLX_BUILD_METAL "Build metal backend" ON) option(MLX_BUILD_CPU "Build cpu backend" ON) option(MLX_BUILD_CUDA "Build cuda backend" OFF) +option(MLX_BUILD_ROCM "Build ROCm backend" OFF) option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF) option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF) option(MLX_BUILD_GGUF "Include support for GGUF format" ON) @@ -88,6 +89,10 @@ if(MLX_BUILD_CUDA) enable_language(CUDA) endif() +if(MLX_BUILD_ROCM) + enable_language(HIP) +endif() + if(MLX_BUILD_METAL AND NOT METAL_LIB) message(STATUS "Metal not found. Unable to build GPU") set(MLX_BUILD_METAL OFF) diff --git a/mlx/CMakeLists.txt b/mlx/CMakeLists.txt index 7aa648533..a4e6260e9 100644 --- a/mlx/CMakeLists.txt +++ b/mlx/CMakeLists.txt @@ -60,7 +60,16 @@ else() PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda/no_cuda.cpp) endif() -if(MLX_BUILD_METAL OR MLX_BUILD_CUDA) +if(MLX_BUILD_ROCM) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/rocm) +else() + target_sources(mlx + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/rocm/no_rocm.cpp) +endif() + +if(MLX_BUILD_METAL + OR MLX_BUILD_CUDA + OR MLX_BUILD_ROCM) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/gpu) else() add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_gpu) diff --git a/mlx/backend/rocm/CMakeLists.txt b/mlx/backend/rocm/CMakeLists.txt new file mode 100644 index 000000000..260c5128e --- /dev/null +++ b/mlx/backend/rocm/CMakeLists.txt @@ -0,0 +1,85 @@ +# Filename rules in ROCm backend: +# +# * Use .hip/.hpp if code contains device code, and .cpp/.h if not. +# * Device-only code should be put in device/ subdir. +# * Files in device/ subdir should not include files outside. +target_sources( + mlx + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.hip + ${CMAKE_CURRENT_SOURCE_DIR}/binary.hip + ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/copy.hip + ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/event.hip + ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.hip + ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/layer_norm.hip + ${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.hip + ${CMAKE_CURRENT_SOURCE_DIR}/primitives.hip + ${CMAKE_CURRENT_SOURCE_DIR}/random.hip + ${CMAKE_CURRENT_SOURCE_DIR}/reduce.hip + ${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.hip + ${CMAKE_CURRENT_SOURCE_DIR}/rope.hip + ${CMAKE_CURRENT_SOURCE_DIR}/rocm.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/softmax.hip + ${CMAKE_CURRENT_SOURCE_DIR}/sort.hip + ${CMAKE_CURRENT_SOURCE_DIR}/ternary.hip + ${CMAKE_CURRENT_SOURCE_DIR}/unary.hip + ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp) + +target_compile_definitions(mlx PRIVATE MLX_USE_ROCM) + +# Embed kernel sources in binary for JIT compilation. +file( + GLOB MLX_JIT_SOURCES + RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} + "${CMAKE_CURRENT_SOURCE_DIR}/device/*.h" + "${CMAKE_CURRENT_SOURCE_DIR}/device/*.hpp") +string(JOIN ":" MLX_JIT_SOURCES_ARG ${MLX_JIT_SOURCES}) +add_custom_command( + OUTPUT gen/rocm_jit_sources.h + COMMAND + ${CMAKE_COMMAND} -DMLX_SOURCE_ROOT=${CMAKE_CURRENT_SOURCE_DIR} + -DMLX_JIT_SOURCES=${MLX_JIT_SOURCES_ARG} -P + "${CMAKE_CURRENT_SOURCE_DIR}/bin2h.cmake" + DEPENDS bin2h.cmake ${MLX_JIT_SOURCES}) +add_custom_target(rocm_jit_sources DEPENDS gen/rocm_jit_sources.h) +add_dependencies(mlx rocm_jit_sources) +target_include_directories(mlx PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/gen") + +# Find ROCm installation +find_package(hip REQUIRED) +find_package(rocblas REQUIRED) + +# Link with ROCm libraries +target_link_libraries(mlx PRIVATE hip::device roc::rocblas) + +# Set GPU architectures for ROCm Common ROCm architectures: gfx900, gfx906, +# gfx908, gfx90a, gfx1030, gfx1100 +set(MLX_ROCM_ARCHITECTURES + "gfx900;gfx906;gfx908;gfx90a;gfx1030;gfx1100" + CACHE STRING "ROCm GPU architectures") +message(STATUS "ROCm GPU architectures: ${MLX_ROCM_ARCHITECTURES}") + +# Set GPU targets for HIP compilation +set_property(TARGET mlx PROPERTY HIP_ARCHITECTURES "${MLX_ROCM_ARCHITECTURES}") + +# Enable HIP language support +enable_language(HIP) + +# Set HIP compiler flags +target_compile_options( + mlx + PRIVATE "$<$:-fgpu-rdc>" + "$<$:-Xcompiler=-Wall>" + "$<$:-Xcompiler=-Wextra>") + +# Add ROCm include directories +target_include_directories(mlx PRIVATE ${hip_INCLUDE_DIRS}) +target_include_directories(mlx PRIVATE ${rocblas_INCLUDE_DIRS}) diff --git a/mlx/backend/rocm/allocator.cpp b/mlx/backend/rocm/allocator.cpp new file mode 100644 index 000000000..347ab719a --- /dev/null +++ b/mlx/backend/rocm/allocator.cpp @@ -0,0 +1,20 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/rocm/allocator.h" +#include "mlx/backend/rocm/utils.h" + +namespace mlx::core::rocm { + +void* allocate(size_t size) { + void* ptr; + check_hip_error("hipMalloc", hipMalloc(&ptr, size)); + return ptr; +} + +void deallocate(void* ptr) { + if (ptr) { + check_hip_error("hipFree", hipFree(ptr)); + } +} + +} // namespace mlx::core::rocm \ No newline at end of file diff --git a/mlx/backend/rocm/allocator.h b/mlx/backend/rocm/allocator.h new file mode 100644 index 000000000..eb8052769 --- /dev/null +++ b/mlx/backend/rocm/allocator.h @@ -0,0 +1,12 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include + +namespace mlx::core::rocm { + +void* allocate(size_t size); +void deallocate(void* ptr); + +} // namespace mlx::core::rocm \ No newline at end of file diff --git a/mlx/backend/rocm/arg_reduce.hip b/mlx/backend/rocm/arg_reduce.hip new file mode 100644 index 000000000..068625b35 --- /dev/null +++ b/mlx/backend/rocm/arg_reduce.hip @@ -0,0 +1,28 @@ +// Copyright © 2025 Apple Inc. + +#include + +namespace mlx::core::rocm { + +__global__ void argmax_kernel(float* input, int* output, int n) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + + // Simple argmax placeholder + if (idx == 0) { + int max_idx = 0; + float max_val = input[0]; + for (int i = 1; i < n; i++) { + if (input[i] > max_val) { + max_val = input[i]; + max_idx = i; + } + } + output[0] = max_idx; + } +} + +void launch_argmax(float* input, int* output, int n, hipStream_t stream) { + hipLaunchKernelGGL(argmax_kernel, dim3(1), dim3(1), 0, stream, input, output, n); +} + +} // namespace mlx::core::rocm \ No newline at end of file diff --git a/mlx/backend/rocm/bin2h.cmake b/mlx/backend/rocm/bin2h.cmake new file mode 100644 index 000000000..1766b27c9 --- /dev/null +++ b/mlx/backend/rocm/bin2h.cmake @@ -0,0 +1,47 @@ +# Copyright © 2025 Apple Inc. + +# Script to embed kernel source files as header for JIT compilation + +set(MLX_OUTPUT_FILE "${CMAKE_CURRENT_BINARY_DIR}/gen/rocm_jit_sources.h") +set(MLX_KERNEL_HEADER + "#pragma once\n\n#include \n#include \n\nnamespace mlx::core::rocm {\n\n" +) +set(MLX_KERNEL_FOOTER "\n} // namespace mlx::core::rocm\n") + +# Create output directory +get_filename_component(MLX_OUTPUT_DIR ${MLX_OUTPUT_FILE} DIRECTORY) +file(MAKE_DIRECTORY ${MLX_OUTPUT_DIR}) + +# Write header +file(WRITE ${MLX_OUTPUT_FILE} ${MLX_KERNEL_HEADER}) + +# Process JIT sources +string(REPLACE ":" ";" MLX_JIT_SOURCES_LIST ${MLX_JIT_SOURCES}) + +set(MLX_SOURCE_MAP + "const std::unordered_map kernel_sources = {\n") + +foreach(source IN LISTS MLX_JIT_SOURCES_LIST) + set(source_file "${MLX_SOURCE_ROOT}/${source}") + if(EXISTS ${source_file}) + # Read source file + file(READ ${source_file} source_content) + + # Escape content for C++ string literal + string(REPLACE "\\" "\\\\" source_content "${source_content}") + string(REPLACE "\"" "\\\"" source_content "${source_content}") + string(REPLACE "\n" "\\n\"\n\"" source_content "${source_content}") + + # Add to map + set(MLX_SOURCE_MAP + "${MLX_SOURCE_MAP} {\"${source}\", \"${source_content}\"},\n") + endif() +endforeach() + +set(MLX_SOURCE_MAP "${MLX_SOURCE_MAP}};\n") + +# Write source map +file(APPEND ${MLX_OUTPUT_FILE} ${MLX_SOURCE_MAP}) + +# Write footer +file(APPEND ${MLX_OUTPUT_FILE} ${MLX_KERNEL_FOOTER}) diff --git a/mlx/backend/rocm/binary.hip b/mlx/backend/rocm/binary.hip new file mode 100644 index 000000000..14b48bfc9 --- /dev/null +++ b/mlx/backend/rocm/binary.hip @@ -0,0 +1,36 @@ +// Copyright © 2025 Apple Inc. + +#include + +#include "mlx/backend/rocm/utils.h" + +namespace mlx::core::rocm { + +// Basic binary operation kernels will go here +__global__ void add_kernel(float* a, float* b, float* c, int n) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + c[idx] = a[idx] + b[idx]; + } +} + +__global__ void multiply_kernel(float* a, float* b, float* c, int n) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + c[idx] = a[idx] * b[idx]; + } +} + +void launch_add(float* a, float* b, float* c, int n, hipStream_t stream) { + int threads = 256; + int blocks = (n + threads - 1) / threads; + hipLaunchKernelGGL(add_kernel, dim3(blocks), dim3(threads), 0, stream, a, b, c, n); +} + +void launch_multiply(float* a, float* b, float* c, int n, hipStream_t stream) { + int threads = 256; + int blocks = (n + threads - 1) / threads; + hipLaunchKernelGGL(multiply_kernel, dim3(blocks), dim3(threads), 0, stream, a, b, c, n); +} + +} // namespace mlx::core::rocm \ No newline at end of file diff --git a/mlx/backend/rocm/compiled.cpp b/mlx/backend/rocm/compiled.cpp new file mode 100644 index 000000000..a41bc433c --- /dev/null +++ b/mlx/backend/rocm/compiled.cpp @@ -0,0 +1,9 @@ +// Copyright © 2025 Apple Inc. + +namespace mlx::core::rocm { + +void compile() { + // Placeholder for ROCm compilation +} + +} // namespace mlx::core::rocm \ No newline at end of file diff --git a/mlx/backend/rocm/copy.hip b/mlx/backend/rocm/copy.hip new file mode 100644 index 000000000..4419a2db2 --- /dev/null +++ b/mlx/backend/rocm/copy.hip @@ -0,0 +1,20 @@ +// Copyright © 2025 Apple Inc. + +#include + +namespace mlx::core::rocm { + +__global__ void copy_kernel(float* src, float* dst, int n) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + dst[idx] = src[idx]; + } +} + +void launch_copy(float* src, float* dst, int n, hipStream_t stream) { + int threads = 256; + int blocks = (n + threads - 1) / threads; + hipLaunchKernelGGL(copy_kernel, dim3(blocks), dim3(threads), 0, stream, src, dst, n); +} + +} // namespace mlx::core::rocm \ No newline at end of file diff --git a/mlx/backend/rocm/device.cpp b/mlx/backend/rocm/device.cpp new file mode 100644 index 000000000..9ab97ea20 --- /dev/null +++ b/mlx/backend/rocm/device.cpp @@ -0,0 +1,104 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/rocm/device.h" +#include "mlx/backend/rocm/utils.h" + +namespace mlx::core::rocm { + +DeviceStream::DeviceStream(Device& device) : device_(device) { + check_hip_error("hipStreamCreate", hipStreamCreate(&stream_)); + encoder_ = std::make_unique(*this); +} + +void DeviceStream::synchronize() { + check_hip_error("hipStreamSynchronize", hipStreamSynchronize(stream_)); +} + +hipStream_t DeviceStream::schedule_hip_stream() { + return stream_; +} + +hipStream_t DeviceStream::last_hip_stream() { + return stream_; +} + +CommandEncoder& DeviceStream::get_encoder() { + return *encoder_; +} + +Device::Device(int device) : device_(device) { + check_hip_error("hipSetDevice", hipSetDevice(device_)); + + // Get device properties + hipDeviceProp_t prop; + check_hip_error( + "hipGetDeviceProperties", hipGetDeviceProperties(&prop, device_)); + compute_capability_major_ = prop.major; + compute_capability_minor_ = prop.minor; + + // Create rocBLAS handle + check_hip_error( + "rocblas_create_handle", + static_cast(rocblas_create_handle(&rocblas_handle_))); +} + +Device::~Device() { + if (rocblas_handle_) { + rocblas_destroy_handle(rocblas_handle_); + } +} + +void Device::make_current() { + check_hip_error("hipSetDevice", hipSetDevice(device_)); +} + +DeviceStream& Device::get_stream(Stream s) { + auto it = streams_.find(s.index); + if (it != streams_.end()) { + return it->second; + } + + auto [new_it, inserted] = streams_.emplace(s.index, DeviceStream(*this)); + return new_it->second; +} + +CommandEncoder::CommandEncoder(DeviceStream& stream) + : device_(stream.device()), stream_(stream), worker_() {} + +void CommandEncoder::add_completed_handler(std::function task) { + worker_.enqueue(task); +} + +void CommandEncoder::end_encoding() { + // Implementation for ending encoding +} + +void CommandEncoder::commit() { + worker_.commit(); +} + +// Global device management +static std::unordered_map> devices_; + +Device& device(mlx::core::Device device) { + auto it = devices_.find(device.index); + if (it != devices_.end()) { + return *it->second; + } + + auto new_device = std::make_unique(device.index); + Device& dev_ref = *new_device; + devices_[device.index] = std::move(new_device); + return dev_ref; +} + +DeviceStream& get_stream(Stream s) { + // Use default device (index 0) for now + return device(mlx::core::Device{mlx::core::Device::gpu, 0}).get_stream(s); +} + +CommandEncoder& get_command_encoder(Stream s) { + return get_stream(s).get_encoder(); +} + +} // namespace mlx::core::rocm \ No newline at end of file diff --git a/mlx/backend/rocm/device.h b/mlx/backend/rocm/device.h new file mode 100644 index 000000000..bd122d547 --- /dev/null +++ b/mlx/backend/rocm/device.h @@ -0,0 +1,141 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/array.h" +#include "mlx/backend/rocm/worker.h" +#include "mlx/stream.h" + +#include +#include + +#include + +namespace mlx::core::rocm { + +class Device; +class CommandEncoder; + +class DeviceStream { + public: + explicit DeviceStream(Device& device); + + DeviceStream(const DeviceStream&) = delete; + DeviceStream& operator=(const DeviceStream&) = delete; + + // Wait until kernels in the stream complete. + void synchronize(); + + // Return a HIP stream for launching kernels. + hipStream_t schedule_hip_stream(); + + // Return the last HIP stream used. + hipStream_t last_hip_stream(); + + CommandEncoder& get_encoder(); + + Device& device() { + return device_; + } + + private: + Device& device_; + HipStream stream_; + std::unique_ptr encoder_; +}; + +class Device { + public: + explicit Device(int device); + ~Device(); + + Device(const Device&) = delete; + Device& operator=(const Device&) = delete; + + // Make this device the current HIP device, required by some HIP calls. + void make_current(); + + DeviceStream& get_stream(Stream s); + + int hip_device() const { + return device_; + } + int compute_capability_major() const { + return compute_capability_major_; + } + int compute_capability_minor() const { + return compute_capability_minor_; + } + rocblas_handle rocblas_handle() const { + return rocblas_handle_; + } + + private: + int device_; + int compute_capability_major_; + int compute_capability_minor_; + rocblas_handle rocblas_handle_; + std::unordered_map streams_; +}; + +class CommandEncoder { + public: + explicit CommandEncoder(DeviceStream& stream); + + CommandEncoder(const CommandEncoder&) = delete; + CommandEncoder& operator=(const CommandEncoder&) = delete; + + void set_input_array(const array& arr) {} + void set_output_array(const array& arr) {} + + void add_temporary(const array& arr) { + temporaries_.push_back(arr.data_shared_ptr()); + } + + void add_completed_handler(std::function task); + void end_encoding(); + void commit(); + + // Schedule a HIP stream for |fun| to launch kernels, and check error + // afterwards. + template + void launch_kernel(F&& fun) { + launch_kernel(stream_.schedule_hip_stream(), std::forward(fun)); + } + + template + void launch_kernel(hipStream_t stream, F&& fun) { + device_.make_current(); + fun(stream); + check_hip_error("kernel launch", hipGetLastError()); + has_gpu_work_ = true; + } + + Device& device() { + return device_; + } + + DeviceStream& stream() { + return stream_; + } + + bool has_gpu_work() const { + return has_gpu_work_; + } + + private: + Device& device_; + DeviceStream& stream_; + Worker worker_; + bool has_gpu_work_{false}; + std::vector> temporaries_; +}; + +Device& device(mlx::core::Device device); +DeviceStream& get_stream(Stream s); +CommandEncoder& get_command_encoder(Stream s); + +// Utility function to check HIP errors +void check_hip_error(const char* msg, hipError_t error); + +} // namespace mlx::core::rocm \ No newline at end of file diff --git a/mlx/backend/rocm/eval.cpp b/mlx/backend/rocm/eval.cpp new file mode 100644 index 000000000..6fd43c668 --- /dev/null +++ b/mlx/backend/rocm/eval.cpp @@ -0,0 +1,11 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/rocm/utils.h" + +namespace mlx::core::rocm { + +void eval() { + // Placeholder for ROCm evaluation +} + +} // namespace mlx::core::rocm \ No newline at end of file diff --git a/mlx/backend/rocm/event.hip b/mlx/backend/rocm/event.hip new file mode 100644 index 000000000..0358d9e6e --- /dev/null +++ b/mlx/backend/rocm/event.hip @@ -0,0 +1,32 @@ +// Copyright © 2025 Apple Inc. + +#include +#include "mlx/backend/rocm/utils.h" + +namespace mlx::core::rocm { + +class Event { +public: + Event() { + check_hip_error("hipEventCreate", hipEventCreate(&event_)); + } + + ~Event() { + hipEventDestroy(event_); + } + + void record(hipStream_t stream) { + check_hip_error("hipEventRecord", hipEventRecord(event_, stream)); + } + + void wait() { + check_hip_error("hipEventSynchronize", hipEventSynchronize(event_)); + } + + hipEvent_t event() const { return event_; } + +private: + hipEvent_t event_; +}; + +} // namespace mlx::core::rocm \ No newline at end of file diff --git a/mlx/backend/rocm/fence.cpp b/mlx/backend/rocm/fence.cpp new file mode 100644 index 000000000..d96c99c06 --- /dev/null +++ b/mlx/backend/rocm/fence.cpp @@ -0,0 +1,9 @@ +// Copyright © 2025 Apple Inc. + +namespace mlx::core::rocm { + +void fence() { + // Placeholder for ROCm fence operation +} + +} // namespace mlx::core::rocm \ No newline at end of file diff --git a/mlx/backend/rocm/indexing.cpp b/mlx/backend/rocm/indexing.cpp new file mode 100644 index 000000000..25e13c36b --- /dev/null +++ b/mlx/backend/rocm/indexing.cpp @@ -0,0 +1,9 @@ +// Copyright © 2025 Apple Inc. + +namespace mlx::core::rocm { + +void index() { + // Placeholder for ROCm indexing operation +} + +} // namespace mlx::core::rocm \ No newline at end of file diff --git a/mlx/backend/rocm/kernel_utils.hip b/mlx/backend/rocm/kernel_utils.hip new file mode 100644 index 000000000..81b3be805 --- /dev/null +++ b/mlx/backend/rocm/kernel_utils.hip @@ -0,0 +1,29 @@ +// Copyright © 2025 Apple Inc. + +#include + +namespace mlx::core::rocm { + +// Utility functions for HIP kernels + +__device__ inline int get_global_id() { + return blockIdx.x * blockDim.x + threadIdx.x; +} + +__device__ inline int get_local_id() { + return threadIdx.x; +} + +__device__ inline int get_group_id() { + return blockIdx.x; +} + +__device__ inline int get_local_size() { + return blockDim.x; +} + +__device__ inline int get_num_groups() { + return gridDim.x; +} + +} // namespace mlx::core::rocm \ No newline at end of file diff --git a/mlx/backend/rocm/layer_norm.hip b/mlx/backend/rocm/layer_norm.hip new file mode 100644 index 000000000..c92b667eb --- /dev/null +++ b/mlx/backend/rocm/layer_norm.hip @@ -0,0 +1,37 @@ +// Copyright © 2025 Apple Inc. + +#include + +namespace mlx::core::rocm { + +__global__ void layer_norm_kernel( + float* input, + float* output, + float* gamma, + float* beta, + int n, + float eps) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + + if (idx < n) { + // Simplified layer norm placeholder + // Real implementation would compute mean and variance + output[idx] = gamma[idx] * input[idx] + beta[idx]; + } +} + +void launch_layer_norm( + float* input, + float* output, + float* gamma, + float* beta, + int n, + float eps, + hipStream_t stream) { + int threads = 256; + int blocks = (n + threads - 1) / threads; + hipLaunchKernelGGL(layer_norm_kernel, dim3(blocks), dim3(threads), 0, stream, + input, output, gamma, beta, n, eps); +} + +} // namespace mlx::core::rocm \ No newline at end of file diff --git a/mlx/backend/rocm/logsumexp.hip b/mlx/backend/rocm/logsumexp.hip new file mode 100644 index 000000000..94dfc6525 --- /dev/null +++ b/mlx/backend/rocm/logsumexp.hip @@ -0,0 +1,13 @@ +// Copyright © 2025 Apple Inc. + +#include + +namespace mlx::core::rocm { + +__global__ void logsumexp_kernel(float* input, float* output, int n) { + // Placeholder implementation + int idx = blockIdx.x * blockDim.x + threadIdx.x; + (void)input; (void)output; (void)n; (void)idx; +} + +} // namespace mlx::core::rocm \ No newline at end of file diff --git a/mlx/backend/rocm/matmul.cpp b/mlx/backend/rocm/matmul.cpp new file mode 100644 index 000000000..9d6dbc065 --- /dev/null +++ b/mlx/backend/rocm/matmul.cpp @@ -0,0 +1,30 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/rocm/device.h" +#include "mlx/backend/rocm/utils.h" + +namespace mlx::core::rocm { + +void matmul_hip( + float* a, + float* b, + float* c, + int m, + int n, + int k, + hipStream_t stream) { + // This is a placeholder - in a real implementation, this would use rocBLAS + // auto& device = get_current_device(); + // rocblas_sgemm(device.rocblas_handle(), ...); + + // For now, just a placeholder + (void)a; + (void)b; + (void)c; + (void)m; + (void)n; + (void)k; + (void)stream; +} + +} // namespace mlx::core::rocm \ No newline at end of file diff --git a/mlx/backend/rocm/no_rocm.cpp b/mlx/backend/rocm/no_rocm.cpp new file mode 100644 index 000000000..da686f59d --- /dev/null +++ b/mlx/backend/rocm/no_rocm.cpp @@ -0,0 +1,11 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/rocm/rocm.h" + +namespace mlx::core::rocm { + +bool is_available() { + return false; +} + +} // namespace mlx::core::rocm \ No newline at end of file diff --git a/mlx/backend/rocm/primitives.hip b/mlx/backend/rocm/primitives.hip new file mode 100644 index 000000000..c91e36da3 --- /dev/null +++ b/mlx/backend/rocm/primitives.hip @@ -0,0 +1,21 @@ +// Copyright © 2025 Apple Inc. + +#include + +#include "mlx/backend/rocm/utils.h" +#include "mlx/backend/common/primitives.h" + +namespace mlx::core::rocm { + +// Basic kernel implementations will go here +// This is a placeholder for ROCm-specific primitive operations + +void add_hip() { + // Placeholder for HIP add operation +} + +void multiply_hip() { + // Placeholder for HIP multiply operation +} + +} // namespace mlx::core::rocm \ No newline at end of file diff --git a/mlx/backend/rocm/random.hip b/mlx/backend/rocm/random.hip new file mode 100644 index 000000000..d192eb68d --- /dev/null +++ b/mlx/backend/rocm/random.hip @@ -0,0 +1,23 @@ +// Copyright © 2025 Apple Inc. + +#include + +namespace mlx::core::rocm { + +__global__ void random_uniform_kernel(float* output, int n, unsigned int seed) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + // Simple LCG placeholder - real implementation would use rocRAND + unsigned int state = seed + idx; + state = state * 1103515245 + 12345; + output[idx] = (float)(state & 0x7FFFFFFF) / (float)0x7FFFFFFF; + } +} + +void launch_random_uniform(float* output, int n, unsigned int seed, hipStream_t stream) { + int threads = 256; + int blocks = (n + threads - 1) / threads; + hipLaunchKernelGGL(random_uniform_kernel, dim3(blocks), dim3(threads), 0, stream, output, n, seed); +} + +} // namespace mlx::core::rocm \ No newline at end of file diff --git a/mlx/backend/rocm/reduce.hip b/mlx/backend/rocm/reduce.hip new file mode 100644 index 000000000..6259e9a57 --- /dev/null +++ b/mlx/backend/rocm/reduce.hip @@ -0,0 +1,24 @@ +// Copyright © 2025 Apple Inc. + +#include + +namespace mlx::core::rocm { + +__global__ void sum_reduce_kernel(float* input, float* output, int n) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + + // Simple reduction placeholder + if (idx == 0) { + float sum = 0.0f; + for (int i = 0; i < n; i++) { + sum += input[i]; + } + output[0] = sum; + } +} + +void launch_sum_reduce(float* input, float* output, int n, hipStream_t stream) { + hipLaunchKernelGGL(sum_reduce_kernel, dim3(1), dim3(1), 0, stream, input, output, n); +} + +} // namespace mlx::core::rocm \ No newline at end of file diff --git a/mlx/backend/rocm/rms_norm.hip b/mlx/backend/rocm/rms_norm.hip new file mode 100644 index 000000000..0d76640a7 --- /dev/null +++ b/mlx/backend/rocm/rms_norm.hip @@ -0,0 +1,13 @@ +// Copyright © 2025 Apple Inc. + +#include + +namespace mlx::core::rocm { + +__global__ void rms_norm_kernel(float* input, float* output, int n) { + // Placeholder implementation + int idx = blockIdx.x * blockDim.x + threadIdx.x; + (void)input; (void)output; (void)n; (void)idx; +} + +} // namespace mlx::core::rocm \ No newline at end of file diff --git a/mlx/backend/rocm/rocm.cpp b/mlx/backend/rocm/rocm.cpp new file mode 100644 index 000000000..83548423a --- /dev/null +++ b/mlx/backend/rocm/rocm.cpp @@ -0,0 +1,11 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/rocm/rocm.h" + +namespace mlx::core::rocm { + +bool is_available() { + return true; +} + +} // namespace mlx::core::rocm \ No newline at end of file diff --git a/mlx/backend/rocm/rocm.h b/mlx/backend/rocm/rocm.h new file mode 100644 index 000000000..8cc6be67d --- /dev/null +++ b/mlx/backend/rocm/rocm.h @@ -0,0 +1,10 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +namespace mlx::core::rocm { + +/* Check if the ROCm backend is available. */ +bool is_available(); + +} // namespace mlx::core::rocm \ No newline at end of file diff --git a/mlx/backend/rocm/rope.hip b/mlx/backend/rocm/rope.hip new file mode 100644 index 000000000..d31da99e8 --- /dev/null +++ b/mlx/backend/rocm/rope.hip @@ -0,0 +1,13 @@ +// Copyright © 2025 Apple Inc. + +#include + +namespace mlx::core::rocm { + +__global__ void rope_kernel(float* input, float* output, int n) { + // Placeholder for RoPE implementation + int idx = blockIdx.x * blockDim.x + threadIdx.x; + (void)input; (void)output; (void)n; (void)idx; +} + +} // namespace mlx::core::rocm \ No newline at end of file diff --git a/mlx/backend/rocm/slicing.cpp b/mlx/backend/rocm/slicing.cpp new file mode 100644 index 000000000..2d5c3e54a --- /dev/null +++ b/mlx/backend/rocm/slicing.cpp @@ -0,0 +1,9 @@ +// Copyright © 2025 Apple Inc. + +namespace mlx::core::rocm { + +void slice() { + // Placeholder for ROCm slicing operation +} + +} // namespace mlx::core::rocm \ No newline at end of file diff --git a/mlx/backend/rocm/softmax.hip b/mlx/backend/rocm/softmax.hip new file mode 100644 index 000000000..244e69c61 --- /dev/null +++ b/mlx/backend/rocm/softmax.hip @@ -0,0 +1,22 @@ +// Copyright © 2025 Apple Inc. + +#include + +namespace mlx::core::rocm { + +__global__ void softmax_kernel(float* input, float* output, int n) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + + if (idx < n) { + // Simplified softmax placeholder - real implementation needs reduction + output[idx] = expf(input[idx]); + } +} + +void launch_softmax(float* input, float* output, int n, hipStream_t stream) { + int threads = 256; + int blocks = (n + threads - 1) / threads; + hipLaunchKernelGGL(softmax_kernel, dim3(blocks), dim3(threads), 0, stream, input, output, n); +} + +} // namespace mlx::core::rocm \ No newline at end of file diff --git a/mlx/backend/rocm/sort.hip b/mlx/backend/rocm/sort.hip new file mode 100644 index 000000000..0519ecba6 --- /dev/null +++ b/mlx/backend/rocm/sort.hip @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/mlx/backend/rocm/ternary.hip b/mlx/backend/rocm/ternary.hip new file mode 100644 index 000000000..85b75aaf6 --- /dev/null +++ b/mlx/backend/rocm/ternary.hip @@ -0,0 +1,20 @@ +// Copyright © 2025 Apple Inc. + +#include + +namespace mlx::core::rocm { + +__global__ void select_kernel(float* condition, float* a, float* b, float* output, int n) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + output[idx] = (condition[idx] != 0.0f) ? a[idx] : b[idx]; + } +} + +void launch_select(float* condition, float* a, float* b, float* output, int n, hipStream_t stream) { + int threads = 256; + int blocks = (n + threads - 1) / threads; + hipLaunchKernelGGL(select_kernel, dim3(blocks), dim3(threads), 0, stream, condition, a, b, output, n); +} + +} // namespace mlx::core::rocm \ No newline at end of file diff --git a/mlx/backend/rocm/unary.hip b/mlx/backend/rocm/unary.hip new file mode 100644 index 000000000..d9c7f5671 --- /dev/null +++ b/mlx/backend/rocm/unary.hip @@ -0,0 +1,33 @@ +// Copyright © 2025 Apple Inc. + +#include + +namespace mlx::core::rocm { + +__global__ void relu_kernel(float* input, float* output, int n) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + output[idx] = fmaxf(0.0f, input[idx]); + } +} + +__global__ void sigmoid_kernel(float* input, float* output, int n) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + output[idx] = 1.0f / (1.0f + expf(-input[idx])); + } +} + +void launch_relu(float* input, float* output, int n, hipStream_t stream) { + int threads = 256; + int blocks = (n + threads - 1) / threads; + hipLaunchKernelGGL(relu_kernel, dim3(blocks), dim3(threads), 0, stream, input, output, n); +} + +void launch_sigmoid(float* input, float* output, int n, hipStream_t stream) { + int threads = 256; + int blocks = (n + threads - 1) / threads; + hipLaunchKernelGGL(sigmoid_kernel, dim3(blocks), dim3(threads), 0, stream, input, output, n); +} + +} // namespace mlx::core::rocm \ No newline at end of file diff --git a/mlx/backend/rocm/utils.cpp b/mlx/backend/rocm/utils.cpp new file mode 100644 index 000000000..d79aa783e --- /dev/null +++ b/mlx/backend/rocm/utils.cpp @@ -0,0 +1,17 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/rocm/utils.h" +#include +#include + +namespace mlx::core::rocm { + +void check_hip_error(const char* msg, hipError_t error) { + if (error != hipSuccess) { + std::ostringstream oss; + oss << "[ROCm] " << msg << ": " << hipGetErrorString(error); + throw std::runtime_error(oss.str()); + } +} + +} // namespace mlx::core::rocm \ No newline at end of file diff --git a/mlx/backend/rocm/utils.h b/mlx/backend/rocm/utils.h new file mode 100644 index 000000000..20aab3836 --- /dev/null +++ b/mlx/backend/rocm/utils.h @@ -0,0 +1,12 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include + +namespace mlx::core::rocm { + +// Utility function to check HIP errors +void check_hip_error(const char* msg, hipError_t error); + +} // namespace mlx::core::rocm \ No newline at end of file diff --git a/mlx/backend/rocm/worker.cpp b/mlx/backend/rocm/worker.cpp new file mode 100644 index 000000000..2dbbf98c7 --- /dev/null +++ b/mlx/backend/rocm/worker.cpp @@ -0,0 +1,61 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/rocm/worker.h" + +namespace mlx::core::rocm { + +Worker::Worker() : worker_thread_(&Worker::worker_loop, this) {} + +Worker::~Worker() { + { + std::lock_guard lock(mutex_); + stop_ = true; + } + cv_.notify_all(); + if (worker_thread_.joinable()) { + worker_thread_.join(); + } +} + +void Worker::enqueue(std::function task) { + { + std::lock_guard lock(mutex_); + tasks_.push(task); + } + cv_.notify_one(); +} + +void Worker::commit() { + std::lock_guard lock(mutex_); + committed_ = true; +} + +void Worker::join() { + std::unique_lock lock(mutex_); + cv_.wait(lock, [this] { return tasks_.empty() && committed_; }); +} + +void Worker::worker_loop() { + while (true) { + std::function task; + { + std::unique_lock lock(mutex_); + cv_.wait(lock, [this] { return stop_ || !tasks_.empty(); }); + + if (stop_) { + break; + } + + if (!tasks_.empty()) { + task = tasks_.front(); + tasks_.pop(); + } + } + + if (task) { + task(); + } + } +} + +} // namespace mlx::core::rocm \ No newline at end of file diff --git a/mlx/backend/rocm/worker.h b/mlx/backend/rocm/worker.h new file mode 100644 index 000000000..a20b0effd --- /dev/null +++ b/mlx/backend/rocm/worker.h @@ -0,0 +1,38 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include +#include +#include +#include +#include + +namespace mlx::core::rocm { + +using HipStream = hipStream_t; + +class Worker { + public: + Worker(); + ~Worker(); + + Worker(const Worker&) = delete; + Worker& operator=(const Worker&) = delete; + + void enqueue(std::function task); + void commit(); + void join(); + + private: + void worker_loop(); + + std::thread worker_thread_; + std::queue> tasks_; + std::mutex mutex_; + std::condition_variable cv_; + bool stop_{false}; + bool committed_{false}; +}; + +} // namespace mlx::core::rocm \ No newline at end of file diff --git a/mlx/device.cpp b/mlx/device.cpp index ec17a509a..aec5f40b0 100644 --- a/mlx/device.cpp +++ b/mlx/device.cpp @@ -6,10 +6,23 @@ #include "mlx/backend/gpu/available.h" #include "mlx/device.h" +#ifdef MLX_USE_ROCM +#include "mlx/backend/rocm/rocm.h" +#endif + namespace mlx::core { Device& mutable_default_device() { - static Device default_device{gpu::is_available() ? Device::gpu : Device::cpu}; + Device::DeviceType default_type = Device::cpu; + if (gpu::is_available()) { + default_type = Device::gpu; + } +#ifdef MLX_USE_ROCM + else if (rocm::is_available()) { + default_type = Device::gpu; // ROCm devices use the generic gpu type + } +#endif + static Device default_device{default_type}; return default_device; } @@ -38,7 +51,11 @@ bool is_available(const Device& d) { case Device::cpu: return cpu::is_available(); case Device::gpu: +#ifdef MLX_USE_ROCM + return gpu::is_available() || rocm::is_available(); +#else return gpu::is_available(); +#endif } // appease compiler return false;