Nripesh Niketan 2025-06-21 02:23:45 +01:00 committed by GitHub
commit c4b30485f2
59 changed files with 5746 additions and 2 deletions

CMakeLists.txt

@ -35,6 +35,7 @@ option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
option(MLX_BUILD_METAL "Build metal backend" ON)
option(MLX_BUILD_CPU "Build cpu backend" ON)
option(MLX_BUILD_CUDA "Build cuda backend" OFF)
option(MLX_BUILD_ROCM "Build ROCm backend" OFF)
option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
@ -88,6 +89,10 @@ if(MLX_BUILD_CUDA)
enable_language(CUDA)
endif()
if(MLX_BUILD_ROCM)
enable_language(HIP)
endif()
if(MLX_BUILD_METAL AND NOT METAL_LIB)
message(STATUS "Metal not found. Unable to build GPU")
set(MLX_BUILD_METAL OFF)

mlx/CMakeLists.txt

@ -60,7 +60,16 @@ else()
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda/no_cuda.cpp)
endif()
if(MLX_BUILD_METAL OR MLX_BUILD_CUDA)
if(MLX_BUILD_ROCM)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/rocm)
else()
target_sources(mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/rocm/no_rocm.cpp)
endif()
if(MLX_BUILD_METAL
OR MLX_BUILD_CUDA
OR MLX_BUILD_ROCM)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/gpu)
else()
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_gpu)

mlx/backend/rocm/CMakeLists.txt Normal file

@ -0,0 +1,85 @@
# Filename rules in ROCm backend:
#
# * Use .hip/.hpp if code contains device code, and .cpp/.h if not.
# * Device-only code should be put in the device/ subdir.
# * Files in the device/ subdir should not include files outside of it.
target_sources(
mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.hip
${CMAKE_CURRENT_SOURCE_DIR}/binary.hip
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
${CMAKE_CURRENT_SOURCE_DIR}/copy.hip
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
${CMAKE_CURRENT_SOURCE_DIR}/event.hip
${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.hip
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
${CMAKE_CURRENT_SOURCE_DIR}/layer_norm.hip
${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.hip
${CMAKE_CURRENT_SOURCE_DIR}/primitives.hip
${CMAKE_CURRENT_SOURCE_DIR}/random.hip
${CMAKE_CURRENT_SOURCE_DIR}/reduce.hip
${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.hip
${CMAKE_CURRENT_SOURCE_DIR}/rope.hip
${CMAKE_CURRENT_SOURCE_DIR}/rocm.cpp
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.hip
${CMAKE_CURRENT_SOURCE_DIR}/sort.hip
${CMAKE_CURRENT_SOURCE_DIR}/ternary.hip
${CMAKE_CURRENT_SOURCE_DIR}/unary.hip
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)
target_compile_definitions(mlx PRIVATE MLX_USE_ROCM)
# Embed kernel sources in binary for JIT compilation.
file(
GLOB MLX_JIT_SOURCES
RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
"${CMAKE_CURRENT_SOURCE_DIR}/device/*.h"
"${CMAKE_CURRENT_SOURCE_DIR}/device/*.hpp")
string(JOIN ":" MLX_JIT_SOURCES_ARG ${MLX_JIT_SOURCES})
add_custom_command(
OUTPUT gen/rocm_jit_sources.h
COMMAND
${CMAKE_COMMAND} -DMLX_SOURCE_ROOT=${CMAKE_CURRENT_SOURCE_DIR}
-DMLX_JIT_SOURCES=${MLX_JIT_SOURCES_ARG} -P
"${CMAKE_CURRENT_SOURCE_DIR}/bin2h.cmake"
DEPENDS bin2h.cmake ${MLX_JIT_SOURCES})
add_custom_target(rocm_jit_sources DEPENDS gen/rocm_jit_sources.h)
add_dependencies(mlx rocm_jit_sources)
target_include_directories(mlx PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/gen")
# Find ROCm installation
find_package(hip REQUIRED)
find_package(rocblas REQUIRED)
# Link with ROCm libraries
target_link_libraries(mlx PRIVATE hip::device roc::rocblas)
# Set GPU architectures for ROCm. Common ROCm architectures: gfx900, gfx906,
# gfx908, gfx90a, gfx1030, gfx1100.
set(MLX_ROCM_ARCHITECTURES
"gfx900;gfx906;gfx908;gfx90a;gfx1030;gfx1100"
CACHE STRING "ROCm GPU architectures")
message(STATUS "ROCm GPU architectures: ${MLX_ROCM_ARCHITECTURES}")
# Set GPU targets for HIP compilation
set_property(TARGET mlx PROPERTY HIP_ARCHITECTURES "${MLX_ROCM_ARCHITECTURES}")
# Enable HIP language support
enable_language(HIP)
# Set HIP compiler flags
target_compile_options(
mlx
PRIVATE "$<$<COMPILE_LANGUAGE:HIP>:-fgpu-rdc>"
"$<$<COMPILE_LANGUAGE:HIP>:-Xcompiler=-Wall>"
"$<$<COMPILE_LANGUAGE:HIP>:-Xcompiler=-Wextra>")
# Add ROCm include directories
target_include_directories(mlx PRIVATE ${hip_INCLUDE_DIRS})
target_include_directories(mlx PRIVATE ${rocblas_INCLUDE_DIRS})

mlx/backend/rocm/allocator.cpp Normal file

@ -0,0 +1,206 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/rocm/allocator.h"
#include "mlx/backend/rocm/utils.h"
#include "mlx/backend/rocm/worker.h"
#include <fmt/format.h>
#include <hip/hip_runtime.h>
#include <unistd.h>
#include <cassert>
namespace mlx::core {
namespace rocm {
RocmAllocator::RocmAllocator()
: buffer_cache_(
getpagesize(),
[](RocmBuffer* buf) { return buf->size; },
[this](RocmBuffer* buf) {
rocm_free(buf->data);
delete buf;
}) {
// TODO: Set memory limit for multi-device.
size_t free, total;
CHECK_HIP_ERROR(hipMemGetInfo(&free, &total));
memory_limit_ = total * 0.8;
max_pool_size_ = memory_limit_;
}
Buffer RocmAllocator::malloc(size_t size) {
// Find available buffer from cache.
std::unique_lock lock(mutex_);
RocmBuffer* buf = buffer_cache_.reuse_from_cache(size);
if (!buf) {
// If we have a lot of memory pressure or are over the maximum cache size,
// try to reclaim memory from the cache.
size_t mem_required = get_active_memory() + get_cache_memory() + size;
if (mem_required >= memory_limit_) {
buffer_cache_.release_cached_buffers(mem_required - memory_limit_);
}
lock.unlock();
buf = new RocmBuffer{nullptr, size};
hipError_t err = hipMallocManaged(&buf->data, size);
if (err != hipSuccess && err != hipErrorMemoryAllocation) {
throw std::runtime_error(
fmt::format("hipMallocManaged failed: {}.", hipGetErrorString(err)));
}
lock.lock();
}
active_memory_ += size;
peak_memory_ = std::max(active_memory_, peak_memory_);
// Maintain the cache below the requested limit.
if (get_cache_memory() > max_pool_size_) {
buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
}
return Buffer{buf};
}
void RocmAllocator::free(Buffer buffer) {
auto* buf = static_cast<RocmBuffer*>(buffer.ptr());
if (!buf) {
return;
}
std::unique_lock lock(mutex_);
active_memory_ -= buf->size;
if (get_cache_memory() < max_pool_size_) {
buffer_cache_.recycle_to_cache(buf);
} else {
lock.unlock();
rocm_free(buf->data);
delete buf;
}
}
size_t RocmAllocator::size(Buffer buffer) const {
auto* buf = static_cast<RocmBuffer*>(buffer.ptr());
if (!buf) {
return 0;
}
return buf->size;
}
void RocmAllocator::register_this_thread() {
std::lock_guard lock(worker_mutex_);
allowed_threads_.insert(std::this_thread::get_id());
}
void RocmAllocator::rocm_free(void* buf) {
// If rocm_free() is called from an unregistered thread, reschedule the call
// to the worker thread.
{
std::lock_guard lock(worker_mutex_);
if (allowed_threads_.count(std::this_thread::get_id()) == 0) {
if (!worker_) {
worker_.reset(new Worker);
}
worker_->add_task([this, buf]() { this->rocm_free(buf); });
worker_->end_batch();
worker_->commit();
return;
}
}
hipFree(buf);
}
size_t RocmAllocator::get_active_memory() const {
return active_memory_;
}
size_t RocmAllocator::get_peak_memory() const {
return peak_memory_;
}
void RocmAllocator::reset_peak_memory() {
std::lock_guard lock(mutex_);
peak_memory_ = 0;
}
size_t RocmAllocator::get_memory_limit() {
return memory_limit_;
}
size_t RocmAllocator::set_memory_limit(size_t limit) {
std::lock_guard lock(mutex_);
std::swap(limit, memory_limit_);
return limit;
}
size_t RocmAllocator::get_cache_memory() const {
return buffer_cache_.cache_size();
}
size_t RocmAllocator::set_cache_limit(size_t limit) {
std::lock_guard lk(mutex_);
std::swap(limit, max_pool_size_);
return limit;
}
void RocmAllocator::clear_cache() {
std::lock_guard lk(mutex_);
buffer_cache_.clear();
}
RocmAllocator& allocator() {
// |allocator_| is created on the heap so the RocmAllocator destructor never
// runs at exit; buffers left in the cache are intentionally leaked, which
// saves time during program shutdown.
static RocmAllocator* allocator_ = new RocmAllocator;
return *allocator_;
}
} // namespace rocm
namespace allocator {
Allocator& allocator() {
return rocm::allocator();
}
void* Buffer::raw_ptr() {
if (!ptr_) {
return nullptr;
}
return static_cast<rocm::RocmBuffer*>(ptr_)->data;
}
} // namespace allocator
size_t get_active_memory() {
return rocm::allocator().get_active_memory();
}
size_t get_peak_memory() {
return rocm::allocator().get_peak_memory();
}
void reset_peak_memory() {
return rocm::allocator().reset_peak_memory();
}
size_t set_memory_limit(size_t limit) {
return rocm::allocator().set_memory_limit(limit);
}
size_t get_memory_limit() {
return rocm::allocator().get_memory_limit();
}
size_t get_cache_memory() {
return rocm::allocator().get_cache_memory();
}
size_t set_cache_limit(size_t limit) {
return rocm::allocator().set_cache_limit(limit);
}
void clear_cache() {
rocm::allocator().clear_cache();
}
// Not supported in ROCm.
size_t set_wired_limit(size_t) {
return 0;
}
} // namespace mlx::core

mlx/backend/rocm/allocator.h Normal file

@ -0,0 +1,67 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/allocator.h"
#include "mlx/backend/common/buffer_cache.h"
#include <mutex>
#include <set>
#include <thread>
#include <utility>
namespace mlx::core::rocm {
class Worker;
using allocator::Buffer;
// Stores ROCm-managed unified memory.
struct RocmBuffer {
void* data;
size_t size;
};
class RocmAllocator : public allocator::Allocator {
public:
Buffer malloc(size_t size) override;
void free(Buffer buffer) override;
size_t size(Buffer buffer) const override;
// Register current thread as safe to free buffers.
// In ROCm, freeing a buffer implicitly synchronizes the stream, so freeing
// buffers from a thread that the GPU stream may be waiting on (for example,
// a CPU stream thread) would deadlock.
void register_this_thread();
// Call hipFree in the safe thread.
void rocm_free(void* buf);
size_t get_active_memory() const;
size_t get_peak_memory() const;
void reset_peak_memory();
size_t get_memory_limit();
size_t set_memory_limit(size_t limit);
size_t get_cache_memory() const;
size_t set_cache_limit(size_t limit);
void clear_cache();
private:
RocmAllocator();
friend RocmAllocator& allocator();
std::mutex worker_mutex_;
std::unique_ptr<Worker> worker_;
std::set<std::thread::id> allowed_threads_;
std::mutex mutex_;
size_t memory_limit_;
size_t max_pool_size_;
BufferCache<RocmBuffer> buffer_cache_;
size_t active_memory_{0};
size_t peak_memory_{0};
};
RocmAllocator& allocator();
} // namespace mlx::core::rocm
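
A usage note on the thread-registration API above: only threads that the GPU stream will never wait on should register themselves; every other thread gets its hipFree deferred to the worker automatically. A minimal sketch (the thread function name is hypothetical, not part of this diff):

// Hypothetical scheduler thread that is safe to call hipFree directly.
void scheduler_thread_main() {
  // Registered threads free buffers inline; unregistered threads have
  // rocm_free() rescheduled onto the allocator's worker thread.
  mlx::core::rocm::allocator().register_this_thread();
  // ... subsequent allocator().free(buffer) calls run hipFree here ...
}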

mlx/backend/rocm/arg_reduce.hip Normal file

@ -0,0 +1,28 @@
// Copyright © 2025 Apple Inc.
#include <hip/hip_runtime.h>
namespace mlx::core::rocm {
__global__ void argmax_kernel(float* input, int* output, int n) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
// Simple argmax placeholder
if (idx == 0) {
int max_idx = 0;
float max_val = input[0];
for (int i = 1; i < n; i++) {
if (input[i] > max_val) {
max_val = input[i];
max_idx = i;
}
}
output[0] = max_idx;
}
}
void launch_argmax(float* input, int* output, int n, hipStream_t stream) {
hipLaunchKernelGGL(argmax_kernel, dim3(1), dim3(1), 0, stream, input, output, n);
}
} // namespace mlx::core::rocm
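
The kernel above computes the argmax with a single thread; a block-parallel variant would look roughly like the sketch below (assumes a launch with one block of exactly 256 threads and n > 0; illustrative, not part of this diff):

__global__ void argmax_block_kernel(const float* input, int* output, int n) {
  __shared__ float vals[256];
  __shared__ int idxs[256];
  int tid = threadIdx.x;
  // Each thread scans a strided slice of the input.
  float best_val = -INFINITY;
  int best_idx = 0;
  for (int i = tid; i < n; i += blockDim.x) {
    if (input[i] > best_val) {
      best_val = input[i];
      best_idx = i;
    }
  }
  vals[tid] = best_val;
  idxs[tid] = best_idx;
  __syncthreads();
  // Tree reduction in shared memory.
  for (int s = blockDim.x / 2; s > 0; s >>= 1) {
    if (tid < s && vals[tid + s] > vals[tid]) {
      vals[tid] = vals[tid + s];
      idxs[tid] = idxs[tid + s];
    }
    __syncthreads();
  }
  if (tid == 0) {
    output[0] = idxs[0];
  }
}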

mlx/backend/rocm/bin2h.cmake Normal file

@ -0,0 +1,47 @@
# Copyright © 2025 Apple Inc.
# Script to embed kernel source files as header for JIT compilation
set(MLX_OUTPUT_FILE "${CMAKE_CURRENT_BINARY_DIR}/gen/rocm_jit_sources.h")
set(MLX_KERNEL_HEADER
"#pragma once\n\n#include <unordered_map>\n#include <string>\n\nnamespace mlx::core::rocm {\n\n"
)
set(MLX_KERNEL_FOOTER "\n} // namespace mlx::core::rocm\n")
# Create output directory
get_filename_component(MLX_OUTPUT_DIR ${MLX_OUTPUT_FILE} DIRECTORY)
file(MAKE_DIRECTORY ${MLX_OUTPUT_DIR})
# Write header
file(WRITE ${MLX_OUTPUT_FILE} ${MLX_KERNEL_HEADER})
# Process JIT sources
string(REPLACE ":" ";" MLX_JIT_SOURCES_LIST ${MLX_JIT_SOURCES})
set(MLX_SOURCE_MAP
"const std::unordered_map<std::string, std::string> kernel_sources = {\n")
foreach(source IN LISTS MLX_JIT_SOURCES_LIST)
set(source_file "${MLX_SOURCE_ROOT}/${source}")
if(EXISTS ${source_file})
# Read source file
file(READ ${source_file} source_content)
# Escape content for C++ string literal
string(REPLACE "\\" "\\\\" source_content "${source_content}")
string(REPLACE "\"" "\\\"" source_content "${source_content}")
string(REPLACE "\n" "\\n\"\n\"" source_content "${source_content}")
# Add to map
set(MLX_SOURCE_MAP
"${MLX_SOURCE_MAP} {\"${source}\", \"${source_content}\"},\n")
endif()
endforeach()
set(MLX_SOURCE_MAP "${MLX_SOURCE_MAP}};\n")
# Write source map
file(APPEND ${MLX_OUTPUT_FILE} ${MLX_SOURCE_MAP})
# Write footer
file(APPEND ${MLX_OUTPUT_FILE} ${MLX_KERNEL_FOOTER})
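
The generated rocm_jit_sources.h defines kernel_sources in namespace mlx::core::rocm; a minimal consumption sketch (the embedded_source helper is hypothetical, not part of this diff):

#include "rocm_jit_sources.h" // generated above; brings in <string> and <unordered_map>
#include <stdexcept>

namespace mlx::core::rocm {
// Look up an embedded device source by its path relative to backend/rocm.
inline const std::string& embedded_source(const std::string& name) {
  auto it = kernel_sources.find(name);
  if (it == kernel_sources.end()) {
    throw std::runtime_error("No embedded ROCm kernel source: " + name);
  }
  return it->second;
}
} // namespace mlx::core::rocm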

mlx/backend/rocm/binary.hip Normal file

@ -0,0 +1,312 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/binary.h"
#include "mlx/backend/rocm/device.h"
#include "mlx/backend/rocm/device/binary_ops.hpp"
#include "mlx/backend/rocm/kernel_utils.hpp"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
#include <hip/hip_cooperative_groups.h>
namespace mlx::core {
namespace rocm {
namespace cg = cooperative_groups;
template <typename Op, typename In, typename Out, typename IdxT>
__global__ void binary_ss(const In* a, const In* b, Out* out, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
out[index] = Op{}(a[0], b[0]);
}
}
template <typename Op, typename In, typename Out, typename IdxT>
__global__ void binary_sv(const In* a, const In* b, Out* out, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
out[index] = Op{}(a[0], b[index]);
}
}
template <typename Op, typename In, typename Out, typename IdxT>
__global__ void binary_vs(const In* a, const In* b, Out* out, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
out[index] = Op{}(a[index], b[0]);
}
}
template <typename Op, typename In, typename Out, typename IdxT>
__global__ void binary_vv(const In* a, const In* b, Out* out, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
out[index] = Op{}(a[index], b[index]);
}
}
template <typename Op, typename In, typename Out, typename IdxT, int NDIM>
__global__ void binary_g_nd(
const In* a,
const In* b,
Out* out,
IdxT size,
const hip_array<int32_t, NDIM> shape,
const hip_array<int64_t, NDIM> a_strides,
const hip_array<int64_t, NDIM> b_strides) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto [a_idx, b_idx] = elem_to_loc_nd<NDIM>(
index, shape.data(), a_strides.data(), b_strides.data());
out[index] = Op{}(a[a_idx], b[b_idx]);
}
}
template <typename Op, typename In, typename Out, typename IdxT>
__global__ void binary_g(
const In* a,
const In* b,
Out* out,
IdxT size,
const hip_array<int32_t, MAX_DIMS> shape,
const hip_array<int64_t, MAX_DIMS> a_strides,
const hip_array<int64_t, MAX_DIMS> b_strides,
int ndim) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto [a_idx, b_idx] = elem_to_loc_4d(
index, shape.data(), a_strides.data(), b_strides.data(), ndim);
out[index] = Op{}(a[a_idx], b[b_idx]);
}
}
// Binary operation support checking
template <typename Op, typename In, typename Out>
constexpr bool supports_binary_op() {
if (std::is_same_v<Op, Add> || std::is_same_v<Op, Divide> ||
std::is_same_v<Op, Maximum> || std::is_same_v<Op, Minimum> ||
std::is_same_v<Op, Multiply> || std::is_same_v<Op, Subtract> ||
std::is_same_v<Op, Power> || std::is_same_v<Op, Remainder>) {
return std::is_same_v<In, Out>;
}
if (std::is_same_v<Op, Equal> || std::is_same_v<Op, Greater> ||
std::is_same_v<Op, GreaterEqual> || std::is_same_v<Op, Less> ||
std::is_same_v<Op, LessEqual> || std::is_same_v<Op, NotEqual>) {
return std::is_same_v<Out, bool>;
}
if (std::is_same_v<Op, LogicalAnd> || std::is_same_v<Op, LogicalOr>) {
return std::is_same_v<Out, bool> && std::is_same_v<In, bool>;
}
if (std::is_same_v<Op, NaNEqual>) {
return std::is_same_v<Out, bool> && is_inexact_v<In>;
}
if (std::is_same_v<Op, LogAddExp>) {
return std::is_same_v<In, Out> && is_inexact_v<In>;
}
if (std::is_same_v<Op, ArcTan2>) {
return std::is_same_v<In, Out> && is_floating_v<In>;
}
if (std::is_same_v<Op, BitwiseAnd> || std::is_same_v<Op, BitwiseOr> ||
std::is_same_v<Op, BitwiseXor>) {
return std::is_same_v<In, Out> && std::is_integral_v<In>;
}
if (std::is_same_v<Op, LeftShift> || std::is_same_v<Op, RightShift>) {
return std::is_same_v<In, Out> && std::is_integral_v<In> &&
!std::is_same_v<In, bool>;
}
return false;
}
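// For example, the rules above admit:
//   supports_binary_op<Add, float, float>()        -> true  (same in/out type)
//   supports_binary_op<Less, float, bool>()        -> true  (comparisons emit bool)
//   supports_binary_op<BitwiseAnd, float, float>() -> false (requires integral)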
} // namespace rocm
template <typename Op>
void binary_op_gpu_inplace(
const std::vector<array>& inputs,
std::vector<array>& outputs,
std::string_view op,
const Stream& s) {
assert(inputs.size() > 1);
const auto& a = inputs[0];
const auto& b = inputs[1];
auto& out = outputs[0];
if (out.size() == 0) {
return;
}
auto& encoder = rocm::get_command_encoder(s);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.launch_kernel([&](hipStream_t stream) {
MLX_SWITCH_ALL_TYPES(a.dtype(), CTYPE_IN, {
MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, {
if constexpr (rocm::supports_binary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
using InType = hip_type_t<CTYPE_IN>;
using OutType = hip_type_t<CTYPE_OUT>;
auto bopt = get_binary_op_type(a, b);
if (bopt == BinaryOpType::General) {
auto [shape, strides] = collapse_contiguous_dims(a, b, out);
auto& a_strides = strides[0];
auto& b_strides = strides[1];
bool large = a.data_size() > INT32_MAX ||
b.data_size() > INT32_MAX || out.data_size() > INT32_MAX;
MLX_SWITCH_BOOL(large, LARGE, {
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
int ndim = shape.size();
if (ndim <= 3) {
MLX_SWITCH_1_2_3(ndim, NDIM, {
auto kernel =
&rocm::binary_g_nd<Op, InType, OutType, IdxT, NDIM>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, out, large);
hipLaunchKernelGGL(kernel, num_blocks, block_dims, 0, stream,
a.data<InType>(),
b.data<InType>(),
out.data<OutType>(),
out.size(),
make_hip_array<NDIM>(shape),
make_hip_array<NDIM>(a_strides),
make_hip_array<NDIM>(b_strides));
});
} else {
auto kernel = rocm::binary_g<Op, InType, OutType, IdxT>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, out, large);
hipLaunchKernelGGL(kernel, num_blocks, block_dims, 0, stream,
a.data<InType>(),
b.data<InType>(),
out.data<OutType>(),
out.size(),
make_hip_array(shape),
make_hip_array(a_strides),
make_hip_array(b_strides),
ndim);
}
});
} else {
MLX_SWITCH_BOOL(out.data_size() > UINT32_MAX, LARGE, {
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
auto kernel = rocm::binary_ss<Op, InType, OutType, IdxT>;
if (bopt == BinaryOpType::ScalarVector) {
kernel = rocm::binary_sv<Op, InType, OutType, IdxT>;
} else if (bopt == BinaryOpType::VectorScalar) {
kernel = rocm::binary_vs<Op, InType, OutType, IdxT>;
} else if (bopt == BinaryOpType::VectorVector) {
kernel = rocm::binary_vv<Op, InType, OutType, IdxT>;
}
auto [num_blocks, block_dims] = get_launch_args(
kernel, out.data_size(), out.shape(), out.strides(), LARGE);
hipLaunchKernelGGL(kernel, num_blocks, block_dims, 0, stream,
a.data<InType>(),
b.data<InType>(),
out.data<OutType>(),
out.data_size());
});
}
} else {
throw std::runtime_error(fmt::format(
"Can not do binary op {} on inputs of {} with result of {}.",
op,
dtype_to_string(a.dtype()),
dtype_to_string(out.dtype())));
}
});
});
});
}
template <typename Op>
void binary_op_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs,
std::string_view op,
const Stream& s) {
auto& a = inputs[0];
auto& b = inputs[1];
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, outputs[0], bopt);
set_binary_op_output_data(a, b, outputs[1], bopt);
binary_op_gpu_inplace<Op>(inputs, outputs, op, s);
}
template <typename Op>
void binary_op_gpu(
const std::vector<array>& inputs,
array& out,
std::string_view op,
const Stream& s) {
auto& a = inputs[0];
auto& b = inputs[1];
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
std::vector<array> outputs{out};
binary_op_gpu_inplace<Op>(inputs, outputs, op, s);
}
#define BINARY_GPU(func) \
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
auto& s = out.primitive().stream(); \
binary_op_gpu<rocm::func>(inputs, out, get_primitive_string(this), s); \
}
#define BINARY_GPU_MULTI(func) \
void func::eval_gpu( \
const std::vector<array>& inputs, std::vector<array>& outputs) { \
auto& s = outputs[0].primitive().stream(); \
binary_op_gpu<rocm::func>(inputs, outputs, get_primitive_string(this), s); \
}
BINARY_GPU(Add)
BINARY_GPU(ArcTan2)
BINARY_GPU(Divide)
BINARY_GPU(Remainder)
BINARY_GPU(Greater)
BINARY_GPU(GreaterEqual)
BINARY_GPU(Less)
BINARY_GPU(LessEqual)
BINARY_GPU(LogicalAnd)
BINARY_GPU(LogicalOr)
BINARY_GPU(LogAddExp)
BINARY_GPU(Maximum)
BINARY_GPU(Minimum)
BINARY_GPU(Multiply)
BINARY_GPU(NotEqual)
BINARY_GPU(Power)
BINARY_GPU(Subtract)
void Equal::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& s = out.primitive().stream();
auto op = get_primitive_string(this);
if (equal_nan_) {
binary_op_gpu<rocm::NaNEqual>(inputs, out, op, s);
} else {
binary_op_gpu<rocm::Equal>(inputs, out, op, s);
}
}
void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& s = out.primitive().stream();
auto op = get_primitive_string(this);
switch (op_) {
case BitwiseBinary::And:
binary_op_gpu<rocm::BitwiseAnd>(inputs, out, op, s);
break;
case BitwiseBinary::Or:
binary_op_gpu<rocm::BitwiseOr>(inputs, out, op, s);
break;
case BitwiseBinary::Xor:
binary_op_gpu<rocm::BitwiseXor>(inputs, out, op, s);
break;
case BitwiseBinary::LeftShift:
binary_op_gpu<rocm::LeftShift>(inputs, out, op, s);
break;
case BitwiseBinary::RightShift:
binary_op_gpu<rocm::RightShift>(inputs, out, op, s);
break;
}
}
} // namespace mlx::core

mlx/backend/rocm/compiled.cpp Normal file

@ -0,0 +1,9 @@
// Copyright © 2025 Apple Inc.
namespace mlx::core::rocm {
void compile() {
// Placeholder for ROCm compilation
}
} // namespace mlx::core::rocm

mlx/backend/rocm/copy.hip Normal file

@ -0,0 +1,20 @@
// Copyright © 2025 Apple Inc.
#include <hip/hip_runtime.h>
namespace mlx::core::rocm {
__global__ void copy_kernel(float* src, float* dst, int n) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < n) {
dst[idx] = src[idx];
}
}
void launch_copy(float* src, float* dst, int n, hipStream_t stream) {
int threads = 256;
int blocks = (n + threads - 1) / threads;
hipLaunchKernelGGL(copy_kernel, dim3(blocks), dim3(threads), 0, stream, src, dst, n);
}
} // namespace mlx::core::rocm

mlx/backend/rocm/copy/copy.hpp Normal file

@ -0,0 +1,60 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <hip/hip_runtime.h>
#include <cstddef>
namespace mlx::core::rocm {
// Copy function declarations
void copy_contiguous(
const void* src,
void* dst,
size_t size,
hipStream_t stream);
void copy_general(
const void* src,
void* dst,
const int* src_shape,
const size_t* src_strides,
const int* dst_shape,
const size_t* dst_strides,
int ndim,
size_t size,
size_t dtype_size,
hipStream_t stream);
void copy_general_dynamic(
const void* src,
void* dst,
const int* src_shape,
const size_t* src_strides,
const int* dst_shape,
const size_t* dst_strides,
int ndim,
size_t size,
size_t dtype_size,
hipStream_t stream);
void copy_general_input(
const void* src,
void* dst,
const int* src_shape,
const size_t* src_strides,
const int* dst_shape,
const size_t* dst_strides,
int ndim,
size_t size,
size_t dtype_size,
hipStream_t stream);
// Utility functions for element location calculation
__device__ size_t
elem_to_loc(size_t elem, const int* shape, const size_t* strides, int ndim);
__device__ size_t
loc_to_elem(size_t loc, const int* shape, const size_t* strides, int ndim);
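// Example: for shape {2, 3} and strides {1, 2}, linear element 4 decomposes
// to indices (1, 1), so elem_to_loc returns 1 * 1 + 1 * 2 = 3.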
} // namespace mlx::core::rocm

mlx/backend/rocm/copy/copy_contiguous.hip Normal file

@ -0,0 +1,38 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/rocm/copy/copy.hpp"
#include "mlx/backend/rocm/kernel_utils.hpp"
#include <hip/hip_runtime.h>
namespace mlx::core::rocm {
__global__ void copy_contiguous_kernel(
const char* src,
char* dst,
size_t size) {
size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < size) {
dst[tid] = src[tid];
}
}
void copy_contiguous(
const void* src,
void* dst,
size_t size,
hipStream_t stream) {
if (size == 0) {
return;
}
const int threads_per_block = 256;
const int blocks = (size + threads_per_block - 1) / threads_per_block;
copy_contiguous_kernel<<<blocks, threads_per_block, 0, stream>>>(
static_cast<const char*>(src),
static_cast<char*>(dst),
size);
}
} // namespace mlx::core::rocm

mlx/backend/rocm/device.cpp Normal file

@ -0,0 +1,130 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/rocm/device.h"
#include "mlx/backend/metal/metal.h"
#include "mlx/backend/rocm/worker.h"
#include <fmt/format.h>
namespace mlx::core {
namespace rocm {
DeviceStream::DeviceStream(Device& device) : device_(device), stream_(device) {}
void DeviceStream::synchronize() {
CHECK_HIP_ERROR(hipStreamSynchronize(stream_));
}
hipStream_t DeviceStream::schedule_hip_stream() {
// TODO: Return a stream that maximizes parallelism.
return stream_;
}
hipStream_t DeviceStream::last_hip_stream() {
return stream_;
}
CommandEncoder& DeviceStream::get_encoder() {
if (!encoder_) {
encoder_ = std::make_unique<CommandEncoder>(*this);
}
return *encoder_;
}
Device::Device(int device) : device_(device) {
CHECK_HIP_ERROR(hipDeviceGetAttribute(
&compute_capability_major_,
hipDeviceAttributeComputeCapabilityMajor,
device_));
CHECK_HIP_ERROR(hipDeviceGetAttribute(
&compute_capability_minor_,
hipDeviceAttributeComputeCapabilityMinor,
device_));
// Validate device requirements
int attr = 0;
CHECK_HIP_ERROR(hipDeviceGetAttribute(
&attr, hipDeviceAttributeConcurrentManagedAccess, device_));
if (attr != 1) {
// ROCm unified memory might not be available on all devices
// This is a warning rather than an error for ROCm
// TODO: Add proper ROCm unified memory checking
}
// Create rocBLAS handle
make_current();
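// NOTE: rocblas_create_handle returns rocblas_status, not hipError_t; the
// cast below is only sound because both enums use 0 for success.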
CHECK_HIP_ERROR(
static_cast<hipError_t>(rocblas_create_handle(&rocblas_handle_)));
}
Device::~Device() {
if (rocblas_handle_) {
rocblas_destroy_handle(rocblas_handle_);
}
}
void Device::make_current() {
// Cache current device to reduce HIP API calls
static int current = 0;
if (current != device_) {
CHECK_HIP_ERROR(hipSetDevice(device_));
current = device_;
}
}
DeviceStream& Device::get_stream(Stream s) {
auto it = streams_.find(s.index);
if (it == streams_.end()) {
it = streams_.try_emplace(s.index, *this).first;
}
return it->second;
}
CommandEncoder::CommandEncoder(DeviceStream& s)
: device_(s.device()), stream_(s) {}
void CommandEncoder::add_completed_handler(std::function<void()> task) {
worker_.add_task(std::move(task));
}
void CommandEncoder::end_encoding() {
if (!temporaries_.empty()) {
add_completed_handler([temporaries = std::move(temporaries_)]() {});
}
// If there is no kernel running, run completion handlers immediately.
if (!has_gpu_work_) {
worker_.consume_in_this_thread();
return;
}
has_gpu_work_ = false;
// Commit tasks
commit();
}
void CommandEncoder::commit() {
worker_.commit(stream_.last_hip_stream());
}
Device& device(mlx::core::Device device) {
static std::unordered_map<int, Device> devices;
auto it = devices.find(device.index);
if (it == devices.end()) {
it = devices.try_emplace(device.index, device.index).first;
}
return it->second;
}
DeviceStream& get_stream(Stream s) {
return device(s.device).get_stream(s);
}
CommandEncoder& get_command_encoder(Stream s) {
return get_stream(s).get_encoder();
}
} // namespace rocm
} // namespace mlx::core

mlx/backend/rocm/device.h Normal file

@ -0,0 +1,146 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/array.h"
#include "mlx/backend/rocm/utils.h"
#include "mlx/backend/rocm/worker.h"
#include "mlx/stream.h"
#include <hip/hip_runtime.h>
#include <rocblas/rocblas.h>
#include <unordered_map>
namespace mlx::core {
namespace rocm {
class Device;
class CommandEncoder;
class DeviceStream {
public:
explicit DeviceStream(Device& device);
DeviceStream(const DeviceStream&) = delete;
DeviceStream& operator=(const DeviceStream&) = delete;
// Wait until kernels in the stream complete.
void synchronize();
// Return a HIP stream for launching kernels.
hipStream_t schedule_hip_stream();
// Return the last HIP stream used.
hipStream_t last_hip_stream();
CommandEncoder& get_encoder();
Device& device() {
return device_;
}
private:
Device& device_;
HipStream stream_;
std::unique_ptr<CommandEncoder> encoder_;
};
class Device {
public:
explicit Device(int device);
~Device();
Device(const Device&) = delete;
Device& operator=(const Device&) = delete;
// Make this device the current HIP device, required by some HIP calls.
void make_current();
DeviceStream& get_stream(Stream s);
int hip_device() const {
return device_;
}
int compute_capability_major() const {
return compute_capability_major_;
}
int compute_capability_minor() const {
return compute_capability_minor_;
}
// Use the globally qualified type: this getter shadows the ::rocblas_handle
// typedef inside the class.
::rocblas_handle rocblas_handle() const {
return rocblas_handle_;
}
private:
int device_;
int compute_capability_major_;
int compute_capability_minor_;
::rocblas_handle rocblas_handle_;
std::unordered_map<int, DeviceStream> streams_;
};
class CommandEncoder {
public:
explicit CommandEncoder(DeviceStream& stream);
CommandEncoder(const CommandEncoder&) = delete;
CommandEncoder& operator=(const CommandEncoder&) = delete;
void set_input_array(const array& arr) {}
void set_output_array(const array& arr) {}
void add_temporary(const array& arr) {
temporaries_.push_back(arr.data_shared_ptr());
}
void add_completed_handler(std::function<void()> task);
void end_encoding();
void commit();
// Schedule a HIP stream for |fun| to launch kernels, and check error
// afterwards.
template <typename F>
void launch_kernel(F&& fun) {
launch_kernel(stream_.schedule_hip_stream(), std::forward<F>(fun));
}
template <typename F>
void launch_kernel(hipStream_t stream, F&& fun) {
device_.make_current();
fun(stream);
check_hip_error("kernel launch", hipGetLastError());
has_gpu_work_ = true;
}
Device& device() {
return device_;
}
DeviceStream& stream() {
return stream_;
}
bool has_gpu_work() const {
return has_gpu_work_;
}
private:
Device& device_;
DeviceStream& stream_;
Worker worker_;
bool has_gpu_work_{false};
std::vector<std::shared_ptr<array::Data>> temporaries_;
};
Device& device(mlx::core::Device device);
DeviceStream& get_stream(Stream s);
CommandEncoder& get_command_encoder(Stream s);
// Utility function to check HIP errors
void check_hip_error(const char* msg, hipError_t error);
} // namespace rocm
} // namespace mlx::core

mlx/backend/rocm/device/arange.hpp Normal file

@ -0,0 +1,17 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <hip/hip_runtime.h>
namespace mlx::core::rocm {
template <typename T>
__global__ void arange_kernel(T* out, T start, T step, size_t size) {
size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < size) {
out[tid] = start + static_cast<T>(tid) * step;
}
}
} // namespace mlx::core::rocm

mlx/backend/rocm/device/atomic_ops.hpp Normal file

@ -0,0 +1,36 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <hip/hip_runtime.h>
namespace mlx::core::rocm {
// Atomic operations for HIP
__device__ inline float atomicAddFloat(float* address, float val) {
return atomicAdd(address, val);
}
__device__ inline double atomicAddDouble(double* address, double val) {
return atomicAdd(address, val);
}
__device__ inline int atomicAddInt(int* address, int val) {
return atomicAdd(address, val);
}
__device__ inline unsigned int atomicAddUInt(
unsigned int* address,
unsigned int val) {
return atomicAdd(address, val);
}
__device__ inline float atomicMaxFloat(float* address, float val) {
return atomicMax(address, val);
}
__device__ inline float atomicMinFloat(float* address, float val) {
return atomicMin(address, val);
}
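// Float overloads of atomicMax/atomicMin are not guaranteed by every HIP
// version; a CAS-based fallback would look like this sketch (illustrative,
// not used elsewhere in this diff):
__device__ inline float atomic_max_float_cas(float* address, float val) {
  unsigned int* bits = reinterpret_cast<unsigned int*>(address);
  unsigned int old_bits = *bits;
  unsigned int assumed;
  do {
    assumed = old_bits;
    if (__uint_as_float(assumed) >= val) {
      break; // Current value already >= val; nothing to do.
    }
    old_bits = atomicCAS(bits, assumed, __float_as_uint(val));
  } while (old_bits != assumed);
  return __uint_as_float(old_bits);
}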
} // namespace mlx::core::rocm

mlx/backend/rocm/device/binary_ops.hpp Normal file

@ -0,0 +1,217 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <hip/hip_bfloat16.h>
#include <hip/hip_fp16.h>
#include <hip/hip_runtime.h>
#include <hip/hip_complex.h>
namespace mlx::core::rocm {
// Arithmetic operations
struct Add {
template <typename T>
__device__ T operator()(T a, T b) {
return a + b;
}
};
struct Subtract {
template <typename T>
__device__ T operator()(T a, T b) {
return a - b;
}
};
struct Multiply {
template <typename T>
__device__ T operator()(T a, T b) {
return a * b;
}
};
struct Divide {
template <typename T>
__device__ T operator()(T a, T b) {
return a / b;
}
};
struct Power {
template <typename T>
__device__ T operator()(T a, T b) {
return powf(a, b);
}
__device__ double operator()(double a, double b) {
return pow(a, b);
}
};
struct Remainder {
template <typename T>
__device__ T operator()(T a, T b) {
return fmodf(a, b);
}
__device__ double operator()(double a, double b) {
return fmod(a, b);
}
};
// Comparison operations
struct Equal {
template <typename T>
__device__ bool operator()(T a, T b) {
return a == b;
}
};
struct NotEqual {
template <typename T>
__device__ bool operator()(T a, T b) {
return a != b;
}
};
struct Greater {
template <typename T>
__device__ bool operator()(T a, T b) {
return a > b;
}
};
struct GreaterEqual {
template <typename T>
__device__ bool operator()(T a, T b) {
return a >= b;
}
};
struct Less {
template <typename T>
__device__ bool operator()(T a, T b) {
return a < b;
}
};
struct LessEqual {
template <typename T>
__device__ bool operator()(T a, T b) {
return a <= b;
}
};
struct NaNEqual {
template <typename T>
__device__ bool operator()(T a, T b) {
return (isnan(a) && isnan(b)) || (a == b);
}
};
// Logic operations
struct LogicalAnd {
__device__ bool operator()(bool a, bool b) {
return a && b;
}
};
struct LogicalOr {
__device__ bool operator()(bool a, bool b) {
return a || b;
}
};
// Math operations
struct Maximum {
template <typename T>
__device__ T operator()(T a, T b) {
return fmaxf(a, b);
}
__device__ double operator()(double a, double b) {
return fmax(a, b);
}
};
struct Minimum {
template <typename T>
__device__ T operator()(T a, T b) {
return fminf(a, b);
}
__device__ double operator()(double a, double b) {
return fmin(a, b);
}
};
struct LogAddExp {
template <typename T>
__device__ T operator()(T a, T b) {
T max_val = fmaxf(a, b);
T min_val = fminf(a, b);
if (isinf(max_val)) {
return max_val;
}
return max_val + log1pf(expf(min_val - max_val));
}
__device__ double operator()(double a, double b) {
double max_val = fmax(a, b);
double min_val = fmin(a, b);
if (isinf(max_val)) {
return max_val;
}
return max_val + log1p(exp(min_val - max_val));
}
};
struct ArcTan2 {
template <typename T>
__device__ T operator()(T a, T b) {
return atan2f(a, b);
}
__device__ double operator()(double a, double b) {
return atan2(a, b);
}
};
// Bitwise operations
struct BitwiseAnd {
template <typename T>
__device__ T operator()(T a, T b) {
return a & b;
}
};
struct BitwiseOr {
template <typename T>
__device__ T operator()(T a, T b) {
return a | b;
}
};
struct BitwiseXor {
template <typename T>
__device__ T operator()(T a, T b) {
return a ^ b;
}
};
struct LeftShift {
template <typename T>
__device__ T operator()(T a, T b) {
return a << b;
}
};
struct RightShift {
template <typename T>
__device__ T operator()(T a, T b) {
return a >> b;
}
};
} // namespace mlx::core::rocm

mlx/backend/rocm/device/cast_op.hpp Normal file

@ -0,0 +1,21 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <hip/hip_runtime.h>
namespace mlx::core::rocm {
template <typename To, typename From>
struct CastOp {
__device__ To operator()(From x) const {
return static_cast<To>(x);
}
};
template <typename To, typename From>
__device__ inline To cast_op(From x) {
return static_cast<To>(x);
}
} // namespace mlx::core::rocm

mlx/backend/rocm/device/config.h Normal file

@ -0,0 +1,14 @@
// Copyright © 2025 Apple Inc.
#pragma once
// ROCm/HIP specific configuration
#define ROCM_MAX_THREADS_PER_BLOCK 1024
#define ROCM_WARP_SIZE 64
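// NOTE: a 64-wide wavefront matches GCN/CDNA; RDNA parts (e.g. the gfx1030
// and gfx1100 targets in MLX_ROCM_ARCHITECTURES) default to wave32, so query
// hipDeviceProp_t::warpSize at runtime where the exact value matters.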
#define ROCM_MAX_BLOCKS_PER_GRID 65535
namespace mlx::core::rocm {
constexpr int kMaxThreadsPerBlock = ROCM_MAX_THREADS_PER_BLOCK;
constexpr int kWarpSize = ROCM_WARP_SIZE;
constexpr int kMaxBlocksPerGrid = ROCM_MAX_BLOCKS_PER_GRID;
} // namespace mlx::core::rocm

mlx/backend/rocm/device/fp16_math.hpp Normal file

@ -0,0 +1,87 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <hip/hip_fp16.h>
#include <hip/hip_runtime.h>
namespace mlx::core::rocm {
// HIP/ROCm equivalents of CUDA half precision math functions
inline __device__ __half2 h2sin(__half2 x) {
return __half2{hsin(x.x), hsin(x.y)};
}
inline __device__ __half2 h2cos(__half2 x) {
return __half2{hcos(x.x), hcos(x.y)};
}
inline __device__ __half2 h2exp(__half2 x) {
return __half2{hexp(x.x), hexp(x.y)};
}
inline __device__ __half2 h2log(__half2 x) {
return __half2{hlog(x.x), hlog(x.y)};
}
inline __device__ __half2 h2sqrt(__half2 x) {
return __half2{hsqrt(x.x), hsqrt(x.y)};
}
inline __device__ __half2 h2rsqrt(__half2 x) {
return __half2{hrsqrt(x.x), hrsqrt(x.y)};
}
inline __device__ __half2 h2ceil(__half2 x) {
return __half2{hceil(x.x), hceil(x.y)};
}
inline __device__ __half2 h2floor(__half2 x) {
return __half2{hfloor(x.x), hfloor(x.y)};
}
inline __device__ __half2 h2rint(__half2 x) {
return __half2{hrint(x.x), hrint(x.y)};
}
inline __device__ __half2 h2trunc(__half2 x) {
return __half2{htrunc(x.x), htrunc(x.y)};
}
// Additional math functions for half precision
inline __device__ __half habs(__half x) {
return __half{fabsf(__half2float(x))};
}
inline __device__ __half2 h2abs(__half2 x) {
return __half2{habs(x.x), habs(x.y)};
}
inline __device__ __half hneg(__half x) {
return __half{-__half2float(x)};
}
inline __device__ __half2 h2neg(__half2 x) {
return __half2{hneg(x.x), hneg(x.y)};
}
// BFloat16 support functions
#ifdef __HIP_BFLOAT16__
inline __device__ __hip_bfloat16 habs(__hip_bfloat16 x) {
return __hip_bfloat16{fabsf(__bfloat162float(x))};
}
inline __device__ __hip_bfloat162 h2abs(__hip_bfloat162 x) {
return __hip_bfloat162{habs(x.x), habs(x.y)};
}
inline __device__ __hip_bfloat16 hneg(__hip_bfloat16 x) {
return __hip_bfloat16{-__bfloat162float(x)};
}
inline __device__ __hip_bfloat162 h2neg(__hip_bfloat162 x) {
return __hip_bfloat162{hneg(x.x), hneg(x.y)};
}
#endif
} // namespace mlx::core::rocm

mlx/backend/rocm/device/hip_complex_math.hpp Normal file

@ -0,0 +1,52 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <hip/hip_complex.h>
#include <hip/hip_runtime.h>
namespace mlx::core::rocm {
// HIP complex math functions
__device__ inline hipFloatComplex hip_complex_add(
hipFloatComplex a,
hipFloatComplex b) {
return make_hipFloatComplex(
hipCrealf(a) + hipCrealf(b), hipCimagf(a) + hipCimagf(b));
}
__device__ inline hipFloatComplex hip_complex_sub(
hipFloatComplex a,
hipFloatComplex b) {
return make_hipFloatComplex(
hipCrealf(a) - hipCrealf(b), hipCimagf(a) - hipCimagf(b));
}
__device__ inline hipFloatComplex hip_complex_mul(
hipFloatComplex a,
hipFloatComplex b) {
float real = hipCrealf(a) * hipCrealf(b) - hipCimagf(a) * hipCimagf(b);
float imag = hipCrealf(a) * hipCimagf(b) + hipCimagf(a) * hipCrealf(b);
return make_hipFloatComplex(real, imag);
}
__device__ inline hipFloatComplex hip_complex_div(
hipFloatComplex a,
hipFloatComplex b) {
float denom = hipCrealf(b) * hipCrealf(b) + hipCimagf(b) * hipCimagf(b);
float real =
(hipCrealf(a) * hipCrealf(b) + hipCimagf(a) * hipCimagf(b)) / denom;
float imag =
(hipCimagf(a) * hipCrealf(b) - hipCrealf(a) * hipCimagf(b)) / denom;
return make_hipFloatComplex(real, imag);
}
__device__ inline float hip_complex_abs(hipFloatComplex z) {
return sqrtf(hipCrealf(z) * hipCrealf(z) + hipCimagf(z) * hipCimagf(z));
}
__device__ inline hipFloatComplex hip_complex_conj(hipFloatComplex z) {
return make_hipFloatComplex(hipCrealf(z), -hipCimagf(z));
}
} // namespace mlx::core::rocm

mlx/backend/rocm/device/ternary_ops.hpp Normal file

@ -0,0 +1,16 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <hip/hip_runtime.h>
namespace mlx::core::rocm {
struct Select {
template <typename T>
__device__ T operator()(bool condition, T a, T b) const {
return condition ? a : b;
}
};
} // namespace mlx::core::rocm

mlx/backend/rocm/device/unary_ops.hpp Normal file

@ -0,0 +1,368 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/backend/rocm/device/fp16_math.hpp"
#include "mlx/backend/rocm/device/utils.hpp"
#include <hip/hip_runtime.h>
namespace mlx::core::rocm {
struct Abs {
template <typename T>
__device__ T operator()(T x) {
if constexpr (std::is_unsigned_v<T>) {
return x;
} else if constexpr (std::is_same_v<T, hipFloatComplex>) {
return {
sqrt(hipCrealf(x) * hipCrealf(x) + hipCimagf(x) * hipCimagf(x)), 0};
} else {
return abs(x);
}
}
};
struct ArcCos {
template <typename T>
__device__ T operator()(T x) {
return acos(x);
}
};
struct ArcCosh {
template <typename T>
__device__ T operator()(T x) {
return acosh(x);
}
};
struct ArcSin {
template <typename T>
__device__ T operator()(T x) {
return asin(x);
}
};
struct ArcSinh {
template <typename T>
__device__ T operator()(T x) {
return asinh(x);
}
};
struct ArcTan {
template <typename T>
__device__ T operator()(T x) {
return atan(x);
}
};
struct ArcTanh {
template <typename T>
__device__ T operator()(T x) {
return atanh(x);
}
};
struct BitwiseInvert {
template <typename T>
__device__ T operator()(T x) {
return ~x;
}
};
struct Ceil {
template <typename T>
__device__ T operator()(T x) {
if constexpr (std::is_integral_v<T>) {
return x;
} else {
return ceil(x);
}
}
};
struct Conjugate {
__device__ hipFloatComplex operator()(hipFloatComplex x) {
return {hipCrealf(x), -hipCimagf(x)};
}
};
struct Cos {
template <typename T>
__device__ T operator()(T x) {
if constexpr (std::is_same_v<T, hipFloatComplex>) {
return {
cos(hipCrealf(x)) * cosh(hipCimagf(x)),
-sin(hipCrealf(x)) * sinh(hipCimagf(x))};
} else {
return cos(x);
}
}
};
struct Cosh {
template <typename T>
__device__ T operator()(T x) {
if constexpr (std::is_same_v<T, hipFloatComplex>) {
return {
cosh(hipCrealf(x)) * cos(hipCimagf(x)),
sinh(hipCrealf(x)) * sin(hipCimagf(x))};
} else {
return cosh(x);
}
}
};
struct Erf {
template <typename T>
__device__ T operator()(T x) {
if constexpr (std::is_same_v<T, __half>) {
return erf(__half2float(x));
} else if constexpr (std::is_same_v<T, __hip_bfloat16>) {
return erf(__bfloat162float(x));
} else {
return erf(x);
}
}
};
struct ErfInv {
template <typename T>
__device__ T operator()(T x) {
if constexpr (std::is_same_v<T, __half>) {
return erfinv(__half2float(x));
} else if constexpr (std::is_same_v<T, __hip_bfloat16>) {
return erfinv(__bfloat162float(x));
} else {
return erfinv(x);
}
}
};
struct Exp {
template <typename T>
__device__ T operator()(T x) {
if constexpr (std::is_same_v<T, hipFloatComplex>) {
auto m = exp(hipCrealf(x));
// exp(a + bi) = e^a * (cos(b) + i*sin(b)).
return {m * cos(hipCimagf(x)), m * sin(hipCimagf(x))};
} else {
return exp(x);
}
}
};
struct Expm1 {
template <typename T>
__device__ T operator()(T x) {
if constexpr (std::is_same_v<T, __half>) {
return expm1(__half2float(x));
} else if constexpr (std::is_same_v<T, __hip_bfloat16>) {
return expm1(__bfloat162float(x));
} else {
return expm1(x);
}
}
};
struct Floor {
template <typename T>
__device__ T operator()(T x) {
if constexpr (std::is_integral_v<T>) {
return x;
} else {
return floor(x);
}
}
};
struct Imag {
__device__ float operator()(hipFloatComplex x) {
return hipCimagf(x);
}
};
struct Log {
template <typename T>
__device__ T operator()(T x) {
if constexpr (std::is_same_v<T, hipFloatComplex>) {
auto r = log(hipCrealf(Abs{}(x)));
auto i = atan2f(hipCimagf(x), hipCrealf(x));
return {r, i};
} else {
return log(x);
}
}
};
struct Log2 {
template <typename T>
__device__ T operator()(T x) {
if constexpr (std::is_same_v<T, hipFloatComplex>) {
auto y = Log{}(x);
return {hipCrealf(y) / M_LN2, hipCimagf(y) / M_LN2};
} else {
return log2(x);
}
}
};
struct Log10 {
template <typename T>
__device__ T operator()(T x) {
if constexpr (std::is_same_v<T, hipFloatComplex>) {
auto y = Log{}(x);
return {hipCrealf(y) / M_LN10, hipCimagf(y) / M_LN10};
} else {
return log10(x);
}
}
};
struct Log1p {
template <typename T>
__device__ T operator()(T x) {
return log1p(x);
}
};
struct LogicalNot {
__device__ bool operator()(bool x) {
return !x;
}
};
struct Negative {
template <typename T>
__device__ T operator()(T x) {
if constexpr (std::is_same_v<T, hipFloatComplex>) {
return make_hipFloatComplex(-hipCrealf(x), -hipCimagf(x));
} else {
return -x;
}
}
};
struct Real {
__device__ float operator()(hipFloatComplex x) {
return hipCrealf(x);
}
};
struct Round {
template <typename T>
__device__ T operator()(T x) {
if constexpr (std::is_same_v<T, hipFloatComplex>) {
return {rint(hipCrealf(x)), rint(hipCimagf(x))};
} else {
return rint(x);
}
}
};
struct Rsqrt {
template <typename T>
__device__ T operator()(T x) {
return rsqrt(x);
}
};
struct Sigmoid {
template <typename T>
__device__ T operator()(T x) {
T y = 1 / (1 + exp(-abs(x)));
return (x < 0) ? 1 - y : y;
}
};
struct Sign {
template <typename T>
__device__ T operator()(T x) {
if constexpr (std::is_unsigned_v<T>) {
return x != 0;
} else if constexpr (std::is_same_v<T, hipFloatComplex>) {
if (hipCrealf(x) == 0 && hipCimagf(x) == 0) {
return x;
} else {
return x / Abs()(x);
}
} else if constexpr (std::is_same_v<T, __hip_bfloat16>) {
return static_cast<float>((x > T(0.f)) - (x < T(0.f)));
} else {
return (x > T(0)) - (x < T(0));
}
}
};
struct Sin {
template <typename T>
__device__ T operator()(T x) {
if constexpr (std::is_same_v<T, hipFloatComplex>) {
return {
sin(hipCrealf(x)) * cosh(hipCimagf(x)),
cos(hipCrealf(x)) * sinh(hipCimagf(x))};
} else {
return sin(x);
}
}
};
struct Sinh {
template <typename T>
__device__ T operator()(T x) {
if constexpr (std::is_same_v<T, hipFloatComplex>) {
return {
sinh(hipCrealf(x)) * cos(hipCimagf(x)),
cosh(hipCrealf(x)) * sin(hipCimagf(x))};
} else {
return sinh(x);
}
}
};
struct Square {
template <typename T>
__device__ T operator()(T x) {
return x * x;
}
};
struct Sqrt {
template <typename T>
__device__ T operator()(T x) {
return sqrt(x);
}
};
struct Tan {
template <typename T>
__device__ T operator()(T x) {
if constexpr (std::is_same_v<T, hipFloatComplex>) {
float tan_a = tan(hipCrealf(x));
float tanh_b = tanh(hipCimagf(x));
float t1 = tan_a * tanh_b;
float denom = 1. + t1 * t1;
return {(tan_a - tanh_b * t1) / denom, (tanh_b + tan_a * t1) / denom};
} else {
return tan(x);
}
}
};
struct Tanh {
template <typename T>
__device__ T operator()(T x) {
if constexpr (std::is_same_v<T, hipFloatComplex>) {
float tanh_a = tanh(hipCrealf(x));
float tan_b = tan(hipCimagf(x));
float t1 = tanh_a * tan_b;
float denom = 1. + t1 * t1;
return {(tanh_a + tan_b * t1) / denom, (tan_b - tanh_a * t1) / denom};
} else {
return tanh(x);
}
}
};
} // namespace mlx::core::rocm

mlx/backend/rocm/device/utils.hpp Normal file

@ -0,0 +1,173 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <hip/hip_complex.h>
#include <hip/hip_runtime.h>
namespace mlx::core::rocm {
// HIP/ROCm type definitions
using hip_complex = hipFloatComplex;
// Utility functions for HIP device code
template <typename T>
struct hip_type {
using type = T;
};
template <>
struct hip_type<bool> {
using type = bool;
};
template <>
struct hip_type<int8_t> {
using type = int8_t;
};
template <>
struct hip_type<uint8_t> {
using type = uint8_t;
};
template <>
struct hip_type<int16_t> {
using type = int16_t;
};
template <>
struct hip_type<uint16_t> {
using type = uint16_t;
};
template <>
struct hip_type<int32_t> {
using type = int32_t;
};
template <>
struct hip_type<uint32_t> {
using type = uint32_t;
};
template <>
struct hip_type<int64_t> {
using type = int64_t;
};
template <>
struct hip_type<uint64_t> {
using type = uint64_t;
};
template <>
struct hip_type<float> {
using type = float;
};
template <>
struct hip_type<double> {
using type = double;
};
#ifdef __HIP_PLATFORM_HCC__
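// NOTE: __HIP_PLATFORM_HCC__ is deprecated; recent ROCm releases spell this
// __HIP_PLATFORM_AMD__.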
template <>
struct hip_type<__half> {
using type = __half;
};
template <>
struct hip_type<__hip_bfloat16> {
using type = __hip_bfloat16;
};
#endif
template <typename T>
using hip_type_t = typename hip_type<T>::type;
// Element-wise operations support
template <typename T>
constexpr bool is_floating_point_v = std::is_floating_point_v<T>;
template <typename T>
constexpr bool is_integral_v = std::is_integral_v<T>;
template <typename T>
constexpr bool is_signed_v = std::is_signed_v<T>;
template <typename T>
constexpr bool is_unsigned_v = std::is_unsigned_v<T>;
// Complex number helper functions
inline __device__ hipFloatComplex make_complex(float real, float imag) {
return make_hipFloatComplex(real, imag);
}
inline __device__ float hip_real(hipFloatComplex z) {
return hipCrealf(z);
}
inline __device__ float hip_imag(hipFloatComplex z) {
return hipCimagf(z);
}
inline __device__ hipFloatComplex hip_conj(hipFloatComplex z) {
return make_hipFloatComplex(hipCrealf(z), -hipCimagf(z));
}
inline __device__ float hip_abs(hipFloatComplex z) {
return sqrtf(hipCrealf(z) * hipCrealf(z) + hipCimagf(z) * hipCimagf(z));
}
// Memory access utilities
template <typename T>
inline __device__ T hip_load_global(const T* ptr) {
return *ptr;
}
template <typename T>
inline __device__ void hip_store_global(T* ptr, T value) {
*ptr = value;
}
// Grid and block utilities
inline __device__ int hip_thread_idx() {
return threadIdx.x;
}
inline __device__ int hip_block_idx() {
return blockIdx.x;
}
inline __device__ int hip_block_dim() {
return blockDim.x;
}
inline __device__ int hip_grid_dim() {
return gridDim.x;
}
inline __device__ int hip_global_thread_idx() {
return blockIdx.x * blockDim.x + threadIdx.x;
}
// Synchronization
inline __device__ void hip_sync_threads() {
__syncthreads();
}
// Math constants for HIP (equivalent to CUDA's math_constants.h)
#ifndef M_PI
#define M_PI 3.14159265358979323846
#endif
#ifndef M_LN2
#define M_LN2 0.693147180559945309417
#endif
#ifndef M_LN10
#define M_LN10 2.302585092994045684018
#endif
} // namespace mlx::core::rocm

mlx/backend/rocm/eval.cpp Normal file

@ -0,0 +1,11 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/rocm/utils.h"
namespace mlx::core::rocm {
void eval() {
// Placeholder for ROCm evaluation
}
} // namespace mlx::core::rocm

mlx/backend/rocm/event.cpp Normal file

@ -0,0 +1,50 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/rocm/event.h"
#include "mlx/backend/rocm/utils.h"
namespace mlx::core::rocm {
HipEvent::HipEvent() {
CHECK_HIP_ERROR(hipEventCreate(&event_));
}
HipEvent::~HipEvent() {
CHECK_HIP_ERROR(hipEventDestroy(event_));
}
void HipEvent::record(hipStream_t stream) {
CHECK_HIP_ERROR(hipEventRecord(event_, stream));
}
void HipEvent::wait() {
CHECK_HIP_ERROR(hipEventSynchronize(event_));
}
bool HipEvent::query() const {
hipError_t status = hipEventQuery(event_);
if (status == hipSuccess) {
return true;
} else if (status == hipErrorNotReady) {
return false;
} else {
CHECK_HIP_ERROR(status);
return false;
}
}
SharedEvent::SharedEvent() = default;
void SharedEvent::notify() {
std::lock_guard<std::mutex> lock(mutex_);
ready_ = true;
cv_.notify_one();
}
void SharedEvent::wait() {
std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock, [this] { return ready_; });
ready_ = false;
}
} // namespace mlx::core::rocm

mlx/backend/rocm/event.h Normal file

@ -0,0 +1,48 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <hip/hip_runtime.h>
#include <condition_variable>
#include <memory>
#include <mutex>
namespace mlx::core::rocm {
// HIP event managed with RAII.
class HipEvent {
public:
HipEvent();
~HipEvent();
HipEvent(const HipEvent&) = delete;
HipEvent& operator=(const HipEvent&) = delete;
void record(hipStream_t stream);
void wait();
bool query() const;
operator hipEvent_t() const {
return event_;
}
private:
hipEvent_t event_;
};
// Shared event for worker thread synchronization.
class SharedEvent {
public:
SharedEvent();
void notify();
void wait();
private:
std::mutex mutex_;
std::condition_variable cv_;
bool ready_{false};
};
} // namespace mlx::core::rocm

mlx/backend/rocm/event.hip Normal file

@ -0,0 +1,32 @@
// Copyright © 2025 Apple Inc.
#include <hip/hip_runtime.h>
#include "mlx/backend/rocm/utils.h"
namespace mlx::core::rocm {
class Event {
public:
Event() {
check_hip_error("hipEventCreate", hipEventCreate(&event_));
}
~Event() {
hipEventDestroy(event_);
}
void record(hipStream_t stream) {
check_hip_error("hipEventRecord", hipEventRecord(event_, stream));
}
void wait() {
check_hip_error("hipEventSynchronize", hipEventSynchronize(event_));
}
hipEvent_t event() const { return event_; }
private:
hipEvent_t event_;
};
} // namespace mlx::core::rocm

mlx/backend/rocm/fence.cpp Normal file

@ -0,0 +1,9 @@
// Copyright © 2025 Apple Inc.
namespace mlx::core::rocm {
void fence() {
// Placeholder for ROCm fence operation
}
} // namespace mlx::core::rocm

mlx/backend/rocm/indexing.cpp Normal file

@ -0,0 +1,9 @@
// Copyright © 2025 Apple Inc.
namespace mlx::core::rocm {
void index() {
// Placeholder for ROCm indexing operation
}
} // namespace mlx::core::rocm

mlx/backend/rocm/iterators/general_iterator.hpp Normal file

@ -0,0 +1,153 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <hip/hip_runtime.h>
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <utility>
namespace mlx::core::rocm {
template <typename IdxType>
struct GeneralIterator {
using difference_type = ptrdiff_t;
using value_type = IdxType;
using pointer = IdxType*;
using reference = IdxType&;
using iterator_category = std::random_access_iterator_tag;
const IdxType* base_ptr;
IdxType offset;
const int* shape;
const size_t* strides;
int ndim;
size_t size;
__device__ GeneralIterator(
const IdxType* base_ptr,
IdxType offset,
const int* shape,
const size_t* strides,
int ndim,
size_t size)
: base_ptr(base_ptr),
offset(offset),
shape(shape),
strides(strides),
ndim(ndim),
size(size) {}
__device__ GeneralIterator operator+(difference_type n) const {
return GeneralIterator(base_ptr, offset + n, shape, strides, ndim, size);
}
__device__ GeneralIterator operator-(difference_type n) const {
return GeneralIterator(base_ptr, offset - n, shape, strides, ndim, size);
}
__device__ difference_type operator-(const GeneralIterator& other) const {
return offset - other.offset;
}
__device__ GeneralIterator& operator+=(difference_type n) {
offset += n;
return *this;
}
__device__ GeneralIterator& operator-=(difference_type n) {
offset -= n;
return *this;
}
__device__ GeneralIterator& operator++() {
++offset;
return *this;
}
__device__ GeneralIterator operator++(int) {
GeneralIterator temp = *this;
++offset;
return temp;
}
__device__ GeneralIterator& operator--() {
--offset;
return *this;
}
__device__ GeneralIterator operator--(int) {
GeneralIterator temp = *this;
--offset;
return temp;
}
__device__ bool operator==(const GeneralIterator& other) const {
return offset == other.offset;
}
__device__ bool operator!=(const GeneralIterator& other) const {
return offset != other.offset;
}
__device__ bool operator<(const GeneralIterator& other) const {
return offset < other.offset;
}
__device__ bool operator>(const GeneralIterator& other) const {
return offset > other.offset;
}
__device__ bool operator<=(const GeneralIterator& other) const {
return offset <= other.offset;
}
__device__ bool operator>=(const GeneralIterator& other) const {
return offset >= other.offset;
}
__device__ IdxType operator*() const {
return base_ptr[elem_to_loc(offset, shape, strides, ndim)];
}
__device__ IdxType operator[](difference_type n) const {
return base_ptr[elem_to_loc(offset + n, shape, strides, ndim)];
}
private:
__device__ size_t elem_to_loc(
size_t elem,
const int* shape,
const size_t* strides,
int ndim) const {
size_t loc = 0;
for (int i = ndim - 1; i >= 0; --i) {
auto q_and_r = div(elem, static_cast<size_t>(shape[i]));
loc += q_and_r.rem * strides[i];
elem = q_and_r.quot;
}
return loc;
}
// std::div_t stores int fields, which would truncate large sizes; use a
// size_t-based result instead.
struct DivResult {
size_t quot;
size_t rem;
};
__device__ DivResult div(size_t numer, size_t denom) const {
return {numer / denom, numer % denom};
}
};
template <typename IdxType>
__device__ std::pair<GeneralIterator<IdxType>, GeneralIterator<IdxType>>
make_general_iterators(
const IdxType* base_ptr,
size_t size,
const int* shape,
const size_t* strides,
int ndim) {
auto begin =
GeneralIterator<IdxType>(base_ptr, 0, shape, strides, ndim, size);
auto end =
GeneralIterator<IdxType>(base_ptr, size, shape, strides, ndim, size);
return std::make_pair(begin, end);
}
} // namespace mlx::core::rocm

mlx/backend/rocm/iterators/strided_iterator.hpp Normal file

@ -0,0 +1,106 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <hip/hip_runtime.h>
#include <cstddef>
#include <cstdint>
#include <iterator>
namespace mlx::core::rocm {
template <typename T>
struct StridedIterator {
using difference_type = ptrdiff_t;
using value_type = T;
using pointer = T*;
using reference = T&;
using iterator_category = std::random_access_iterator_tag;
T* ptr;
size_t stride;
__device__ StridedIterator(T* ptr, size_t stride)
: ptr(ptr), stride(stride) {}
__device__ StridedIterator operator+(difference_type n) const {
return StridedIterator(ptr + n * stride, stride);
}
__device__ StridedIterator operator-(difference_type n) const {
return StridedIterator(ptr - n * stride, stride);
}
__device__ difference_type operator-(const StridedIterator& other) const {
return (ptr - other.ptr) / stride;
}
__device__ StridedIterator& operator+=(difference_type n) {
ptr += n * stride;
return *this;
}
__device__ StridedIterator& operator-=(difference_type n) {
ptr -= n * stride;
return *this;
}
__device__ StridedIterator& operator++() {
ptr += stride;
return *this;
}
__device__ StridedIterator operator++(int) {
StridedIterator temp = *this;
ptr += stride;
return temp;
}
__device__ StridedIterator& operator--() {
ptr -= stride;
return *this;
}
__device__ StridedIterator operator--(int) {
StridedIterator temp = *this;
ptr -= stride;
return temp;
}
__device__ bool operator==(const StridedIterator& other) const {
return ptr == other.ptr;
}
__device__ bool operator!=(const StridedIterator& other) const {
return ptr != other.ptr;
}
__device__ bool operator<(const StridedIterator& other) const {
return ptr < other.ptr;
}
__device__ bool operator>(const StridedIterator& other) const {
return ptr > other.ptr;
}
__device__ bool operator<=(const StridedIterator& other) const {
return ptr <= other.ptr;
}
__device__ bool operator>=(const StridedIterator& other) const {
return ptr >= other.ptr;
}
__device__ T& operator*() const {
return *ptr;
}
__device__ T& operator[](difference_type n) const {
return *(ptr + n * stride);
}
};
template <typename T>
__device__ StridedIterator<T> make_strided_iterator(T* ptr, size_t stride) {
return StridedIterator<T>(ptr, stride);
}
} // namespace mlx::core::rocm
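// Device-side sketch (not from the diff) of how a stride-0 StridedIterator
// broadcasts a single weight across a row, the way the norm kernels later
// in this commit use it when `w` has no stride:
namespace mlx::core::rocm {
__global__ void scale_row(const float* x, const float* w, float* out, int n) {
  StridedIterator<const float> it(w, /*stride=*/0); // it[i] is w[0] for all i
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    out[i] = it[i] * x[i];
  }
}
} // namespace mlx::core::rocm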

View File

@ -0,0 +1,167 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/rocm/jit_module.h"
#include "mlx/backend/rocm/utils.h"
#include <fmt/format.h>
#include <mutex>
#include <sstream>
namespace mlx::core::rocm {
JitModule::JitModule(
const std::string& kernel_name,
const std::string& kernel_source,
const std::vector<std::string>& template_args,
const std::vector<std::string>& compiler_flags,
bool verbose) {
compile(kernel_name, kernel_source, template_args, compiler_flags, verbose);
}
JitModule::~JitModule() {
  // hipFunction_t handles are owned by their module, so kernel_ needs no
  // explicit destroy.
if (module_) {
CHECK_HIP_ERROR(hipModuleUnload(module_));
}
if (program_) {
hiprtcDestroyProgram(&program_);
}
}
void JitModule::compile(
const std::string& kernel_name,
const std::string& kernel_source,
const std::vector<std::string>& template_args,
const std::vector<std::string>& compiler_flags,
bool verbose) {
// Create HIPRTC program
CHECK_HIP_ERROR(hiprtcCreateProgram(
&program_,
kernel_source.c_str(),
kernel_name.c_str(),
0,
nullptr,
nullptr));
// Build compiler options
std::vector<const char*> options;
std::vector<std::string> option_strings;
// Add default options
option_strings.push_back("--std=c++17");
option_strings.push_back("-O3");
option_strings.push_back("-DMLX_USE_ROCM");
// Add user-provided flags
for (const auto& flag : compiler_flags) {
option_strings.push_back(flag);
}
// Add template arguments
for (const auto& arg : template_args) {
option_strings.push_back("-D" + arg);
}
// Convert to char* array
for (const auto& option : option_strings) {
options.push_back(option.c_str());
}
// Compile the program
hiprtcResult compile_result =
hiprtcCompileProgram(program_, options.size(), options.data());
// Get compilation log
size_t log_size;
CHECK_HIP_ERROR(hiprtcGetProgramLogSize(program_, &log_size));
if (log_size > 1) {
std::vector<char> log(log_size);
CHECK_HIP_ERROR(hiprtcGetProgramLog(program_, log.data()));
if (verbose || compile_result != HIPRTC_SUCCESS) {
fmt::print(
"HIPRTC compilation log for {}:\n{}\n", kernel_name, log.data());
}
}
if (compile_result != HIPRTC_SUCCESS) {
throw std::runtime_error(
fmt::format("HIPRTC compilation failed for kernel {}", kernel_name));
}
// Get compiled code
size_t code_size;
CHECK_HIP_ERROR(hiprtcGetCodeSize(program_, &code_size));
std::vector<char> code(code_size);
CHECK_HIP_ERROR(hiprtcGetCode(program_, code.data()));
// Load module
CHECK_HIP_ERROR(hipModuleLoadData(&module_, code.data()));
// Get kernel function
CHECK_HIP_ERROR(hipModuleGetFunction(&kernel_, module_, kernel_name.c_str()));
}
JitCache& JitCache::instance() {
static JitCache cache;
return cache;
}
std::shared_ptr<JitModule> JitCache::get_or_create(
const std::string& kernel_name,
const std::string& kernel_source,
const std::vector<std::string>& template_args,
const std::vector<std::string>& compiler_flags) {
std::string key =
make_key(kernel_name, kernel_source, template_args, compiler_flags);
std::lock_guard<std::mutex> lock(mutex_);
auto it = cache_.find(key);
if (it != cache_.end()) {
if (auto module = it->second.lock()) {
return module;
} else {
cache_.erase(it);
}
}
auto module = std::make_shared<JitModule>(
kernel_name, kernel_source, template_args, compiler_flags);
cache_[key] = module;
return module;
}
std::string JitCache::make_key(
const std::string& kernel_name,
const std::string& kernel_source,
const std::vector<std::string>& template_args,
const std::vector<std::string>& compiler_flags) const {
std::ostringstream oss;
oss << kernel_name << "|" << kernel_source;
for (const auto& arg : template_args) {
oss << "|" << arg;
}
for (const auto& flag : compiler_flags) {
oss << "|" << flag;
}
return oss.str();
}
std::shared_ptr<JitModule> make_jit_kernel(
const std::string& kernel_name,
const std::string& kernel_source,
const std::vector<std::string>& template_args,
const std::vector<std::string>& compiler_flags) {
return JitCache::instance().get_or_create(
kernel_name, kernel_source, template_args, compiler_flags);
}
} // namespace mlx::core::rocm
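// Usage sketch (hypothetical: the kernel source and names below are made up
// for illustration; the API is the one defined in this file):
namespace mlx::core::rocm {
inline void jit_axpy_example(
    float a, float* x, float* y, int n, hipStream_t stream) {
  const char* src = R"(
    extern "C" __global__ void axpy(float a, const float* x, float* y, int n) {
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n) y[i] = a * x[i] + y[i];
    })";
  auto mod = make_jit_kernel("axpy", src);
  // Scalar arguments must be lvalues because launch() takes their address.
  mod->launch(dim3((n + 255) / 256), dim3(256), 0, stream, a, x, y, n);
}
} // namespace mlx::core::rocm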

View File

@ -0,0 +1,100 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <hip/hip_runtime.h>
#include <hip/hiprtc.h>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <vector>
#include "mlx/backend/rocm/utils.h" // for CHECK_HIP_ERROR used in launch()
namespace mlx::core::rocm {
// JIT compilation module for ROCm
class JitModule {
public:
JitModule(
const std::string& kernel_name,
const std::string& kernel_source,
const std::vector<std::string>& template_args = {},
const std::vector<std::string>& compiler_flags = {},
bool verbose = false);
~JitModule();
JitModule(const JitModule&) = delete;
JitModule& operator=(const JitModule&) = delete;
// Get the compiled kernel function
hipFunction_t get_kernel() const {
return kernel_;
}
// Launch the kernel with given arguments
template <typename... Args>
void launch(
dim3 grid_dims,
dim3 block_dims,
size_t shared_memory,
hipStream_t stream,
Args&&... args) {
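    // hipModuleLaunchKernel expects an array of pointers to each argument
    // value, so take the address of every parameter pack element.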
void* kernel_args[] = {(void*)&args...};
CHECK_HIP_ERROR(hipModuleLaunchKernel(
kernel_,
grid_dims.x,
grid_dims.y,
grid_dims.z,
block_dims.x,
block_dims.y,
block_dims.z,
shared_memory,
stream,
kernel_args,
nullptr));
}
private:
void compile(
const std::string& kernel_name,
const std::string& kernel_source,
const std::vector<std::string>& template_args,
const std::vector<std::string>& compiler_flags,
bool verbose);
hiprtcProgram program_{nullptr};
hipModule_t module_{nullptr};
hipFunction_t kernel_{nullptr};
};
// JIT cache for compiled modules
class JitCache {
public:
static JitCache& instance();
std::shared_ptr<JitModule> get_or_create(
const std::string& kernel_name,
const std::string& kernel_source,
const std::vector<std::string>& template_args = {},
const std::vector<std::string>& compiler_flags = {});
private:
std::unordered_map<std::string, std::weak_ptr<JitModule>> cache_;
std::mutex mutex_;
std::string make_key(
const std::string& kernel_name,
const std::string& kernel_source,
const std::vector<std::string>& template_args,
const std::vector<std::string>& compiler_flags) const;
};
// Helper function to create and cache JIT modules
std::shared_ptr<JitModule> make_jit_kernel(
const std::string& kernel_name,
const std::string& kernel_source,
const std::vector<std::string>& template_args = {},
const std::vector<std::string>& compiler_flags = {});
} // namespace mlx::core::rocm

View File

@ -0,0 +1,29 @@
// Copyright © 2025 Apple Inc.
#include <hip/hip_runtime.h>
namespace mlx::core::rocm {
// Utility functions for HIP kernels
__device__ inline int get_global_id() {
return blockIdx.x * blockDim.x + threadIdx.x;
}
__device__ inline int get_local_id() {
return threadIdx.x;
}
__device__ inline int get_group_id() {
return blockIdx.x;
}
__device__ inline int get_local_size() {
return blockDim.x;
}
__device__ inline int get_num_groups() {
return gridDim.x;
}
} // namespace mlx::core::rocm
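// Usage sketch (illustration only): a kernel written against the
// OpenCL-style helpers above.
namespace mlx::core::rocm {
__global__ void iota_kernel(int* out, int n) {
  int gid = get_global_id(); // flat thread index across the grid
  if (gid < n) {
    out[gid] = gid;
  }
}
} // namespace mlx::core::rocm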

View File

@ -0,0 +1,135 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <hip/hip_runtime.h>
#include <array>
#include <cstdint>
#include <utility>
#include <vector>
#include "mlx/array.h" // for mlx::core::array used by get_launch_args
namespace mlx::core::rocm {
// Constants
constexpr int MAX_DIMS = 8;
// HIP array type for passing arrays to kernels
template <typename T, int N>
using hip_array = std::array<T, N>;
// Helper to create hip_array from vector
template <int N, typename T>
__host__ hip_array<T, N> make_hip_array(const std::vector<T>& vec) {
hip_array<T, N> arr;
  for (int i = 0; i < N && i < static_cast<int>(vec.size()); ++i) {
arr[i] = vec[i];
}
return arr;
}
template <typename T>
__host__ hip_array<T, MAX_DIMS> make_hip_array(const std::vector<T>& vec) {
return make_hip_array<MAX_DIMS>(vec);
}
// Type mapping from MLX types to HIP types. An alias template cannot be
// specialized directly, so route through a trait struct.
template <typename T>
struct hip_type {
  using type = T;
};
template <>
struct hip_type<float16> {
  using type = __half;
};
template <>
struct hip_type<bfloat16> {
  using type = __hip_bfloat16;
};
template <>
struct hip_type<complex64> {
  using type = hipFloatComplex;
};
template <typename T>
using hip_type_t = typename hip_type<T>::type;
// Element to location mapping for general broadcasting
template <int NDIM>
__device__ std::pair<int64_t, int64_t> elem_to_loc_nd(
int64_t elem,
const int32_t* shape,
const int64_t* a_strides,
const int64_t* b_strides) {
int64_t a_idx = 0;
int64_t b_idx = 0;
for (int i = NDIM - 1; i >= 0; --i) {
int64_t pos_in_dim = elem % shape[i];
elem /= shape[i];
a_idx += pos_in_dim * a_strides[i];
b_idx += pos_in_dim * b_strides[i];
}
return {a_idx, b_idx};
}
// 4D specialization for performance
__device__ inline std::pair<int64_t, int64_t> elem_to_loc_4d(
int64_t elem,
const int32_t* shape,
const int64_t* a_strides,
const int64_t* b_strides,
int ndim) {
int64_t a_idx = 0;
int64_t b_idx = 0;
for (int i = ndim - 1; i >= 0; --i) {
int64_t pos_in_dim = elem % shape[i];
elem /= shape[i];
a_idx += pos_in_dim * a_strides[i];
b_idx += pos_in_dim * b_strides[i];
}
return {a_idx, b_idx};
}
// Launch configuration calculation. The `large` flag is accepted for parity
// with the CUDA backend, but both paths currently produce the same 1-D
// configuration.
template <typename Kernel>
std::pair<dim3, dim3>
get_launch_args(Kernel kernel, const array& out, bool large = false) {
  int threads_per_block = 256;
  int64_t total_threads = out.size();
  int64_t blocks = (total_threads + threads_per_block - 1) / threads_per_block;
  return {dim3(blocks), dim3(threads_per_block)};
}
template <typename Kernel>
std::pair<dim3, dim3> get_launch_args(
    Kernel kernel,
    int64_t size,
    const std::vector<int>& shape,
    const std::vector<size_t>& strides,
    bool large = false) {
  int threads_per_block = 256;
  int64_t blocks = (size + threads_per_block - 1) / threads_per_block;
  return {dim3(blocks), dim3(threads_per_block)};
}
// Cooperative groups thread rank equivalent
namespace cooperative_groups {
class grid_group {
public:
__device__ int64_t thread_rank() const {
return blockIdx.x * blockDim.x + threadIdx.x;
}
};
__device__ inline grid_group this_grid() {
  return grid_group{};
}
} // namespace cooperative_groups
} // namespace mlx::core::rocm
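// Usage sketch (illustration only): collapsing host-side strides into a
// fixed-size, by-value array for a kernel launch. Slots past vec.size() are
// left uninitialized, so callers should size-check or zero-fill.
namespace mlx::core::rocm {
inline hip_array<int64_t, MAX_DIMS> strides_for_kernel(
    const std::vector<int64_t>& strides) {
  return make_hip_array(strides);
}
} // namespace mlx::core::rocm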

View File

@ -0,0 +1,437 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/rocm/device.h"
#include "mlx/backend/rocm/iterators/strided_iterator.hpp"
#include "mlx/backend/rocm/kernel_utils.hpp"
#include "mlx/backend/rocm/reduce/reduce.hpp"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/fast_primitives.h"
#include <hip/hip_runtime.h>
#include <hip/hip_cooperative_groups.h>
#include <rocprim/block/block_load.hpp>
#include <rocprim/block/block_reduce.hpp>
namespace mlx::core {
namespace rocm {
namespace cg = cooperative_groups;
// Utility helpers, defined before the kernels that call them.
template <typename T>
struct hip_plus {
  __device__ T operator()(const T& a, const T& b) const {
    return a + b;
  }
};
// Host + device so the launch code below can size its loops with it too.
inline __host__ __device__ int hip_ceil_div(int a, int b) {
  return (a + b - 1) / b;
}
template <typename T>
__device__ inline auto strided_iterator(const T* ptr, int64_t stride) {
  // Wrap in a real strided iterator: stride 0 broadcasts one element and
  // stride 1 walks contiguously; a plain `ptr + stride` offset would not.
  return StridedIterator<const T>(ptr, stride);
}
inline __device__ float3 plus_f3(const float3& a, const float3& b) {
  return {a.x + b.x, a.y + b.y, a.z + b.z};
}
// Similar to rocprim::BlockReduce, but result is broadcasted to every thread.
template <typename T, int BLOCK_DIM>
struct BlockBroadcastReduce {
static_assert(WARP_SIZE <= BLOCK_DIM && BLOCK_DIM <= WARP_SIZE * WARP_SIZE);
static_assert(BLOCK_DIM % WARP_SIZE == 0);
using TempStorage = T[BLOCK_DIM / WARP_SIZE];
cg::thread_block& block;
TempStorage& temp;
template <typename Op>
__device__ T Reduce(const T& input, const Op& op, const T& init_value) {
auto warp = cg::tiled_partition<WARP_SIZE>(block);
T x = cg::reduce(warp, input, op);
if (warp.thread_rank() == 0) {
temp[warp.meta_group_rank()] = x;
}
block.sync();
x = warp.thread_rank() < warp.meta_group_size() ? temp[warp.thread_rank()]
: init_value;
return cg::reduce(warp, x, op);
}
__device__ T Sum(const T& input) {
return Reduce(input, hip_plus<T>{}, T{});
}
};
template <typename T, int BLOCK_DIM, int N_READS = 4>
__global__ void layer_norm(
const T* x,
const T* w,
const T* b,
T* out,
float eps,
int32_t axis_size,
int64_t w_stride,
int64_t b_stride) {
auto grid = cg::this_grid();
auto block = cg::this_thread_block();
using BlockReduceT = BlockBroadcastReduce<float, BLOCK_DIM>;
__shared__ typename BlockReduceT::TempStorage temp;
x += grid.block_rank() * axis_size;
out += grid.block_rank() * axis_size;
// Sum.
float sum = 0;
for (int r = 0; r < hip_ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
auto index = r * BLOCK_DIM + block.thread_rank();
T xn[N_READS] = {};
rocprim::block_load_direct_blocked(index, x, xn, axis_size);
sum += static_cast<float>(rocprim::thread_reduce(xn, hip_plus<T>{}));
}
sum = BlockReduceT{block, temp}.Sum(sum);
// Mean.
float mean = sum / axis_size;
// Normalizer.
float normalizer = 0;
for (int r = 0; r < hip_ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
auto index = r * BLOCK_DIM + block.thread_rank();
T xn[N_READS];
rocprim::block_load_direct_blocked(index, x, xn, axis_size, mean);
for (int i = 0; i < N_READS; ++i) {
float t = static_cast<float>(xn[i]) - mean;
normalizer += t * t;
}
}
normalizer = BlockReduceT{block, temp}.Sum(normalizer);
normalizer = rsqrt(normalizer / axis_size + eps);
// Outputs.
for (int r = 0; r < hip_ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
auto index = r * BLOCK_DIM + block.thread_rank();
T xn[N_READS];
T wn[N_READS];
T bn[N_READS];
rocprim::block_load_direct_blocked(index, x, xn, axis_size);
rocprim::block_load_direct_blocked(index, strided_iterator(w, w_stride), wn, axis_size);
rocprim::block_load_direct_blocked(index, strided_iterator(b, b_stride), bn, axis_size);
for (int i = 0; i < N_READS; ++i) {
float norm = (static_cast<float>(xn[i]) - mean) * normalizer;
xn[i] = wn[i] * static_cast<T>(norm) + bn[i];
}
rocprim::block_store_direct_blocked(index, out, xn, axis_size);
}
}
template <typename T, bool HAS_W, int BLOCK_DIM, int N_READS = 4>
__global__ void layer_norm_vjp(
const T* x,
const T* w,
const T* g,
T* gx,
T* gw,
float eps,
int32_t axis_size,
int64_t w_stride) {
auto grid = cg::this_grid();
auto block = cg::this_thread_block();
using BlockReduceF = BlockBroadcastReduce<float, BLOCK_DIM>;
using BlockReduceF3 = BlockBroadcastReduce<float3, BLOCK_DIM>;
__shared__ union {
typename BlockReduceF::TempStorage f;
typename BlockReduceF3::TempStorage f3;
} temp;
x += grid.block_rank() * axis_size;
g += grid.block_rank() * axis_size;
gx += grid.block_rank() * axis_size;
gw += grid.block_rank() * axis_size;
// Sum.
float sum = 0;
for (int r = 0; r < hip_ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
auto index = r * BLOCK_DIM + block.thread_rank();
T xn[N_READS] = {};
rocprim::block_load_direct_blocked(index, x, xn, axis_size);
sum += static_cast<float>(rocprim::thread_reduce(xn, hip_plus<T>{}));
}
sum = BlockReduceF{block, temp.f}.Sum(sum);
// Mean.
float mean = sum / axis_size;
// Normalizer.
float3 factors = {};
for (int r = 0; r < hip_ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
T xn[N_READS];
T wn[N_READS] = {};
T gn[N_READS] = {};
auto index = r * BLOCK_DIM + block.thread_rank();
rocprim::block_load_direct_blocked(index, x, xn, axis_size, mean);
rocprim::block_load_direct_blocked(index, g, gn, axis_size);
rocprim::block_load_direct_blocked(index, strided_iterator(w, w_stride), wn, axis_size);
for (int i = 0; i < N_READS; i++) {
float t = static_cast<float>(xn[i]) - mean;
float wi = wn[i];
float gi = gn[i];
float wg = wi * gi;
factors = plus_f3(factors, {wg, wg * t, t * t});
}
}
factors = BlockReduceF3{block, temp.f3}.Reduce(factors, plus_f3, {});
float meanwg = factors.x / axis_size;
float meanwgxc = factors.y / axis_size;
float normalizer2 = 1 / (factors.z / axis_size + eps);
float normalizer = sqrt(normalizer2);
// Outputs.
for (int r = 0; r < hip_ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
auto index = r * BLOCK_DIM + block.thread_rank();
T xn[N_READS];
T wn[N_READS];
T gn[N_READS];
rocprim::block_load_direct_blocked(index, x, xn, axis_size);
rocprim::block_load_direct_blocked(index, g, gn, axis_size);
rocprim::block_load_direct_blocked(index, strided_iterator(w, w_stride), wn, axis_size);
for (int i = 0; i < N_READS; i++) {
float xi = (static_cast<float>(xn[i]) - mean) * normalizer;
float wi = wn[i];
float gi = gn[i];
xn[i] = normalizer * (wi * gi - meanwg) - xi * meanwgxc * normalizer2;
if constexpr (HAS_W) {
wn[i] = gi * xi;
}
}
rocprim::block_store_direct_blocked(index, gx, xn, axis_size);
if constexpr (HAS_W) {
rocprim::block_store_direct_blocked(index, gw, wn, axis_size);
}
}
}
} // namespace rocm
namespace fast {
bool LayerNorm::use_fallback(Stream s) {
return s.device == Device::cpu;
}
// TODO: There is duplicated code shared with backend/metal/normalization.cpp.
void LayerNorm::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
auto& s = stream();
auto& out = outputs[0];
// Make sure that the last dimension is contiguous.
auto set_output = [&s, &out](const array& x) {
bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
if (no_copy && x.ndim() > 1) {
auto s = x.strides()[x.ndim() - 2];
no_copy &= (s == 0 || s == x.shape().back());
}
if (no_copy) {
if (x.is_donatable()) {
out.copy_shared_buffer(x);
} else {
out.set_data(
allocator::malloc(x.data_size() * x.itemsize()),
x.data_size(),
x.strides(),
x.flags());
}
return x;
} else {
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
out.copy_shared_buffer(x_copy);
return x_copy;
}
};
const array x = set_output(inputs[0]);
const array& w = inputs[1];
const array& b = inputs[2];
int32_t axis_size = x.shape().back();
int32_t n_rows = x.data_size() / axis_size;
int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
int64_t b_stride = (b.ndim() == 1) ? b.strides()[0] : 0;
auto& encoder = rocm::get_command_encoder(s);
encoder.set_input_array(x);
encoder.set_input_array(w);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.launch_kernel([&](hipStream_t stream) {
MLX_SWITCH_FLOAT_TYPES_CHECKED(out.dtype(), "layernorm", CTYPE, {
using DataType = hip_type_t<CTYPE>;
constexpr uint32_t N_READS = 4;
MLX_SWITCH_BLOCK_DIM(rocm::hip_ceil_div(axis_size, N_READS), BLOCK_DIM, {
auto kernel = rocm::layer_norm<DataType, BLOCK_DIM, N_READS>;
hipLaunchKernelGGL(kernel, n_rows, BLOCK_DIM, 0, stream,
x.data<DataType>(),
w.data<DataType>(),
b.data<DataType>(),
out.data<DataType>(),
eps_,
axis_size,
w_stride,
b_stride);
});
});
});
}
void LayerNormVJP::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
auto& s = stream();
auto& encoder = rocm::get_command_encoder(s);
// Ensure row contiguity. We could relax this step by checking that the array
// is contiguous (no broadcasts or holes) and that the input strides are the
// same as the cotangent strides but for now this is simpler.
auto check_input = [&s](const array& x) -> std::pair<array, bool> {
if (x.flags().row_contiguous) {
return {x, false};
}
array x_copy(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
return {x_copy, true};
};
bool donate_x = inputs[0].is_donatable();
bool donate_g = inputs[3].is_donatable();
auto [x, copied] = check_input(inputs[0]);
donate_x |= copied;
const array& w = inputs[1];
const array& b = inputs[2];
auto [g, g_copied] = check_input(inputs[3]);
donate_g |= g_copied;
array& gx = outputs[0];
array& gw = outputs[1];
array& gb = outputs[2];
// Check whether we had a weight.
bool has_w = w.ndim() != 0;
// Allocate space for the outputs.
bool g_in_gx = false;
if (donate_x) {
gx.copy_shared_buffer(x);
} else if (donate_g) {
gx.copy_shared_buffer(g);
g_in_gx = true;
} else {
gx.set_data(allocator::malloc(gx.nbytes()));
}
if (g_copied && !g_in_gx) {
encoder.add_temporary(g);
}
int32_t axis_size = x.shape().back();
int32_t n_rows = x.data_size() / axis_size;
int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
// Allocate a temporary to store the gradients for w and allocate the output
// gradient accumulators.
array gw_temp =
(has_w) ? array({n_rows, x.shape().back()}, gw.dtype(), nullptr, {}) : w;
if (has_w) {
if (!g_in_gx && donate_g) {
gw_temp.copy_shared_buffer(g);
} else {
gw_temp.set_data(allocator::malloc(gw_temp.nbytes()));
encoder.add_temporary(gw_temp);
}
}
gw.set_data(allocator::malloc(gw.nbytes()));
gb.set_data(allocator::malloc(gb.nbytes()));
// Finish with the gradient for b in case we had a b.
if (gb.ndim() == 1 && gb.size() == axis_size) {
ReductionPlan plan(
ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size});
col_reduce(encoder, g, gb, Reduce::ReduceType::Sum, {0}, plan);
}
encoder.set_input_array(x);
encoder.set_input_array(w);
encoder.set_input_array(g);
encoder.set_output_array(gx);
encoder.set_output_array(gw_temp);
encoder.launch_kernel([&, x = x, g = g](hipStream_t stream) {
MLX_SWITCH_FLOAT_TYPES_CHECKED(gx.dtype(), "layernorm_vjp", CTYPE, {
using DataType = hip_type_t<CTYPE>;
constexpr int N_READS = 4;
MLX_SWITCH_BOOL(has_w, HAS_W, {
MLX_SWITCH_BLOCK_DIM(rocm::hip_ceil_div(axis_size, N_READS), BLOCK_DIM, {
auto kernel = rocm::layer_norm_vjp<DataType, HAS_W, BLOCK_DIM, N_READS>;
hipLaunchKernelGGL(kernel, n_rows, BLOCK_DIM, 0, stream,
x.data<DataType>(),
w.data<DataType>(),
g.data<DataType>(),
gx.data<DataType>(),
gw_temp.data<DataType>(),
eps_,
axis_size,
w_stride);
});
});
});
});
if (has_w) {
ReductionPlan plan(
ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size});
col_reduce(encoder, gw_temp, gw, Reduce::ReduceType::Sum, {0}, plan);
}
}
} // namespace fast
} // namespace mlx::core
namespace mlx::core::rocm {
__global__ void layer_norm_kernel(
float* input,
float* output,
float* gamma,
float* beta,
int n,
float eps) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < n) {
    // Simplified placeholder: elementwise scale-and-shift only. A real
    // implementation would reduce each row to its mean and variance and
    // fold eps into the rsqrt of the variance (see layer_norm above).
    (void)eps; // unused until the full implementation lands
    output[idx] = gamma[idx] * input[idx] + beta[idx];
}
}
void launch_layer_norm(
float* input,
float* output,
float* gamma,
float* beta,
int n,
float eps,
hipStream_t stream) {
int threads = 256;
int blocks = (n + threads - 1) / threads;
hipLaunchKernelGGL(layer_norm_kernel, dim3(blocks), dim3(threads), 0, stream,
input, output, gamma, beta, n, eps);
}
} // namespace mlx::core::rocm
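// Host-side reference (a sketch, not wired into the backend) of the math the
// placeholder kernel above elides, matching the block kernel earlier in this
// file: reduce to mean, then variance, then scale and shift.
inline void layer_norm_ref(
    const float* x, const float* gamma, const float* beta,
    float* out, int n, float eps) {
  float mean = 0.0f;
  for (int i = 0; i < n; ++i) mean += x[i];
  mean /= n;
  float var = 0.0f;
  for (int i = 0; i < n; ++i) var += (x[i] - mean) * (x[i] - mean);
  var /= n;
  float inv_std = 1.0f / sqrtf(var + eps);
  for (int i = 0; i < n; ++i)
    out[i] = gamma[i] * ((x[i] - mean) * inv_std) + beta[i];
}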

View File

@ -0,0 +1,13 @@
// Copyright © 2025 Apple Inc.
#include <hip/hip_runtime.h>
namespace mlx::core::rocm {
__global__ void logsumexp_kernel(float* input, float* output, int n) {
// Placeholder implementation
int idx = blockIdx.x * blockDim.x + threadIdx.x;
(void)input; (void)output; (void)n; (void)idx;
}
} // namespace mlx::core::rocm
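// Host-side sketch (not from the diff) of the numerically stable streaming
// recurrence a real logsumexp kernel would implement: track the running max
// m and the sum of exp(x - m), rescaling the sum whenever m grows.
inline float logsumexp_ref(const float* x, int n) {
  float m = -INFINITY;
  float s = 0.0f;
  for (int i = 0; i < n; ++i) {
    float new_m = fmaxf(m, x[i]);
    s = s * expf(m - new_m) + expf(x[i] - new_m);
    m = new_m;
  }
  return m + logf(s);
}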

View File

@ -0,0 +1,30 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/rocm/device.h"
#include "mlx/backend/rocm/utils.h"
namespace mlx::core::rocm {
void matmul_hip(
float* a,
float* b,
float* c,
int m,
int n,
int k,
hipStream_t stream) {
// This is a placeholder - in a real implementation, this would use rocBLAS
// auto& device = get_current_device();
// rocblas_sgemm(device.rocblas_handle(), ...);
// For now, just a placeholder
(void)a;
(void)b;
(void)c;
(void)m;
(void)n;
(void)k;
(void)stream;
}
} // namespace mlx::core::rocm
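// A sketch (hypothetical wiring, not part of the diff) of the rocBLAS call
// the comment above alludes to; the header path may vary by ROCm version.
// rocBLAS is column-major, so computing C^T = B^T * A^T on the transposed
// views yields row-major C = A * B.
#include <rocblas/rocblas.h>
namespace mlx::core::rocm {
inline void matmul_rocblas_sketch(
    rocblas_handle handle, const float* a, const float* b, float* c,
    int m, int n, int k) {
  const float alpha = 1.0f;
  const float beta = 0.0f;
  rocblas_sgemm(
      handle, rocblas_operation_none, rocblas_operation_none,
      n, m, k,
      &alpha,
      b, n, // B viewed as an n x k column-major matrix
      a, k, // A viewed as a k x m column-major matrix
      &beta,
      c, n); // C viewed as an n x m column-major matrix
}
} // namespace mlx::core::rocm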

View File

@ -0,0 +1,11 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/rocm/rocm.h"
namespace mlx::core::rocm {
bool is_available() {
return false;
}
} // namespace mlx::core::rocm

View File

@ -0,0 +1,21 @@
// Copyright © 2025 Apple Inc.
#include <hip/hip_runtime.h>
#include "mlx/backend/rocm/utils.h"
#include "mlx/backend/common/primitives.h"
namespace mlx::core::rocm {
// Basic kernel implementations will go here
// This is a placeholder for ROCm-specific primitive operations
void add_hip() {
// Placeholder for HIP add operation
}
void multiply_hip() {
// Placeholder for HIP multiply operation
}
} // namespace mlx::core::rocm

View File

@ -0,0 +1,23 @@
// Copyright © 2025 Apple Inc.
#include <hip/hip_runtime.h>
namespace mlx::core::rocm {
__global__ void random_uniform_kernel(float* output, int n, unsigned int seed) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < n) {
// Simple LCG placeholder - real implementation would use rocRAND
unsigned int state = seed + idx;
state = state * 1103515245 + 12345;
output[idx] = (float)(state & 0x7FFFFFFF) / (float)0x7FFFFFFF;
}
}
void launch_random_uniform(float* output, int n, unsigned int seed, hipStream_t stream) {
int threads = 256;
int blocks = (n + threads - 1) / threads;
hipLaunchKernelGGL(random_uniform_kernel, dim3(blocks), dim3(threads), 0, stream, output, n, seed);
}
} // namespace mlx::core::rocm
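// Sketch of the rocRAND-backed path the comment above mentions (assumes the
// hipRAND wrapper API; a real implementation would cache the generator per
// device instead of recreating it on every call):
#include <hiprand/hiprand.h>
namespace mlx::core::rocm {
inline void random_uniform_hiprand_sketch(
    float* output, size_t n, unsigned long long seed, hipStream_t stream) {
  hiprandGenerator_t gen;
  hiprandCreateGenerator(&gen, HIPRAND_RNG_PSEUDO_PHILOX4_32_10);
  hiprandSetPseudoRandomGeneratorSeed(gen, seed);
  hiprandSetStream(gen, stream);
  hiprandGenerateUniform(gen, output, n); // uniform floats in (0, 1]
  hiprandDestroyGenerator(gen);
}
} // namespace mlx::core::rocm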

View File

@ -0,0 +1,24 @@
// Copyright © 2025 Apple Inc.
#include <hip/hip_runtime.h>
namespace mlx::core::rocm {
__global__ void sum_reduce_kernel(float* input, float* output, int n) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
// Simple reduction placeholder
if (idx == 0) {
float sum = 0.0f;
for (int i = 0; i < n; i++) {
sum += input[i];
}
output[0] = sum;
}
}
void launch_sum_reduce(float* input, float* output, int n, hipStream_t stream) {
hipLaunchKernelGGL(sum_reduce_kernel, dim3(1), dim3(1), 0, stream, input, output, n);
}
} // namespace mlx::core::rocm
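// Sketch of the shared-memory tree reduction a real kernel would use in
// place of the single-thread placeholder above: each block produces one
// partial sum, and the partials are reduced again (or atomically added).
// Assumes blockDim.x == 256, a power of two.
namespace mlx::core::rocm {
__global__ void sum_reduce_block_kernel(
    const float* input, float* partials, int n) {
  __shared__ float smem[256];
  int tid = threadIdx.x;
  int idx = blockIdx.x * blockDim.x + tid;
  smem[tid] = (idx < n) ? input[idx] : 0.0f;
  __syncthreads();
  for (int s = blockDim.x / 2; s > 0; s >>= 1) {
    if (tid < s) {
      smem[tid] += smem[tid + s];
    }
    __syncthreads();
  }
  if (tid == 0) {
    partials[blockIdx.x] = smem[0];
  }
}
} // namespace mlx::core::rocm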

View File

@ -0,0 +1,311 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/rocm/device.h"
#include "mlx/backend/rocm/device/cast_op.hpp"
#include "mlx/backend/rocm/reduce/reduce.hpp"
#include <hip/hip_runtime.h>
#include <hip/hip_cooperative_groups.h>
#include <rocprim/block/block_load.hpp>
#include <algorithm>
namespace mlx::core {
namespace rocm {
namespace cg = cooperative_groups;
// Utility functions and templates, defined before the kernels that use them.
template <int NDIM, bool USE_FAST_PATH>
struct LoopedElemToLoc {
  size_t loc;
  __device__ LoopedElemToLoc(int reduce_ndim) : loc(0) {}
  __device__ void next(size_t step, const int* shape, const size_t* strides) {
    // Simplified placeholder - a full implementation would keep
    // per-dimension counters and wrap them against `shape`.
    loc += step;
  }
  __device__ size_t location() const {
    return loc;
  }
};
template <typename T>
__device__ inline T* make_cast_iterator(const T* ptr) {
  // Simplified identity version; only valid when no element cast is needed.
  return const_cast<T*>(ptr);
}
__device__ inline size_t elem_to_loc(
    size_t elem,
    const int* shape,
    const size_t* strides,
    int ndim) {
  size_t loc = 0;
  for (int i = ndim - 1; i >= 0; --i) {
    size_t q = elem / shape[i];
    size_t r = elem % shape[i];
    loc += r * strides[i];
    elem = q;
  }
  return loc;
}
struct ColReduceArgs {
// The size of the contiguous column reduction.
size_t reduction_size;
int64_t reduction_stride;
// Input shape and strides excluding the reduction axes.
Shape shape;
Strides strides;
int ndim;
// Input shape and strides of the reduction axes (including last dimension).
Shape reduce_shape;
Strides reduce_strides;
int reduce_ndim;
  // The number of column reductions we are doing, namely prod(reduce_shape).
size_t non_col_reductions;
ColReduceArgs(
const array& in,
const ReductionPlan& plan,
const std::vector<int>& axes) {
assert(!plan.shape.empty());
reduction_size = plan.shape.back();
reduction_stride = plan.strides.back();
int64_t stride_back = 1;
auto [shape_vec, strides_vec] = shapes_without_reduction_axes(in, axes);
while (!shape_vec.empty() && stride_back < reduction_stride) {
stride_back *= shape_vec.back();
shape_vec.pop_back();
strides_vec.pop_back();
}
std::tie(shape_vec, strides_vec) =
collapse_contiguous_dims(shape_vec, strides_vec);
shape = const_param(shape_vec);
strides = const_param(strides_vec);
ndim = shape_vec.size();
reduce_shape = const_param(plan.shape);
reduce_strides = const_param(plan.strides);
reduce_ndim = plan.shape.size();
non_col_reductions = 1;
for (int i = 0; i < reduce_ndim - 1; i++) {
non_col_reductions *= reduce_shape[i];
}
}
};
template <typename T, typename U, typename Op, int NDIM, int N_READS = 4>
__global__ void col_reduce_small(
const T* in,
U* out,
const ColReduceArgs args) {
auto grid = cg::this_grid();
auto block = cg::this_thread_block();
int column =
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
if (column * N_READS >= args.reduction_stride) {
return;
}
int out_idx = grid.block_rank() / grid.dim_blocks().x;
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
Op op;
U totals[N_READS];
for (int i = 0; i < N_READS; i++) {
totals[i] = ReduceInit<Op, T>::value();
}
// Read input to local.
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
loop.next(
block.thread_index().y,
args.reduce_shape.data(),
args.reduce_strides.data());
for (size_t r = block.thread_index().y;
r < args.non_col_reductions * args.reduction_size;
r += block.dim_threads().y) {
U vals[N_READS];
rocprim::block_load_direct_blocked(
column,
make_cast_iterator<U>(in + loop.location()),
vals,
args.reduction_stride,
ReduceInit<Op, T>::value());
for (int i = 0; i < N_READS; i++) {
totals[i] = op(vals[i], totals[i]);
}
loop.next(
block.dim_threads().y,
args.reduce_shape.data(),
args.reduce_strides.data());
}
// Do block reduce when each column has more than 1 element to reduce.
if (block.dim_threads().y > 1) {
__shared__ U shared_vals[32 * 8 * N_READS];
size_t col =
block.thread_index().y * block.dim_threads().x + block.thread_index().x;
for (int i = 0; i < N_READS; i++) {
shared_vals[col * N_READS + i] = totals[i];
}
block.sync();
if (block.thread_index().y == 0) {
for (int i = 0; i < N_READS; i++) {
totals[i] = shared_vals[block.thread_index().x * N_READS + i];
}
for (int j = 1; j < block.dim_threads().y; j++) {
col = j * block.dim_threads().x + block.thread_index().x;
for (int i = 0; i < N_READS; i++) {
totals[i] = op(shared_vals[col * N_READS + i], totals[i]);
}
}
}
}
// Write result.
if (block.thread_index().y == 0) {
rocprim::block_store_direct_blocked(
column,
out + out_idx * args.reduction_stride,
totals,
args.reduction_stride);
}
}
template <
typename T,
typename U,
typename Op,
int NDIM,
int BM,
int BN,
int N_READS = 4>
__global__ void col_reduce_looped(
const T* in,
U* out,
const ColReduceArgs args) {
auto grid = cg::this_grid();
auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<WARP_SIZE>(block);
constexpr int n_warps = BN / N_READS;
int out_idx = grid.block_rank() / grid.dim_blocks().x;
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
Op op;
U totals[N_READS];
for (int i = 0; i < N_READS; i++) {
totals[i] = ReduceInit<Op, T>::value();
}
// Read input to local.
int r = block.thread_rank() / n_warps;
int column = block.thread_rank() % n_warps;
int in_offset = grid.block_index().x * BN;
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
loop.next(r, args.reduce_shape.data(), args.reduce_strides.data());
for (; r < args.non_col_reductions * args.reduction_size; r += BM) {
U vals[N_READS];
rocprim::block_load_direct_blocked(
column,
make_cast_iterator<U>(in + loop.location() + in_offset),
vals,
args.reduction_stride - in_offset,
ReduceInit<Op, T>::value());
for (int i = 0; i < N_READS; i++) {
totals[i] = op(vals[i], totals[i]);
}
loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
}
// Do warp reduce for each output.
constexpr int n_outputs = BN / n_warps;
static_assert(BM == 32 && n_outputs == N_READS);
__shared__ U shared_vals[BM * BN];
size_t col = block.thread_index().y * BN + block.thread_index().x * N_READS;
for (int i = 0; i < N_READS; i++) {
shared_vals[col + i] = totals[i];
}
block.sync();
col = warp.thread_rank() * BN + warp.meta_group_rank() * n_outputs;
for (int i = 0; i < n_outputs; i++) {
totals[i] = cg::reduce(warp, shared_vals[col + i], op);
}
// Write result.
if (warp.thread_rank() == 0) {
size_t out_offset = grid.block_index().x * BN;
rocprim::block_store_direct_blocked(
warp.meta_group_rank(),
out + out_idx * args.reduction_stride + out_offset,
totals,
args.reduction_stride - out_offset);
}
}
} // namespace rocm
// Host-side ceil-div helper for the launch computations below.
inline size_t hip_ceil_div(size_t a, size_t b) {
  return (a + b - 1) / b;
}
inline auto output_grid_for_col_reduce(
const array& out,
const rocm::ColReduceArgs& args) {
auto out_shape = out.shape();
auto out_strides = out.strides();
while (!out_shape.empty() && out_strides.back() < args.reduction_stride) {
out_shape.pop_back();
out_strides.pop_back();
}
return get_2d_grid_dims(out_shape, out_strides);
}
void col_reduce(
rocm::CommandEncoder& encoder,
const array& in,
array& out,
Reduce::ReduceType reduce_type,
const std::vector<int>& axes,
const ReductionPlan& plan) {
rocm::ColReduceArgs args(in, plan, axes);
encoder.launch_kernel([&](hipStream_t stream) {
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
using InType = hip_type_t<CTYPE>;
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
using OutType = rocm::ReduceResult<OP, InType>::type;
MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, {
constexpr int N_READS = 4;
dim3 block_dims;
dim3 num_blocks = output_grid_for_col_reduce(out, args);
num_blocks.z = num_blocks.y;
num_blocks.y = num_blocks.x;
auto kernel =
rocm::col_reduce_small<InType, OutType, OP, NDIM, N_READS>;
size_t total = args.non_col_reductions * args.reduction_size;
if (total < 32) {
size_t stride_blocks =
hip_ceil_div(args.reduction_stride, N_READS);
block_dims.x = std::min(stride_blocks, 32ul);
block_dims.y = std::min(total, 8ul);
num_blocks.x = hip_ceil_div(stride_blocks, block_dims.x);
} else {
constexpr int BM = 32;
constexpr int BN = 32;
block_dims.x = BM * BN / N_READS;
num_blocks.x = hip_ceil_div(args.reduction_stride, BN);
kernel = rocm::
col_reduce_looped<InType, OutType, OP, NDIM, BM, BN, N_READS>;
}
hipLaunchKernelGGL(kernel, num_blocks, block_dims, 0, stream,
in.data<InType>(), out.data<OutType>(), args);
});
});
});
});
}
} // namespace mlx::core
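// Host-side sketch (illustration only) of the operation col_reduce
// implements in the simplest contiguous case: reduce along the leading axis,
// producing one output per column.
inline void col_sum_ref(const float* in, float* out, int rows, int cols) {
  for (int c = 0; c < cols; ++c) out[c] = 0.0f;
  for (int r = 0; r < rows; ++r)
    for (int c = 0; c < cols; ++c)
      out[c] += in[r * cols + c];
}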

View File

@ -0,0 +1,119 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <hip/hip_runtime.h>
#include <cstddef>
#include <cstdint>
#include <limits>
namespace mlx::core::rocm {
// Reduction operation types
template <typename Op, typename T>
struct ReduceInit {
static constexpr T value();
};
template <typename T>
struct ReduceInit<struct Sum, T> {
static constexpr T value() {
return T(0);
}
};
template <typename T>
struct ReduceInit<struct Max, T> {
static constexpr T value() {
return -std::numeric_limits<T>::infinity();
}
};
template <typename T>
struct ReduceInit<struct Min, T> {
  static constexpr T value() {
    return std::numeric_limits<T>::infinity();
  }
};
template <typename T>
struct ReduceInit<struct Prod, T> {
  static constexpr T value() {
    return T(1);
  }
};
// Reduction operations
struct Sum {
template <typename T>
__device__ T operator()(T a, T b) const {
return a + b;
}
};
struct Max {
template <typename T>
__device__ T operator()(T a, T b) const {
return fmax(a, b);
}
};
struct Min {
template <typename T>
__device__ T operator()(T a, T b) const {
return fmin(a, b);
}
};
struct Prod {
template <typename T>
__device__ T operator()(T a, T b) const {
return a * b;
}
};
// Map a reduction op and input type to the accumulator type. Simplified:
// accumulate in the input type (a full version would widen, e.g. bool sums
// to int32).
template <typename Op, typename T>
struct ReduceResult {
  using type = T;
};
// Utility functions for reductions. Taking the operation as a functor
// (rather than a function pointer) lets the compiler inline it, and the
// caller supplies the identity so min/max reductions stay correct.
template <typename T, typename Op>
__device__ T warp_reduce(T val, Op op) {
  for (int offset = warpSize / 2; offset > 0; offset /= 2) {
    val = op(val, __shfl_down(val, offset));
  }
  return val;
}
template <typename T, typename Op>
__device__ T block_reduce(T val, Op op, T init) {
  __shared__ T shared[32];
  int lane = threadIdx.x % warpSize;
  int wid = threadIdx.x / warpSize;
  val = warp_reduce(val, op);
  if (lane == 0)
    shared[wid] = val;
  __syncthreads();
  val = (threadIdx.x < blockDim.x / warpSize) ? shared[lane] : init;
  if (wid == 0)
    val = warp_reduce(val, op);
  return val;
}
// Column reduction arguments
struct ColReduceArgs {
size_t reduction_size;
int64_t reduction_stride;
int* shape;
size_t* strides;
int ndim;
int* reduce_shape;
size_t* reduce_strides;
int reduce_ndim;
size_t non_col_reductions;
};
// Row reduction arguments
struct RowReduceArgs {
size_t reduction_size;
int64_t reduction_stride;
int* shape;
size_t* strides;
int ndim;
int* reduce_shape;
size_t* reduce_strides;
int reduce_ndim;
};
} // namespace mlx::core::rocm
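// Usage sketch (illustration only): one block per row, reducing with the
// Max functor and its matching identity from ReduceInit.
namespace mlx::core::rocm {
__global__ void row_max_kernel(const float* in, float* out, int cols) {
  float v = ReduceInit<Max, float>::value();
  for (int c = threadIdx.x; c < cols; c += blockDim.x) {
    v = Max{}(v, in[blockIdx.x * cols + c]);
  }
  v = block_reduce(v, Max{}, ReduceInit<Max, float>::value());
  if (threadIdx.x == 0) {
    out[blockIdx.x] = v;
  }
}
} // namespace mlx::core::rocm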

View File

@ -0,0 +1,375 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/rocm/device.h"
#include "mlx/backend/rocm/iterators/strided_iterator.hpp"
#include "mlx/backend/rocm/kernel_utils.hpp"
#include "mlx/backend/rocm/reduce/reduce.hpp"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/fast_primitives.h"
#include <hip/hip_runtime.h>
#include <hip/hip_cooperative_groups.h>
#include <rocprim/block/block_load.hpp>
#include <rocprim/block/block_reduce.hpp>
namespace mlx::core {
namespace rocm {
namespace cg = cooperative_groups;
// Utility helpers, defined before the kernels that call them.
template <typename T>
struct hip_plus {
  __device__ T operator()(const T& a, const T& b) const {
    return a + b;
  }
};
// Host + device so the launch code below can size its loops with it too.
inline __host__ __device__ int hip_ceil_div(int a, int b) {
  return (a + b - 1) / b;
}
template <typename T>
__device__ inline auto strided_iterator(const T* ptr, int64_t stride) {
  // Wrap in a real strided iterator: stride 0 broadcasts one element and
  // stride 1 walks contiguously; a plain `ptr + stride` offset would not.
  return StridedIterator<const T>(ptr, stride);
}
// Similar to rocprim::BlockReduce, but result is broadcasted to every thread.
template <typename T, int BLOCK_DIM>
struct BlockBroadcastReduce {
static_assert(WARP_SIZE <= BLOCK_DIM && BLOCK_DIM <= WARP_SIZE * WARP_SIZE);
static_assert(BLOCK_DIM % WARP_SIZE == 0);
using TempStorage = T[BLOCK_DIM / WARP_SIZE];
cg::thread_block& block;
TempStorage& temp;
template <typename Op>
__device__ T Reduce(const T& input, const Op& op, const T& init_value) {
auto warp = cg::tiled_partition<WARP_SIZE>(block);
T x = cg::reduce(warp, input, op);
if (warp.thread_rank() == 0) {
temp[warp.meta_group_rank()] = x;
}
block.sync();
x = warp.thread_rank() < warp.meta_group_size() ? temp[warp.thread_rank()]
: init_value;
return cg::reduce(warp, x, op);
}
__device__ T Sum(const T& input) {
return Reduce(input, hip_plus<T>{}, T{});
}
};
template <typename T, int BLOCK_DIM, int N_READS = 4>
__global__ void rms_norm(
const T* x,
const T* w,
T* out,
float eps,
int32_t axis_size,
int64_t w_stride) {
auto grid = cg::this_grid();
auto block = cg::this_thread_block();
using BlockReduceT = BlockBroadcastReduce<float, BLOCK_DIM>;
__shared__ typename BlockReduceT::TempStorage temp;
x += grid.block_rank() * axis_size;
out += grid.block_rank() * axis_size;
// Sum of squares.
float sum_sq = 0;
for (int r = 0; r < hip_ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
auto index = r * BLOCK_DIM + block.thread_rank();
T xn[N_READS] = {};
rocprim::block_load_direct_blocked(index, x, xn, axis_size);
for (int i = 0; i < N_READS; ++i) {
float val = static_cast<float>(xn[i]);
sum_sq += val * val;
}
}
sum_sq = BlockReduceT{block, temp}.Sum(sum_sq);
// RMS normalizer.
float rms_normalizer = rsqrt(sum_sq / axis_size + eps);
// Outputs.
for (int r = 0; r < hip_ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
auto index = r * BLOCK_DIM + block.thread_rank();
T xn[N_READS];
T wn[N_READS];
rocprim::block_load_direct_blocked(index, x, xn, axis_size);
rocprim::block_load_direct_blocked(index, strided_iterator(w, w_stride), wn, axis_size);
for (int i = 0; i < N_READS; ++i) {
float norm = static_cast<float>(xn[i]) * rms_normalizer;
xn[i] = wn[i] * static_cast<T>(norm);
}
rocprim::block_store_direct_blocked(index, out, xn, axis_size);
}
}
template <typename T, bool HAS_W, int BLOCK_DIM, int N_READS = 4>
__global__ void rms_norm_vjp(
const T* x,
const T* w,
const T* g,
T* gx,
T* gw,
float eps,
int32_t axis_size,
int64_t w_stride) {
auto grid = cg::this_grid();
auto block = cg::this_thread_block();
using BlockReduceF = BlockBroadcastReduce<float, BLOCK_DIM>;
using BlockReduceF2 = BlockBroadcastReduce<float2, BLOCK_DIM>;
__shared__ union {
typename BlockReduceF::TempStorage f;
typename BlockReduceF2::TempStorage f2;
} temp;
x += grid.block_rank() * axis_size;
g += grid.block_rank() * axis_size;
gx += grid.block_rank() * axis_size;
gw += grid.block_rank() * axis_size;
// Sum of squares.
float sum_sq = 0;
for (int r = 0; r < hip_ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
auto index = r * BLOCK_DIM + block.thread_rank();
T xn[N_READS] = {};
rocprim::block_load_direct_blocked(index, x, xn, axis_size);
for (int i = 0; i < N_READS; ++i) {
float val = static_cast<float>(xn[i]);
sum_sq += val * val;
}
}
sum_sq = BlockReduceF{block, temp.f}.Sum(sum_sq);
// RMS normalizer.
float rms_normalizer = rsqrt(sum_sq / axis_size + eps);
// Compute gradient terms.
float2 factors = {};
for (int r = 0; r < hip_ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
T xn[N_READS];
T wn[N_READS] = {};
T gn[N_READS] = {};
auto index = r * BLOCK_DIM + block.thread_rank();
rocprim::block_load_direct_blocked(index, x, xn, axis_size);
rocprim::block_load_direct_blocked(index, g, gn, axis_size);
rocprim::block_load_direct_blocked(index, strided_iterator(w, w_stride), wn, axis_size);
for (int i = 0; i < N_READS; i++) {
float xi = static_cast<float>(xn[i]);
float wi = wn[i];
float gi = gn[i];
float wg = wi * gi;
factors.x += wg;
factors.y += wg * xi;
}
}
auto plus_f2 = [] __device__ (const float2& a, const float2& b) -> float2 {
return {a.x + b.x, a.y + b.y};
};
factors = BlockReduceF2{block, temp.f2}.Reduce(factors, plus_f2, {});
float mean_wg = factors.x / axis_size;
float mean_wgx = factors.y / axis_size;
float rms3 = rms_normalizer * rms_normalizer * rms_normalizer;
// Outputs.
for (int r = 0; r < hip_ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
auto index = r * BLOCK_DIM + block.thread_rank();
T xn[N_READS];
T wn[N_READS];
T gn[N_READS];
rocprim::block_load_direct_blocked(index, x, xn, axis_size);
rocprim::block_load_direct_blocked(index, g, gn, axis_size);
rocprim::block_load_direct_blocked(index, strided_iterator(w, w_stride), wn, axis_size);
for (int i = 0; i < N_READS; i++) {
float xi = static_cast<float>(xn[i]);
float wi = wn[i];
float gi = gn[i];
float norm = xi * rms_normalizer;
xn[i] = rms_normalizer * (wi * gi - mean_wg) - norm * mean_wgx * rms3;
if constexpr (HAS_W) {
wn[i] = gi * norm;
}
}
rocprim::block_store_direct_blocked(index, gx, xn, axis_size);
if constexpr (HAS_W) {
rocprim::block_store_direct_blocked(index, gw, wn, axis_size);
}
}
}
} // namespace rocm
namespace fast {
bool RMSNorm::use_fallback(Stream s) {
return s.device == Device::cpu;
}
void RMSNorm::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
auto& s = stream();
auto& out = outputs[0];
// Make sure that the last dimension is contiguous.
auto set_output = [&s, &out](const array& x) {
bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
if (no_copy && x.ndim() > 1) {
auto s = x.strides()[x.ndim() - 2];
no_copy &= (s == 0 || s == x.shape().back());
}
if (no_copy) {
if (x.is_donatable()) {
out.copy_shared_buffer(x);
} else {
out.set_data(
allocator::malloc(x.data_size() * x.itemsize()),
x.data_size(),
x.strides(),
x.flags());
}
return x;
} else {
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
out.copy_shared_buffer(x_copy);
return x_copy;
}
};
const array x = set_output(inputs[0]);
const array& w = inputs[1];
int32_t axis_size = x.shape().back();
int32_t n_rows = x.data_size() / axis_size;
int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
auto& encoder = rocm::get_command_encoder(s);
encoder.set_input_array(x);
encoder.set_input_array(w);
encoder.set_output_array(out);
encoder.launch_kernel([&](hipStream_t stream) {
MLX_SWITCH_FLOAT_TYPES_CHECKED(out.dtype(), "rmsnorm", CTYPE, {
using DataType = hip_type_t<CTYPE>;
constexpr uint32_t N_READS = 4;
MLX_SWITCH_BLOCK_DIM(rocm::hip_ceil_div(axis_size, N_READS), BLOCK_DIM, {
auto kernel = rocm::rms_norm<DataType, BLOCK_DIM, N_READS>;
hipLaunchKernelGGL(kernel, n_rows, BLOCK_DIM, 0, stream,
x.data<DataType>(),
w.data<DataType>(),
out.data<DataType>(),
eps_,
axis_size,
w_stride);
});
});
});
}
void RMSNormVJP::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
auto& s = stream();
auto& encoder = rocm::get_command_encoder(s);
// Ensure row contiguity. We could relax this step by checking that the array
// is contiguous (no broadcasts or holes) and that the input strides are the
// same as the cotangent strides but for now this is simpler.
auto check_input = [&s](const array& x) -> std::pair<array, bool> {
if (x.flags().row_contiguous) {
return {x, false};
}
array x_copy(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
return {x_copy, true};
};
bool donate_x = inputs[0].is_donatable();
bool donate_g = inputs[2].is_donatable();
auto [x, copied] = check_input(inputs[0]);
donate_x |= copied;
const array& w = inputs[1];
auto [g, g_copied] = check_input(inputs[2]);
donate_g |= g_copied;
array& gx = outputs[0];
array& gw = outputs[1];
// Check whether we had a weight.
bool has_w = w.ndim() != 0;
// Allocate space for the outputs.
bool g_in_gx = false;
if (donate_x) {
gx.copy_shared_buffer(x);
} else if (donate_g) {
gx.copy_shared_buffer(g);
g_in_gx = true;
} else {
gx.set_data(allocator::malloc(gx.nbytes()));
}
if (g_copied && !g_in_gx) {
encoder.add_temporary(g);
}
int32_t axis_size = x.shape().back();
int32_t n_rows = x.data_size() / axis_size;
int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
// Allocate a temporary to store the gradients for w and allocate the output
// gradient accumulators.
array gw_temp =
(has_w) ? array({n_rows, x.shape().back()}, gw.dtype(), nullptr, {}) : w;
if (has_w) {
if (!g_in_gx && donate_g) {
gw_temp.copy_shared_buffer(g);
} else {
gw_temp.set_data(allocator::malloc(gw_temp.nbytes()));
encoder.add_temporary(gw_temp);
}
}
gw.set_data(allocator::malloc(gw.nbytes()));
encoder.set_input_array(x);
encoder.set_input_array(w);
encoder.set_input_array(g);
encoder.set_output_array(gx);
encoder.set_output_array(gw_temp);
encoder.launch_kernel([&, x = x, g = g](hipStream_t stream) {
MLX_SWITCH_FLOAT_TYPES_CHECKED(gx.dtype(), "rmsnorm_vjp", CTYPE, {
using DataType = hip_type_t<CTYPE>;
constexpr int N_READS = 4;
MLX_SWITCH_BOOL(has_w, HAS_W, {
MLX_SWITCH_BLOCK_DIM(rocm::hip_ceil_div(axis_size, N_READS), BLOCK_DIM, {
auto kernel = rocm::rms_norm_vjp<DataType, HAS_W, BLOCK_DIM, N_READS>;
hipLaunchKernelGGL(kernel, n_rows, BLOCK_DIM, 0, stream,
x.data<DataType>(),
w.data<DataType>(),
g.data<DataType>(),
gx.data<DataType>(),
gw_temp.data<DataType>(),
eps_,
axis_size,
w_stride);
});
});
});
});
if (has_w) {
ReductionPlan plan(
ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size});
col_reduce(encoder, gw_temp, gw, Reduce::ReduceType::Sum, {0}, plan);
}
}
} // namespace fast
} // namespace mlx::core
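// Host-side reference (a sketch, not wired into the backend) of the RMSNorm
// math implemented above: out = w * x / sqrt(mean(x^2) + eps), with no mean
// subtraction, unlike layer norm.
inline void rms_norm_ref(
    const float* x, const float* w, float* out, int n, float eps) {
  float sum_sq = 0.0f;
  for (int i = 0; i < n; ++i) sum_sq += x[i] * x[i];
  float inv_rms = 1.0f / sqrtf(sum_sq / n + eps);
  for (int i = 0; i < n; ++i) out[i] = w[i] * (x[i] * inv_rms);
}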

mlx/backend/rocm/rocm.cpp Normal file
View File

@ -0,0 +1,11 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/rocm/rocm.h"
namespace mlx::core::rocm {
bool is_available() {
return true;
}
} // namespace mlx::core::rocm

mlx/backend/rocm/rocm.h Normal file
View File

@ -0,0 +1,10 @@
// Copyright © 2025 Apple Inc.
#pragma once
namespace mlx::core::rocm {
/* Check if the ROCm backend is available. */
bool is_available();
} // namespace mlx::core::rocm

mlx/backend/rocm/rope.hip Normal file
View File

@ -0,0 +1,383 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/rocm/device.h"
#include "mlx/backend/rocm/kernel_utils.hpp"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/fast_primitives.h"
#include <hip/hip_runtime.h>
namespace mlx::core {
namespace rocm {
template <typename T, bool traditional, bool forward>
__device__ void rope_single_impl(
const T* in,
T* out,
int32_t offset,
float inv_freq,
float scale,
int64_t stride,
uint2 pos,
uint2 dims) {
float L = scale * static_cast<float>(offset);
// Compute costheta, sintheta
float theta = L * inv_freq;
float costheta = cos(theta);
float sintheta = sin(theta);
// Compute the input and output indices
uint index_1, index_2;
if (traditional) {
index_1 = 2 * pos.x + pos.y * stride;
index_2 = index_1 + 1;
} else {
index_1 = pos.x + pos.y * stride;
index_2 = index_1 + dims.x;
}
// Read and write the output
float x1 = static_cast<float>(in[index_1]);
float x2 = static_cast<float>(in[index_2]);
float rx1;
float rx2;
if (forward) {
rx1 = x1 * costheta - x2 * sintheta;
rx2 = x1 * sintheta + x2 * costheta;
} else {
rx1 = x2 * sintheta + x1 * costheta;
rx2 = x2 * costheta - x1 * sintheta;
}
out[index_1] = static_cast<T>(rx1);
out[index_2] = static_cast<T>(rx2);
}
template <typename T, bool traditional, bool forward>
__global__ void rope_single(
const T* in,
T* out,
const int32_t* offset,
float scale,
float base,
int64_t stride,
uint2 dims) {
uint2 pos = make_uint2(
blockIdx.x * blockDim.x + threadIdx.x,
blockIdx.y * blockDim.y + threadIdx.y);
if (pos.x >= dims.x || pos.y >= dims.y) {
return;
}
float d = static_cast<float>(pos.x) / static_cast<float>(dims.x);
float inv_freq = exp2(-d * base);
rope_single_impl<T, traditional, forward>(
in, out, *offset, inv_freq, scale, stride, pos, dims);
}
template <typename T, bool traditional, bool forward>
__global__ void rope_single_freqs(
const T* in,
T* out,
const int32_t* offset,
const float* freqs,
float scale,
int64_t stride,
uint2 dims,
int64_t freq_stride) {
uint2 pos = make_uint2(
blockIdx.x * blockDim.x + threadIdx.x,
blockIdx.y * blockDim.y + threadIdx.y);
if (pos.x >= dims.x || pos.y >= dims.y) {
return;
}
  float inv_freq = 1.0f / freqs[freq_stride * pos.x];
rope_single_impl<T, traditional, forward>(
in, out, *offset, inv_freq, scale, stride, pos, dims);
}
template <typename T, bool traditional, bool forward, int N = 4>
__device__ void rope_impl(
const T* in,
T* out,
int offset,
float inv_freq,
float scale,
const hip_array<int64_t, 3> strides,
const hip_array<int64_t, 3> out_strides,
int64_t n_batch,
uint3 pos,
uint3 dims) {
float L = scale * static_cast<float>(pos.y + offset);
// Compute costheta, sintheta
float theta = L * inv_freq;
float costheta = cos(theta);
float sintheta = sin(theta);
// Compute the input and output indices
size_t in_index_1, in_index_2;
size_t out_index_1, out_index_2;
if (traditional) {
out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] +
N * pos.z * out_strides[0];
out_index_2 = out_index_1 + 1;
in_index_1 =
2 * pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0];
in_index_2 = in_index_1 + strides[2];
} else {
out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] +
N * pos.z * out_strides[0];
out_index_2 = out_index_1 + dims.x * out_strides[2];
in_index_1 =
pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0];
in_index_2 = in_index_1 + dims.x * strides[2];
}
for (int i = 0; i < N && pos.z * N + i < n_batch; ++i) {
// Read and write the output
float x1 = static_cast<float>(in[in_index_1]);
float x2 = static_cast<float>(in[in_index_2]);
float rx1;
float rx2;
if (forward) {
rx1 = x1 * costheta - x2 * sintheta;
rx2 = x1 * sintheta + x2 * costheta;
} else {
rx1 = x2 * sintheta + x1 * costheta;
rx2 = x2 * costheta - x1 * sintheta;
}
out[out_index_1] = static_cast<T>(rx1);
out[out_index_2] = static_cast<T>(rx2);
in_index_1 += strides[0];
in_index_2 += strides[0];
out_index_1 += out_strides[0];
out_index_2 += out_strides[0];
}
}
template <typename T, bool traditional, bool forward>
__global__ void rope(
const T* in,
T* out,
const int32_t* offset,
float scale,
float base,
const hip_array<int64_t, 3> strides,
const hip_array<int64_t, 3> out_strides,
int64_t n_batch,
uint3 dims) {
uint3 pos = make_uint3(
blockIdx.x * blockDim.x + threadIdx.x,
blockIdx.y * blockDim.y + threadIdx.y,
blockIdx.z * blockDim.z + threadIdx.z);
if (pos.x >= dims.x || pos.y >= dims.y || pos.z >= dims.z) {
return;
}
float d = static_cast<float>(pos.x) / static_cast<float>(dims.x);
float inv_freq = exp2(-d * base);
rope_impl<T, traditional, forward>(
in,
out,
*offset,
inv_freq,
scale,
strides,
out_strides,
n_batch,
pos,
dims);
}
template <typename T, bool traditional, bool forward>
__global__ void rope_freqs(
const T* in,
T* out,
const int32_t* offset,
const float* freqs,
float scale,
float base,
const hip_array<int64_t, 3> strides,
const hip_array<int64_t, 3> out_strides,
int64_t n_batch,
uint3 dims,
int64_t freq_stride) {
uint3 pos = make_uint3(
blockIdx.x * blockDim.x + threadIdx.x,
blockIdx.y * blockDim.y + threadIdx.y,
blockIdx.z * blockDim.z + threadIdx.z);
if (pos.x >= dims.x || pos.y >= dims.y || pos.z >= dims.z) {
return;
}
  float inv_freq = 1.0f / freqs[freq_stride * pos.x];
rope_impl<T, traditional, forward>(
in,
out,
*offset,
inv_freq,
scale,
strides,
out_strides,
n_batch,
pos,
dims);
}
} // namespace rocm
namespace fast {
bool RoPE::use_fallback(Stream s) {
return s.device == Device::cpu;
}
void RoPE::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
auto& s = stream();
auto& in = inputs[0];
auto& offset = inputs[1];
auto& out = outputs[0];
if (in.ndim() < 3) {
throw std::runtime_error("[RoPE] Input must have at least 3 dimensions");
}
hip_array<int64_t, 3> strides;
hip_array<int64_t, 3> out_strides;
bool donated = false;
int ndim = in.ndim();
int dispatch_ndim = in.ndim();
while (in.shape(-dispatch_ndim) == 1 && dispatch_ndim > 3) {
dispatch_ndim--;
}
size_t mat_size = in.shape(-2) * in.shape(-1);
// We apply rope to less than the whole vector so copy to output and then
// apply in-place.
if (dims_ < in.shape(-1)) {
donated = true;
auto ctype =
(in.flags().row_contiguous) ? CopyType::Vector : CopyType::General;
copy_gpu(in, out, ctype, s);
strides[0] = mat_size;
strides[1] = out.strides()[ndim - 2];
strides[2] = out.strides()[ndim - 1];
}
// Either copy or apply in-place
else if (in.flags().row_contiguous) {
if (in.is_donatable()) {
donated = true;
out.copy_shared_buffer(in);
} else {
out.set_data(allocator::malloc(out.nbytes()));
}
strides[0] = mat_size;
strides[1] = in.strides()[ndim - 2];
strides[2] = in.strides()[ndim - 1];
} else if (dispatch_ndim == 3) {
// Handle non-contiguous 3D inputs
out.set_data(allocator::malloc(out.nbytes()));
strides[0] = in.strides()[ndim - 3];
strides[1] = in.strides()[ndim - 2];
strides[2] = in.strides()[ndim - 1];
} else {
// Copy non-contiguous > 3D inputs into the output and treat
// input as donated
donated = true;
copy_gpu(in, out, CopyType::General, s);
strides[0] = mat_size;
strides[1] = out.strides()[ndim - 2];
strides[2] = out.strides()[ndim - 1];
}
out_strides[0] = mat_size;
out_strides[1] = out.strides()[ndim - 2];
out_strides[2] = out.strides()[ndim - 1];
// Some flags to help us dispatch below
bool single = in.flags().row_contiguous && (mat_size == in.shape(-1));
bool with_freqs = inputs.size() == 3;
auto& encoder = rocm::get_command_encoder(s);
encoder.set_input_array(donated ? out : in);
encoder.set_input_array(offset);
encoder.set_output_array(out);
encoder.launch_kernel([&](hipStream_t stream) {
MLX_SWITCH_FLOAT_TYPES_CHECKED(in.dtype(), "rope", CTYPE, {
using DataType = hip_type_t<CTYPE>;
MLX_SWITCH_BOOL(traditional_, TRADITIONAL, {
MLX_SWITCH_BOOL(forward_, FORWARD, {
if (single && !with_freqs) {
auto kernel = rocm::rope_single<DataType, TRADITIONAL, FORWARD>;
uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size);
auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);
hipLaunchKernelGGL(kernel, grid, block, 0, stream,
(donated ? out : in).data<DataType>(),
out.data<DataType>(),
offset.data<int32_t>(),
scale_,
std::log2(base_),
mat_size,
dims);
} else if (single) {
auto kernel = rocm::rope_single_freqs<DataType, TRADITIONAL, FORWARD>;
uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size);
auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);
hipLaunchKernelGGL(kernel, grid, block, 0, stream,
(donated ? out : in).data<DataType>(),
out.data<DataType>(),
offset.data<int32_t>(),
inputs[2].data<float>(),
scale_,
mat_size,
dims,
inputs[2].strides(0));
} else if (with_freqs) {
auto kernel = rocm::rope_freqs<DataType, TRADITIONAL, FORWARD>;
uint3 dims =
make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size);
dims.z = (dims.z + 3) / 4;
auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z);
hipLaunchKernelGGL(kernel, grid, block, 0, stream,
(donated ? out : in).data<DataType>(),
out.data<DataType>(),
offset.data<int32_t>(),
inputs[2].data<float>(),
scale_,
std::log2(base_),
strides,
out_strides,
in.size() / mat_size,
dims,
inputs[2].strides(0));
} else {
auto kernel = rocm::rope<DataType, TRADITIONAL, FORWARD>;
uint3 dims =
make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size);
dims.z = (dims.z + 3) / 4;
auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z);
hipLaunchKernelGGL(kernel, grid, block, 0, stream,
(donated ? out : in).data<DataType>(),
out.data<DataType>(),
offset.data<int32_t>(),
scale_,
std::log2(base_),
strides,
out_strides,
in.size() / mat_size,
dims);
}
});
});
});
});
}
} // namespace fast
} // namespace mlx::core
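// Host-side sketch (illustration only) of the per-pair rotation the kernels
// above implement: each (x1, x2) pair is rotated by
// theta = scale * position * inv_freq, and the backward pass applies the
// inverse rotation.
inline void rope_pair_ref(float& x1, float& x2, float theta, bool forward) {
  float c = cosf(theta);
  float s = sinf(theta);
  float r1 = forward ? x1 * c - x2 * s : x2 * s + x1 * c;
  float r2 = forward ? x1 * s + x2 * c : x2 * c - x1 * s;
  x1 = r1;
  x2 = r2;
}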

View File

@ -0,0 +1,9 @@
// Copyright © 2025 Apple Inc.
namespace mlx::core::rocm {
void slice() {
// Placeholder for ROCm slicing operation
}
} // namespace mlx::core::rocm

View File

@ -0,0 +1,179 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/rocm/device.h"
#include "mlx/backend/rocm/device/cast_op.hpp"
#include "mlx/backend/rocm/device/fp16_math.hpp"
#include "mlx/backend/rocm/kernel_utils.hpp"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
#include <hip/hip_runtime.h>
#include <hip/hip_cooperative_groups.h>
#include <rocprim/block/block_load.hpp>
#include <cassert>
namespace mlx::core {
namespace rocm {
namespace cg = cooperative_groups;
// Helper functors and utilities used by the kernel below; they must be
// defined before their first use.
template <typename T>
struct hip_max {
  __device__ T operator()(const T& a, const T& b) const {
    return fmax(a, b);
  }
};
template <typename T>
struct hip_plus {
  __device__ T operator()(const T& a, const T& b) const {
    return a + b;
  }
};
// Callable from both host (block-dim dispatch) and device (kernel loops).
inline __host__ __device__ int hip_ceil_div(int a, int b) {
  return (a + b - 1) / b;
}
// NOTE: identity pass-through; this assumes the accumulator type matches T.
// A real casting iterator would be needed for mixed-precision reads.
template <typename T>
__device__ inline T* make_cast_iterator(const T* ptr) {
  return const_cast<T*>(ptr);
}
template <typename T>
inline __device__ T softmax_exp(T x) {
// Softmax does not need a high-precision exponential: x is always in
// (-inf, 0], and the result is subsequently divided by sum(exp(x_i)).
return __expf(x);
}
template <typename T, typename AccT, int BLOCK_DIM, int N_READS = 4>
__global__ void softmax(const T* in, T* out, int axis_size) {
auto grid = cg::this_grid();
auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<WARP_SIZE>(block);
in += grid.block_rank() * axis_size;
out += grid.block_rank() * axis_size;
// Thread reduce.
AccT prevmax;
AccT maxval = -INFINITY;
AccT normalizer = 0;
for (int r = 0; r < hip_ceil_div(axis_size, BLOCK_DIM * N_READS); r++) {
AccT vals[N_READS];
rocprim::block_load_direct_blocked(
r * BLOCK_DIM + block.thread_rank(),
make_cast_iterator<AccT>(in),
vals,
axis_size,
-INFINITY);
prevmax = maxval;
maxval = fmax(maxval, rocprim::thread_reduce(vals, hip_max<AccT>()));
// Online normalizer calculation for softmax:
// https://github.com/NVIDIA/online-softmax
normalizer = normalizer * softmax_exp(prevmax - maxval);
for (int i = 0; i < N_READS; i++) {
normalizer = normalizer + softmax_exp(vals[i] - maxval);
}
}
// First warp reduce.
prevmax = maxval;
maxval = cg::reduce(warp, maxval, hip_max<AccT>());
normalizer = normalizer * softmax_exp(prevmax - maxval);
normalizer = cg::reduce(warp, normalizer, hip_plus<AccT>());
__shared__ AccT local_max[WARP_SIZE];
__shared__ AccT local_normalizer[WARP_SIZE];
// Write to shared memory and do second warp reduce.
prevmax = maxval;
if (warp.thread_rank() == 0) {
local_max[warp.meta_group_rank()] = maxval;
}
block.sync();
maxval = warp.thread_rank() < warp.meta_group_size()
? local_max[warp.thread_rank()]
: -INFINITY;
maxval = cg::reduce(warp, maxval, hip_max<AccT>());
normalizer = normalizer * softmax_exp(prevmax - maxval);
if (warp.thread_rank() == 0) {
local_normalizer[warp.meta_group_rank()] = normalizer;
}
block.sync();
normalizer = warp.thread_rank() < warp.meta_group_size()
? local_normalizer[warp.thread_rank()]
: AccT{};
normalizer = cg::reduce(warp, normalizer, hip_plus<AccT>());
normalizer = 1 / normalizer;
// Write output.
for (int r = 0; r < hip_ceil_div(axis_size, BLOCK_DIM * N_READS); r++) {
auto index = r * BLOCK_DIM + block.thread_rank();
T vals[N_READS];
rocprim::block_load_direct_blocked(index, in, vals, axis_size);
for (int i = 0; i < N_READS; i++) {
vals[i] = softmax_exp(static_cast<AccT>(vals[i]) - maxval) * normalizer;
}
rocprim::block_store_direct_blocked(index, out, vals, axis_size);
}
}
} // namespace rocm
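// The kernel above uses the single-pass "online softmax" recurrence cited in
// its comment: whenever a larger maximum appears, the running sum is rescaled
// by exp(prevmax - maxval). A minimal host-side sketch of the same recurrence
// (reference only, not part of the backend; relies on the C math routines
// already pulled in by the HIP headers):
inline void softmax_reference(const float* x, float* out, int n) {
  float maxval = -INFINITY;
  float normalizer = 0.0f;
  for (int i = 0; i < n; ++i) {
    float prevmax = maxval;
    maxval = fmaxf(maxval, x[i]);
    // Rescale the running sum when a new max is found, then accumulate.
    normalizer = normalizer * expf(prevmax - maxval) + expf(x[i] - maxval);
  }
  for (int i = 0; i < n; ++i) {
    out[i] = expf(x[i] - maxval) / normalizer;
  }
}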
void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& s = stream();
// Make sure that the last dimension is contiguous.
auto set_output = [&s, &out](const array& x) {
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
if (x.is_donatable()) {
out.copy_shared_buffer(x);
} else {
out.set_data(
allocator::malloc(x.data_size() * x.itemsize()),
x.data_size(),
x.strides(),
x.flags());
}
return x;
} else {
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
out.copy_shared_buffer(x_copy);
return x_copy;
}
};
array in = set_output(inputs[0]);
bool precise = in.dtype() != float32 && precise_;
int axis_size = in.shape().back();
int n_rows = in.data_size() / axis_size;
auto& encoder = rocm::get_command_encoder(s);
encoder.set_input_array(in);
encoder.set_output_array(out);
encoder.launch_kernel([&](hipStream_t stream) {
MLX_SWITCH_FLOAT_TYPES_CHECKED(out.dtype(), "softmax", CTYPE, {
using DataType = hip_type_t<CTYPE>;
constexpr int N_READS = 4;
MLX_SWITCH_BLOCK_DIM(rocm::hip_ceil_div(axis_size, N_READS), BLOCK_DIM, {
auto kernel = rocm::softmax<DataType, DataType, BLOCK_DIM, N_READS>;
if (precise) {
kernel = rocm::softmax<DataType, float, BLOCK_DIM, N_READS>;
}
hipLaunchKernelGGL(kernel, n_rows, BLOCK_DIM, 0, stream,
in.data<DataType>(), out.data<DataType>(), axis_size);
});
});
});
}
} // namespace mlx::core

178
mlx/backend/rocm/sort.hip Normal file
View File

@ -0,0 +1,178 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/utils.h"
#include "mlx/backend/rocm/device.h"
#include "mlx/backend/rocm/kernel_utils.hpp"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
#include <hip/hip_runtime.h>
#include <thrust/device_ptr.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/transform_iterator.h>
#include <thrust/transform.h>
#include <rocprim/device/device_segmented_radix_sort.hpp>
#include <cassert>
#include <numeric>
namespace mlx::core {
namespace {
template <typename T>
struct ModOp {
T divisor;
__device__ T operator()(T x) {
return x % divisor;
}
};
// We cannot call ops while evaluating, so use a local utility instead.
array swapaxes_in_eval(const array& in, int axis1, int axis2) {
std::vector<int> axes(in.ndim());
std::iota(axes.begin(), axes.end(), 0);
std::swap(axes[axis1], axes[axis2]);
// TODO: Share the code with Transpose::eval.
Shape shape(axes.size());
Strides strides(in.ndim());
for (size_t ax = 0; ax < axes.size(); ++ax) {
shape[ax] = in.shape()[axes[ax]];
strides[ax] = in.strides()[axes[ax]];
}
auto flags = in.flags();
if (flags.contiguous) {
auto [_, row_contiguous, col_contiguous] = check_contiguity(shape, strides);
flags.row_contiguous = row_contiguous;
flags.col_contiguous = col_contiguous;
}
array out(shape, in.dtype(), nullptr, {});
out.copy_shared_buffer(in, strides, flags, in.data_size());
return out;
}
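// swapaxes_in_eval above is a metadata-only transpose: it permutes shape and
// strides over the same buffer instead of moving bytes. A standalone sketch
// of that bookkeeping (hypothetical helper, illustration only):
std::pair<Shape, Strides> swapped_axes_meta(
    Shape shape, Strides strides, int ax1, int ax2) {
  std::swap(shape[ax1], shape[ax2]);
  std::swap(strides[ax1], strides[ax2]);
  return {std::move(shape), std::move(strides)};
}
// E.g. a row-major 2x3 array (strides {3, 1}) becomes shape {3, 2} with
// strides {1, 3}; indexing through the new strides visits the same memory.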
template <typename... Args>
void segmented_sort_pairs(rocm::CommandEncoder& encoder, Args&&... args) {
  // rocPRIM exposes segmented sorting as segmented_radix_sort_*. It uses a
  // two-call pattern: the first call with a null buffer only computes the
  // required temporary storage size.
  size_t size;
  CHECK_HIP_ERROR(rocprim::segmented_radix_sort_pairs(nullptr, size, args...));
  array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
  encoder.add_temporary(temp);
  // Run op.
  CHECK_HIP_ERROR(
      rocprim::segmented_radix_sort_pairs(temp.data<void>(), size, args...));
}
template <typename... Args>
void segmented_sort(rocm::CommandEncoder& encoder, Args&&... args) {
  // Allocate temporary storage.
  size_t size;
  CHECK_HIP_ERROR(rocprim::segmented_radix_sort_keys(nullptr, size, args...));
  array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
  encoder.add_temporary(temp);
  // Run op.
  CHECK_HIP_ERROR(
      rocprim::segmented_radix_sort_keys(temp.data<void>(), size, args...));
}
void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
array out = out_;
auto& encoder = rocm::get_command_encoder(s);
encoder.set_input_array(in);
encoder.set_output_array(out);
if (axis < 0) {
axis += in.ndim();
}
int nsort = in.shape(axis);
int nsegments = in.data_size() / nsort;
int last_dim = in.ndim() - 1;
// If we are not sorting the innermost dimension of a contiguous array,
// transpose and make a copy.
bool is_segmented_sort = in.flags().contiguous && in.strides()[axis] == 1;
if (!is_segmented_sort) {
array trans = swapaxes_in_eval(in, axis, last_dim);
in = array(trans.shape(), trans.dtype(), nullptr, {});
copy_gpu(trans, in, CopyType::General, s);
encoder.add_temporary(in);
out = array(allocator::malloc(out.nbytes()), in.shape(), out.dtype());
encoder.add_temporary(out);
} else {
out.set_data(allocator::malloc(out.nbytes()));
}
encoder.launch_kernel([&](hipStream_t stream) {
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
if constexpr (!std::is_same_v<CTYPE, complex64_t>) {
using Type = hip_type_t<CTYPE>;
auto offsets = thrust::make_transform_iterator(
    thrust::make_counting_iterator(0),
[nsort] __device__(int i) { return i * nsort; });
if (argsort) {
// Indices in the sorted dimension.
array indices(
allocator::malloc(out.nbytes()), in.shape(), out.dtype());
encoder.add_temporary(indices);
thrust::transform(
    rocm::thrust_policy(stream),
    thrust::counting_iterator<uint32_t>(0),
    thrust::counting_iterator<uint32_t>(indices.data_size()),
    thrust::device_pointer_cast(indices.data<uint32_t>()),
    ModOp<uint32_t>{static_cast<uint32_t>(nsort)});
// Although argsort does not need the sorted values, the API requires an
// array to store them.
array discard(allocator::malloc(in.nbytes()), in.shape(), in.dtype());
encoder.add_temporary(discard);
segmented_sort_pairs(
    encoder,
    in.data<Type>(),
    discard.data<Type>(),
    indices.data<uint32_t>(),
    out.data<uint32_t>(),
    in.data_size(),
    nsegments,
    offsets,
    offsets + 1,
    // rocPRIM's radix sort takes an explicit bit range before the stream.
    0,
    sizeof(Type) * 8,
    stream);
} else {
segmented_sort(
    encoder,
    in.data<Type>(),
    out.data<Type>(),
    in.data_size(),
    nsegments,
    offsets,
    offsets + 1,
    0,
    sizeof(Type) * 8,
    stream);
}
} else {
throw std::runtime_error(
"ROCm backend does not support sorting complex numbers");
}
});
});
if (!is_segmented_sort) {
// Swap the sorted axis back.
// TODO: Do in-place transpose instead of using a temporary out array.
copy_gpu(swapaxes_in_eval(out, axis, last_dim), out_, CopyType::General, s);
}
}
} // namespace
void ArgSort::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
gpu_sort(stream(), inputs[0], out, axis_, true);
}
void Sort::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
gpu_sort(stream(), inputs[0], out, axis_, false);
}
} // namespace mlx::core
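// For reference, argsort returns the permutation that sorts each segment. A
// minimal host-side sketch for a single row (illustration only):
#include <algorithm>
#include <cstdint>

std::vector<uint32_t> argsort_reference(const std::vector<float>& row) {
  std::vector<uint32_t> indices(row.size());
  std::iota(indices.begin(), indices.end(), 0u);
  std::stable_sort(indices.begin(), indices.end(), [&](uint32_t a, uint32_t b) {
    return row[a] < row[b];
  });
  return indices;
}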

View File

@ -0,0 +1,148 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/ternary.h"
#include "mlx/backend/rocm/device.h"
#include "mlx/backend/rocm/device/ternary_ops.hpp"
#include "mlx/backend/rocm/kernel_utils.hpp"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
#include <hip/hip_runtime.h>
#include <thrust/device_ptr.h>
#include <thrust/iterator/zip_iterator.h>
#include <thrust/transform.h>
#include <thrust/tuple.h>
namespace mlx::core {
namespace rocm {
template <typename Op, typename Condition, typename A, typename B, typename Out>
constexpr bool supports_ternary_op() {
if (std::is_same_v<Op, Select>) {
return std::is_same_v<Condition, bool> && std::is_same_v<A, Out> && std::is_same_v<B, Out>;
}
return false;
}
} // namespace rocm
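// For reference, Select computes out[i] = condition[i] ? a[i] : b[i]
// elementwise. A minimal host-side sketch (illustration only, not the GPU
// path):
inline void select_reference(
    const bool* condition,
    const float* a,
    const float* b,
    float* out,
    int n) {
  for (int i = 0; i < n; ++i) {
    out[i] = condition[i] ? a[i] : b[i];
  }
}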
template <typename Op>
void ternary_op_gpu_inplace(
const std::vector<array>& inputs,
array& out,
const std::string& op,
const Stream& s) {
auto& condition = inputs[0];
auto& a = inputs[1];
auto& b = inputs[2];
if (condition.size() == 0) {
return;
}
auto& encoder = rocm::get_command_encoder(s);
encoder.set_input_array(condition);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.launch_kernel([&](hipStream_t stream) {
MLX_SWITCH_ALL_TYPES(condition.dtype(), CONDITION_TYPE, {
MLX_SWITCH_ALL_TYPES(a.dtype(), A_TYPE, {
MLX_SWITCH_ALL_TYPES(b.dtype(), B_TYPE, {
MLX_SWITCH_ALL_TYPES(out.dtype(), OUT_TYPE, {
if constexpr (rocm::supports_ternary_op<Op, CONDITION_TYPE, A_TYPE, B_TYPE, OUT_TYPE>()) {
using ConditionType = hip_type_t<CONDITION_TYPE>;
using AType = hip_type_t<A_TYPE>;
using BType = hip_type_t<B_TYPE>;
using OutType = hip_type_t<OUT_TYPE>;
auto policy = rocm::thrust_policy(stream);
auto condition_ptr = thrust::device_pointer_cast(condition.data<ConditionType>());
auto a_ptr = thrust::device_pointer_cast(a.data<AType>());
auto b_ptr = thrust::device_pointer_cast(b.data<BType>());
auto out_ptr = thrust::device_pointer_cast(out.data<OutType>());
if (condition.flags().contiguous && a.flags().contiguous && b.flags().contiguous) {
auto ternary_op = [=] __device__(const auto& tuple) -> OutType {
  return Op{}(thrust::get<0>(tuple), thrust::get<1>(tuple), thrust::get<2>(tuple));
};
auto zip_begin = thrust::make_zip_iterator(
    thrust::make_tuple(condition_ptr, a_ptr, b_ptr));
auto zip_end = thrust::make_zip_iterator(
    thrust::make_tuple(condition_ptr + condition.data_size(),
                       a_ptr + a.data_size(),
                       b_ptr + b.data_size()));
thrust::transform(policy, zip_begin, zip_end, out_ptr, ternary_op);
} else {
// Handle non-contiguous arrays with general iterators
auto [condition_shape, condition_strides] = collapse_contiguous_dims(condition);
auto [a_shape, a_strides] = collapse_contiguous_dims(a);
auto [b_shape, b_strides] = collapse_contiguous_dims(b);
auto [condition_begin, condition_end] = rocm::make_general_iterators<int64_t>(
condition_ptr, condition.size(), condition_shape, condition_strides);
auto [a_begin, a_end] = rocm::make_general_iterators<int64_t>(
a_ptr, a.size(), a_shape, a_strides);
auto [b_begin, b_end] = rocm::make_general_iterators<int64_t>(
b_ptr, b.size(), b_shape, b_strides);
auto ternary_op = [=] __device__(const auto& tuple) -> OutType {
  return Op{}(thrust::get<0>(tuple), thrust::get<1>(tuple), thrust::get<2>(tuple));
};
auto zip_begin = thrust::make_zip_iterator(
    thrust::make_tuple(condition_begin, a_begin, b_begin));
auto zip_end = thrust::make_zip_iterator(
    thrust::make_tuple(condition_end, a_end, b_end));
thrust::transform(policy, zip_begin, zip_end, out_ptr, ternary_op);
}
} else {
throw std::runtime_error(fmt::format(
"Can not do ternary op {} on inputs of {}, {}, {} with output of {}.",
op,
dtype_to_string(condition.dtype()),
dtype_to_string(a.dtype()),
dtype_to_string(b.dtype()),
dtype_to_string(out.dtype())));
}
});
});
});
});
});
}
template <typename Op>
void ternary_op_gpu(
const std::vector<array>& inputs,
array& out,
const std::string& op,
const Stream& s) {
set_ternary_output_data(inputs, out);
ternary_op_gpu_inplace<Op>(inputs, out, op, s);
}
void Select::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& s = out.primitive().stream();
ternary_op_gpu<rocm::Select>(inputs, out, get_primitive_string(this), s);
}
} // namespace mlx::core
// Standalone example kernel and launcher; wrapped in the rocm namespace so
// that the closing brace below matches its opener.
namespace mlx::core::rocm {
__global__ void select_kernel(float* condition, float* a, float* b, float* output, int n) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < n) {
output[idx] = (condition[idx] != 0.0f) ? a[idx] : b[idx];
}
}
void launch_select(float* condition, float* a, float* b, float* output, int n, hipStream_t stream) {
int threads = 256;
int blocks = (n + threads - 1) / threads;
hipLaunchKernelGGL(select_kernel, dim3(blocks), dim3(threads), 0, stream, condition, a, b, output, n);
}
} // namespace mlx::core::rocm

222
mlx/backend/rocm/unary.hip Normal file
View File

@ -0,0 +1,222 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/unary.h"
#include "mlx/backend/rocm/device.h"
#include "mlx/backend/rocm/device/hip_complex_math.hpp"
#include "mlx/backend/rocm/device/unary_ops.hpp"
#include "mlx/backend/rocm/iterators/general_iterator.hpp"
#include "mlx/backend/rocm/kernel_utils.hpp"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
#include <hip/hip_runtime.h>
#include <thrust/device_ptr.h>
#include <thrust/transform.h>
namespace mlx::core {
namespace rocm {
template <typename Op, typename In, typename Out>
constexpr bool supports_unary_op() {
if (std::is_same_v<Op, Abs> || std::is_same_v<Op, Negative> ||
std::is_same_v<Op, Sign>) {
return std::is_same_v<In, Out>;
}
if (std::is_same_v<Op, ArcCos> || std::is_same_v<Op, ArcCosh> ||
std::is_same_v<Op, ArcSin> || std::is_same_v<Op, ArcSinh> ||
std::is_same_v<Op, ArcTan> || std::is_same_v<Op, ArcTanh> ||
std::is_same_v<Op, Erf> || std::is_same_v<Op, ErfInv> ||
std::is_same_v<Op, Expm1> || std::is_same_v<Op, Sigmoid> ||
std::is_same_v<Op, Sqrt> || std::is_same_v<Op, Rsqrt>) {
return std::is_same_v<In, Out> && is_floating_v<In>;
}
if (std::is_same_v<Op, Log> || std::is_same_v<Op, Log2> ||
std::is_same_v<Op, Log10> || std::is_same_v<Op, Log1p>) {
return std::is_same_v<In, Out> && is_inexact_v<In>;
}
if (std::is_same_v<Op, BitwiseInvert>) {
return std::is_same_v<In, Out> && std::is_integral_v<In> &&
!std::is_same_v<In, bool>;
}
if (std::is_same_v<Op, Ceil> || std::is_same_v<Op, Floor> ||
std::is_same_v<Op, Square>) {
return std::is_same_v<In, Out> && !std::is_same_v<In, complex64_t>;
}
if (std::is_same_v<Op, Conjugate>) {
return std::is_same_v<In, Out> && std::is_same_v<In, complex64_t>;
}
if (std::is_same_v<Op, Cos> || std::is_same_v<Op, Cosh> ||
std::is_same_v<Op, Exp> || std::is_same_v<Op, Round> ||
std::is_same_v<Op, Sin> || std::is_same_v<Op, Sinh> ||
std::is_same_v<Op, Tan> || std::is_same_v<Op, Tanh>) {
return std::is_same_v<In, Out> &&
(is_floating_v<In> || std::is_same_v<In, complex64_t>);
}
if (std::is_same_v<Op, Imag> || std::is_same_v<Op, Real>) {
return std::is_same_v<In, complex64_t> && std::is_same_v<Out, float>;
}
if (std::is_same_v<Op, LogicalNot>) {
return std::is_same_v<In, Out> && std::is_same_v<In, bool>;
}
return false;
}
} // namespace rocm
template <typename Op>
void unary_op_gpu_inplace(
const std::vector<array>& inputs,
array& out,
const std::string& op,
const Stream& s) {
auto& in = inputs[0];
if (in.size() == 0) {
return;
}
auto& encoder = rocm::get_command_encoder(s);
encoder.set_input_array(in);
encoder.set_output_array(out);
encoder.launch_kernel([&](hipStream_t stream) {
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE_IN, {
MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, {
if constexpr (rocm::supports_unary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
using InType = hip_type_t<CTYPE_IN>;
using OutType = hip_type_t<CTYPE_OUT>;
auto policy = rocm::thrust_policy(stream);
auto in_ptr = thrust::device_pointer_cast(in.data<InType>());
auto out_ptr = thrust::device_pointer_cast(out.data<OutType>());
if (in.flags().contiguous) {
thrust::transform(
    policy, in_ptr, in_ptr + in.data_size(), out_ptr, Op());
} else {
auto [shape, strides] = collapse_contiguous_dims(in);
auto [in_begin, in_end] = rocm::make_general_iterators<int64_t>(
in_ptr, in.size(), shape, strides);
thrust::transform(policy, in_begin, in_end, out_ptr, Op());
}
} else {
throw std::runtime_error(fmt::format(
"Can not do unary op {} on input of {} with output of {}.",
op,
dtype_to_string(in.dtype()),
dtype_to_string(out.dtype())));
}
});
});
});
}
template <typename Op>
void unary_op_gpu(
const std::vector<array>& inputs,
array& out,
const std::string& op,
const Stream& s) {
set_unary_output_data(inputs[0], out);
unary_op_gpu_inplace<Op>(inputs, out, op, s);
}
#define UNARY_GPU(func) \
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
auto& s = out.primitive().stream(); \
unary_op_gpu<rocm::func>(inputs, out, get_primitive_string(this), s); \
}
UNARY_GPU(Abs)
UNARY_GPU(ArcCos)
UNARY_GPU(ArcCosh)
UNARY_GPU(ArcSin)
UNARY_GPU(ArcSinh)
UNARY_GPU(ArcTan)
UNARY_GPU(ArcTanh)
UNARY_GPU(BitwiseInvert)
UNARY_GPU(Ceil)
UNARY_GPU(Conjugate)
UNARY_GPU(Cos)
UNARY_GPU(Cosh)
UNARY_GPU(Erf)
UNARY_GPU(ErfInv)
UNARY_GPU(Exp)
UNARY_GPU(Expm1)
UNARY_GPU(Floor)
UNARY_GPU(Imag)
UNARY_GPU(Log1p)
UNARY_GPU(LogicalNot)
UNARY_GPU(Negative)
UNARY_GPU(Real)
UNARY_GPU(Sigmoid)
UNARY_GPU(Sign)
UNARY_GPU(Sin)
UNARY_GPU(Sinh)
UNARY_GPU(Square)
UNARY_GPU(Tan)
UNARY_GPU(Tanh)
void Log::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& s = out.primitive().stream();
auto op = get_primitive_string(this);
switch (base_) {
case Base::e:
unary_op_gpu<rocm::Log>(inputs, out, op, s);
break;
case Base::two:
unary_op_gpu<rocm::Log2>(inputs, out, op, s);
break;
case Base::ten:
unary_op_gpu<rocm::Log10>(inputs, out, op, s);
break;
}
}
void Round::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
auto& s = out.primitive().stream();
if (issubdtype(in.dtype(), inexact)) {
unary_op_gpu<rocm::Round>(inputs, out, get_primitive_string(this), s);
} else {
// No-op for integer types
out.copy_shared_buffer(in);
}
}
void Sqrt::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& s = out.primitive().stream();
if (recip_) {
unary_op_gpu<rocm::Rsqrt>(inputs, out, "Rsqrt", s);
} else {
unary_op_gpu<rocm::Sqrt>(inputs, out, "Sqrt", s);
}
}
} // namespace mlx::core
// Standalone example kernels and launchers; wrapped in the rocm namespace so
// that the closing brace below matches its opener.
namespace mlx::core::rocm {
__global__ void relu_kernel(float* input, float* output, int n) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < n) {
output[idx] = fmaxf(0.0f, input[idx]);
}
}
__global__ void sigmoid_kernel(float* input, float* output, int n) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < n) {
output[idx] = 1.0f / (1.0f + expf(-input[idx]));
}
}
void launch_relu(float* input, float* output, int n, hipStream_t stream) {
int threads = 256;
int blocks = (n + threads - 1) / threads;
hipLaunchKernelGGL(relu_kernel, dim3(blocks), dim3(threads), 0, stream, input, output, n);
}
void launch_sigmoid(float* input, float* output, int n, hipStream_t stream) {
int threads = 256;
int blocks = (n + threads - 1) / threads;
hipLaunchKernelGGL(sigmoid_kernel, dim3(blocks), dim3(threads), 0, stream, input, output, n);
}
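// Illustrative driver for the example launchers above (error handling
// omitted; assumes the default stream):
void relu_demo() {
  const int n = 1024;
  float* input = nullptr;
  float* output = nullptr;
  hipMalloc(&input, n * sizeof(float));
  hipMalloc(&output, n * sizeof(float));
  launch_relu(input, output, n, /*stream=*/nullptr);
  hipDeviceSynchronize();
  hipFree(input);
  hipFree(output);
}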
} // namespace mlx::core::rocm

View File

@ -0,0 +1,46 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/rocm/utils.h"
#include "mlx/backend/rocm/device.h"
#include "mlx/dtype_utils.h"
#include <fmt/format.h>
namespace mlx::core {
HipStream::HipStream(rocm::Device& device) {
device.make_current();
CHECK_HIP_ERROR(hipStreamCreateWithFlags(&stream_, hipStreamNonBlocking));
}
HipStream::~HipStream() {
CHECK_HIP_ERROR(hipStreamDestroy(stream_));
}
void check_hip_error(const char* name, hipError_t err) {
if (err != hipSuccess) {
throw std::runtime_error(
fmt::format("{} failed: {}", name, hipGetErrorString(err)));
}
}
const char* dtype_to_hip_type(const Dtype& dtype) {
if (dtype == float16) {
return "__half";
}
if (dtype == bfloat16) {
return "__hip_bfloat16";
}
if (dtype == complex64) {
return "hipFloatComplex";
}
#define SPECIALIZE_DtypeToString(CPP_TYPE, DTYPE) \
if (dtype == DTYPE) { \
return #CPP_TYPE; \
}
MLX_FORALL_DTYPES(SPECIALIZE_DtypeToString)
#undef SPECIALIZE_DtypeToString
return nullptr;
}
} // namespace mlx::core

43
mlx/backend/rocm/utils.h Normal file
View File

@ -0,0 +1,43 @@
// Copyright © 2025 Apple Inc.
// This file includes utilities that are used by C++ code (i.e. .cpp files).
#pragma once
#include <hip/hip_runtime.h>
namespace mlx::core {
namespace rocm {
class Device;
}
struct Dtype;
// HIP stream managed with RAII.
class HipStream {
public:
explicit HipStream(rocm::Device& device);
~HipStream();
HipStream(const HipStream&) = delete;
HipStream& operator=(const HipStream&) = delete;
operator hipStream_t() const {
return stream_;
}
private:
hipStream_t stream_;
};
// Throw exception if the HIP API does not succeed.
void check_hip_error(const char* name, hipError_t err);
// The macro version that prints the command that failed.
#define CHECK_HIP_ERROR(cmd) check_hip_error(#cmd, (cmd))
// Convert Dtype to HIP C++ types.
const char* dtype_to_hip_type(const Dtype& dtype);
} // namespace mlx::core
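// A minimal usage sketch of the utilities above (illustration only; assumes
// a valid rocm::Device reference):
inline void example_stream_usage(mlx::core::rocm::Device& device) {
  mlx::core::HipStream stream(device); // released when it goes out of scope
  // The implicit hipStream_t conversion passes straight into the HIP API.
  CHECK_HIP_ERROR(hipStreamSynchronize(stream));
}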

View File

@ -0,0 +1,76 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/rocm/worker.h"
#include "mlx/backend/rocm/utils.h"
namespace mlx::core::rocm {
Worker::Worker() : worker_thread_(&Worker::worker_loop, this) {}
Worker::~Worker() {
{
std::lock_guard<std::mutex> lock(mutex_);
stop_ = true;
}
cv_.notify_all();
if (worker_thread_.joinable()) {
worker_thread_.join();
}
}
void Worker::add_task(std::function<void()> task) {
{
std::lock_guard<std::mutex> lock(mutex_);
tasks_.push(task);
}
cv_.notify_one();
}
void Worker::consume_in_this_thread() {
std::queue<std::function<void()>> local_tasks;
{
std::lock_guard<std::mutex> lock(mutex_);
local_tasks.swap(tasks_);
}
while (!local_tasks.empty()) {
auto task = local_tasks.front();
local_tasks.pop();
task();
}
}
void Worker::commit(hipStream_t stream) {
// Synchronize with stream and then process tasks
CHECK_HIP_ERROR(hipStreamSynchronize(stream));
consume_in_this_thread();
}
void Worker::commit() {
cv_.notify_all();
}
void Worker::worker_loop() {
while (true) {
std::function<void()> task;
{
std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock, [this] { return stop_ || !tasks_.empty(); });
if (stop_) {
break;
}
if (!tasks_.empty()) {
task = tasks_.front();
tasks_.pop();
}
}
if (task) {
task();
}
}
}
} // namespace mlx::core::rocm

46
mlx/backend/rocm/worker.h Normal file
View File

@ -0,0 +1,46 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <hip/hip_runtime.h>
#include <condition_variable>
#include <functional>
#include <mutex>
#include <queue>
#include <thread>
namespace mlx::core::rocm {
// Simple worker for async task execution synchronized with HIP streams.
class Worker {
public:
Worker();
~Worker();
Worker(const Worker&) = delete;
Worker& operator=(const Worker&) = delete;
// Add a task to be executed
void add_task(std::function<void()> task);
// Run pending tasks immediately in current thread.
void consume_in_this_thread();
// Commit tasks to be run after stream completion
void commit(hipStream_t stream);
// Simple commit without stream dependency
void commit();
private:
void worker_loop();
std::thread worker_thread_;
std::queue<std::function<void()>> tasks_;
std::mutex mutex_;
std::condition_variable cv_;
bool stop_{false};
};
} // namespace mlx::core::rocm
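// A minimal usage sketch (illustration only; assumes a valid HIP stream):
inline void example_worker_usage(hipStream_t stream) {
  mlx::core::rocm::Worker worker;
  worker.add_task([] { /* e.g. release host-side staging buffers */ });
  // Synchronizes with the stream, then drains pending tasks on this thread.
  worker.commit(stream);
}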

View File

@ -6,10 +6,23 @@
#include "mlx/backend/gpu/available.h"
#include "mlx/device.h"
#ifdef MLX_USE_ROCM
#include "mlx/backend/rocm/rocm.h"
#endif
namespace mlx::core {
Device& mutable_default_device() {
static Device default_device{gpu::is_available() ? Device::gpu : Device::cpu};
Device::DeviceType default_type = Device::cpu;
if (gpu::is_available()) {
default_type = Device::gpu;
}
#ifdef MLX_USE_ROCM
else if (rocm::is_available()) {
default_type = Device::gpu; // ROCm devices use the generic gpu type
}
#endif
static Device default_device{default_type};
return default_device;
}
@ -38,7 +51,11 @@ bool is_available(const Device& d) {
case Device::cpu:
return cpu::is_available();
case Device::gpu:
#ifdef MLX_USE_ROCM
return gpu::is_available() || rocm::is_available();
#else
return gpu::is_available();
#endif
}
// appease compiler
return false;