[Experiment] ROCm backend initial push

Nripesh Niketan 2025-06-16 22:42:56 +01:00
parent bc53f8293f
commit 8bb8b76ae4
38 changed files with 1044 additions and 2 deletions

CMakeLists.txt

@ -35,6 +35,7 @@ option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
 option(MLX_BUILD_METAL "Build metal backend" ON)
 option(MLX_BUILD_CPU "Build cpu backend" ON)
 option(MLX_BUILD_CUDA "Build cuda backend" OFF)
+option(MLX_BUILD_ROCM "Build ROCm backend" OFF)
 option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
 option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
 option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
@ -88,6 +89,10 @@ if(MLX_BUILD_CUDA)
   enable_language(CUDA)
 endif()

+if(MLX_BUILD_ROCM)
+  enable_language(HIP)
+endif()
+
 if(MLX_BUILD_METAL AND NOT METAL_LIB)
   message(STATUS "Metal not found. Unable to build GPU")
   set(MLX_BUILD_METAL OFF)

mlx/CMakeLists.txt

@ -60,7 +60,16 @@ else()
                  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda/no_cuda.cpp)
 endif()

-if(MLX_BUILD_METAL OR MLX_BUILD_CUDA)
+if(MLX_BUILD_ROCM)
+  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/rocm)
+else()
+  target_sources(mlx
+                 PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/rocm/no_rocm.cpp)
+endif()
+
+if(MLX_BUILD_METAL
+   OR MLX_BUILD_CUDA
+   OR MLX_BUILD_ROCM)
   add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/gpu)
 else()
   add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_gpu)

mlx/backend/rocm/CMakeLists.txt (new file)

@ -0,0 +1,85 @@
# Filename rules in ROCm backend:
#
# * Use .hip/.hpp if code contains device code, and .cpp/.h if not.
# * Device-only code should be put in device/ subdir.
# * Files in device/ subdir should not include files outside.
target_sources(
mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.hip
${CMAKE_CURRENT_SOURCE_DIR}/binary.hip
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
${CMAKE_CURRENT_SOURCE_DIR}/copy.hip
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
${CMAKE_CURRENT_SOURCE_DIR}/event.hip
${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.hip
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
${CMAKE_CURRENT_SOURCE_DIR}/layer_norm.hip
${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.hip
${CMAKE_CURRENT_SOURCE_DIR}/primitives.hip
${CMAKE_CURRENT_SOURCE_DIR}/random.hip
${CMAKE_CURRENT_SOURCE_DIR}/reduce.hip
${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.hip
${CMAKE_CURRENT_SOURCE_DIR}/rope.hip
${CMAKE_CURRENT_SOURCE_DIR}/rocm.cpp
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.hip
${CMAKE_CURRENT_SOURCE_DIR}/sort.hip
${CMAKE_CURRENT_SOURCE_DIR}/ternary.hip
${CMAKE_CURRENT_SOURCE_DIR}/unary.hip
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)
target_compile_definitions(mlx PRIVATE MLX_USE_ROCM)
# Embed kernel sources in binary for JIT compilation.
file(
GLOB MLX_JIT_SOURCES
RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
"${CMAKE_CURRENT_SOURCE_DIR}/device/*.h"
"${CMAKE_CURRENT_SOURCE_DIR}/device/*.hpp")
string(JOIN ":" MLX_JIT_SOURCES_ARG ${MLX_JIT_SOURCES})
add_custom_command(
OUTPUT gen/rocm_jit_sources.h
COMMAND
${CMAKE_COMMAND} -DMLX_SOURCE_ROOT=${CMAKE_CURRENT_SOURCE_DIR}
-DMLX_JIT_SOURCES=${MLX_JIT_SOURCES_ARG} -P
"${CMAKE_CURRENT_SOURCE_DIR}/bin2h.cmake"
DEPENDS bin2h.cmake ${MLX_JIT_SOURCES})
add_custom_target(rocm_jit_sources DEPENDS gen/rocm_jit_sources.h)
add_dependencies(mlx rocm_jit_sources)
target_include_directories(mlx PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/gen")
# Find ROCm installation
find_package(hip REQUIRED)
find_package(rocblas REQUIRED)
# Link with ROCm libraries
target_link_libraries(mlx PRIVATE hip::device roc::rocblas)
# Set GPU architectures for ROCm. Common ROCm architectures: gfx900, gfx906,
# gfx908, gfx90a, gfx1030, gfx1100
set(MLX_ROCM_ARCHITECTURES
"gfx900;gfx906;gfx908;gfx90a;gfx1030;gfx1100"
CACHE STRING "ROCm GPU architectures")
message(STATUS "ROCm GPU architectures: ${MLX_ROCM_ARCHITECTURES}")
# Set GPU targets for HIP compilation
set_property(TARGET mlx PROPERTY HIP_ARCHITECTURES "${MLX_ROCM_ARCHITECTURES}")
# Enable HIP language support
enable_language(HIP)
# Set HIP compiler flags
target_compile_options(
mlx
PRIVATE "$<$<COMPILE_LANGUAGE:HIP>:-fgpu-rdc>"
"$<$<COMPILE_LANGUAGE:HIP>:-Xcompiler=-Wall>"
"$<$<COMPILE_LANGUAGE:HIP>:-Xcompiler=-Wextra>")
# Add ROCm include directories
target_include_directories(mlx PRIVATE ${hip_INCLUDE_DIRS})
target_include_directories(mlx PRIVATE ${rocblas_INCLUDE_DIRS})

mlx/backend/rocm/allocator.cpp (new file)

@ -0,0 +1,20 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/rocm/allocator.h"
#include "mlx/backend/rocm/utils.h"
namespace mlx::core::rocm {
void* allocate(size_t size) {
void* ptr;
check_hip_error("hipMalloc", hipMalloc(&ptr, size));
return ptr;
}
void deallocate(void* ptr) {
if (ptr) {
check_hip_error("hipFree", hipFree(ptr));
}
}
} // namespace mlx::core::rocm

mlx/backend/rocm/allocator.h (new file)

@ -0,0 +1,12 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <cstddef>
namespace mlx::core::rocm {
void* allocate(size_t size);
void deallocate(void* ptr);
} // namespace mlx::core::rocm

mlx/backend/rocm/arg_reduce.hip (new file)

@ -0,0 +1,28 @@
// Copyright © 2025 Apple Inc.
#include <hip/hip_runtime.h>
namespace mlx::core::rocm {
__global__ void argmax_kernel(float* input, int* output, int n) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
// Simple argmax placeholder
if (idx == 0) {
int max_idx = 0;
float max_val = input[0];
for (int i = 1; i < n; i++) {
if (input[i] > max_val) {
max_val = input[i];
max_idx = i;
}
}
output[0] = max_idx;
}
}
void launch_argmax(float* input, int* output, int n, hipStream_t stream) {
hipLaunchKernelGGL(argmax_kernel, dim3(1), dim3(1), 0, stream, input, output, n);
}
} // namespace mlx::core::rocm

mlx/backend/rocm/bin2h.cmake (new file)

@ -0,0 +1,47 @@
# Copyright © 2025 Apple Inc.
# Script to embed kernel source files as header for JIT compilation
set(MLX_OUTPUT_FILE "${CMAKE_CURRENT_BINARY_DIR}/gen/rocm_jit_sources.h")
set(MLX_KERNEL_HEADER
"#pragma once\n\n#include <unordered_map>\n#include <string>\n\nnamespace mlx::core::rocm {\n\n"
)
set(MLX_KERNEL_FOOTER "\n} // namespace mlx::core::rocm\n")
# Create output directory
get_filename_component(MLX_OUTPUT_DIR ${MLX_OUTPUT_FILE} DIRECTORY)
file(MAKE_DIRECTORY ${MLX_OUTPUT_DIR})
# Write header
file(WRITE ${MLX_OUTPUT_FILE} ${MLX_KERNEL_HEADER})
# Process JIT sources
string(REPLACE ":" ";" MLX_JIT_SOURCES_LIST ${MLX_JIT_SOURCES})
set(MLX_SOURCE_MAP
"const std::unordered_map<std::string, std::string> kernel_sources = {\n")
foreach(source IN LISTS MLX_JIT_SOURCES_LIST)
set(source_file "${MLX_SOURCE_ROOT}/${source}")
if(EXISTS ${source_file})
# Read source file
file(READ ${source_file} source_content)
# Escape content for C++ string literal
string(REPLACE "\\" "\\\\" source_content "${source_content}")
string(REPLACE "\"" "\\\"" source_content "${source_content}")
string(REPLACE "\n" "\\n\"\n\"" source_content "${source_content}")
# Add to map
set(MLX_SOURCE_MAP
"${MLX_SOURCE_MAP} {\"${source}\", \"${source_content}\"},\n")
endif()
endforeach()
set(MLX_SOURCE_MAP "${MLX_SOURCE_MAP}};\n")
# Write source map
file(APPEND ${MLX_OUTPUT_FILE} ${MLX_SOURCE_MAP})
# Write footer
file(APPEND ${MLX_OUTPUT_FILE} ${MLX_KERNEL_FOOTER})
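
For orientation, the header generated by this script would look roughly like the sketch below. The commit adds no files under device/ yet, so kernel_sources would start out empty; the device/utils.hpp entry is hypothetical, shown only to illustrate the layout.

// Sketch of gen/rocm_jit_sources.h as emitted by bin2h.cmake; the
// "device/utils.hpp" entry is hypothetical, for illustration only.
#pragma once

#include <unordered_map>
#include <string>

namespace mlx::core::rocm {

const std::unordered_map<std::string, std::string> kernel_sources = {
    {"device/utils.hpp",
     "// Copyright © 2025 Apple Inc.\n"
     "#pragma once\n"
     "// ... embedded kernel source, escaped by bin2h.cmake ...\n"},
};

} // namespace mlx::core::rocm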

mlx/backend/rocm/binary.hip (new file)

@ -0,0 +1,36 @@
// Copyright © 2025 Apple Inc.
#include <hip/hip_runtime.h>
#include "mlx/backend/rocm/utils.h"
namespace mlx::core::rocm {
// Basic binary operation kernels will go here
__global__ void add_kernel(float* a, float* b, float* c, int n) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < n) {
c[idx] = a[idx] + b[idx];
}
}
__global__ void multiply_kernel(float* a, float* b, float* c, int n) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < n) {
c[idx] = a[idx] * b[idx];
}
}
void launch_add(float* a, float* b, float* c, int n, hipStream_t stream) {
int threads = 256;
int blocks = (n + threads - 1) / threads;
hipLaunchKernelGGL(add_kernel, dim3(blocks), dim3(threads), 0, stream, a, b, c, n);
}
void launch_multiply(float* a, float* b, float* c, int n, hipStream_t stream) {
int threads = 256;
int blocks = (n + threads - 1) / threads;
hipLaunchKernelGGL(multiply_kernel, dim3(blocks), dim3(threads), 0, stream, a, b, c, n);
}
} // namespace mlx::core::rocm

mlx/backend/rocm/compiled.cpp (new file)

@ -0,0 +1,9 @@
// Copyright © 2025 Apple Inc.
namespace mlx::core::rocm {
void compile() {
// Placeholder for ROCm compilation
}
} // namespace mlx::core::rocm

mlx/backend/rocm/copy.hip (new file)

@ -0,0 +1,20 @@
// Copyright © 2025 Apple Inc.
#include <hip/hip_runtime.h>
namespace mlx::core::rocm {
__global__ void copy_kernel(float* src, float* dst, int n) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < n) {
dst[idx] = src[idx];
}
}
void launch_copy(float* src, float* dst, int n, hipStream_t stream) {
int threads = 256;
int blocks = (n + threads - 1) / threads;
hipLaunchKernelGGL(copy_kernel, dim3(blocks), dim3(threads), 0, stream, src, dst, n);
}
} // namespace mlx::core::rocm

mlx/backend/rocm/device.cpp (new file)

@ -0,0 +1,104 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/rocm/device.h"
#include "mlx/backend/rocm/utils.h"
namespace mlx::core::rocm {
DeviceStream::DeviceStream(Device& device) : device_(device) {
check_hip_error("hipStreamCreate", hipStreamCreate(&stream_));
encoder_ = std::make_unique<CommandEncoder>(*this);
}
void DeviceStream::synchronize() {
check_hip_error("hipStreamSynchronize", hipStreamSynchronize(stream_));
}
hipStream_t DeviceStream::schedule_hip_stream() {
return stream_;
}
hipStream_t DeviceStream::last_hip_stream() {
return stream_;
}
CommandEncoder& DeviceStream::get_encoder() {
return *encoder_;
}
Device::Device(int device) : device_(device) {
check_hip_error("hipSetDevice", hipSetDevice(device_));
// Get device properties
hipDeviceProp_t prop;
check_hip_error(
"hipGetDeviceProperties", hipGetDeviceProperties(&prop, device_));
compute_capability_major_ = prop.major;
compute_capability_minor_ = prop.minor;
// Create rocBLAS handle
check_hip_error(
"rocblas_create_handle",
static_cast<hipError_t>(rocblas_create_handle(&rocblas_handle_)));
}
Device::~Device() {
if (rocblas_handle_) {
rocblas_destroy_handle(rocblas_handle_);
}
}
void Device::make_current() {
check_hip_error("hipSetDevice", hipSetDevice(device_));
}
DeviceStream& Device::get_stream(Stream s) {
auto it = streams_.find(s.index);
if (it != streams_.end()) {
return it->second;
}
auto [new_it, inserted] = streams_.emplace(s.index, DeviceStream(*this));
return new_it->second;
}
CommandEncoder::CommandEncoder(DeviceStream& stream)
: device_(stream.device()), stream_(stream), worker_() {}
void CommandEncoder::add_completed_handler(std::function<void()> task) {
worker_.enqueue(task);
}
void CommandEncoder::end_encoding() {
// Implementation for ending encoding
}
void CommandEncoder::commit() {
worker_.commit();
}
// Global device management
static std::unordered_map<int, std::unique_ptr<Device>> devices_;
Device& device(mlx::core::Device device) {
auto it = devices_.find(device.index);
if (it != devices_.end()) {
return *it->second;
}
auto new_device = std::make_unique<Device>(device.index);
Device& dev_ref = *new_device;
devices_[device.index] = std::move(new_device);
return dev_ref;
}
DeviceStream& get_stream(Stream s) {
// Use default device (index 0) for now
return device(mlx::core::Device{mlx::core::Device::gpu, 0}).get_stream(s);
}
CommandEncoder& get_command_encoder(Stream s) {
return get_stream(s).get_encoder();
}
} // namespace mlx::core::rocm

mlx/backend/rocm/device.h (new file)

@ -0,0 +1,141 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/array.h"
#include "mlx/backend/rocm/worker.h"
#include "mlx/stream.h"
#include <hip/hip_runtime.h>
#include <rocblas/rocblas.h>
#include <unordered_map>
namespace mlx::core::rocm {
class Device;
class CommandEncoder;
class DeviceStream {
public:
explicit DeviceStream(Device& device);
DeviceStream(const DeviceStream&) = delete;
DeviceStream& operator=(const DeviceStream&) = delete;
// Wait until kernels in the stream complete.
void synchronize();
// Return a HIP stream for launching kernels.
hipStream_t schedule_hip_stream();
// Return the last HIP stream used.
hipStream_t last_hip_stream();
CommandEncoder& get_encoder();
Device& device() {
return device_;
}
private:
Device& device_;
HipStream stream_;
std::unique_ptr<CommandEncoder> encoder_;
};
class Device {
public:
explicit Device(int device);
~Device();
Device(const Device&) = delete;
Device& operator=(const Device&) = delete;
// Make this device the current HIP device, required by some HIP calls.
void make_current();
DeviceStream& get_stream(Stream s);
int hip_device() const {
return device_;
}
int compute_capability_major() const {
return compute_capability_major_;
}
int compute_capability_minor() const {
return compute_capability_minor_;
}
rocblas_handle rocblas_handle() const {
return rocblas_handle_;
}
private:
int device_;
int compute_capability_major_;
int compute_capability_minor_;
rocblas_handle rocblas_handle_;
std::unordered_map<int, DeviceStream> streams_;
};
class CommandEncoder {
public:
explicit CommandEncoder(DeviceStream& stream);
CommandEncoder(const CommandEncoder&) = delete;
CommandEncoder& operator=(const CommandEncoder&) = delete;
void set_input_array(const array& arr) {}
void set_output_array(const array& arr) {}
void add_temporary(const array& arr) {
temporaries_.push_back(arr.data_shared_ptr());
}
void add_completed_handler(std::function<void()> task);
void end_encoding();
void commit();
// Schedule a HIP stream for |fun| to launch kernels, and check error
// afterwards.
template <typename F>
void launch_kernel(F&& fun) {
launch_kernel(stream_.schedule_hip_stream(), std::forward<F>(fun));
}
template <typename F>
void launch_kernel(hipStream_t stream, F&& fun) {
device_.make_current();
fun(stream);
check_hip_error("kernel launch", hipGetLastError());
has_gpu_work_ = true;
}
Device& device() {
return device_;
}
DeviceStream& stream() {
return stream_;
}
bool has_gpu_work() const {
return has_gpu_work_;
}
private:
Device& device_;
DeviceStream& stream_;
Worker worker_;
bool has_gpu_work_{false};
std::vector<std::shared_ptr<array::Data>> temporaries_;
};
Device& device(mlx::core::Device device);
DeviceStream& get_stream(Stream s);
CommandEncoder& get_command_encoder(Stream s);
// Utility function to check HIP errors
void check_hip_error(const char* msg, hipError_t error);
} // namespace mlx::core::rocm
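
As a usage sketch (not part of this commit), an op could drive a kernel through the encoder API declared above roughly as follows; launch_add is the helper defined in binary.hip.

// Hypothetical usage of the CommandEncoder API above; not code from the commit.
#include "mlx/backend/rocm/device.h"

namespace mlx::core::rocm {

// Defined in binary.hip; repeated here so the sketch is self-contained.
void launch_add(float* a, float* b, float* c, int n, hipStream_t stream);

void add_op_sketch(float* a, float* b, float* c, int n, Stream s) {
  auto& enc = get_command_encoder(s);
  // launch_kernel() makes the device current, invokes the lambda with the
  // stream's hipStream_t, then checks hipGetLastError().
  enc.launch_kernel([&](hipStream_t stream) { launch_add(a, b, c, n, stream); });
}

} // namespace mlx::core::rocm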

mlx/backend/rocm/eval.cpp (new file)

@ -0,0 +1,11 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/rocm/utils.h"
namespace mlx::core::rocm {
void eval() {
// Placeholder for ROCm evaluation
}
} // namespace mlx::core::rocm

mlx/backend/rocm/event.hip (new file)

@ -0,0 +1,32 @@
// Copyright © 2025 Apple Inc.
#include <hip/hip_runtime.h>
#include "mlx/backend/rocm/utils.h"
namespace mlx::core::rocm {
class Event {
public:
Event() {
check_hip_error("hipEventCreate", hipEventCreate(&event_));
}
~Event() {
hipEventDestroy(event_);
}
void record(hipStream_t stream) {
check_hip_error("hipEventRecord", hipEventRecord(event_, stream));
}
void wait() {
check_hip_error("hipEventSynchronize", hipEventSynchronize(event_));
}
hipEvent_t event() const { return event_; }
private:
hipEvent_t event_;
};
} // namespace mlx::core::rocm

mlx/backend/rocm/fence.cpp (new file)

@ -0,0 +1,9 @@
// Copyright © 2025 Apple Inc.
namespace mlx::core::rocm {
void fence() {
// Placeholder for ROCm fence operation
}
} // namespace mlx::core::rocm

mlx/backend/rocm/indexing.cpp (new file)

@ -0,0 +1,9 @@
// Copyright © 2025 Apple Inc.
namespace mlx::core::rocm {
void index() {
// Placeholder for ROCm indexing operation
}
} // namespace mlx::core::rocm

mlx/backend/rocm/kernel_utils.hip (new file)

@ -0,0 +1,29 @@
// Copyright © 2025 Apple Inc.
#include <hip/hip_runtime.h>
namespace mlx::core::rocm {
// Utility functions for HIP kernels
__device__ inline int get_global_id() {
return blockIdx.x * blockDim.x + threadIdx.x;
}
__device__ inline int get_local_id() {
return threadIdx.x;
}
__device__ inline int get_group_id() {
return blockIdx.x;
}
__device__ inline int get_local_size() {
return blockDim.x;
}
__device__ inline int get_num_groups() {
return gridDim.x;
}
} // namespace mlx::core::rocm

mlx/backend/rocm/layer_norm.hip (new file)

@ -0,0 +1,37 @@
// Copyright © 2025 Apple Inc.
#include <hip/hip_runtime.h>
namespace mlx::core::rocm {
__global__ void layer_norm_kernel(
float* input,
float* output,
float* gamma,
float* beta,
int n,
float eps) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < n) {
// Simplified layer norm placeholder
// Real implementation would compute mean and variance
output[idx] = gamma[idx] * input[idx] + beta[idx];
}
}
void launch_layer_norm(
float* input,
float* output,
float* gamma,
float* beta,
int n,
float eps,
hipStream_t stream) {
int threads = 256;
int blocks = (n + threads - 1) / threads;
hipLaunchKernelGGL(layer_norm_kernel, dim3(blocks), dim3(threads), 0, stream,
input, output, gamma, beta, n, eps);
}
} // namespace mlx::core::rocm

mlx/backend/rocm/logsumexp.hip (new file)

@ -0,0 +1,13 @@
// Copyright © 2025 Apple Inc.
#include <hip/hip_runtime.h>
namespace mlx::core::rocm {
__global__ void logsumexp_kernel(float* input, float* output, int n) {
// Placeholder implementation
int idx = blockIdx.x * blockDim.x + threadIdx.x;
(void)input; (void)output; (void)n; (void)idx;
}
} // namespace mlx::core::rocm

mlx/backend/rocm/matmul.cpp (new file)

@ -0,0 +1,30 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/rocm/device.h"
#include "mlx/backend/rocm/utils.h"
namespace mlx::core::rocm {
void matmul_hip(
float* a,
float* b,
float* c,
int m,
int n,
int k,
hipStream_t stream) {
// This is a placeholder - in a real implementation, this would use rocBLAS
// auto& device = get_current_device();
// rocblas_sgemm(device.rocblas_handle(), ...);
// For now, just a placeholder
(void)a;
(void)b;
(void)c;
(void)m;
(void)n;
(void)k;
(void)stream;
}
} // namespace mlx::core::rocm
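
The comment above points at rocBLAS. A minimal sketch of that dispatch, assuming row-major a (m x k), b (k x n), c (m x n) and the rocblas_handle() accessor declared in device.h, with error handling omitted, might look like this; it is not code from the commit.

// Hypothetical rocBLAS dispatch for matmul_hip; a sketch, not the commit's code.
#include <rocblas/rocblas.h>

#include "mlx/backend/rocm/device.h"

namespace mlx::core::rocm {

void matmul_rocblas_sketch(
    Device& dev, float* a, float* b, float* c, int m, int n, int k,
    hipStream_t stream) {
  const float alpha = 1.0f;
  const float beta = 0.0f;
  rocblas_set_stream(dev.rocblas_handle(), stream);
  // rocBLAS is column-major; swapping the operands computes the row-major
  // product c = a * b as the column-major product c^T = b^T * a^T.
  rocblas_sgemm(
      dev.rocblas_handle(),
      rocblas_operation_none, rocblas_operation_none,
      n, m, k,
      &alpha,
      b, n,
      a, k,
      &beta,
      c, n);
}

} // namespace mlx::core::rocm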

mlx/backend/rocm/no_rocm.cpp (new file)

@ -0,0 +1,11 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/rocm/rocm.h"
namespace mlx::core::rocm {
bool is_available() {
return false;
}
} // namespace mlx::core::rocm

mlx/backend/rocm/primitives.hip (new file)

@ -0,0 +1,21 @@
// Copyright © 2025 Apple Inc.
#include <hip/hip_runtime.h>
#include "mlx/backend/rocm/utils.h"
#include "mlx/backend/common/primitives.h"
namespace mlx::core::rocm {
// Basic kernel implementations will go here
// This is a placeholder for ROCm-specific primitive operations
void add_hip() {
// Placeholder for HIP add operation
}
void multiply_hip() {
// Placeholder for HIP multiply operation
}
} // namespace mlx::core::rocm

mlx/backend/rocm/random.hip (new file)

@ -0,0 +1,23 @@
// Copyright © 2025 Apple Inc.
#include <hip/hip_runtime.h>
namespace mlx::core::rocm {
__global__ void random_uniform_kernel(float* output, int n, unsigned int seed) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < n) {
// Simple LCG placeholder - real implementation would use rocRAND
unsigned int state = seed + idx;
state = state * 1103515245 + 12345;
output[idx] = (float)(state & 0x7FFFFFFF) / (float)0x7FFFFFFF;
}
}
void launch_random_uniform(float* output, int n, unsigned int seed, hipStream_t stream) {
int threads = 256;
int blocks = (n + threads - 1) / threads;
hipLaunchKernelGGL(random_uniform_kernel, dim3(blocks), dim3(threads), 0, stream, output, n, seed);
}
} // namespace mlx::core::rocm
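
The kernel above says a real implementation would use rocRAND. A host-side sketch with rocRAND's generator API is below; the generator choice and header path are assumptions, and error checking is omitted.

// Hypothetical rocRAND-based replacement for the LCG above; a sketch only.
// The header path may differ across ROCm versions.
#include <hip/hip_runtime.h>
#include <rocrand/rocrand.h>

namespace mlx::core::rocm {

void fill_uniform_rocrand_sketch(
    float* output, size_t n, unsigned long long seed, hipStream_t stream) {
  rocrand_generator gen;
  rocrand_create_generator(&gen, ROCRAND_RNG_PSEUDO_PHILOX4_32_10);
  rocrand_set_seed(gen, seed);
  rocrand_set_stream(gen, stream);
  // Fills `output` (device memory) with n uniformly distributed floats.
  rocrand_generate_uniform(gen, output, n);
  rocrand_destroy_generator(gen);
}

} // namespace mlx::core::rocm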

mlx/backend/rocm/reduce.hip (new file)

@ -0,0 +1,24 @@
// Copyright © 2025 Apple Inc.
#include <hip/hip_runtime.h>
namespace mlx::core::rocm {
__global__ void sum_reduce_kernel(float* input, float* output, int n) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
// Simple reduction placeholder
if (idx == 0) {
float sum = 0.0f;
for (int i = 0; i < n; i++) {
sum += input[i];
}
output[0] = sum;
}
}
void launch_sum_reduce(float* input, float* output, int n, hipStream_t stream) {
hipLaunchKernelGGL(sum_reduce_kernel, dim3(1), dim3(1), 0, stream, input, output, n);
}
} // namespace mlx::core::rocm
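
The single-thread loop above is a placeholder. A conventional shared-memory tree reduction, sketched below assuming 256-thread blocks, produces one partial sum per block that a second kernel (or the host) then combines.

// Hypothetical shared-memory tree reduction; a sketch, not the commit's code.
// Assumes the kernel is launched with 256 threads per block.
#include <hip/hip_runtime.h>

namespace mlx::core::rocm {

__global__ void sum_reduce_block_kernel(
    const float* input, float* partial, int n) {
  __shared__ float scratch[256];
  int tid = threadIdx.x;
  int idx = blockIdx.x * blockDim.x + threadIdx.x;

  scratch[tid] = (idx < n) ? input[idx] : 0.0f;
  __syncthreads();

  // Halve the number of active threads each step.
  for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
    if (tid < stride) {
      scratch[tid] += scratch[tid + stride];
    }
    __syncthreads();
  }
  if (tid == 0) {
    partial[blockIdx.x] = scratch[0];
  }
}

} // namespace mlx::core::rocm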

mlx/backend/rocm/rms_norm.hip (new file)

@ -0,0 +1,13 @@
// Copyright © 2025 Apple Inc.
#include <hip/hip_runtime.h>
namespace mlx::core::rocm {
__global__ void rms_norm_kernel(float* input, float* output, int n) {
// Placeholder implementation
int idx = blockIdx.x * blockDim.x + threadIdx.x;
(void)input; (void)output; (void)n; (void)idx;
}
} // namespace mlx::core::rocm

mlx/backend/rocm/rocm.cpp (new file)

@ -0,0 +1,11 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/rocm/rocm.h"
namespace mlx::core::rocm {
bool is_available() {
return true;
}
} // namespace mlx::core::rocm

mlx/backend/rocm/rocm.h (new file)

@ -0,0 +1,10 @@
// Copyright © 2025 Apple Inc.
#pragma once
namespace mlx::core::rocm {
/* Check if the ROCm backend is available. */
bool is_available();
} // namespace mlx::core::rocm

mlx/backend/rocm/rope.hip (new file)

@ -0,0 +1,13 @@
// Copyright © 2025 Apple Inc.
#include <hip/hip_runtime.h>
namespace mlx::core::rocm {
__global__ void rope_kernel(float* input, float* output, int n) {
// Placeholder for RoPE implementation
int idx = blockIdx.x * blockDim.x + threadIdx.x;
(void)input; (void)output; (void)n; (void)idx;
}
} // namespace mlx::core::rocm

mlx/backend/rocm/slicing.cpp (new file)

@ -0,0 +1,9 @@
// Copyright © 2025 Apple Inc.
namespace mlx::core::rocm {
void slice() {
// Placeholder for ROCm slicing operation
}
} // namespace mlx::core::rocm

mlx/backend/rocm/softmax.hip (new file)

@ -0,0 +1,22 @@
// Copyright © 2025 Apple Inc.
#include <hip/hip_runtime.h>
namespace mlx::core::rocm {
__global__ void softmax_kernel(float* input, float* output, int n) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < n) {
// Simplified softmax placeholder - real implementation needs reduction
output[idx] = expf(input[idx]);
}
}
void launch_softmax(float* input, float* output, int n, hipStream_t stream) {
int threads = 256;
int blocks = (n + threads - 1) / threads;
hipLaunchKernelGGL(softmax_kernel, dim3(blocks), dim3(threads), 0, stream, input, output, n);
}
} // namespace mlx::core::rocm
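
As the comment notes, a real softmax needs reductions. Below is a numerically stable sketch that assumes a single 256-thread block handles one whole vector; production code would tile rows across blocks.

// Hypothetical numerically stable softmax for one vector; a sketch only,
// launched with a single block of 256 threads.
#include <hip/hip_runtime.h>

#include <cfloat>

namespace mlx::core::rocm {

__global__ void softmax_stable_kernel(const float* input, float* output, int n) {
  __shared__ float scratch[256];
  int tid = threadIdx.x;

  // Pass 1: block-wide maximum, for numerical stability.
  float local_max = -FLT_MAX;
  for (int i = tid; i < n; i += blockDim.x) {
    local_max = fmaxf(local_max, input[i]);
  }
  scratch[tid] = local_max;
  __syncthreads();
  for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
    if (tid < stride) {
      scratch[tid] = fmaxf(scratch[tid], scratch[tid + stride]);
    }
    __syncthreads();
  }
  float max_val = scratch[0];
  __syncthreads();

  // Pass 2: block-wide sum of exp(x - max).
  float local_sum = 0.0f;
  for (int i = tid; i < n; i += blockDim.x) {
    local_sum += expf(input[i] - max_val);
  }
  scratch[tid] = local_sum;
  __syncthreads();
  for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
    if (tid < stride) {
      scratch[tid] += scratch[tid + stride];
    }
    __syncthreads();
  }
  float sum = scratch[0];

  // Pass 3: normalize.
  for (int i = tid; i < n; i += blockDim.x) {
    output[i] = expf(input[i] - max_val) / sum;
  }
}

} // namespace mlx::core::rocm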

mlx/backend/rocm/sort.hip (new file, empty)

@ -0,0 +1 @@

mlx/backend/rocm/ternary.hip (new file)

@ -0,0 +1,20 @@
// Copyright © 2025 Apple Inc.
#include <hip/hip_runtime.h>
namespace mlx::core::rocm {
__global__ void select_kernel(float* condition, float* a, float* b, float* output, int n) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < n) {
output[idx] = (condition[idx] != 0.0f) ? a[idx] : b[idx];
}
}
void launch_select(float* condition, float* a, float* b, float* output, int n, hipStream_t stream) {
int threads = 256;
int blocks = (n + threads - 1) / threads;
hipLaunchKernelGGL(select_kernel, dim3(blocks), dim3(threads), 0, stream, condition, a, b, output, n);
}
} // namespace mlx::core::rocm

mlx/backend/rocm/unary.hip (new file)

@ -0,0 +1,33 @@
// Copyright © 2025 Apple Inc.
#include <hip/hip_runtime.h>
namespace mlx::core::rocm {
__global__ void relu_kernel(float* input, float* output, int n) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < n) {
output[idx] = fmaxf(0.0f, input[idx]);
}
}
__global__ void sigmoid_kernel(float* input, float* output, int n) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < n) {
output[idx] = 1.0f / (1.0f + expf(-input[idx]));
}
}
void launch_relu(float* input, float* output, int n, hipStream_t stream) {
int threads = 256;
int blocks = (n + threads - 1) / threads;
hipLaunchKernelGGL(relu_kernel, dim3(blocks), dim3(threads), 0, stream, input, output, n);
}
void launch_sigmoid(float* input, float* output, int n, hipStream_t stream) {
int threads = 256;
int blocks = (n + threads - 1) / threads;
hipLaunchKernelGGL(sigmoid_kernel, dim3(blocks), dim3(threads), 0, stream, input, output, n);
}
} // namespace mlx::core::rocm

mlx/backend/rocm/utils.cpp (new file)

@ -0,0 +1,17 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/rocm/utils.h"
#include <sstream>
#include <stdexcept>
namespace mlx::core::rocm {
void check_hip_error(const char* msg, hipError_t error) {
if (error != hipSuccess) {
std::ostringstream oss;
oss << "[ROCm] " << msg << ": " << hipGetErrorString(error);
throw std::runtime_error(oss.str());
}
}
} // namespace mlx::core::rocm

mlx/backend/rocm/utils.h (new file)

@ -0,0 +1,12 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <hip/hip_runtime.h>
namespace mlx::core::rocm {
// Utility function to check HIP errors
void check_hip_error(const char* msg, hipError_t error);
} // namespace mlx::core::rocm

mlx/backend/rocm/worker.cpp (new file)

@ -0,0 +1,61 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/rocm/worker.h"
namespace mlx::core::rocm {
Worker::Worker() : worker_thread_(&Worker::worker_loop, this) {}
Worker::~Worker() {
{
std::lock_guard<std::mutex> lock(mutex_);
stop_ = true;
}
cv_.notify_all();
if (worker_thread_.joinable()) {
worker_thread_.join();
}
}
void Worker::enqueue(std::function<void()> task) {
{
std::lock_guard<std::mutex> lock(mutex_);
tasks_.push(task);
}
cv_.notify_one();
}
void Worker::commit() {
std::lock_guard<std::mutex> lock(mutex_);
committed_ = true;
}
void Worker::join() {
std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock, [this] { return tasks_.empty() && committed_; });
}
void Worker::worker_loop() {
while (true) {
std::function<void()> task;
{
std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock, [this] { return stop_ || !tasks_.empty(); });
if (stop_) {
break;
}
if (!tasks_.empty()) {
task = tasks_.front();
tasks_.pop();
}
}
if (task) {
task();
}
}
}
} // namespace mlx::core::rocm

mlx/backend/rocm/worker.h (new file)

@ -0,0 +1,38 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <hip/hip_runtime.h>
#include <functional>
#include <future>
#include <queue>
#include <thread>
namespace mlx::core::rocm {
using HipStream = hipStream_t;
class Worker {
public:
Worker();
~Worker();
Worker(const Worker&) = delete;
Worker& operator=(const Worker&) = delete;
void enqueue(std::function<void()> task);
void commit();
void join();
private:
void worker_loop();
std::thread worker_thread_;
std::queue<std::function<void()>> tasks_;
std::mutex mutex_;
std::condition_variable cv_;
bool stop_{false};
bool committed_{false};
};
} // namespace mlx::core::rocm
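
One plausible use of the Worker (a sketch, not from this commit) is deferring host-side cleanup until the GPU work queued on a stream has drained, without blocking the encoding thread.

// Hypothetical Worker usage; an assumption about intent, not the commit's code.
#include "mlx/backend/rocm/worker.h"

#include <hip/hip_runtime.h>
#include <memory>
#include <vector>

namespace mlx::core::rocm {

void release_after_stream_sketch(
    Worker& worker,
    hipStream_t stream,
    std::vector<std::shared_ptr<void>> retained) {
  // The task owns `retained`; the buffers are released on the worker thread
  // only after the stream has finished its queued work.
  worker.enqueue([stream, buffers = std::move(retained)]() mutable {
    hipStreamSynchronize(stream);
    buffers.clear(); // drop the references once the GPU work is done
  });
  worker.commit();
}

} // namespace mlx::core::rocm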

mlx/device.cpp

@ -6,10 +6,23 @@
#include "mlx/backend/gpu/available.h"
#include "mlx/device.h"
#ifdef MLX_USE_ROCM
#include "mlx/backend/rocm/rocm.h"
#endif
namespace mlx::core {
Device& mutable_default_device() {
static Device default_device{gpu::is_available() ? Device::gpu : Device::cpu};
Device::DeviceType default_type = Device::cpu;
if (gpu::is_available()) {
default_type = Device::gpu;
}
#ifdef MLX_USE_ROCM
else if (rocm::is_available()) {
default_type = Device::gpu; // ROCm devices use the generic gpu type
}
#endif
static Device default_device{default_type};
return default_device;
}
@ -38,7 +51,11 @@ bool is_available(const Device& d) {
     case Device::cpu:
       return cpu::is_available();
     case Device::gpu:
+#ifdef MLX_USE_ROCM
+      return gpu::is_available() || rocm::is_available();
+#else
       return gpu::is_available();
+#endif
   }
   // appease compiler
   return false;