From ac5adfa9634ec7f2b3b003305173cdffb1461a2c Mon Sep 17 00:00:00 2001
From: Nripesh Niketan
Date: Thu, 19 Jun 2025 00:33:57 +0100
Subject: [PATCH] increment 1: few ops and jit update

---
 mlx/backend/rocm/binary.hip            | 318 +++++++++++++++++++++++--
 mlx/backend/rocm/device.cpp            | 108 +++++----
 mlx/backend/rocm/device.h              |   9 +-
 mlx/backend/rocm/device/binary_ops.hpp | 217 +++++++++++++++++
 mlx/backend/rocm/event.cpp             |  50 ++++
 mlx/backend/rocm/event.h               |  48 ++++
 mlx/backend/rocm/jit_module.cpp        | 167 +++++++++++++
 mlx/backend/rocm/jit_module.h          | 100 ++++++++
 mlx/backend/rocm/kernel_utils.hpp      | 135 +++++++++++
 mlx/backend/rocm/utils.cpp             |  47 +++-
 mlx/backend/rocm/utils.h               |  39 ++-
 mlx/backend/rocm/worker.cpp            |  29 ++-
 mlx/backend/rocm/worker.h              |  20 +-
 13 files changed, 1197 insertions(+), 90 deletions(-)
 create mode 100644 mlx/backend/rocm/device/binary_ops.hpp
 create mode 100644 mlx/backend/rocm/event.cpp
 create mode 100644 mlx/backend/rocm/event.h
 create mode 100644 mlx/backend/rocm/jit_module.cpp
 create mode 100644 mlx/backend/rocm/jit_module.h
 create mode 100644 mlx/backend/rocm/kernel_utils.hpp

diff --git a/mlx/backend/rocm/binary.hip b/mlx/backend/rocm/binary.hip
index 14b48bfc9..8976befa2 100644
--- a/mlx/backend/rocm/binary.hip
+++ b/mlx/backend/rocm/binary.hip
@@ -1,36 +1,312 @@
 // Copyright © 2025 Apple Inc.
 
-#include <hip/hip_runtime.h>
+#include "mlx/backend/common/binary.h"
+#include "mlx/backend/rocm/device.h"
+#include "mlx/backend/rocm/device/binary_ops.hpp"
+#include "mlx/backend/rocm/kernel_utils.hpp"
+#include "mlx/dtype_utils.h"
+#include "mlx/primitives.h"
 
-#include "mlx/backend/rocm/utils.h"
+#include <fmt/format.h>
 
-namespace mlx::core::rocm {
+namespace mlx::core {
 
-// Basic binary operation kernels will go here
-__global__ void add_kernel(float* a, float* b, float* c, int n) {
-  int idx = blockIdx.x * blockDim.x + threadIdx.x;
-  if (idx < n) {
-    c[idx] = a[idx] + b[idx];
+namespace rocm {
+
+namespace cg = cooperative_groups;
+
+template <typename Op, typename In, typename Out, typename IdxT>
+__global__ void binary_ss(const In* a, const In* b, Out* out, IdxT size) {
+  IdxT index = cg::this_grid().thread_rank();
+  if (index < size) {
+    out[index] = Op{}(a[0], b[0]);
   }
 }
 
-__global__ void multiply_kernel(float* a, float* b, float* c, int n) {
-  int idx = blockIdx.x * blockDim.x + threadIdx.x;
-  if (idx < n) {
-    c[idx] = a[idx] * b[idx];
+template <typename Op, typename In, typename Out, typename IdxT>
+__global__ void binary_sv(const In* a, const In* b, Out* out, IdxT size) {
+  IdxT index = cg::this_grid().thread_rank();
+  if (index < size) {
+    out[index] = Op{}(a[0], b[index]);
   }
 }
 
-void launch_add(float* a, float* b, float* c, int n, hipStream_t stream) {
-  int threads = 256;
-  int blocks = (n + threads - 1) / threads;
-  hipLaunchKernelGGL(add_kernel, dim3(blocks), dim3(threads), 0, stream, a, b, c, n);
+template <typename Op, typename In, typename Out, typename IdxT>
+__global__ void binary_vs(const In* a, const In* b, Out* out, IdxT size) {
+  IdxT index = cg::this_grid().thread_rank();
+  if (index < size) {
+    out[index] = Op{}(a[index], b[0]);
+  }
 }
 
-void launch_multiply(float* a, float* b, float* c, int n, hipStream_t stream) {
-  int threads = 256;
-  int blocks = (n + threads - 1) / threads;
-  hipLaunchKernelGGL(multiply_kernel, dim3(blocks), dim3(threads), 0, stream, a, b, c, n);
+template <typename Op, typename In, typename Out, typename IdxT>
+__global__ void binary_vv(const In* a, const In* b, Out* out, IdxT size) {
+  IdxT index = cg::this_grid().thread_rank();
+  if (index < size) {
+    out[index] = Op{}(a[index], b[index]);
+  }
 }
 
-} // namespace mlx::core::rocm
\ No newline at end of file
+template <typename Op, typename In, typename Out, typename IdxT, int NDIM>
+__global__ void binary_g_nd(
+    const In* a,
+    const In* b,
+    Out* out,
+    IdxT size,
+    const hip_array<int32_t, MAX_DIMS> shape,
+    const hip_array<int64_t, MAX_DIMS> a_strides,
+    const hip_array<int64_t, MAX_DIMS> b_strides) {
+  IdxT index = cg::this_grid().thread_rank();
+  if (index < size) {
+    auto [a_idx, b_idx] = elem_to_loc_nd<NDIM>(
+        index, shape.data(), a_strides.data(), b_strides.data());
+    out[index] = Op{}(a[a_idx], b[b_idx]);
+  }
+}
+
+template <typename Op, typename In, typename Out, typename IdxT>
+__global__ void binary_g(
+    const In* a,
+    const In* b,
+    Out* out,
+    IdxT size,
+    const hip_array<int32_t, MAX_DIMS> shape,
+    const hip_array<int64_t, MAX_DIMS> a_strides,
+    const hip_array<int64_t, MAX_DIMS> b_strides,
+    int ndim) {
+  IdxT index = cg::this_grid().thread_rank();
+  if (index < size) {
+    auto [a_idx, b_idx] = elem_to_loc_4d(
+        index, shape.data(), a_strides.data(), b_strides.data(), ndim);
+    out[index] = Op{}(a[a_idx], b[b_idx]);
+  }
+}
+
+// Binary operation support checking
+template <typename Op, typename In, typename Out>
+constexpr bool supports_binary_op() {
+  if (std::is_same_v<Op, Add> || std::is_same_v<Op, Divide> ||
+      std::is_same_v<Op, Maximum> || std::is_same_v<Op, Minimum> ||
+      std::is_same_v<Op, Multiply> || std::is_same_v<Op, Power> ||
+      std::is_same_v<Op, Remainder> || std::is_same_v<Op, Subtract>) {
+    return std::is_same_v<In, Out>;
+  }
+  if (std::is_same_v<Op, Equal> || std::is_same_v<Op, Greater> ||
+      std::is_same_v<Op, GreaterEqual> || std::is_same_v<Op, Less> ||
+      std::is_same_v<Op, LessEqual> || std::is_same_v<Op, NotEqual>) {
+    return std::is_same_v<Out, bool>;
+  }
+  if (std::is_same_v<Op, LogicalAnd> || std::is_same_v<Op, LogicalOr>) {
+    return std::is_same_v<Out, bool> && std::is_same_v<In, bool>;
+  }
+  if (std::is_same_v<Op, NaNEqual>) {
+    return std::is_same_v<Out, bool> && is_inexact_v<In>;
+  }
+  if (std::is_same_v<Op, LogAddExp>) {
+    return std::is_same_v<In, Out> && is_inexact_v<In>;
+  }
+  if (std::is_same_v<Op, ArcTan2>) {
+    return std::is_same_v<In, Out> && is_floating_v<In>;
+  }
+  if (std::is_same_v<Op, BitwiseAnd> || std::is_same_v<Op, BitwiseOr> ||
+      std::is_same_v<Op, BitwiseXor>) {
+    return std::is_same_v<In, Out> && std::is_integral_v<In>;
+  }
+  if (std::is_same_v<Op, LeftShift> || std::is_same_v<Op, RightShift>) {
+    return std::is_same_v<In, Out> && std::is_integral_v<In> &&
+        !std::is_same_v<In, bool>;
+  }
+  return false;
+}
+
+} // namespace rocm
+
+template <typename Op>
+void binary_op_gpu_inplace(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs,
+    std::string_view op,
+    const Stream& s) {
+  assert(inputs.size() > 1);
+  const auto& a = inputs[0];
+  const auto& b = inputs[1];
+  auto& out = outputs[0];
+  if (out.size() == 0) {
+    return;
+  }
+
+  auto& encoder = rocm::get_command_encoder(s);
+  encoder.set_input_array(a);
+  encoder.set_input_array(b);
+  encoder.set_output_array(out);
+  encoder.launch_kernel([&](hipStream_t stream) {
+    MLX_SWITCH_ALL_TYPES(a.dtype(), CTYPE_IN, {
+      MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, {
+        if constexpr (rocm::supports_binary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
+          using InType = hip_type_t<CTYPE_IN>;
+          using OutType = hip_type_t<CTYPE_OUT>;
+
+          auto bopt = get_binary_op_type(a, b);
+          if (bopt == BinaryOpType::General) {
+            auto [shape, strides] = collapse_contiguous_dims(a, b, out);
+            auto& a_strides = strides[0];
+            auto& b_strides = strides[1];
+            bool large = a.data_size() > INT32_MAX ||
+                b.data_size() > INT32_MAX || out.data_size() > INT32_MAX;
+            MLX_SWITCH_BOOL(large, LARGE, {
+              using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
+              int ndim = shape.size();
+              if (ndim <= 3) {
+                MLX_SWITCH_1_2_3(ndim, NDIM, {
+                  auto kernel =
+                      &rocm::binary_g_nd<Op, InType, OutType, IdxT, NDIM>;
+                  auto [num_blocks, block_dims] =
+                      get_launch_args(kernel, out, large);
+                  hipLaunchKernelGGL(kernel, num_blocks, block_dims, 0, stream,
+                      a.data<InType>(),
+                      b.data<InType>(),
+                      out.data<OutType>(),
+                      out.size(),
+                      make_hip_array(shape),
+                      make_hip_array(a_strides),
+                      make_hip_array(b_strides));
+                });
+              } else {
+                auto kernel = rocm::binary_g<Op, InType, OutType, IdxT>;
+                auto [num_blocks, block_dims] =
+                    get_launch_args(kernel, out, large);
+                hipLaunchKernelGGL(kernel, num_blocks, block_dims, 0, stream,
+                    a.data<InType>(),
+                    b.data<InType>(),
+                    out.data<OutType>(),
+                    out.size(),
+                    make_hip_array(shape),
+                    make_hip_array(a_strides),
+                    make_hip_array(b_strides),
+                    ndim);
+              }
+            });
+          } else {
+            MLX_SWITCH_BOOL(out.data_size() > UINT32_MAX, LARGE, {
+              using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
+              auto kernel = rocm::binary_ss<Op, InType, OutType, IdxT>;
+              if (bopt == BinaryOpType::ScalarVector) {
+                kernel = rocm::binary_sv<Op, InType, OutType, IdxT>;
+              } else if (bopt == BinaryOpType::VectorScalar) {
+                kernel = rocm::binary_vs<Op, InType, OutType, IdxT>;
+              } else if (bopt == BinaryOpType::VectorVector) {
+                kernel = rocm::binary_vv<Op, InType, OutType, IdxT>;
+              }
+              auto [num_blocks, block_dims] = get_launch_args(
+                  kernel, out.data_size(), out.shape(), out.strides(), LARGE);
+              hipLaunchKernelGGL(kernel, num_blocks, block_dims, 0, stream,
+                  a.data<InType>(),
+                  b.data<InType>(),
+                  out.data<OutType>(),
+                  out.data_size());
+            });
+          }
+        } else {
+          throw std::runtime_error(fmt::format(
+              "Cannot do binary op {} on inputs of {} with result of {}.",
+              op,
+              dtype_to_string(a.dtype()),
+              dtype_to_string(out.dtype())));
+        }
+      });
+    });
+  });
+}
+
+template <typename Op>
+void binary_op_gpu(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs,
+    std::string_view op,
+    const Stream& s) {
+  auto& a = inputs[0];
+  auto& b = inputs[1];
+  auto bopt = get_binary_op_type(a, b);
+  set_binary_op_output_data(a, b, outputs[0], bopt);
+  set_binary_op_output_data(a, b, outputs[1], bopt);
+  binary_op_gpu_inplace<Op>(inputs, outputs, op, s);
+}
+
+template <typename Op>
+void binary_op_gpu(
+    const std::vector<array>& inputs,
+    array& out,
+    std::string_view op,
+    const Stream& s) {
+  auto& a = inputs[0];
+  auto& b = inputs[1];
+  auto bopt = get_binary_op_type(a, b);
+  set_binary_op_output_data(a, b, out, bopt);
+  std::vector<array> outputs{out};
+  binary_op_gpu_inplace<Op>(inputs, outputs, op, s);
+}
+
+#define BINARY_GPU(func)                                                  \
+  void func::eval_gpu(const std::vector<array>& inputs, array& out) {     \
+    auto& s = out.primitive().stream();                                   \
+    binary_op_gpu<rocm::func>(inputs, out, get_primitive_string(this), s); \
+  }
+
+#define BINARY_GPU_MULTI(func)                                            \
+  void func::eval_gpu(                                                    \
+      const std::vector<array>& inputs, std::vector<array>& outputs) {    \
+    auto& s = outputs[0].primitive().stream();                            \
+    binary_op_gpu<rocm::func>(                                            \
+        inputs, outputs, get_primitive_string(this), s);                  \
+  }
+
+BINARY_GPU(Add)
+BINARY_GPU(ArcTan2)
+BINARY_GPU(Divide)
+BINARY_GPU(Remainder)
+BINARY_GPU(Greater)
+BINARY_GPU(GreaterEqual)
+BINARY_GPU(Less)
+BINARY_GPU(LessEqual)
+BINARY_GPU(LogicalAnd)
+BINARY_GPU(LogicalOr)
+BINARY_GPU(LogAddExp)
+BINARY_GPU(Maximum)
+BINARY_GPU(Minimum)
+BINARY_GPU(Multiply)
+BINARY_GPU(NotEqual)
+BINARY_GPU(Power)
+BINARY_GPU(Subtract)
+
+void Equal::eval_gpu(const std::vector<array>& inputs, array& out) {
+  auto& s = out.primitive().stream();
+  auto op = get_primitive_string(this);
+  if (equal_nan_) {
+    binary_op_gpu<rocm::NaNEqual>(inputs, out, op, s);
+  } else {
+    binary_op_gpu<rocm::Equal>(inputs, out, op, s);
+  }
+}
+
+void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
+  auto& s = out.primitive().stream();
+  auto op = get_primitive_string(this);
+  switch (op_) {
+    case BitwiseBinary::And:
+      binary_op_gpu<rocm::BitwiseAnd>(inputs, out, op, s);
+      break;
+    case BitwiseBinary::Or:
+      binary_op_gpu<rocm::BitwiseOr>(inputs, out, op, s);
+      break;
+    case BitwiseBinary::Xor:
+      binary_op_gpu<rocm::BitwiseXor>(inputs, out, op, s);
+      break;
+    case BitwiseBinary::LeftShift:
+      binary_op_gpu<rocm::LeftShift>(inputs, out, op, s);
+      break;
+    case BitwiseBinary::RightShift:
+      binary_op_gpu<rocm::RightShift>(inputs, out, op, s);
+      break;
+  }
+}
+
+} // namespace mlx::core
\ No newline at end of file
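For reference, the contiguous dispatch above boils down to the following standalone HIP sketch. This is illustrative only and not part of the patch: AddOp, binary_vv_demo, and the fixed 256-thread geometry are stand-ins for the patch's op functors, binary_vv kernel, and get_launch_args. Error checking is omitted for brevity.

#include <hip/hip_runtime.h>
#include <cstdio>
#include <vector>

struct AddOp {
  __device__ float operator()(float a, float b) const { return a + b; }
};

template <typename Op>
__global__ void binary_vv_demo(const float* a, const float* b, float* out, int size) {
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  if (index < size) {
    out[index] = Op{}(a[index], b[index]);
  }
}

int main() {
  const int n = 1024;
  std::vector<float> ha(n, 1.0f), hb(n, 2.0f), hout(n);
  float *da, *db, *dout;
  hipMalloc(&da, n * sizeof(float));
  hipMalloc(&db, n * sizeof(float));
  hipMalloc(&dout, n * sizeof(float));
  hipMemcpy(da, ha.data(), n * sizeof(float), hipMemcpyHostToDevice);
  hipMemcpy(db, hb.data(), n * sizeof(float), hipMemcpyHostToDevice);
  int threads = 256;
  int blocks = (n + threads - 1) / threads;
  // Parentheses around the kernel keep the template comma out of the macro.
  hipLaunchKernelGGL((binary_vv_demo<AddOp>), dim3(blocks), dim3(threads), 0, 0,
                     da, db, dout, n);
  hipMemcpy(hout.data(), dout, n * sizeof(float), hipMemcpyDeviceToHost);
  printf("out[0] = %f\n", hout[0]); // expect 3.0
  hipFree(da);
  hipFree(db);
  hipFree(dout);
  return 0;
}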
#include "mlx/backend/rocm/device.h" -#include "mlx/backend/rocm/utils.h" +#include "mlx/backend/metal/metal.h" +#include "mlx/backend/rocm/worker.h" -namespace mlx::core::rocm { +#include -DeviceStream::DeviceStream(Device& device) : device_(device) { - check_hip_error("hipStreamCreate", hipStreamCreate(&stream_)); - encoder_ = std::make_unique(*this); -} +namespace mlx::core { + +namespace rocm { + +DeviceStream::DeviceStream(Device& device) : device_(device), stream_(device) {} void DeviceStream::synchronize() { - check_hip_error("hipStreamSynchronize", hipStreamSynchronize(stream_)); + CHECK_HIP_ERROR(hipStreamSynchronize(stream_)); } hipStream_t DeviceStream::schedule_hip_stream() { + // TODO: Return a stream that maximizes parallelism. return stream_; } @@ -23,22 +26,35 @@ hipStream_t DeviceStream::last_hip_stream() { } CommandEncoder& DeviceStream::get_encoder() { + if (!encoder_) { + encoder_ = std::make_unique(*this); + } return *encoder_; } Device::Device(int device) : device_(device) { - check_hip_error("hipSetDevice", hipSetDevice(device_)); + CHECK_HIP_ERROR(hipDeviceGetAttribute( + &compute_capability_major_, + hipDeviceAttributeComputeCapabilityMajor, + device_)); + CHECK_HIP_ERROR(hipDeviceGetAttribute( + &compute_capability_minor_, + hipDeviceAttributeComputeCapabilityMinor, + device_)); - // Get device properties - hipDeviceProp_t prop; - check_hip_error( - "hipGetDeviceProperties", hipGetDeviceProperties(&prop, device_)); - compute_capability_major_ = prop.major; - compute_capability_minor_ = prop.minor; + // Validate device requirements + int attr = 0; + CHECK_HIP_ERROR(hipDeviceGetAttribute( + &attr, hipDeviceAttributeConcurrentManagedAccess, device_)); + if (attr != 1) { + // ROCm unified memory might not be available on all devices + // This is a warning rather than an error for ROCm + // TODO: Add proper ROCm unified memory checking + } // Create rocBLAS handle - check_hip_error( - "rocblas_create_handle", + make_current(); + CHECK_HIP_ERROR( static_cast(rocblas_create_handle(&rocblas_handle_))); } @@ -49,56 +65,66 @@ Device::~Device() { } void Device::make_current() { - check_hip_error("hipSetDevice", hipSetDevice(device_)); + // Cache current device to reduce HIP API calls + static int current = 0; + if (current != device_) { + CHECK_HIP_ERROR(hipSetDevice(device_)); + current = device_; + } } DeviceStream& Device::get_stream(Stream s) { auto it = streams_.find(s.index); - if (it != streams_.end()) { - return it->second; + if (it == streams_.end()) { + it = streams_.try_emplace(s.index, *this).first; } - - auto [new_it, inserted] = streams_.emplace(s.index, DeviceStream(*this)); - return new_it->second; + return it->second; } -CommandEncoder::CommandEncoder(DeviceStream& stream) - : device_(stream.device()), stream_(stream), worker_() {} +CommandEncoder::CommandEncoder(DeviceStream& s) + : device_(s.device()), stream_(s) {} void CommandEncoder::add_completed_handler(std::function task) { - worker_.enqueue(task); + worker_.add_task(std::move(task)); } void CommandEncoder::end_encoding() { - // Implementation for ending encoding + if (!temporaries_.empty()) { + add_completed_handler([temporaries = std::move(temporaries_)]() {}); + } + + // There is no kernel running, run completion handlers immediately. 
+ if (!has_gpu_work_) { + worker_.consume_in_this_thread(); + return; + } + has_gpu_work_ = false; + + // Commit tasks + commit(); } void CommandEncoder::commit() { - worker_.commit(); + worker_.commit(stream_.last_hip_stream()); } -// Global device management -static std::unordered_map> devices_; - Device& device(mlx::core::Device device) { - auto it = devices_.find(device.index); - if (it != devices_.end()) { - return *it->second; + static std::unordered_map devices; + auto it = devices.find(device.index); + if (it == devices.end()) { + it = devices.try_emplace(device.index, device.index).first; } - - auto new_device = std::make_unique(device.index); - Device& dev_ref = *new_device; - devices_[device.index] = std::move(new_device); - return dev_ref; + return it->second; } DeviceStream& get_stream(Stream s) { - // Use default device (index 0) for now - return device(mlx::core::Device{mlx::core::Device::gpu, 0}).get_stream(s); + return device(s.device).get_stream(s); } CommandEncoder& get_command_encoder(Stream s) { return get_stream(s).get_encoder(); } -} // namespace mlx::core::rocm \ No newline at end of file +} // namespace rocm + +} // namespace mlx::core \ No newline at end of file diff --git a/mlx/backend/rocm/device.h b/mlx/backend/rocm/device.h index bd122d547..6a9c18a07 100644 --- a/mlx/backend/rocm/device.h +++ b/mlx/backend/rocm/device.h @@ -3,6 +3,7 @@ #pragma once #include "mlx/array.h" +#include "mlx/backend/rocm/utils.h" #include "mlx/backend/rocm/worker.h" #include "mlx/stream.h" @@ -11,7 +12,9 @@ #include -namespace mlx::core::rocm { +namespace mlx::core { + +namespace rocm { class Device; class CommandEncoder; @@ -138,4 +141,6 @@ CommandEncoder& get_command_encoder(Stream s); // Utility function to check HIP errors void check_hip_error(const char* msg, hipError_t error); -} // namespace mlx::core::rocm \ No newline at end of file +} // namespace rocm + +} // namespace mlx::core \ No newline at end of file diff --git a/mlx/backend/rocm/device/binary_ops.hpp b/mlx/backend/rocm/device/binary_ops.hpp new file mode 100644 index 000000000..01766f2cc --- /dev/null +++ b/mlx/backend/rocm/device/binary_ops.hpp @@ -0,0 +1,217 @@ +// Copyright © 2025 Apple Inc. 
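Seen from a primitive's side, the encoder added above is used in roughly this shape. This is a hypothetical sketch, not part of the patch: fill_ones and eval_fill_sketch are invented names, and it assumes CommandEncoder::launch_kernel hands the lambda a HIP stream and marks GPU work pending, which is how binary.hip drives it.

// Assumes "mlx/backend/rocm/device.h" is included.
__global__ void fill_ones(float* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    out[i] = 1.0f;
  }
}

void eval_fill_sketch(array& out, const Stream& s) {
  auto& encoder = rocm::get_command_encoder(s);
  encoder.set_output_array(out);
  encoder.launch_kernel([&](hipStream_t stream) {
    int n = out.size();
    int blocks = (n + 255) / 256;
    hipLaunchKernelGGL(fill_ones, dim3(blocks), dim3(256), 0, stream,
                       out.data<float>(), n);
  });
  // Handlers queued here run once the worker observes stream completion.
  encoder.add_completed_handler([] { /* release staging buffers, etc. */ });
}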
diff --git a/mlx/backend/rocm/device/binary_ops.hpp b/mlx/backend/rocm/device/binary_ops.hpp
new file mode 100644
index 000000000..01766f2cc
--- /dev/null
+++ b/mlx/backend/rocm/device/binary_ops.hpp
@@ -0,0 +1,217 @@
+// Copyright © 2025 Apple Inc.
+
+#pragma once
+
+#include <hip/hip_runtime.h>
+#include <hip/hip_fp16.h>
+#include <cmath>
+#include <type_traits>
+
+namespace mlx::core::rocm {
+
+// Arithmetic operations
+struct Add {
+  template <typename T>
+  __device__ T operator()(T a, T b) {
+    return a + b;
+  }
+};
+
+struct Subtract {
+  template <typename T>
+  __device__ T operator()(T a, T b) {
+    return a - b;
+  }
+};
+
+struct Multiply {
+  template <typename T>
+  __device__ T operator()(T a, T b) {
+    return a * b;
+  }
+};
+
+struct Divide {
+  template <typename T>
+  __device__ T operator()(T a, T b) {
+    return a / b;
+  }
+};
+
+struct Power {
+  template <typename T>
+  __device__ T operator()(T a, T b) {
+    if constexpr (std::is_integral_v<T>) {
+      // Integer exponentiation by squaring; powf would lose precision.
+      T result = 1;
+      while (b) {
+        if (b & 1) {
+          result *= a;
+        }
+        b >>= 1;
+        a *= a;
+      }
+      return result;
+    } else {
+      return powf(a, b);
+    }
+  }
+
+  __device__ double operator()(double a, double b) {
+    return pow(a, b);
+  }
+};
+
+struct Remainder {
+  template <typename T>
+  __device__ T operator()(T a, T b) {
+    if constexpr (std::is_integral_v<T>) {
+      return a % b;
+    } else {
+      return fmodf(a, b);
+    }
+  }
+
+  __device__ double operator()(double a, double b) {
+    return fmod(a, b);
+  }
+};
+
+// Comparison operations
+struct Equal {
+  template <typename T>
+  __device__ bool operator()(T a, T b) {
+    return a == b;
+  }
+};
+
+struct NotEqual {
+  template <typename T>
+  __device__ bool operator()(T a, T b) {
+    return a != b;
+  }
+};
+
+struct Greater {
+  template <typename T>
+  __device__ bool operator()(T a, T b) {
+    return a > b;
+  }
+};
+
+struct GreaterEqual {
+  template <typename T>
+  __device__ bool operator()(T a, T b) {
+    return a >= b;
+  }
+};
+
+struct Less {
+  template <typename T>
+  __device__ bool operator()(T a, T b) {
+    return a < b;
+  }
+};
+
+struct LessEqual {
+  template <typename T>
+  __device__ bool operator()(T a, T b) {
+    return a <= b;
+  }
+};
+
+struct NaNEqual {
+  template <typename T>
+  __device__ bool operator()(T a, T b) {
+    return (isnan(a) && isnan(b)) || (a == b);
+  }
+};
+
+// Logic operations
+struct LogicalAnd {
+  __device__ bool operator()(bool a, bool b) {
+    return a && b;
+  }
+};
+
+struct LogicalOr {
+  __device__ bool operator()(bool a, bool b) {
+    return a || b;
+  }
+};
+
+// Math operations
+struct Maximum {
+  template <typename T>
+  __device__ T operator()(T a, T b) {
+    if constexpr (std::is_integral_v<T>) {
+      // Avoid routing integers through the float intrinsic.
+      return a > b ? a : b;
+    } else {
+      return fmaxf(a, b);
+    }
+  }
+
+  __device__ double operator()(double a, double b) {
+    return fmax(a, b);
+  }
+};
+
+struct Minimum {
+  template <typename T>
+  __device__ T operator()(T a, T b) {
+    if constexpr (std::is_integral_v<T>) {
+      return a < b ? a : b;
+    } else {
+      return fminf(a, b);
+    }
+  }
+
+  __device__ double operator()(double a, double b) {
+    return fmin(a, b);
+  }
+};
+
+struct LogAddExp {
+  template <typename T>
+  __device__ T operator()(T a, T b) {
+    T max_val = fmaxf(a, b);
+    T min_val = fminf(a, b);
+    if (isinf(max_val)) {
+      return max_val;
+    }
+    return max_val + log1pf(expf(min_val - max_val));
+  }
+
+  __device__ double operator()(double a, double b) {
+    double max_val = fmax(a, b);
+    double min_val = fmin(a, b);
+    if (isinf(max_val)) {
+      return max_val;
+    }
+    return max_val + log1p(exp(min_val - max_val));
+  }
+};
+
+struct ArcTan2 {
+  template <typename T>
+  __device__ T operator()(T a, T b) {
+    return atan2f(a, b);
+  }
+
+  __device__ double operator()(double a, double b) {
+    return atan2(a, b);
+  }
+};
+
+// Bitwise operations
+struct BitwiseAnd {
+  template <typename T>
+  __device__ T operator()(T a, T b) {
+    return a & b;
+  }
+};
+
+struct BitwiseOr {
+  template <typename T>
+  __device__ T operator()(T a, T b) {
+    return a | b;
+  }
+};
+
+struct BitwiseXor {
+  template <typename T>
+  __device__ T operator()(T a, T b) {
+    return a ^ b;
+  }
+};
+
+struct LeftShift {
+  template <typename T>
+  __device__ T operator()(T a, T b) {
+    return a << b;
+  }
+};
+
+struct RightShift {
+  template <typename T>
+  __device__ T operator()(T a, T b) {
+    return a >> b;
+  }
+};
+
+} // namespace mlx::core::rocm
\ No newline at end of file
diff --git a/mlx/backend/rocm/event.cpp b/mlx/backend/rocm/event.cpp
new file mode 100644
index 000000000..a1ff81622
--- /dev/null
+++ b/mlx/backend/rocm/event.cpp
@@ -0,0 +1,50 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/backend/rocm/event.h"
+#include "mlx/backend/rocm/utils.h"
+
+namespace mlx::core::rocm {
+
+HipEvent::HipEvent() {
+  CHECK_HIP_ERROR(hipEventCreate(&event_));
+}
+
+HipEvent::~HipEvent() {
+  CHECK_HIP_ERROR(hipEventDestroy(event_));
+}
+
+void HipEvent::record(hipStream_t stream) {
+  CHECK_HIP_ERROR(hipEventRecord(event_, stream));
+}
+
+void HipEvent::wait() {
+  CHECK_HIP_ERROR(hipEventSynchronize(event_));
+}
+
+bool HipEvent::query() const {
+  hipError_t status = hipEventQuery(event_);
+  if (status == hipSuccess) {
+    return true;
+  } else if (status == hipErrorNotReady) {
+    return false;
+  } else {
+    CHECK_HIP_ERROR(status);
+    return false;
+  }
+}
+
+SharedEvent::SharedEvent() = default;
+
+void SharedEvent::notify() {
+  std::lock_guard lock(mutex_);
+  ready_ = true;
+  cv_.notify_one();
+}
+
+void SharedEvent::wait() {
+  std::unique_lock lock(mutex_);
+  cv_.wait(lock, [this] { return ready_; });
+  ready_ = false;
+}
+
+} // namespace mlx::core::rocm
\ No newline at end of file
diff --git a/mlx/backend/rocm/event.h b/mlx/backend/rocm/event.h
new file mode 100644
index 000000000..1a9d5f5a6
--- /dev/null
+++ b/mlx/backend/rocm/event.h
@@ -0,0 +1,48 @@
+// Copyright © 2025 Apple Inc.
+
+#pragma once
+
+#include <hip/hip_runtime.h>
+
+#include <condition_variable>
+#include <functional>
+#include <mutex>
+
+namespace mlx::core::rocm {
+
+// HIP event managed with RAII.
+class HipEvent {
+ public:
+  HipEvent();
+  ~HipEvent();
+
+  HipEvent(const HipEvent&) = delete;
+  HipEvent& operator=(const HipEvent&) = delete;
+
+  void record(hipStream_t stream);
+  void wait();
+  bool query() const;
+
+  operator hipEvent_t() const {
+    return event_;
+  }
+
+ private:
+  hipEvent_t event_;
+};
+
+// Shared event for worker thread synchronization.
+class SharedEvent {
+ public:
+  SharedEvent();
+
+  void notify();
+  void wait();
+
+ private:
+  std::mutex mutex_;
+  std::condition_variable cv_;
+  bool ready_{false};
+};
+
+} // namespace mlx::core::rocm
\ No newline at end of file
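A minimal usage sketch for the RAII HipEvent above (hypothetical, not part of the patch): record on a producer stream and have a consumer stream wait on the GPU side, without blocking the host.

#include "mlx/backend/rocm/event.h"
#include "mlx/backend/rocm/utils.h"

void cross_stream_sync(hipStream_t producer, hipStream_t consumer) {
  mlx::core::rocm::HipEvent event;
  event.record(producer);
  // HipEvent converts to hipEvent_t, so plain HIP calls accept it directly.
  CHECK_HIP_ERROR(hipStreamWaitEvent(consumer, event, 0));
  // event.query() polls without blocking; event.wait() blocks the host.
}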
diff --git a/mlx/backend/rocm/jit_module.cpp b/mlx/backend/rocm/jit_module.cpp
new file mode 100644
index 000000000..cdda490d5
--- /dev/null
+++ b/mlx/backend/rocm/jit_module.cpp
@@ -0,0 +1,167 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/backend/rocm/jit_module.h"
+#include "mlx/backend/rocm/utils.h"
+
+#include <fmt/format.h>
+#include <sstream>
+#include <vector>
+
+namespace mlx::core::rocm {
+
+JitModule::JitModule(
+    const std::string& kernel_name,
+    const std::string& kernel_source,
+    const std::vector<std::string>& template_args,
+    const std::vector<std::string>& compiler_flags,
+    bool verbose) {
+  compile(kernel_name, kernel_source, template_args, compiler_flags, verbose);
+}
+
+JitModule::~JitModule() {
+  if (kernel_) {
+    // No hipFunctionDestroy equivalent in HIP
+  }
+  if (module_) {
+    CHECK_HIP_ERROR(hipModuleUnload(module_));
+  }
+  if (program_) {
+    hiprtcDestroyProgram(&program_);
+  }
+}
+
+void JitModule::compile(
+    const std::string& kernel_name,
+    const std::string& kernel_source,
+    const std::vector<std::string>& template_args,
+    const std::vector<std::string>& compiler_flags,
+    bool verbose) {
+  // Create HIPRTC program. hiprtcResult is a distinct enum, so cast it
+  // the same way rocblas_status is cast in device.cpp.
+  CHECK_HIP_ERROR(static_cast<hipError_t>(hiprtcCreateProgram(
+      &program_,
+      kernel_source.c_str(),
+      kernel_name.c_str(),
+      0,
+      nullptr,
+      nullptr)));
+
+  // Build compiler options
+  std::vector<const char*> options;
+  std::vector<std::string> option_strings;
+
+  // Add default options
+  option_strings.push_back("--std=c++17");
+  option_strings.push_back("-O3");
+  option_strings.push_back("-DMLX_USE_ROCM");
+
+  // Add user-provided flags
+  for (const auto& flag : compiler_flags) {
+    option_strings.push_back(flag);
+  }
+
+  // Add template arguments
+  for (const auto& arg : template_args) {
+    option_strings.push_back("-D" + arg);
+  }
+
+  // Convert to char* array
+  for (const auto& option : option_strings) {
+    options.push_back(option.c_str());
+  }
+
+  // Compile the program
+  hiprtcResult compile_result =
+      hiprtcCompileProgram(program_, options.size(), options.data());
+
+  // Get compilation log
+  size_t log_size;
+  CHECK_HIP_ERROR(
+      static_cast<hipError_t>(hiprtcGetProgramLogSize(program_, &log_size)));
+
+  if (log_size > 1) {
+    std::vector<char> log(log_size);
+    CHECK_HIP_ERROR(
+        static_cast<hipError_t>(hiprtcGetProgramLog(program_, log.data())));
+
+    if (verbose || compile_result != HIPRTC_SUCCESS) {
+      fmt::print(
+          "HIPRTC compilation log for {}:\n{}\n", kernel_name, log.data());
+    }
+  }
+
+  if (compile_result != HIPRTC_SUCCESS) {
+    throw std::runtime_error(
+        fmt::format("HIPRTC compilation failed for kernel {}", kernel_name));
+  }
+
+  // Get compiled code
+  size_t code_size;
+  CHECK_HIP_ERROR(
+      static_cast<hipError_t>(hiprtcGetCodeSize(program_, &code_size)));
+
+  std::vector<char> code(code_size);
+  CHECK_HIP_ERROR(
+      static_cast<hipError_t>(hiprtcGetCode(program_, code.data())));
+
+  // Load module
+  CHECK_HIP_ERROR(hipModuleLoadData(&module_, code.data()));
+
+  // Get kernel function
+  CHECK_HIP_ERROR(hipModuleGetFunction(&kernel_, module_, kernel_name.c_str()));
+}
+
+JitCache& JitCache::instance() {
+  static JitCache cache;
+  return cache;
+}
+
+std::shared_ptr<JitModule> JitCache::get_or_create(
+    const std::string& kernel_name,
+    const std::string& kernel_source,
+    const std::vector<std::string>& template_args,
+    const std::vector<std::string>& compiler_flags) {
+  std::string key =
+      make_key(kernel_name, kernel_source, template_args, compiler_flags);
+
+  std::lock_guard lock(mutex_);
+
+  auto it = cache_.find(key);
+  if (it != cache_.end()) {
+    if (auto module = it->second.lock()) {
+      return module;
+    } else {
+      cache_.erase(it);
+    }
+  }
+
+  auto module = std::make_shared<JitModule>(
+      kernel_name, kernel_source, template_args, compiler_flags);
+  cache_[key] = module;
+  return module;
+}
+
+std::string JitCache::make_key(
+    const std::string& kernel_name,
+    const std::string& kernel_source,
+    const std::vector<std::string>& template_args,
+    const std::vector<std::string>& compiler_flags) const {
+  std::ostringstream oss;
+  oss << kernel_name << "|" << kernel_source;
+
+  for (const auto& arg : template_args) {
+    oss << "|" << arg;
+  }
+
+  for (const auto& flag : compiler_flags) {
+    oss << "|" << flag;
+  }
+
+  return oss.str();
+}
+
+std::shared_ptr<JitModule> make_jit_kernel(
+    const std::string& kernel_name,
+    const std::string& kernel_source,
+    const std::vector<std::string>& template_args,
+    const std::vector<std::string>& compiler_flags) {
+  return JitCache::instance().get_or_create(
+      kernel_name, kernel_source, template_args, compiler_flags);
+}
+
+} // namespace mlx::core::rocm
\ No newline at end of file
diff --git a/mlx/backend/rocm/jit_module.h b/mlx/backend/rocm/jit_module.h
new file mode 100644
index 000000000..55b655c4d
--- /dev/null
+++ b/mlx/backend/rocm/jit_module.h
@@ -0,0 +1,100 @@
+// Copyright © 2025 Apple Inc.
+
+#pragma once
+
+#include "mlx/backend/rocm/utils.h"
+
+#include <hip/hip_runtime.h>
+#include <hip/hiprtc.h>
+
+#include <memory>
+#include <mutex>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+namespace mlx::core::rocm {
+
+// JIT compilation module for ROCm
+class JitModule {
+ public:
+  JitModule(
+      const std::string& kernel_name,
+      const std::string& kernel_source,
+      const std::vector<std::string>& template_args = {},
+      const std::vector<std::string>& compiler_flags = {},
+      bool verbose = false);
+
+  ~JitModule();
+
+  JitModule(const JitModule&) = delete;
+  JitModule& operator=(const JitModule&) = delete;
+
+  // Get the compiled kernel function
+  hipFunction_t get_kernel() const {
+    return kernel_;
+  }
+
+  // Launch the kernel with given arguments
+  template <typename... Args>
+  void launch(
+      dim3 grid_dims,
+      dim3 block_dims,
+      size_t shared_memory,
+      hipStream_t stream,
+      Args&&... args) {
+    void* kernel_args[] = {(void*)&args...};
+    CHECK_HIP_ERROR(hipModuleLaunchKernel(
+        kernel_,
+        grid_dims.x,
+        grid_dims.y,
+        grid_dims.z,
+        block_dims.x,
+        block_dims.y,
+        block_dims.z,
+        shared_memory,
+        stream,
+        kernel_args,
+        nullptr));
+  }
+
+ private:
+  void compile(
+      const std::string& kernel_name,
+      const std::string& kernel_source,
+      const std::vector<std::string>& template_args,
+      const std::vector<std::string>& compiler_flags,
+      bool verbose);
+
+  hiprtcProgram program_{nullptr};
+  hipModule_t module_{nullptr};
+  hipFunction_t kernel_{nullptr};
+};
+
+// JIT cache for compiled modules
+class JitCache {
+ public:
+  static JitCache& instance();
+
+  std::shared_ptr<JitModule> get_or_create(
+      const std::string& kernel_name,
+      const std::string& kernel_source,
+      const std::vector<std::string>& template_args = {},
+      const std::vector<std::string>& compiler_flags = {});
+
+ private:
+  std::unordered_map<std::string, std::weak_ptr<JitModule>> cache_;
+  std::mutex mutex_;
+
+  std::string make_key(
+      const std::string& kernel_name,
+      const std::string& kernel_source,
+      const std::vector<std::string>& template_args,
+      const std::vector<std::string>& compiler_flags) const;
+};
+
+// Helper function to create and cache JIT modules
+std::shared_ptr<JitModule> make_jit_kernel(
+    const std::string& kernel_name,
+    const std::string& kernel_source,
+    const std::vector<std::string>& template_args = {},
+    const std::vector<std::string>& compiler_flags = {});
+
+} // namespace mlx::core::rocm
\ No newline at end of file
diff --git a/mlx/backend/rocm/kernel_utils.hpp b/mlx/backend/rocm/kernel_utils.hpp
new file mode 100644
index 000000000..f694fd008
--- /dev/null
+++ b/mlx/backend/rocm/kernel_utils.hpp
@@ -0,0 +1,135 @@
+// Copyright © 2025 Apple Inc.
+
+#pragma once
+
+#include "mlx/array.h"
+
+#include <hip/hip_runtime.h>
+#include <hip/hip_fp16.h>
+#include <hip/hip_bf16.h>
+#include <hip/hip_complex.h>
+
+#include <array>
+#include <utility>
+#include <vector>
+
+namespace mlx::core::rocm {
+
+// Constants
+constexpr int MAX_DIMS = 8;
+
+// HIP array type for passing arrays to kernels
+template <typename T, int N>
+using hip_array = std::array<T, N>;
+
+// Helper to create hip_array from vector
+template <typename T, int N>
+__host__ hip_array<T, N> make_hip_array(const std::vector<T>& vec) {
+  hip_array<T, N> arr = {};
+  for (int i = 0; i < N && i < static_cast<int>(vec.size()); ++i) {
+    arr[i] = vec[i];
+  }
+  return arr;
+}
+
+template <typename T>
+__host__ hip_array<T, MAX_DIMS> make_hip_array(const std::vector<T>& vec) {
+  return make_hip_array<T, MAX_DIMS>(vec);
+}
+
+// Type mapping from MLX types to HIP types. Alias templates cannot be
+// specialized, so the mapping goes through a helper struct.
+template <typename T>
+struct HipType {
+  using type = T;
+};
+
+template <>
+struct HipType<float16_t> {
+  using type = __half;
+};
+
+template <>
+struct HipType<bfloat16_t> {
+  using type = __hip_bfloat16;
+};
+
+template <>
+struct HipType<complex64_t> {
+  using type = hipFloatComplex;
+};
+
+template <typename T>
+using hip_type_t = typename HipType<T>::type;
+
+// Element to location mapping for general broadcasting
+template <int NDIM>
+__device__ std::pair<int64_t, int64_t> elem_to_loc_nd(
+    int64_t elem,
+    const int32_t* shape,
+    const int64_t* a_strides,
+    const int64_t* b_strides) {
+  int64_t a_idx = 0;
+  int64_t b_idx = 0;
+
+  for (int i = NDIM - 1; i >= 0; --i) {
+    int64_t pos_in_dim = elem % shape[i];
+    elem /= shape[i];
+    a_idx += pos_in_dim * a_strides[i];
+    b_idx += pos_in_dim * b_strides[i];
+  }
+
+  return {a_idx, b_idx};
+}
+
+// Runtime-ndim variant, used when ndim is not known at compile time.
+__device__ inline std::pair<int64_t, int64_t> elem_to_loc_4d(
+    int64_t elem,
+    const int32_t* shape,
+    const int64_t* a_strides,
+    const int64_t* b_strides,
+    int ndim) {
+  int64_t a_idx = 0;
+  int64_t b_idx = 0;
+
+  for (int i = ndim - 1; i >= 0; --i) {
+    int64_t pos_in_dim = elem % shape[i];
+    elem /= shape[i];
+    a_idx += pos_in_dim * a_strides[i];
+    b_idx += pos_in_dim * b_strides[i];
+  }
+
+  return {a_idx, b_idx};
+}
+
+// Launch configuration calculation
+template <typename Kernel>
+std::pair<dim3, dim3>
+get_launch_args(Kernel kernel, const array& out, bool large = false) {
+  int threads_per_block = 256;
+  int64_t total_threads = out.size();
+
+  if (large) {
+    // For large arrays, compute the block count in 64-bit.
+    int64_t blocks =
+        (total_threads + threads_per_block - 1) / threads_per_block;
+    return {dim3(blocks), dim3(threads_per_block)};
+  } else {
+    int blocks = (total_threads + threads_per_block - 1) / threads_per_block;
+    return {dim3(blocks), dim3(threads_per_block)};
+  }
+}
+
+template <typename Kernel>
+std::pair<dim3, dim3> get_launch_args(
+    Kernel kernel,
+    int64_t size,
+    const std::vector<int32_t>& shape,
+    const std::vector<int64_t>& strides,
+    bool large = false) {
+  int threads_per_block = 256;
+
+  if (large) {
+    int64_t blocks = (size + threads_per_block - 1) / threads_per_block;
+    return {dim3(blocks), dim3(threads_per_block)};
+  } else {
+    int blocks = (size + threads_per_block - 1) / threads_per_block;
+    return {dim3(blocks), dim3(threads_per_block)};
+  }
+}
+
+// Cooperative groups thread rank equivalent
+namespace cooperative_groups {
+class grid_group {
+ public:
+  __device__ int64_t thread_rank() const {
+    return blockIdx.x * blockDim.x + threadIdx.x;
+  }
+};
+
+// Must be inline: this header is included in multiple translation units.
+__device__ inline grid_group this_grid() {
+  return grid_group{};
+}
+} // namespace cooperative_groups
+
+} // namespace mlx::core::rocm
\ No newline at end of file
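The index math in elem_to_loc_nd is easiest to check on the host. A worked example (illustrative only; elem_to_loc_host mirrors the device loop above): broadcasting a (3, 1) array against a (1, 4) array gives broadcast strides {1, 0} and {0, 1}, so flat output element 6, which is row 1, column 2 of the 3x4 result, must read a[1] and b[2].

#include <cstdint>
#include <cstdio>
#include <utility>

std::pair<int64_t, int64_t> elem_to_loc_host(
    int64_t elem, const int32_t* shape, const int64_t* a_strides,
    const int64_t* b_strides, int ndim) {
  int64_t a_idx = 0, b_idx = 0;
  for (int i = ndim - 1; i >= 0; --i) {
    int64_t pos = elem % shape[i];
    elem /= shape[i];
    a_idx += pos * a_strides[i];  // broadcast dims carry stride 0
    b_idx += pos * b_strides[i];
  }
  return {a_idx, b_idx};
}

int main() {
  int32_t shape[2] = {3, 4};
  int64_t a_strides[2] = {1, 0};  // a: shape (3, 1), broadcast over columns
  int64_t b_strides[2] = {0, 1};  // b: shape (1, 4), broadcast over rows
  auto [a_idx, b_idx] = elem_to_loc_host(6, shape, a_strides, b_strides, 2);
  printf("a_idx=%lld b_idx=%lld\n", (long long)a_idx, (long long)b_idx);
  // Prints a_idx=1 b_idx=2.
  return 0;
}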
#include "mlx/backend/rocm/utils.h" -#include -#include +#include "mlx/backend/rocm/device.h" +#include "mlx/dtype_utils.h" -namespace mlx::core::rocm { +#include -void check_hip_error(const char* msg, hipError_t error) { - if (error != hipSuccess) { - std::ostringstream oss; - oss << "[ROCm] " << msg << ": " << hipGetErrorString(error); - throw std::runtime_error(oss.str()); +namespace mlx::core { + +HipStream::HipStream(rocm::Device& device) { + device.make_current(); + CHECK_HIP_ERROR(hipStreamCreateWithFlags(&stream_, hipStreamNonBlocking)); +} + +HipStream::~HipStream() { + CHECK_HIP_ERROR(hipStreamDestroy(stream_)); +} + +void check_hip_error(const char* name, hipError_t err) { + if (err != hipSuccess) { + throw std::runtime_error( + fmt::format("{} failed: {}", name, hipGetErrorString(err))); } } -} // namespace mlx::core::rocm \ No newline at end of file +const char* dtype_to_hip_type(const Dtype& dtype) { + if (dtype == float16) { + return "__half"; + } + if (dtype == bfloat16) { + return "__hip_bfloat16"; + } + if (dtype == complex64) { + return "hipFloatComplex"; + } +#define SPECIALIZE_DtypeToString(CPP_TYPE, DTYPE) \ + if (dtype == DTYPE) { \ + return #CPP_TYPE; \ + } + MLX_FORALL_DTYPES(SPECIALIZE_DtypeToString) +#undef SPECIALIZE_DtypeToString + return nullptr; +} + +} // namespace mlx::core \ No newline at end of file diff --git a/mlx/backend/rocm/utils.h b/mlx/backend/rocm/utils.h index 20aab3836..679828896 100644 --- a/mlx/backend/rocm/utils.h +++ b/mlx/backend/rocm/utils.h @@ -1,12 +1,43 @@ // Copyright © 2025 Apple Inc. +// This file includes utilities that are used by C++ code (i.e. .cpp files). + #pragma once #include -namespace mlx::core::rocm { +namespace mlx::core { -// Utility function to check HIP errors -void check_hip_error(const char* msg, hipError_t error); +namespace rocm { +class Device; +} -} // namespace mlx::core::rocm \ No newline at end of file +struct Dtype; + +// HIP stream managed with RAII. +class HipStream { + public: + explicit HipStream(rocm::Device& device); + ~HipStream(); + + HipStream(const HipStream&) = delete; + HipStream& operator=(const HipStream&) = delete; + + operator hipStream_t() const { + return stream_; + } + + private: + hipStream_t stream_; +}; + +// Throw exception if the HIP API does not succeed. +void check_hip_error(const char* name, hipError_t err); + +// The macro version that prints the command that failed. +#define CHECK_HIP_ERROR(cmd) check_hip_error(#cmd, (cmd)) + +// Convert Dtype to HIP C++ types. +const char* dtype_to_hip_type(const Dtype& dtype); + +} // namespace mlx::core \ No newline at end of file diff --git a/mlx/backend/rocm/worker.cpp b/mlx/backend/rocm/worker.cpp index 2dbbf98c7..db9d0b45b 100644 --- a/mlx/backend/rocm/worker.cpp +++ b/mlx/backend/rocm/worker.cpp @@ -1,6 +1,7 @@ // Copyright © 2025 Apple Inc. 
#include "mlx/backend/rocm/worker.h" +#include "mlx/backend/rocm/utils.h" namespace mlx::core::rocm { @@ -17,7 +18,7 @@ Worker::~Worker() { } } -void Worker::enqueue(std::function task) { +void Worker::add_task(std::function task) { { std::lock_guard lock(mutex_); tasks_.push(task); @@ -25,14 +26,28 @@ void Worker::enqueue(std::function task) { cv_.notify_one(); } -void Worker::commit() { - std::lock_guard lock(mutex_); - committed_ = true; +void Worker::consume_in_this_thread() { + std::queue> local_tasks; + { + std::lock_guard lock(mutex_); + local_tasks.swap(tasks_); + } + + while (!local_tasks.empty()) { + auto task = local_tasks.front(); + local_tasks.pop(); + task(); + } } -void Worker::join() { - std::unique_lock lock(mutex_); - cv_.wait(lock, [this] { return tasks_.empty() && committed_; }); +void Worker::commit(hipStream_t stream) { + // Synchronize with stream and then process tasks + CHECK_HIP_ERROR(hipStreamSynchronize(stream)); + consume_in_this_thread(); +} + +void Worker::commit() { + cv_.notify_all(); } void Worker::worker_loop() { diff --git a/mlx/backend/rocm/worker.h b/mlx/backend/rocm/worker.h index a20b0effd..b41fb75c5 100644 --- a/mlx/backend/rocm/worker.h +++ b/mlx/backend/rocm/worker.h @@ -3,15 +3,16 @@ #pragma once #include + +#include #include -#include +#include #include #include namespace mlx::core::rocm { -using HipStream = hipStream_t; - +// Simple worker for async task execution synchronized with HIP streams. class Worker { public: Worker(); @@ -20,9 +21,17 @@ class Worker { Worker(const Worker&) = delete; Worker& operator=(const Worker&) = delete; - void enqueue(std::function task); + // Add a task to be executed + void add_task(std::function task); + + // Run pending tasks immediately in current thread. + void consume_in_this_thread(); + + // Commit tasks to be run after stream completion + void commit(hipStream_t stream); + + // Simple commit without stream dependency void commit(); - void join(); private: void worker_loop(); @@ -32,7 +41,6 @@ class Worker { std::mutex mutex_; std::condition_variable cv_; bool stop_{false}; - bool committed_{false}; }; } // namespace mlx::core::rocm \ No newline at end of file