increment 1: few ops and jit update

Nripesh Niketan 2025-06-19 00:33:57 +01:00
parent 8bb8b76ae4
commit ac5adfa963
13 changed files with 1197 additions and 90 deletions

View File

@@ -1,36 +1,312 @@
// Copyright © 2025 Apple Inc.
#include <hip/hip_runtime.h>
+#include "mlx/backend/common/binary.h"
#include "mlx/backend/rocm/device.h"
+#include "mlx/backend/rocm/device/binary_ops.hpp"
+#include "mlx/backend/rocm/kernel_utils.hpp"
+#include "mlx/dtype_utils.h"
+#include "mlx/primitives.h"
#include "mlx/backend/rocm/utils.h"
#include <hip/hip_cooperative_groups.h>
-namespace mlx::core::rocm {
+namespace mlx::core {
-// Basic binary operation kernels will go here
-__global__ void add_kernel(float* a, float* b, float* c, int n) {
-int idx = blockIdx.x * blockDim.x + threadIdx.x;
-if (idx < n) {
-c[idx] = a[idx] + b[idx];
+namespace rocm {
+namespace cg = cooperative_groups;
+template <typename Op, typename In, typename Out, typename IdxT>
+__global__ void binary_ss(const In* a, const In* b, Out* out, IdxT size) {
+IdxT index = cg::this_grid().thread_rank();
+if (index < size) {
+out[index] = Op{}(a[0], b[0]);
}
}
-__global__ void multiply_kernel(float* a, float* b, float* c, int n) {
-int idx = blockIdx.x * blockDim.x + threadIdx.x;
-if (idx < n) {
-c[idx] = a[idx] * b[idx];
+template <typename Op, typename In, typename Out, typename IdxT>
+__global__ void binary_sv(const In* a, const In* b, Out* out, IdxT size) {
+IdxT index = cg::this_grid().thread_rank();
+if (index < size) {
+out[index] = Op{}(a[0], b[index]);
}
}
-void launch_add(float* a, float* b, float* c, int n, hipStream_t stream) {
-int threads = 256;
-int blocks = (n + threads - 1) / threads;
-hipLaunchKernelGGL(add_kernel, dim3(blocks), dim3(threads), 0, stream, a, b, c, n);
+template <typename Op, typename In, typename Out, typename IdxT>
+__global__ void binary_vs(const In* a, const In* b, Out* out, IdxT size) {
+IdxT index = cg::this_grid().thread_rank();
+if (index < size) {
+out[index] = Op{}(a[index], b[0]);
}
}
-void launch_multiply(float* a, float* b, float* c, int n, hipStream_t stream) {
-int threads = 256;
-int blocks = (n + threads - 1) / threads;
-hipLaunchKernelGGL(multiply_kernel, dim3(blocks), dim3(threads), 0, stream, a, b, c, n);
+template <typename Op, typename In, typename Out, typename IdxT>
+__global__ void binary_vv(const In* a, const In* b, Out* out, IdxT size) {
+IdxT index = cg::this_grid().thread_rank();
+if (index < size) {
+out[index] = Op{}(a[index], b[index]);
}
}
-} // namespace mlx::core::rocm
template <typename Op, typename In, typename Out, typename IdxT, int NDIM>
__global__ void binary_g_nd(
const In* a,
const In* b,
Out* out,
IdxT size,
const hip_array<int32_t, NDIM> shape,
const hip_array<int64_t, NDIM> a_strides,
const hip_array<int64_t, NDIM> b_strides) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto [a_idx, b_idx] = elem_to_loc_nd<NDIM>(
index, shape.data(), a_strides.data(), b_strides.data());
out[index] = Op{}(a[a_idx], b[b_idx]);
}
}
template <typename Op, typename In, typename Out, typename IdxT>
__global__ void binary_g(
const In* a,
const In* b,
Out* out,
IdxT size,
const hip_array<int32_t, MAX_DIMS> shape,
const hip_array<int64_t, MAX_DIMS> a_strides,
const hip_array<int64_t, MAX_DIMS> b_strides,
int ndim) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto [a_idx, b_idx] = elem_to_loc_4d(
index, shape.data(), a_strides.data(), b_strides.data(), ndim);
out[index] = Op{}(a[a_idx], b[b_idx]);
}
}
// Binary operation support checking
template <typename Op, typename In, typename Out>
constexpr bool supports_binary_op() {
if (std::is_same_v<Op, Add> || std::is_same_v<Op, Divide> ||
std::is_same_v<Op, Maximum> || std::is_same_v<Op, Minimum> ||
std::is_same_v<Op, Multiply> || std::is_same_v<Op, Subtract> ||
std::is_same_v<Op, Power> || std::is_same_v<Op, Remainder>) {
return std::is_same_v<In, Out>;
}
if (std::is_same_v<Op, Equal> || std::is_same_v<Op, Greater> ||
std::is_same_v<Op, GreaterEqual> || std::is_same_v<Op, Less> ||
std::is_same_v<Op, LessEqual> || std::is_same_v<Op, NotEqual>) {
return std::is_same_v<Out, bool>;
}
if (std::is_same_v<Op, LogicalAnd> || std::is_same_v<Op, LogicalOr>) {
return std::is_same_v<Out, bool> && std::is_same_v<In, bool>;
}
if (std::is_same_v<Op, NaNEqual>) {
return std::is_same_v<Out, bool> && is_inexact_v<In>;
}
if (std::is_same_v<Op, LogAddExp>) {
return std::is_same_v<In, Out> && is_inexact_v<In>;
}
if (std::is_same_v<Op, ArcTan2>) {
return std::is_same_v<In, Out> && is_floating_v<In>;
}
if (std::is_same_v<Op, BitwiseAnd> || std::is_same_v<Op, BitwiseOr> ||
std::is_same_v<Op, BitwiseXor>) {
return std::is_same_v<In, Out> && std::is_integral_v<In>;
}
if (std::is_same_v<Op, LeftShift> || std::is_same_v<Op, RightShift>) {
return std::is_same_v<In, Out> && std::is_integral_v<In> &&
!std::is_same_v<In, bool>;
}
return false;
}
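// Illustrative sanity checks of the dispatch rules above (added commentary,
// not part of the commit); they assume the functor types from binary_ops.hpp.
// static_assert(supports_binary_op<Add, float, float>());
// static_assert(supports_binary_op<Less, float, bool>());
// static_assert(!supports_binary_op<BitwiseAnd, float, float>());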
} // namespace rocm
template <typename Op>
void binary_op_gpu_inplace(
const std::vector<array>& inputs,
std::vector<array>& outputs,
std::string_view op,
const Stream& s) {
assert(inputs.size() > 1);
const auto& a = inputs[0];
const auto& b = inputs[1];
auto& out = outputs[0];
if (out.size() == 0) {
return;
}
auto& encoder = rocm::get_command_encoder(s);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.launch_kernel([&](hipStream_t stream) {
MLX_SWITCH_ALL_TYPES(a.dtype(), CTYPE_IN, {
MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, {
if constexpr (rocm::supports_binary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
using InType = hip_type_t<CTYPE_IN>;
using OutType = hip_type_t<CTYPE_OUT>;
auto bopt = get_binary_op_type(a, b);
if (bopt == BinaryOpType::General) {
auto [shape, strides] = collapse_contiguous_dims(a, b, out);
auto& a_strides = strides[0];
auto& b_strides = strides[1];
bool large = a.data_size() > INT32_MAX ||
b.data_size() > INT32_MAX || out.data_size() > INT32_MAX;
MLX_SWITCH_BOOL(large, LARGE, {
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
int ndim = shape.size();
if (ndim <= 3) {
MLX_SWITCH_1_2_3(ndim, NDIM, {
auto kernel =
&rocm::binary_g_nd<Op, InType, OutType, IdxT, NDIM>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, out, large);
hipLaunchKernelGGL(kernel, num_blocks, block_dims, 0, stream,
a.data<InType>(),
b.data<InType>(),
out.data<OutType>(),
out.size(),
make_hip_array<NDIM>(shape),
make_hip_array<NDIM>(a_strides),
make_hip_array<NDIM>(b_strides));
});
} else {
auto kernel = rocm::binary_g<Op, InType, OutType, IdxT>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, out, large);
hipLaunchKernelGGL(kernel, num_blocks, block_dims, 0, stream,
a.data<InType>(),
b.data<InType>(),
out.data<OutType>(),
out.size(),
make_hip_array(shape),
make_hip_array(a_strides),
make_hip_array(b_strides),
ndim);
}
});
} else {
MLX_SWITCH_BOOL(out.data_size() > UINT32_MAX, LARGE, {
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
auto kernel = rocm::binary_ss<Op, InType, OutType, IdxT>;
if (bopt == BinaryOpType::ScalarVector) {
kernel = rocm::binary_sv<Op, InType, OutType, IdxT>;
} else if (bopt == BinaryOpType::VectorScalar) {
kernel = rocm::binary_vs<Op, InType, OutType, IdxT>;
} else if (bopt == BinaryOpType::VectorVector) {
kernel = rocm::binary_vv<Op, InType, OutType, IdxT>;
}
auto [num_blocks, block_dims] = get_launch_args(
kernel, out.data_size(), out.shape(), out.strides(), LARGE);
hipLaunchKernelGGL(kernel, num_blocks, block_dims, 0, stream,
a.data<InType>(),
b.data<InType>(),
out.data<OutType>(),
out.data_size());
});
}
} else {
throw std::runtime_error(fmt::format(
"Can not do binary op {} on inputs of {} with result of {}.",
op,
dtype_to_string(a.dtype()),
dtype_to_string(out.dtype())));
}
});
});
});
}
template <typename Op>
void binary_op_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs,
std::string_view op,
const Stream& s) {
auto& a = inputs[0];
auto& b = inputs[1];
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, outputs[0], bopt);
set_binary_op_output_data(a, b, outputs[1], bopt);
binary_op_gpu_inplace<Op>(inputs, outputs, op, s);
}
template <typename Op>
void binary_op_gpu(
const std::vector<array>& inputs,
array& out,
std::string_view op,
const Stream& s) {
auto& a = inputs[0];
auto& b = inputs[1];
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
std::vector<array> outputs{out};
binary_op_gpu_inplace<Op>(inputs, outputs, op, s);
}
#define BINARY_GPU(func) \
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
auto& s = out.primitive().stream(); \
binary_op_gpu<rocm::func>(inputs, out, get_primitive_string(this), s); \
}
#define BINARY_GPU_MULTI(func) \
void func::eval_gpu( \
const std::vector<array>& inputs, std::vector<array>& outputs) { \
auto& s = outputs[0].primitive().stream(); \
binary_op_gpu<rocm::func>(inputs, outputs, get_primitive_string(this), s); \
}
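// For reference (added commentary, not in the commit): BINARY_GPU(Add) expands
// to an eval_gpu override that forwards to binary_op_gpu<rocm::Add>, i.e.
//   void Add::eval_gpu(const std::vector<array>& inputs, array& out) {
//     auto& s = out.primitive().stream();
//     binary_op_gpu<rocm::Add>(inputs, out, get_primitive_string(this), s);
//   }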
BINARY_GPU(Add)
BINARY_GPU(ArcTan2)
BINARY_GPU(Divide)
BINARY_GPU(Remainder)
BINARY_GPU(Greater)
BINARY_GPU(GreaterEqual)
BINARY_GPU(Less)
BINARY_GPU(LessEqual)
BINARY_GPU(LogicalAnd)
BINARY_GPU(LogicalOr)
BINARY_GPU(LogAddExp)
BINARY_GPU(Maximum)
BINARY_GPU(Minimum)
BINARY_GPU(Multiply)
BINARY_GPU(NotEqual)
BINARY_GPU(Power)
BINARY_GPU(Subtract)
void Equal::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& s = out.primitive().stream();
auto op = get_primitive_string(this);
if (equal_nan_) {
binary_op_gpu<rocm::NaNEqual>(inputs, out, op, s);
} else {
binary_op_gpu<rocm::Equal>(inputs, out, op, s);
}
}
void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& s = out.primitive().stream();
auto op = get_primitive_string(this);
switch (op_) {
case BitwiseBinary::And:
binary_op_gpu<rocm::BitwiseAnd>(inputs, out, op, s);
break;
case BitwiseBinary::Or:
binary_op_gpu<rocm::BitwiseOr>(inputs, out, op, s);
break;
case BitwiseBinary::Xor:
binary_op_gpu<rocm::BitwiseXor>(inputs, out, op, s);
break;
case BitwiseBinary::LeftShift:
binary_op_gpu<rocm::LeftShift>(inputs, out, op, s);
break;
case BitwiseBinary::RightShift:
binary_op_gpu<rocm::RightShift>(inputs, out, op, s);
break;
}
}
} // namespace mlx::core

View File

@@ -1,20 +1,23 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/rocm/device.h"
#include "mlx/backend/rocm/utils.h"
#include "mlx/backend/metal/metal.h"
#include "mlx/backend/rocm/worker.h"
-namespace mlx::core::rocm {
+#include <fmt/format.h>
-DeviceStream::DeviceStream(Device& device) : device_(device) {
-check_hip_error("hipStreamCreate", hipStreamCreate(&stream_));
-encoder_ = std::make_unique<CommandEncoder>(*this);
-}
+namespace mlx::core {
+namespace rocm {
+DeviceStream::DeviceStream(Device& device) : device_(device), stream_(device) {}
void DeviceStream::synchronize() {
-check_hip_error("hipStreamSynchronize", hipStreamSynchronize(stream_));
+CHECK_HIP_ERROR(hipStreamSynchronize(stream_));
}
hipStream_t DeviceStream::schedule_hip_stream() {
// TODO: Return a stream that maximizes parallelism.
return stream_;
}
@@ -23,22 +26,35 @@ hipStream_t DeviceStream::last_hip_stream() {
}
CommandEncoder& DeviceStream::get_encoder() {
+if (!encoder_) {
+encoder_ = std::make_unique<CommandEncoder>(*this);
+}
return *encoder_;
}
Device::Device(int device) : device_(device) {
-check_hip_error("hipSetDevice", hipSetDevice(device_));
+CHECK_HIP_ERROR(hipDeviceGetAttribute(
+&compute_capability_major_,
+hipDeviceAttributeComputeCapabilityMajor,
+device_));
+CHECK_HIP_ERROR(hipDeviceGetAttribute(
+&compute_capability_minor_,
+hipDeviceAttributeComputeCapabilityMinor,
+device_));
-// Get device properties
-hipDeviceProp_t prop;
-check_hip_error(
-"hipGetDeviceProperties", hipGetDeviceProperties(&prop, device_));
-compute_capability_major_ = prop.major;
-compute_capability_minor_ = prop.minor;
+// Validate device requirements
+int attr = 0;
+CHECK_HIP_ERROR(hipDeviceGetAttribute(
+&attr, hipDeviceAttributeConcurrentManagedAccess, device_));
+if (attr != 1) {
+// ROCm unified memory might not be available on all devices
+// This is a warning rather than an error for ROCm
+// TODO: Add proper ROCm unified memory checking
+}
-// Create rocBLAS handle
-check_hip_error(
-"rocblas_create_handle",
+make_current();
+CHECK_HIP_ERROR(
static_cast<hipError_t>(rocblas_create_handle(&rocblas_handle_)));
}
@@ -49,56 +65,66 @@ Device::~Device() {
}
void Device::make_current() {
-check_hip_error("hipSetDevice", hipSetDevice(device_));
+// Cache current device to reduce HIP API calls
+static int current = 0;
+if (current != device_) {
+CHECK_HIP_ERROR(hipSetDevice(device_));
+current = device_;
+}
}
DeviceStream& Device::get_stream(Stream s) {
auto it = streams_.find(s.index);
-if (it != streams_.end()) {
+if (it == streams_.end()) {
+it = streams_.try_emplace(s.index, *this).first;
}
return it->second;
}
-auto [new_it, inserted] = streams_.emplace(s.index, DeviceStream(*this));
-return new_it->second;
-}
-CommandEncoder::CommandEncoder(DeviceStream& stream)
-: device_(stream.device()), stream_(stream), worker_() {}
+CommandEncoder::CommandEncoder(DeviceStream& s)
+: device_(s.device()), stream_(s) {}
void CommandEncoder::add_completed_handler(std::function<void()> task) {
-worker_.enqueue(task);
+worker_.add_task(std::move(task));
}
void CommandEncoder::end_encoding() {
-// Implementation for ending encoding
+if (!temporaries_.empty()) {
+add_completed_handler([temporaries = std::move(temporaries_)]() {});
+}
+// There is no kernel running, run completion handlers immediately.
+if (!has_gpu_work_) {
+worker_.consume_in_this_thread();
+return;
+}
+has_gpu_work_ = false;
+// Commit tasks
+commit();
}
void CommandEncoder::commit() {
-worker_.commit();
+worker_.commit(stream_.last_hip_stream());
}
-// Global device management
-static std::unordered_map<int, std::unique_ptr<Device>> devices_;
Device& device(mlx::core::Device device) {
-auto it = devices_.find(device.index);
-if (it != devices_.end()) {
-return *it->second;
+static std::unordered_map<int, Device> devices;
+auto it = devices.find(device.index);
+if (it == devices.end()) {
+it = devices.try_emplace(device.index, device.index).first;
}
-auto new_device = std::make_unique<Device>(device.index);
-Device& dev_ref = *new_device;
-devices_[device.index] = std::move(new_device);
-return dev_ref;
+return it->second;
}
DeviceStream& get_stream(Stream s) {
-// Use default device (index 0) for now
-return device(mlx::core::Device{mlx::core::Device::gpu, 0}).get_stream(s);
+return device(s.device).get_stream(s);
}
CommandEncoder& get_command_encoder(Stream s) {
return get_stream(s).get_encoder();
}
-} // namespace mlx::core::rocm
+} // namespace rocm
+} // namespace mlx::core

View File

@@ -3,6 +3,7 @@
#pragma once
#include "mlx/array.h"
+#include "mlx/backend/rocm/utils.h"
#include "mlx/backend/rocm/worker.h"
#include "mlx/stream.h"
@@ -11,7 +12,9 @@
#include <unordered_map>
-namespace mlx::core::rocm {
+namespace mlx::core {
+namespace rocm {
class Device;
class CommandEncoder;
@@ -138,4 +141,6 @@ CommandEncoder& get_command_encoder(Stream s);
-// Utility function to check HIP errors
-void check_hip_error(const char* msg, hipError_t error);
-} // namespace mlx::core::rocm
+} // namespace rocm
+} // namespace mlx::core

View File

@@ -0,0 +1,217 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <hip/hip_bfloat16.h>
#include <hip/hip_fp16.h>
#include <hip/hip_runtime.h>
#include <hip/hip_complex.h>
namespace mlx::core::rocm {
// Arithmetic operations
struct Add {
template <typename T>
__device__ T operator()(T a, T b) {
return a + b;
}
};
struct Subtract {
template <typename T>
__device__ T operator()(T a, T b) {
return a - b;
}
};
struct Multiply {
template <typename T>
__device__ T operator()(T a, T b) {
return a * b;
}
};
struct Divide {
template <typename T>
__device__ T operator()(T a, T b) {
return a / b;
}
};
struct Power {
template <typename T>
__device__ T operator()(T a, T b) {
return powf(a, b);
}
__device__ double operator()(double a, double b) {
return pow(a, b);
}
};
struct Remainder {
template <typename T>
__device__ T operator()(T a, T b) {
return fmodf(a, b);
}
__device__ double operator()(double a, double b) {
return fmod(a, b);
}
};
// Comparison operations
struct Equal {
template <typename T>
__device__ bool operator()(T a, T b) {
return a == b;
}
};
struct NotEqual {
template <typename T>
__device__ bool operator()(T a, T b) {
return a != b;
}
};
struct Greater {
template <typename T>
__device__ bool operator()(T a, T b) {
return a > b;
}
};
struct GreaterEqual {
template <typename T>
__device__ bool operator()(T a, T b) {
return a >= b;
}
};
struct Less {
template <typename T>
__device__ bool operator()(T a, T b) {
return a < b;
}
};
struct LessEqual {
template <typename T>
__device__ bool operator()(T a, T b) {
return a <= b;
}
};
struct NaNEqual {
template <typename T>
__device__ bool operator()(T a, T b) {
return (isnan(a) && isnan(b)) || (a == b);
}
};
// Logic operations
struct LogicalAnd {
__device__ bool operator()(bool a, bool b) {
return a && b;
}
};
struct LogicalOr {
__device__ bool operator()(bool a, bool b) {
return a || b;
}
};
// Math operations
struct Maximum {
template <typename T>
__device__ T operator()(T a, T b) {
return fmaxf(a, b);
}
__device__ double operator()(double a, double b) {
return fmax(a, b);
}
};
struct Minimum {
template <typename T>
__device__ T operator()(T a, T b) {
return fminf(a, b);
}
__device__ double operator()(double a, double b) {
return fmin(a, b);
}
};
struct LogAddExp {
template <typename T>
__device__ T operator()(T a, T b) {
T max_val = fmaxf(a, b);
T min_val = fminf(a, b);
if (isinf(max_val)) {
return max_val;
}
return max_val + log1pf(expf(min_val - max_val));
}
__device__ double operator()(double a, double b) {
double max_val = fmax(a, b);
double min_val = fmin(a, b);
if (isinf(max_val)) {
return max_val;
}
return max_val + log1p(exp(min_val - max_val));
}
};
struct ArcTan2 {
template <typename T>
__device__ T operator()(T a, T b) {
return atan2f(a, b);
}
__device__ double operator()(double a, double b) {
return atan2(a, b);
}
};
// Bitwise operations
struct BitwiseAnd {
template <typename T>
__device__ T operator()(T a, T b) {
return a & b;
}
};
struct BitwiseOr {
template <typename T>
__device__ T operator()(T a, T b) {
return a | b;
}
};
struct BitwiseXor {
template <typename T>
__device__ T operator()(T a, T b) {
return a ^ b;
}
};
struct LeftShift {
template <typename T>
__device__ T operator()(T a, T b) {
return a << b;
}
};
struct RightShift {
template <typename T>
__device__ T operator()(T a, T b) {
return a >> b;
}
};
} // namespace mlx::core::rocm

View File

@@ -0,0 +1,50 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/rocm/event.h"
#include "mlx/backend/rocm/utils.h"
namespace mlx::core::rocm {
HipEvent::HipEvent() {
CHECK_HIP_ERROR(hipEventCreate(&event_));
}
HipEvent::~HipEvent() {
CHECK_HIP_ERROR(hipEventDestroy(event_));
}
void HipEvent::record(hipStream_t stream) {
CHECK_HIP_ERROR(hipEventRecord(event_, stream));
}
void HipEvent::wait() {
CHECK_HIP_ERROR(hipEventSynchronize(event_));
}
bool HipEvent::query() const {
hipError_t status = hipEventQuery(event_);
if (status == hipSuccess) {
return true;
} else if (status == hipErrorNotReady) {
return false;
} else {
CHECK_HIP_ERROR(status);
return false;
}
}
SharedEvent::SharedEvent() = default;
void SharedEvent::notify() {
std::lock_guard<std::mutex> lock(mutex_);
ready_ = true;
cv_.notify_one();
}
void SharedEvent::wait() {
std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock, [this] { return ready_; });
ready_ = false;
}
} // namespace mlx::core::rocm

mlx/backend/rocm/event.h
View File

@@ -0,0 +1,48 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <hip/hip_runtime.h>
#include <condition_variable>
#include <memory>
#include <mutex>
namespace mlx::core::rocm {
// HIP event managed with RAII.
class HipEvent {
public:
HipEvent();
~HipEvent();
HipEvent(const HipEvent&) = delete;
HipEvent& operator=(const HipEvent&) = delete;
void record(hipStream_t stream);
void wait();
bool query() const;
operator hipEvent_t() const {
return event_;
}
private:
hipEvent_t event_;
};
// Shared event for worker thread synchronization.
class SharedEvent {
public:
SharedEvent();
void notify();
void wait();
private:
std::mutex mutex_;
std::condition_variable cv_;
bool ready_{false};
};
} // namespace mlx::core::rocm
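A minimal host-side usage sketch for the event class above (not part of the commit; the helper name is made up). It records a HipEvent at the current tail of a stream and blocks the caller until everything enqueued before it has finished:

#include <hip/hip_runtime.h>
#include "mlx/backend/rocm/event.h"

// Hypothetical helper: block the host until prior work on `stream` completes.
void wait_for_stream(hipStream_t stream) {
  mlx::core::rocm::HipEvent ev;
  ev.record(stream); // hipEventRecord on the stream's current tail
  ev.wait();         // hipEventSynchronize under the hood
}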

View File

@@ -0,0 +1,167 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/rocm/jit_module.h"
#include "mlx/backend/rocm/utils.h"
#include <fmt/format.h>
#include <mutex>
#include <sstream>
namespace mlx::core::rocm {
JitModule::JitModule(
const std::string& kernel_name,
const std::string& kernel_source,
const std::vector<std::string>& template_args,
const std::vector<std::string>& compiler_flags,
bool verbose) {
compile(kernel_name, kernel_source, template_args, compiler_flags, verbose);
}
JitModule::~JitModule() {
if (kernel_) {
// No hipFunctionDestroy equivalent in HIP
}
if (module_) {
CHECK_HIP_ERROR(hipModuleUnload(module_));
}
if (program_) {
hiprtcDestroyProgram(&program_);
}
}
// hiprtc entry points return hiprtcResult, not hipError_t, so they need their
// own checker that reports failures through hiprtcGetErrorString.
namespace {
void check_hiprtc_error(const char* name, hiprtcResult err) {
if (err != HIPRTC_SUCCESS) {
throw std::runtime_error(
fmt::format("{} failed: {}", name, hiprtcGetErrorString(err)));
}
}
} // namespace
#define CHECK_HIPRTC_ERROR(cmd) check_hiprtc_error(#cmd, (cmd))
void JitModule::compile(
const std::string& kernel_name,
const std::string& kernel_source,
const std::vector<std::string>& template_args,
const std::vector<std::string>& compiler_flags,
bool verbose) {
// Create HIPRTC program
CHECK_HIPRTC_ERROR(hiprtcCreateProgram(
&program_,
kernel_source.c_str(),
kernel_name.c_str(),
0,
nullptr,
nullptr));
// Build compiler options
std::vector<const char*> options;
std::vector<std::string> option_strings;
// Add default options
option_strings.push_back("--std=c++17");
option_strings.push_back("-O3");
option_strings.push_back("-DMLX_USE_ROCM");
// Add user-provided flags
for (const auto& flag : compiler_flags) {
option_strings.push_back(flag);
}
// Add template arguments
for (const auto& arg : template_args) {
option_strings.push_back("-D" + arg);
}
// Convert to char* array
for (const auto& option : option_strings) {
options.push_back(option.c_str());
}
// Compile the program
hiprtcResult compile_result =
hiprtcCompileProgram(program_, options.size(), options.data());
// Get compilation log
size_t log_size;
CHECK_HIPRTC_ERROR(hiprtcGetProgramLogSize(program_, &log_size));
if (log_size > 1) {
std::vector<char> log(log_size);
CHECK_HIPRTC_ERROR(hiprtcGetProgramLog(program_, log.data()));
if (verbose || compile_result != HIPRTC_SUCCESS) {
fmt::print(
"HIPRTC compilation log for {}:\n{}\n", kernel_name, log.data());
}
}
if (compile_result != HIPRTC_SUCCESS) {
throw std::runtime_error(
fmt::format("HIPRTC compilation failed for kernel {}", kernel_name));
}
// Get compiled code
size_t code_size;
CHECK_HIPRTC_ERROR(hiprtcGetCodeSize(program_, &code_size));
std::vector<char> code(code_size);
CHECK_HIPRTC_ERROR(hiprtcGetCode(program_, code.data()));
// Load module
CHECK_HIP_ERROR(hipModuleLoadData(&module_, code.data()));
// Get kernel function
CHECK_HIP_ERROR(hipModuleGetFunction(&kernel_, module_, kernel_name.c_str()));
}
JitCache& JitCache::instance() {
static JitCache cache;
return cache;
}
std::shared_ptr<JitModule> JitCache::get_or_create(
const std::string& kernel_name,
const std::string& kernel_source,
const std::vector<std::string>& template_args,
const std::vector<std::string>& compiler_flags) {
std::string key =
make_key(kernel_name, kernel_source, template_args, compiler_flags);
std::lock_guard<std::mutex> lock(mutex_);
auto it = cache_.find(key);
if (it != cache_.end()) {
if (auto module = it->second.lock()) {
return module;
} else {
cache_.erase(it);
}
}
auto module = std::make_shared<JitModule>(
kernel_name, kernel_source, template_args, compiler_flags);
cache_[key] = module;
return module;
}
std::string JitCache::make_key(
const std::string& kernel_name,
const std::string& kernel_source,
const std::vector<std::string>& template_args,
const std::vector<std::string>& compiler_flags) const {
std::ostringstream oss;
oss << kernel_name << "|" << kernel_source;
for (const auto& arg : template_args) {
oss << "|" << arg;
}
for (const auto& flag : compiler_flags) {
oss << "|" << flag;
}
return oss.str();
}
std::shared_ptr<JitModule> make_jit_kernel(
const std::string& kernel_name,
const std::string& kernel_source,
const std::vector<std::string>& template_args,
const std::vector<std::string>& compiler_flags) {
return JitCache::instance().get_or_create(
kernel_name, kernel_source, template_args, compiler_flags);
}
} // namespace mlx::core::rocm

View File

@@ -0,0 +1,100 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <hip/hip_runtime.h>
#include <hip/hiprtc.h>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
namespace mlx::core::rocm {
// JIT compilation module for ROCm
class JitModule {
public:
JitModule(
const std::string& kernel_name,
const std::string& kernel_source,
const std::vector<std::string>& template_args = {},
const std::vector<std::string>& compiler_flags = {},
bool verbose = false);
~JitModule();
JitModule(const JitModule&) = delete;
JitModule& operator=(const JitModule&) = delete;
// Get the compiled kernel function
hipFunction_t get_kernel() const {
return kernel_;
}
// Launch the kernel with given arguments
template <typename... Args>
void launch(
dim3 grid_dims,
dim3 block_dims,
size_t shared_memory,
hipStream_t stream,
Args&&... args) {
void* kernel_args[] = {(void*)&args...};
CHECK_HIP_ERROR(hipModuleLaunchKernel(
kernel_,
grid_dims.x,
grid_dims.y,
grid_dims.z,
block_dims.x,
block_dims.y,
block_dims.z,
shared_memory,
stream,
kernel_args,
nullptr));
}
private:
void compile(
const std::string& kernel_name,
const std::string& kernel_source,
const std::vector<std::string>& template_args,
const std::vector<std::string>& compiler_flags,
bool verbose);
hiprtcProgram program_{nullptr};
hipModule_t module_{nullptr};
hipFunction_t kernel_{nullptr};
};
// JIT cache for compiled modules
class JitCache {
public:
static JitCache& instance();
std::shared_ptr<JitModule> get_or_create(
const std::string& kernel_name,
const std::string& kernel_source,
const std::vector<std::string>& template_args = {},
const std::vector<std::string>& compiler_flags = {});
private:
std::unordered_map<std::string, std::weak_ptr<JitModule>> cache_;
std::mutex mutex_;
std::string make_key(
const std::string& kernel_name,
const std::string& kernel_source,
const std::vector<std::string>& template_args,
const std::vector<std::string>& compiler_flags) const;
};
// Helper function to create and cache JIT modules
std::shared_ptr<JitModule> make_jit_kernel(
const std::string& kernel_name,
const std::string& kernel_source,
const std::vector<std::string>& template_args = {},
const std::vector<std::string>& compiler_flags = {});
} // namespace mlx::core::rocm
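A usage sketch for the API above (not part of the commit; the kernel source, its fill_one entry point, and the surrounding helper are invented for illustration). make_jit_kernel() goes through the process-wide JitCache, and JitModule::launch() forwards to hipModuleLaunchKernel:

#include <hip/hip_runtime.h>
#include "mlx/backend/rocm/jit_module.h"

// Hypothetical runtime-compiled kernel; extern "C" keeps the symbol unmangled
// so hipModuleGetFunction() can look it up by the same name string.
static const char* kFillOneSource = R"(
extern "C" __global__ void fill_one(float* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    out[i] = 1.0f;
  }
})";

void fill_with_ones(float* device_out, int n, hipStream_t stream) {
  auto module = mlx::core::rocm::make_jit_kernel("fill_one", kFillOneSource);
  dim3 block(256);
  dim3 grid((n + block.x - 1) / block.x);
  module->launch(grid, block, /*shared_memory=*/0, stream, device_out, n);
}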

View File

@@ -0,0 +1,135 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <hip/hip_bfloat16.h>
#include <hip/hip_complex.h>
#include <hip/hip_fp16.h>
#include <hip/hip_runtime.h>
#include <array>
#include <utility>
#include <vector>
#include "mlx/array.h"
namespace mlx::core::rocm {
// Constants
constexpr int MAX_DIMS = 8;
// HIP array type for passing arrays to kernels
template <typename T, int N>
using hip_array = std::array<T, N>;
// Helper to create hip_array from vector
template <int N, typename T>
__host__ hip_array<T, N> make_hip_array(const std::vector<T>& vec) {
hip_array<T, N> arr = {};
for (int i = 0; i < N && static_cast<size_t>(i) < vec.size(); ++i) {
arr[i] = vec[i];
}
return arr;
}
template <typename T>
__host__ hip_array<T, MAX_DIMS> make_hip_array(const std::vector<T>& vec) {
return make_hip_array<MAX_DIMS>(vec);
}
// Type mapping from MLX types to HIP types. An alias template cannot be
// specialized, so the mapping goes through a helper struct.
template <typename T>
struct HipTypeMap { using type = T; };
template <>
struct HipTypeMap<float16_t> { using type = __half; };
template <>
struct HipTypeMap<bfloat16_t> { using type = __hip_bfloat16; };
template <>
struct HipTypeMap<complex64_t> { using type = hipFloatComplex; };
template <typename T>
using hip_type_t = typename HipTypeMap<T>::type;
// Element to location mapping for general broadcasting
template <int NDIM>
__device__ std::pair<int64_t, int64_t> elem_to_loc_nd(
int64_t elem,
const int32_t* shape,
const int64_t* a_strides,
const int64_t* b_strides) {
int64_t a_idx = 0;
int64_t b_idx = 0;
for (int i = NDIM - 1; i >= 0; --i) {
int64_t pos_in_dim = elem % shape[i];
elem /= shape[i];
a_idx += pos_in_dim * a_strides[i];
b_idx += pos_in_dim * b_strides[i];
}
return {a_idx, b_idx};
}
// 4D specialization for performance
__device__ inline std::pair<int64_t, int64_t> elem_to_loc_4d(
int64_t elem,
const int32_t* shape,
const int64_t* a_strides,
const int64_t* b_strides,
int ndim) {
int64_t a_idx = 0;
int64_t b_idx = 0;
for (int i = ndim - 1; i >= 0; --i) {
int64_t pos_in_dim = elem % shape[i];
elem /= shape[i];
a_idx += pos_in_dim * a_strides[i];
b_idx += pos_in_dim * b_strides[i];
}
return {a_idx, b_idx};
}
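// Worked example (added for illustration, not in the commit): with shape
// {2, 3}, a_strides {3, 1} (contiguous) and b_strides {0, 1} (a broadcast
// row), linear element 4 is coordinate (1, 1), so a_idx = 1*3 + 1*1 = 4 and
// b_idx = 1*0 + 1*1 = 1.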
// Launch configuration calculation
template <typename Kernel>
std::pair<dim3, dim3>
get_launch_args(Kernel kernel, const array& out, bool large = false) {
int threads_per_block = 256;
int64_t total_threads = out.size();
if (large) {
// For large arrays, use more blocks
int64_t blocks =
(total_threads + threads_per_block - 1) / threads_per_block;
return {dim3(blocks), dim3(threads_per_block)};
} else {
int blocks = (total_threads + threads_per_block - 1) / threads_per_block;
return {dim3(blocks), dim3(threads_per_block)};
}
}
template <typename Kernel>
std::pair<dim3, dim3> get_launch_args(
Kernel kernel,
int64_t size,
const std::vector<int>& shape,
const std::vector<int64_t>& strides,
bool large = false) {
int threads_per_block = 256;
if (large) {
int64_t blocks = (size + threads_per_block - 1) / threads_per_block;
return {dim3(blocks), dim3(threads_per_block)};
} else {
int blocks = (size + threads_per_block - 1) / threads_per_block;
return {dim3(blocks), dim3(threads_per_block)};
}
}
// Cooperative groups thread rank equivalent
namespace cooperative_groups {
class grid_group {
public:
__device__ int64_t thread_rank() const {
return static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
}
};
__device__ grid_group this_grid() {
return grid_group{};
}
} // namespace cooperative_groups
} // namespace mlx::core::rocm

View File

@@ -1,17 +1,46 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/rocm/utils.h"
-#include <sstream>
-#include <stdexcept>
+#include "mlx/backend/rocm/device.h"
+#include "mlx/dtype_utils.h"
-namespace mlx::core::rocm {
+#include <fmt/format.h>
-void check_hip_error(const char* msg, hipError_t error) {
-if (error != hipSuccess) {
-std::ostringstream oss;
-oss << "[ROCm] " << msg << ": " << hipGetErrorString(error);
-throw std::runtime_error(oss.str());
+namespace mlx::core {
+HipStream::HipStream(rocm::Device& device) {
+device.make_current();
+CHECK_HIP_ERROR(hipStreamCreateWithFlags(&stream_, hipStreamNonBlocking));
}
+HipStream::~HipStream() {
+CHECK_HIP_ERROR(hipStreamDestroy(stream_));
}
+void check_hip_error(const char* name, hipError_t err) {
+if (err != hipSuccess) {
+throw std::runtime_error(
+fmt::format("{} failed: {}", name, hipGetErrorString(err)));
}
}
-} // namespace mlx::core::rocm
+const char* dtype_to_hip_type(const Dtype& dtype) {
+if (dtype == float16) {
+return "__half";
+}
+if (dtype == bfloat16) {
+return "__hip_bfloat16";
+}
+if (dtype == complex64) {
+return "hipFloatComplex";
+}
+#define SPECIALIZE_DtypeToString(CPP_TYPE, DTYPE) \
+if (dtype == DTYPE) { \
+return #CPP_TYPE; \
+}
+MLX_FORALL_DTYPES(SPECIALIZE_DtypeToString)
+#undef SPECIALIZE_DtypeToString
+return nullptr;
+}
+} // namespace mlx::core

View File

@@ -1,12 +1,43 @@
// Copyright © 2025 Apple Inc.
// This file includes utilities that are used by C++ code (i.e. .cpp files).
#pragma once
#include <hip/hip_runtime.h>
-namespace mlx::core::rocm {
+namespace mlx::core {
-// Utility function to check HIP errors
-void check_hip_error(const char* msg, hipError_t error);
+namespace rocm {
+class Device;
+}
-} // namespace mlx::core::rocm
+struct Dtype;
+// HIP stream managed with RAII.
+class HipStream {
+public:
+explicit HipStream(rocm::Device& device);
+~HipStream();
+HipStream(const HipStream&) = delete;
+HipStream& operator=(const HipStream&) = delete;
+operator hipStream_t() const {
+return stream_;
+}
+private:
+hipStream_t stream_;
+};
+// Throw exception if the HIP API does not succeed.
+void check_hip_error(const char* name, hipError_t err);
+// The macro version that prints the command that failed.
+#define CHECK_HIP_ERROR(cmd) check_hip_error(#cmd, (cmd))
+// Convert Dtype to HIP C++ types.
+const char* dtype_to_hip_type(const Dtype& dtype);
+} // namespace mlx::core

View File

@@ -1,6 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/rocm/worker.h"
+#include "mlx/backend/rocm/utils.h"
namespace mlx::core::rocm {
@@ -17,7 +18,7 @@ Worker::~Worker() {
}
}
-void Worker::enqueue(std::function<void()> task) {
+void Worker::add_task(std::function<void()> task) {
{
std::lock_guard<std::mutex> lock(mutex_);
tasks_.push(task);
@@ -25,14 +26,28 @@ void Worker::enqueue(std::function<void()> task) {
cv_.notify_one();
}
-void Worker::commit() {
+void Worker::consume_in_this_thread() {
+std::queue<std::function<void()>> local_tasks;
{
std::lock_guard<std::mutex> lock(mutex_);
-committed_ = true;
+local_tasks.swap(tasks_);
}
-void Worker::join() {
-std::unique_lock<std::mutex> lock(mutex_);
-cv_.wait(lock, [this] { return tasks_.empty() && committed_; });
+while (!local_tasks.empty()) {
+auto task = local_tasks.front();
+local_tasks.pop();
+task();
+}
}
+void Worker::commit(hipStream_t stream) {
+// Synchronize with stream and then process tasks
+CHECK_HIP_ERROR(hipStreamSynchronize(stream));
+consume_in_this_thread();
+}
+void Worker::commit() {
cv_.notify_all();
}
void Worker::worker_loop() {

View File

@@ -3,15 +3,16 @@
#pragma once
#include <hip/hip_runtime.h>
#include <condition_variable>
#include <functional>
#include <future>
#include <mutex>
#include <queue>
#include <thread>
namespace mlx::core::rocm {
-using HipStream = hipStream_t;
+// Simple worker for async task execution synchronized with HIP streams.
class Worker {
public:
Worker();
@@ -20,9 +21,17 @@ class Worker {
Worker(const Worker&) = delete;
Worker& operator=(const Worker&) = delete;
-void enqueue(std::function<void()> task);
+// Add a task to be executed
+void add_task(std::function<void()> task);
+// Run pending tasks immediately in current thread.
+void consume_in_this_thread();
+// Commit tasks to be run after stream completion
+void commit(hipStream_t stream);
+// Simple commit without stream dependency
void commit();
void join();
private:
void worker_loop();
@@ -32,7 +41,6 @@ class Worker {
std::mutex mutex_;
std::condition_variable cv_;
bool stop_{false};
-bool committed_{false};
};
} // namespace mlx::core::rocm
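For context, a sketch of how the worker is meant to be driven (not from the commit; the buffer-release scenario is hypothetical). Tasks added with add_task() are deferred until commit(stream) synchronizes the given HIP stream and then runs them in the calling thread via consume_in_this_thread():

#include <memory>
#include <hip/hip_runtime.h>
#include "mlx/backend/rocm/worker.h"

// Hypothetical use: keep a buffer alive until the GPU work using it finishes.
void release_after_stream(std::shared_ptr<void> buffer, hipStream_t stream,
                          mlx::core::rocm::Worker& worker) {
  worker.add_task([buffer]() mutable { buffer.reset(); });
  worker.commit(stream); // hipStreamSynchronize(stream), then run the task
}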