Mirror of https://github.com/ml-explore/mlx.git (synced 2025-06-24 01:17:26 +08:00)
increment 1: few ops and jit update
This commit is contained in:
parent 8bb8b76ae4
commit ac5adfa963
@@ -1,36 +1,312 @@
 // Copyright © 2025 Apple Inc.

 #include <hip/hip_runtime.h>
+#include "mlx/backend/common/binary.h"
+#include "mlx/backend/rocm/device.h"
+#include "mlx/backend/rocm/device/binary_ops.hpp"
+#include "mlx/backend/rocm/kernel_utils.hpp"
+#include "mlx/dtype_utils.h"
+#include "mlx/primitives.h"

-#include "mlx/backend/rocm/utils.h"
-#include <hip/hip_cooperative_groups.h>

-namespace mlx::core::rocm {
+namespace mlx::core {

-// Basic binary operation kernels will go here
-__global__ void add_kernel(float* a, float* b, float* c, int n) {
-  int idx = blockIdx.x * blockDim.x + threadIdx.x;
-  if (idx < n) {
-    c[idx] = a[idx] + b[idx];
-  }
-}
+namespace rocm {

+namespace cg = cooperative_groups;

+template <typename Op, typename In, typename Out, typename IdxT>
+__global__ void binary_ss(const In* a, const In* b, Out* out, IdxT size) {
+  IdxT index = cg::this_grid().thread_rank();
+  if (index < size) {
+    out[index] = Op{}(a[0], b[0]);
+  }
+}

-__global__ void multiply_kernel(float* a, float* b, float* c, int n) {
-  int idx = blockIdx.x * blockDim.x + threadIdx.x;
-  if (idx < n) {
-    c[idx] = a[idx] * b[idx];
-  }
-}
+template <typename Op, typename In, typename Out, typename IdxT>
+__global__ void binary_sv(const In* a, const In* b, Out* out, IdxT size) {
+  IdxT index = cg::this_grid().thread_rank();
+  if (index < size) {
+    out[index] = Op{}(a[0], b[index]);
+  }
+}

-void launch_add(float* a, float* b, float* c, int n, hipStream_t stream) {
-  int threads = 256;
-  int blocks = (n + threads - 1) / threads;
-  hipLaunchKernelGGL(add_kernel, dim3(blocks), dim3(threads), 0, stream, a, b, c, n);
-}
+template <typename Op, typename In, typename Out, typename IdxT>
+__global__ void binary_vs(const In* a, const In* b, Out* out, IdxT size) {
+  IdxT index = cg::this_grid().thread_rank();
+  if (index < size) {
+    out[index] = Op{}(a[index], b[0]);
+  }
+}

-void launch_multiply(float* a, float* b, float* c, int n, hipStream_t stream) {
-  int threads = 256;
-  int blocks = (n + threads - 1) / threads;
-  hipLaunchKernelGGL(multiply_kernel, dim3(blocks), dim3(threads), 0, stream, a, b, c, n);
-}
+template <typename Op, typename In, typename Out, typename IdxT>
+__global__ void binary_vv(const In* a, const In* b, Out* out, IdxT size) {
+  IdxT index = cg::this_grid().thread_rank();
+  if (index < size) {
+    out[index] = Op{}(a[index], b[index]);
+  }
+}

-} // namespace mlx::core::rocm
+template <typename Op, typename In, typename Out, typename IdxT, int NDIM>
+__global__ void binary_g_nd(
+    const In* a,
+    const In* b,
+    Out* out,
+    IdxT size,
+    const hip_array<int32_t, NDIM> shape,
+    const hip_array<int64_t, NDIM> a_strides,
+    const hip_array<int64_t, NDIM> b_strides) {
+  IdxT index = cg::this_grid().thread_rank();
+  if (index < size) {
+    auto [a_idx, b_idx] = elem_to_loc_nd<NDIM>(
+        index, shape.data(), a_strides.data(), b_strides.data());
+    out[index] = Op{}(a[a_idx], b[b_idx]);
+  }
+}

+template <typename Op, typename In, typename Out, typename IdxT>
+__global__ void binary_g(
+    const In* a,
+    const In* b,
+    Out* out,
+    IdxT size,
+    const hip_array<int32_t, MAX_DIMS> shape,
+    const hip_array<int64_t, MAX_DIMS> a_strides,
+    const hip_array<int64_t, MAX_DIMS> b_strides,
+    int ndim) {
+  IdxT index = cg::this_grid().thread_rank();
+  if (index < size) {
+    auto [a_idx, b_idx] = elem_to_loc_4d(
+        index, shape.data(), a_strides.data(), b_strides.data(), ndim);
+    out[index] = Op{}(a[a_idx], b[b_idx]);
+  }
+}

+// Binary operation support checking
+template <typename Op, typename In, typename Out>
+constexpr bool supports_binary_op() {
+  if (std::is_same_v<Op, Add> || std::is_same_v<Op, Divide> ||
+      std::is_same_v<Op, Maximum> || std::is_same_v<Op, Minimum> ||
+      std::is_same_v<Op, Multiply> || std::is_same_v<Op, Subtract> ||
+      std::is_same_v<Op, Power> || std::is_same_v<Op, Remainder>) {
+    return std::is_same_v<In, Out>;
+  }
+  if (std::is_same_v<Op, Equal> || std::is_same_v<Op, Greater> ||
+      std::is_same_v<Op, GreaterEqual> || std::is_same_v<Op, Less> ||
+      std::is_same_v<Op, LessEqual> || std::is_same_v<Op, NotEqual>) {
+    return std::is_same_v<Out, bool>;
+  }
+  if (std::is_same_v<Op, LogicalAnd> || std::is_same_v<Op, LogicalOr>) {
+    return std::is_same_v<Out, bool> && std::is_same_v<In, bool>;
+  }
+  if (std::is_same_v<Op, NaNEqual>) {
+    return std::is_same_v<Out, bool> && is_inexact_v<In>;
+  }
+  if (std::is_same_v<Op, LogAddExp>) {
+    return std::is_same_v<In, Out> && is_inexact_v<In>;
+  }
+  if (std::is_same_v<Op, ArcTan2>) {
+    return std::is_same_v<In, Out> && is_floating_v<In>;
+  }
+  if (std::is_same_v<Op, BitwiseAnd> || std::is_same_v<Op, BitwiseOr> ||
+      std::is_same_v<Op, BitwiseXor>) {
+    return std::is_same_v<In, Out> && std::is_integral_v<In>;
+  }
+  if (std::is_same_v<Op, LeftShift> || std::is_same_v<Op, RightShift>) {
+    return std::is_same_v<In, Out> && std::is_integral_v<In> &&
+        !std::is_same_v<In, bool>;
+  }
+  return false;
+}

+} // namespace rocm

+template <typename Op>
+void binary_op_gpu_inplace(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs,
+    std::string_view op,
+    const Stream& s) {
+  assert(inputs.size() > 1);
+  const auto& a = inputs[0];
+  const auto& b = inputs[1];
+  auto& out = outputs[0];
+  if (out.size() == 0) {
+    return;
+  }

+  auto& encoder = rocm::get_command_encoder(s);
+  encoder.set_input_array(a);
+  encoder.set_input_array(b);
+  encoder.set_output_array(out);
+  encoder.launch_kernel([&](hipStream_t stream) {
+    MLX_SWITCH_ALL_TYPES(a.dtype(), CTYPE_IN, {
+      MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, {
+        if constexpr (rocm::supports_binary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
+          using InType = hip_type_t<CTYPE_IN>;
+          using OutType = hip_type_t<CTYPE_OUT>;

+          auto bopt = get_binary_op_type(a, b);
+          if (bopt == BinaryOpType::General) {
+            auto [shape, strides] = collapse_contiguous_dims(a, b, out);
+            auto& a_strides = strides[0];
+            auto& b_strides = strides[1];
+            bool large = a.data_size() > INT32_MAX ||
+                b.data_size() > INT32_MAX || out.data_size() > INT32_MAX;
+            MLX_SWITCH_BOOL(large, LARGE, {
+              using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
+              int ndim = shape.size();
+              if (ndim <= 3) {
+                MLX_SWITCH_1_2_3(ndim, NDIM, {
+                  auto kernel =
+                      &rocm::binary_g_nd<Op, InType, OutType, IdxT, NDIM>;
+                  auto [num_blocks, block_dims] =
+                      get_launch_args(kernel, out, large);
+                  hipLaunchKernelGGL(kernel, num_blocks, block_dims, 0, stream,
+                      a.data<InType>(),
+                      b.data<InType>(),
+                      out.data<OutType>(),
+                      out.size(),
+                      make_hip_array<NDIM>(shape),
+                      make_hip_array<NDIM>(a_strides),
+                      make_hip_array<NDIM>(b_strides));
+                });
+              } else {
+                auto kernel = rocm::binary_g<Op, InType, OutType, IdxT>;
+                auto [num_blocks, block_dims] =
+                    get_launch_args(kernel, out, large);
+                hipLaunchKernelGGL(kernel, num_blocks, block_dims, 0, stream,
+                    a.data<InType>(),
+                    b.data<InType>(),
+                    out.data<OutType>(),
+                    out.size(),
+                    make_hip_array(shape),
+                    make_hip_array(a_strides),
+                    make_hip_array(b_strides),
+                    ndim);
+              }
+            });
+          } else {
+            MLX_SWITCH_BOOL(out.data_size() > UINT32_MAX, LARGE, {
+              using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
+              auto kernel = rocm::binary_ss<Op, InType, OutType, IdxT>;
+              if (bopt == BinaryOpType::ScalarVector) {
+                kernel = rocm::binary_sv<Op, InType, OutType, IdxT>;
+              } else if (bopt == BinaryOpType::VectorScalar) {
+                kernel = rocm::binary_vs<Op, InType, OutType, IdxT>;
+              } else if (bopt == BinaryOpType::VectorVector) {
+                kernel = rocm::binary_vv<Op, InType, OutType, IdxT>;
+              }
+              auto [num_blocks, block_dims] = get_launch_args(
+                  kernel, out.data_size(), out.shape(), out.strides(), LARGE);
+              hipLaunchKernelGGL(kernel, num_blocks, block_dims, 0, stream,
+                  a.data<InType>(),
+                  b.data<InType>(),
+                  out.data<OutType>(),
+                  out.data_size());
+            });
+          }
+        } else {
+          throw std::runtime_error(fmt::format(
+              "Can not do binary op {} on inputs of {} with result of {}.",
+              op,
+              dtype_to_string(a.dtype()),
+              dtype_to_string(out.dtype())));
+        }
+      });
+    });
+  });
+}

+template <typename Op>
+void binary_op_gpu(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs,
+    std::string_view op,
+    const Stream& s) {
+  auto& a = inputs[0];
+  auto& b = inputs[1];
+  auto bopt = get_binary_op_type(a, b);
+  set_binary_op_output_data(a, b, outputs[0], bopt);
+  set_binary_op_output_data(a, b, outputs[1], bopt);
+  binary_op_gpu_inplace<Op>(inputs, outputs, op, s);
+}

+template <typename Op>
+void binary_op_gpu(
+    const std::vector<array>& inputs,
+    array& out,
+    std::string_view op,
+    const Stream& s) {
+  auto& a = inputs[0];
+  auto& b = inputs[1];
+  auto bopt = get_binary_op_type(a, b);
+  set_binary_op_output_data(a, b, out, bopt);
+  std::vector<array> outputs{out};
+  binary_op_gpu_inplace<Op>(inputs, outputs, op, s);
+}

+#define BINARY_GPU(func)                                                    \
+  void func::eval_gpu(const std::vector<array>& inputs, array& out) {       \
+    auto& s = out.primitive().stream();                                     \
+    binary_op_gpu<rocm::func>(inputs, out, get_primitive_string(this), s);  \
+  }

+#define BINARY_GPU_MULTI(func)                                                \
+  void func::eval_gpu(                                                        \
+      const std::vector<array>& inputs, std::vector<array>& outputs) {        \
+    auto& s = outputs[0].primitive().stream();                                \
+    binary_op_gpu<rocm::func>(inputs, outputs, get_primitive_string(this), s); \
+  }

+BINARY_GPU(Add)
+BINARY_GPU(ArcTan2)
+BINARY_GPU(Divide)
+BINARY_GPU(Remainder)
+BINARY_GPU(Greater)
+BINARY_GPU(GreaterEqual)
+BINARY_GPU(Less)
+BINARY_GPU(LessEqual)
+BINARY_GPU(LogicalAnd)
+BINARY_GPU(LogicalOr)
+BINARY_GPU(LogAddExp)
+BINARY_GPU(Maximum)
+BINARY_GPU(Minimum)
+BINARY_GPU(Multiply)
+BINARY_GPU(NotEqual)
+BINARY_GPU(Power)
+BINARY_GPU(Subtract)

+void Equal::eval_gpu(const std::vector<array>& inputs, array& out) {
+  auto& s = out.primitive().stream();
+  auto op = get_primitive_string(this);
+  if (equal_nan_) {
+    binary_op_gpu<rocm::NaNEqual>(inputs, out, op, s);
+  } else {
+    binary_op_gpu<rocm::Equal>(inputs, out, op, s);
+  }
+}

+void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
+  auto& s = out.primitive().stream();
+  auto op = get_primitive_string(this);
+  switch (op_) {
+    case BitwiseBinary::And:
+      binary_op_gpu<rocm::BitwiseAnd>(inputs, out, op, s);
+      break;
+    case BitwiseBinary::Or:
+      binary_op_gpu<rocm::BitwiseOr>(inputs, out, op, s);
+      break;
+    case BitwiseBinary::Xor:
+      binary_op_gpu<rocm::BitwiseXor>(inputs, out, op, s);
+      break;
+    case BitwiseBinary::LeftShift:
+      binary_op_gpu<rocm::LeftShift>(inputs, out, op, s);
+      break;
+    case BitwiseBinary::RightShift:
+      binary_op_gpu<rocm::RightShift>(inputs, out, op, s);
+      break;
+  }
+}

+} // namespace mlx::core
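For reference, a minimal sketch of what one of the macro instantiations above expands to, assuming the helpers declared earlier in this file (illustrative only, not an additional part of the commit):

// BINARY_GPU(Add) expands to:
void Add::eval_gpu(const std::vector<array>& inputs, array& out) {
  auto& s = out.primitive().stream();
  binary_op_gpu<rocm::Add>(inputs, out, get_primitive_string(this), s);
}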
@@ -1,20 +1,23 @@
 // Copyright © 2025 Apple Inc.

 #include "mlx/backend/rocm/device.h"
 #include "mlx/backend/rocm/utils.h"
+#include "mlx/backend/metal/metal.h"
 #include "mlx/backend/rocm/worker.h"

-namespace mlx::core::rocm {
+#include <fmt/format.h>

-DeviceStream::DeviceStream(Device& device) : device_(device) {
-  check_hip_error("hipStreamCreate", hipStreamCreate(&stream_));
-  encoder_ = std::make_unique<CommandEncoder>(*this);
-}
+namespace mlx::core {

+namespace rocm {

+DeviceStream::DeviceStream(Device& device) : device_(device), stream_(device) {}

 void DeviceStream::synchronize() {
-  check_hip_error("hipStreamSynchronize", hipStreamSynchronize(stream_));
+  CHECK_HIP_ERROR(hipStreamSynchronize(stream_));
 }

 hipStream_t DeviceStream::schedule_hip_stream() {
   // TODO: Return a stream that maximizes parallelism.
   return stream_;
 }

@@ -23,22 +26,35 @@ hipStream_t DeviceStream::last_hip_stream() {
 }

 CommandEncoder& DeviceStream::get_encoder() {
   if (!encoder_) {
     encoder_ = std::make_unique<CommandEncoder>(*this);
   }
   return *encoder_;
 }

 Device::Device(int device) : device_(device) {
-  check_hip_error("hipSetDevice", hipSetDevice(device_));
+  CHECK_HIP_ERROR(hipDeviceGetAttribute(
+      &compute_capability_major_,
+      hipDeviceAttributeComputeCapabilityMajor,
+      device_));
+  CHECK_HIP_ERROR(hipDeviceGetAttribute(
+      &compute_capability_minor_,
+      hipDeviceAttributeComputeCapabilityMinor,
+      device_));

-  // Get device properties
-  hipDeviceProp_t prop;
-  check_hip_error(
-      "hipGetDeviceProperties", hipGetDeviceProperties(&prop, device_));
-  compute_capability_major_ = prop.major;
-  compute_capability_minor_ = prop.minor;
+  // Validate device requirements
+  int attr = 0;
+  CHECK_HIP_ERROR(hipDeviceGetAttribute(
+      &attr, hipDeviceAttributeConcurrentManagedAccess, device_));
+  if (attr != 1) {
+    // ROCm unified memory might not be available on all devices
+    // This is a warning rather than an error for ROCm
+    // TODO: Add proper ROCm unified memory checking
+  }

-  // Create rocBLAS handle
-  check_hip_error(
-      "rocblas_create_handle",
+  make_current();
+  CHECK_HIP_ERROR(
       static_cast<hipError_t>(rocblas_create_handle(&rocblas_handle_)));
 }

@@ -49,56 +65,66 @@ Device::~Device() {
 }

 void Device::make_current() {
-  check_hip_error("hipSetDevice", hipSetDevice(device_));
+  // Cache current device to reduce HIP API calls
+  static int current = 0;
+  if (current != device_) {
+    CHECK_HIP_ERROR(hipSetDevice(device_));
+    current = device_;
+  }
 }

 DeviceStream& Device::get_stream(Stream s) {
   auto it = streams_.find(s.index);
-  if (it != streams_.end()) {
-    return it->second;
+  if (it == streams_.end()) {
+    it = streams_.try_emplace(s.index, *this).first;
   }

-  auto [new_it, inserted] = streams_.emplace(s.index, DeviceStream(*this));
-  return new_it->second;
+  return it->second;
 }

-CommandEncoder::CommandEncoder(DeviceStream& stream)
-    : device_(stream.device()), stream_(stream), worker_() {}
+CommandEncoder::CommandEncoder(DeviceStream& s)
+    : device_(s.device()), stream_(s) {}

 void CommandEncoder::add_completed_handler(std::function<void()> task) {
-  worker_.enqueue(task);
+  worker_.add_task(std::move(task));
 }

 void CommandEncoder::end_encoding() {
-  // Implementation for ending encoding
+  if (!temporaries_.empty()) {
+    add_completed_handler([temporaries = std::move(temporaries_)]() {});
+  }

+  // There is no kernel running, run completion handlers immediately.
+  if (!has_gpu_work_) {
+    worker_.consume_in_this_thread();
+    return;
+  }
+  has_gpu_work_ = false;

+  // Commit tasks
+  commit();
 }

 void CommandEncoder::commit() {
-  worker_.commit();
+  worker_.commit(stream_.last_hip_stream());
 }

-// Global device management
-static std::unordered_map<int, std::unique_ptr<Device>> devices_;

 Device& device(mlx::core::Device device) {
-  auto it = devices_.find(device.index);
-  if (it != devices_.end()) {
-    return *it->second;
+  static std::unordered_map<int, Device> devices;
+  auto it = devices.find(device.index);
+  if (it == devices.end()) {
+    it = devices.try_emplace(device.index, device.index).first;
   }

-  auto new_device = std::make_unique<Device>(device.index);
-  Device& dev_ref = *new_device;
-  devices_[device.index] = std::move(new_device);
-  return dev_ref;
+  return it->second;
 }

 DeviceStream& get_stream(Stream s) {
-  // Use default device (index 0) for now
-  return device(mlx::core::Device{mlx::core::Device::gpu, 0}).get_stream(s);
+  return device(s.device).get_stream(s);
 }

 CommandEncoder& get_command_encoder(Stream s) {
   return get_stream(s).get_encoder();
 }

-} // namespace mlx::core::rocm
+} // namespace rocm

+} // namespace mlx::core
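A usage sketch of the encoder path, mirroring how binary.cpp drives it (illustrative; assumes the set_input_array/set_output_array/launch_kernel members referenced there):

void example_dispatch(const array& a, array& out, const Stream& s) {
  // device(s.device) -> DeviceStream -> CommandEncoder
  auto& enc = rocm::get_command_encoder(s);
  enc.set_input_array(a);
  enc.set_output_array(out);
  enc.launch_kernel([&](hipStream_t stream) {
    // hipLaunchKernelGGL(kernel, grid, block, 0, stream, ...);
  });
}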
@@ -3,6 +3,7 @@
 #pragma once

 #include "mlx/array.h"
+#include "mlx/backend/rocm/utils.h"
 #include "mlx/backend/rocm/worker.h"
 #include "mlx/stream.h"

@@ -11,7 +12,9 @@
 #include <unordered_map>

-namespace mlx::core::rocm {
+namespace mlx::core {

+namespace rocm {

 class Device;
 class CommandEncoder;

@@ -138,4 +141,6 @@ CommandEncoder& get_command_encoder(Stream s);
 // Utility function to check HIP errors
 void check_hip_error(const char* msg, hipError_t error);

-} // namespace mlx::core::rocm
+} // namespace rocm

+} // namespace mlx::core
mlx/backend/rocm/device/binary_ops.hpp (new file, 217 lines)
@@ -0,0 +1,217 @@
// Copyright © 2025 Apple Inc.

#pragma once

#include <hip/hip_bfloat16.h>
#include <hip/hip_fp16.h>
#include <hip/hip_runtime.h>
#include <hip/hip_complex.h>

namespace mlx::core::rocm {

// Arithmetic operations
struct Add {
  template <typename T>
  __device__ T operator()(T a, T b) {
    return a + b;
  }
};

struct Subtract {
  template <typename T>
  __device__ T operator()(T a, T b) {
    return a - b;
  }
};

struct Multiply {
  template <typename T>
  __device__ T operator()(T a, T b) {
    return a * b;
  }
};

struct Divide {
  template <typename T>
  __device__ T operator()(T a, T b) {
    return a / b;
  }
};

struct Power {
  template <typename T>
  __device__ T operator()(T a, T b) {
    return powf(a, b);
  }

  __device__ double operator()(double a, double b) {
    return pow(a, b);
  }
};

struct Remainder {
  template <typename T>
  __device__ T operator()(T a, T b) {
    return fmodf(a, b);
  }

  __device__ double operator()(double a, double b) {
    return fmod(a, b);
  }
};

// Comparison operations
struct Equal {
  template <typename T>
  __device__ bool operator()(T a, T b) {
    return a == b;
  }
};

struct NotEqual {
  template <typename T>
  __device__ bool operator()(T a, T b) {
    return a != b;
  }
};

struct Greater {
  template <typename T>
  __device__ bool operator()(T a, T b) {
    return a > b;
  }
};

struct GreaterEqual {
  template <typename T>
  __device__ bool operator()(T a, T b) {
    return a >= b;
  }
};

struct Less {
  template <typename T>
  __device__ bool operator()(T a, T b) {
    return a < b;
  }
};

struct LessEqual {
  template <typename T>
  __device__ bool operator()(T a, T b) {
    return a <= b;
  }
};

struct NaNEqual {
  template <typename T>
  __device__ bool operator()(T a, T b) {
    return (isnan(a) && isnan(b)) || (a == b);
  }
};

// Logic operations
struct LogicalAnd {
  __device__ bool operator()(bool a, bool b) {
    return a && b;
  }
};

struct LogicalOr {
  __device__ bool operator()(bool a, bool b) {
    return a || b;
  }
};

// Math operations
struct Maximum {
  template <typename T>
  __device__ T operator()(T a, T b) {
    return fmaxf(a, b);
  }

  __device__ double operator()(double a, double b) {
    return fmax(a, b);
  }
};

struct Minimum {
  template <typename T>
  __device__ T operator()(T a, T b) {
    return fminf(a, b);
  }

  __device__ double operator()(double a, double b) {
    return fmin(a, b);
  }
};

struct LogAddExp {
  template <typename T>
  __device__ T operator()(T a, T b) {
    T max_val = fmaxf(a, b);
    T min_val = fminf(a, b);
    if (isinf(max_val)) {
      return max_val;
    }
    return max_val + log1pf(expf(min_val - max_val));
  }

  __device__ double operator()(double a, double b) {
    double max_val = fmax(a, b);
    double min_val = fmin(a, b);
    if (isinf(max_val)) {
      return max_val;
    }
    return max_val + log1p(exp(min_val - max_val));
  }
};

struct ArcTan2 {
  template <typename T>
  __device__ T operator()(T a, T b) {
    return atan2f(a, b);
  }

  __device__ double operator()(double a, double b) {
    return atan2(a, b);
  }
};

// Bitwise operations
struct BitwiseAnd {
  template <typename T>
  __device__ T operator()(T a, T b) {
    return a & b;
  }
};

struct BitwiseOr {
  template <typename T>
  __device__ T operator()(T a, T b) {
    return a | b;
  }
};

struct BitwiseXor {
  template <typename T>
  __device__ T operator()(T a, T b) {
    return a ^ b;
  }
};

struct LeftShift {
  template <typename T>
  __device__ T operator()(T a, T b) {
    return a << b;
  }
};

struct RightShift {
  template <typename T>
  __device__ T operator()(T a, T b) {
    return a >> b;
  }
};

} // namespace mlx::core::rocm
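A minimal sketch (not part of the commit) of how these functors are consumed: each op is a stateless type whose call operator is instantiated per element inside a kernel, exactly as binary.cpp does with Op{}(a[i], b[i]).

template <typename Op>
__global__ void apply_binary(const float* a, const float* b, float* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    out[i] = Op{}(a[i], b[i]);  // e.g. Op = mlx::core::rocm::Add
  }
}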
mlx/backend/rocm/event.cpp (new file, 50 lines)
@@ -0,0 +1,50 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/rocm/event.h"
#include "mlx/backend/rocm/utils.h"

namespace mlx::core::rocm {

HipEvent::HipEvent() {
  CHECK_HIP_ERROR(hipEventCreate(&event_));
}

HipEvent::~HipEvent() {
  CHECK_HIP_ERROR(hipEventDestroy(event_));
}

void HipEvent::record(hipStream_t stream) {
  CHECK_HIP_ERROR(hipEventRecord(event_, stream));
}

void HipEvent::wait() {
  CHECK_HIP_ERROR(hipEventSynchronize(event_));
}

bool HipEvent::query() const {
  hipError_t status = hipEventQuery(event_);
  if (status == hipSuccess) {
    return true;
  } else if (status == hipErrorNotReady) {
    return false;
  } else {
    CHECK_HIP_ERROR(status);
    return false;
  }
}

SharedEvent::SharedEvent() = default;

void SharedEvent::notify() {
  std::lock_guard<std::mutex> lock(mutex_);
  ready_ = true;
  cv_.notify_one();
}

void SharedEvent::wait() {
  std::unique_lock<std::mutex> lock(mutex_);
  cv_.wait(lock, [this] { return ready_; });
  ready_ = false;
}

} // namespace mlx::core::rocm
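A hypothetical usage sketch (the function name is illustrative): record a HipEvent after queued GPU work and block the host on it.

void wait_for_stream(hipStream_t stream) {
  mlx::core::rocm::HipEvent ev;
  ev.record(stream);   // marks a point after prior work on the stream
  if (!ev.query()) {   // non-blocking completion check
    ev.wait();         // hipEventSynchronize blocks until the GPU reaches it
  }
}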
mlx/backend/rocm/event.h (new file, 48 lines)
@@ -0,0 +1,48 @@
// Copyright © 2025 Apple Inc.

#pragma once

#include <hip/hip_runtime.h>

#include <condition_variable>
#include <memory>
#include <mutex>

namespace mlx::core::rocm {

// HIP event managed with RAII.
class HipEvent {
 public:
  HipEvent();
  ~HipEvent();

  HipEvent(const HipEvent&) = delete;
  HipEvent& operator=(const HipEvent&) = delete;

  void record(hipStream_t stream);
  void wait();
  bool query() const;

  operator hipEvent_t() const {
    return event_;
  }

 private:
  hipEvent_t event_;
};

// Shared event for worker thread synchronization.
class SharedEvent {
 public:
  SharedEvent();

  void notify();
  void wait();

 private:
  std::mutex mutex_;
  std::condition_variable cv_;
  bool ready_{false};
};

} // namespace mlx::core::rocm
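SharedEvent is a plain condition-variable latch; a sketch of the intended host-side pairing (illustrative, not from the commit):

void producer(mlx::core::rocm::SharedEvent& ev) {
  // ... finish some work on another thread ...
  ev.notify();  // wakes a waiter; wait() clears ready_ so the latch is reusable
}

void consumer(mlx::core::rocm::SharedEvent& ev) {
  ev.wait();  // blocks until notify(), then resets the flag
}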
mlx/backend/rocm/jit_module.cpp (new file, 167 lines)
@@ -0,0 +1,167 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/rocm/jit_module.h"
#include "mlx/backend/rocm/utils.h"

#include <fmt/format.h>
#include <mutex>
#include <sstream>

namespace mlx::core::rocm {

JitModule::JitModule(
    const std::string& kernel_name,
    const std::string& kernel_source,
    const std::vector<std::string>& template_args,
    const std::vector<std::string>& compiler_flags,
    bool verbose) {
  compile(kernel_name, kernel_source, template_args, compiler_flags, verbose);
}

JitModule::~JitModule() {
  if (kernel_) {
    // No hipFunctionDestroy equivalent in HIP
  }
  if (module_) {
    CHECK_HIP_ERROR(hipModuleUnload(module_));
  }
  if (program_) {
    hiprtcDestroyProgram(&program_);
  }
}

void JitModule::compile(
    const std::string& kernel_name,
    const std::string& kernel_source,
    const std::vector<std::string>& template_args,
    const std::vector<std::string>& compiler_flags,
    bool verbose) {
  // Create HIPRTC program (hiprtc returns hiprtcResult, so cast for the
  // hipError_t-based check macro)
  CHECK_HIP_ERROR(static_cast<hipError_t>(hiprtcCreateProgram(
      &program_,
      kernel_source.c_str(),
      kernel_name.c_str(),
      0,
      nullptr,
      nullptr)));

  // Build compiler options
  std::vector<const char*> options;
  std::vector<std::string> option_strings;

  // Add default options
  option_strings.push_back("--std=c++17");
  option_strings.push_back("-O3");
  option_strings.push_back("-DMLX_USE_ROCM");

  // Add user-provided flags
  for (const auto& flag : compiler_flags) {
    option_strings.push_back(flag);
  }

  // Add template arguments
  for (const auto& arg : template_args) {
    option_strings.push_back("-D" + arg);
  }

  // Convert to char* array
  for (const auto& option : option_strings) {
    options.push_back(option.c_str());
  }

  // Compile the program
  hiprtcResult compile_result =
      hiprtcCompileProgram(program_, options.size(), options.data());

  // Get compilation log
  size_t log_size;
  CHECK_HIP_ERROR(
      static_cast<hipError_t>(hiprtcGetProgramLogSize(program_, &log_size)));

  if (log_size > 1) {
    std::vector<char> log(log_size);
    CHECK_HIP_ERROR(
        static_cast<hipError_t>(hiprtcGetProgramLog(program_, log.data())));

    if (verbose || compile_result != HIPRTC_SUCCESS) {
      fmt::print(
          "HIPRTC compilation log for {}:\n{}\n", kernel_name, log.data());
    }
  }

  if (compile_result != HIPRTC_SUCCESS) {
    throw std::runtime_error(
        fmt::format("HIPRTC compilation failed for kernel {}", kernel_name));
  }

  // Get compiled code
  size_t code_size;
  CHECK_HIP_ERROR(
      static_cast<hipError_t>(hiprtcGetCodeSize(program_, &code_size)));

  std::vector<char> code(code_size);
  CHECK_HIP_ERROR(
      static_cast<hipError_t>(hiprtcGetCode(program_, code.data())));

  // Load module
  CHECK_HIP_ERROR(hipModuleLoadData(&module_, code.data()));

  // Get kernel function
  CHECK_HIP_ERROR(hipModuleGetFunction(&kernel_, module_, kernel_name.c_str()));
}

JitCache& JitCache::instance() {
  static JitCache cache;
  return cache;
}

std::shared_ptr<JitModule> JitCache::get_or_create(
    const std::string& kernel_name,
    const std::string& kernel_source,
    const std::vector<std::string>& template_args,
    const std::vector<std::string>& compiler_flags) {
  std::string key =
      make_key(kernel_name, kernel_source, template_args, compiler_flags);

  std::lock_guard<std::mutex> lock(mutex_);

  auto it = cache_.find(key);
  if (it != cache_.end()) {
    if (auto module = it->second.lock()) {
      return module;
    } else {
      cache_.erase(it);
    }
  }

  auto module = std::make_shared<JitModule>(
      kernel_name, kernel_source, template_args, compiler_flags);
  cache_[key] = module;
  return module;
}

std::string JitCache::make_key(
    const std::string& kernel_name,
    const std::string& kernel_source,
    const std::vector<std::string>& template_args,
    const std::vector<std::string>& compiler_flags) const {
  std::ostringstream oss;
  oss << kernel_name << "|" << kernel_source;

  for (const auto& arg : template_args) {
    oss << "|" << arg;
  }

  for (const auto& flag : compiler_flags) {
    oss << "|" << flag;
  }

  return oss.str();
}

std::shared_ptr<JitModule> make_jit_kernel(
    const std::string& kernel_name,
    const std::string& kernel_source,
    const std::vector<std::string>& template_args,
    const std::vector<std::string>& compiler_flags) {
  return JitCache::instance().get_or_create(
      kernel_name, kernel_source, template_args, compiler_flags);
}

} // namespace mlx::core::rocm
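A hypothetical end-to-end sketch of this JIT path (the kernel string and names are illustrative). The kernel is declared extern "C" so that hipModuleGetFunction can look it up by its unmangled name.

void fill_ones_jit(float* out, int n, hipStream_t stream) {
  auto mod = mlx::core::rocm::make_jit_kernel(
      "fill_ones",
      R"(extern "C" __global__ void fill_ones(float* out, int n) {
        int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n) {
          out[i] = 1.0f;
        }
      })");
  int blocks = (n + 255) / 256;
  mod->launch(dim3(blocks), dim3(256), 0, stream, out, n);
}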
mlx/backend/rocm/jit_module.h (new file, 100 lines)
@@ -0,0 +1,100 @@
// Copyright © 2025 Apple Inc.

#pragma once

#include <hip/hip_runtime.h>
#include <hip/hiprtc.h>

#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <vector>

#include "mlx/backend/rocm/utils.h"  // CHECK_HIP_ERROR, used by launch()

namespace mlx::core::rocm {

// JIT compilation module for ROCm
class JitModule {
 public:
  JitModule(
      const std::string& kernel_name,
      const std::string& kernel_source,
      const std::vector<std::string>& template_args = {},
      const std::vector<std::string>& compiler_flags = {},
      bool verbose = false);

  ~JitModule();

  JitModule(const JitModule&) = delete;
  JitModule& operator=(const JitModule&) = delete;

  // Get the compiled kernel function
  hipFunction_t get_kernel() const {
    return kernel_;
  }

  // Launch the kernel with given arguments
  template <typename... Args>
  void launch(
      dim3 grid_dims,
      dim3 block_dims,
      size_t shared_memory,
      hipStream_t stream,
      Args&&... args) {
    void* kernel_args[] = {(void*)&args...};
    CHECK_HIP_ERROR(hipModuleLaunchKernel(
        kernel_,
        grid_dims.x,
        grid_dims.y,
        grid_dims.z,
        block_dims.x,
        block_dims.y,
        block_dims.z,
        shared_memory,
        stream,
        kernel_args,
        nullptr));
  }

 private:
  void compile(
      const std::string& kernel_name,
      const std::string& kernel_source,
      const std::vector<std::string>& template_args,
      const std::vector<std::string>& compiler_flags,
      bool verbose);

  hiprtcProgram program_{nullptr};
  hipModule_t module_{nullptr};
  hipFunction_t kernel_{nullptr};
};

// JIT cache for compiled modules
class JitCache {
 public:
  static JitCache& instance();

  std::shared_ptr<JitModule> get_or_create(
      const std::string& kernel_name,
      const std::string& kernel_source,
      const std::vector<std::string>& template_args = {},
      const std::vector<std::string>& compiler_flags = {});

 private:
  std::unordered_map<std::string, std::weak_ptr<JitModule>> cache_;
  std::mutex mutex_;

  std::string make_key(
      const std::string& kernel_name,
      const std::string& kernel_source,
      const std::vector<std::string>& template_args,
      const std::vector<std::string>& compiler_flags) const;
};

// Helper function to create and cache JIT modules
std::shared_ptr<JitModule> make_jit_kernel(
    const std::string& kernel_name,
    const std::string& kernel_source,
    const std::vector<std::string>& template_args = {},
    const std::vector<std::string>& compiler_flags = {});

} // namespace mlx::core::rocm
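The launch() wrapper passes arguments by address, which is what the driver-style hipModuleLaunchKernel expects: each element of kernel_args points at storage for one argument, and that storage must stay alive until the call returns. A sketch of the equivalent direct call (illustrative; grid and block sizes are arbitrary here and CHECK_HIP_ERROR comes from utils.h):

void launch_raw(hipFunction_t f, float* out, int n, hipStream_t stream) {
  void* args[] = {&out, &n};  // addresses of the argument values themselves
  CHECK_HIP_ERROR(hipModuleLaunchKernel(
      f, /*grid*/ 32, 1, 1, /*block*/ 256, 1, 1, 0, stream, args, nullptr));
}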
mlx/backend/rocm/kernel_utils.hpp (new file, 135 lines)
@@ -0,0 +1,135 @@
// Copyright © 2025 Apple Inc.

#pragma once

#include <hip/hip_runtime.h>
#include <array>
#include <cstdint>
#include <utility>
#include <vector>

#include "mlx/array.h"

namespace mlx::core::rocm {

// Constants
constexpr int MAX_DIMS = 8;

// HIP array type for passing arrays to kernels
template <typename T, int N>
using hip_array = std::array<T, N>;

// Helper to create hip_array from vector
template <int N, typename T>
__host__ hip_array<T, N> make_hip_array(const std::vector<T>& vec) {
  hip_array<T, N> arr;
  for (int i = 0; i < N && i < static_cast<int>(vec.size()); ++i) {
    arr[i] = vec[i];
  }
  return arr;
}

template <typename T>
__host__ hip_array<T, MAX_DIMS> make_hip_array(const std::vector<T>& vec) {
  return make_hip_array<MAX_DIMS>(vec);
}

// Type mapping from MLX types to HIP types (a struct map is used because
// alias templates cannot be explicitly specialized)
template <typename T>
struct hip_type {
  using type = T;
};

template <>
struct hip_type<float16> {
  using type = __half;
};

template <>
struct hip_type<bfloat16> {
  using type = __hip_bfloat16;
};

template <>
struct hip_type<complex64> {
  using type = hipFloatComplex;
};

template <typename T>
using hip_type_t = typename hip_type<T>::type;

// Element to location mapping for general broadcasting
template <int NDIM>
__device__ std::pair<int64_t, int64_t> elem_to_loc_nd(
    int64_t elem,
    const int32_t* shape,
    const int64_t* a_strides,
    const int64_t* b_strides) {
  int64_t a_idx = 0;
  int64_t b_idx = 0;

  for (int i = NDIM - 1; i >= 0; --i) {
    int64_t pos_in_dim = elem % shape[i];
    elem /= shape[i];
    a_idx += pos_in_dim * a_strides[i];
    b_idx += pos_in_dim * b_strides[i];
  }

  return {a_idx, b_idx};
}

// General specialization for performance
__device__ inline std::pair<int64_t, int64_t> elem_to_loc_4d(
    int64_t elem,
    const int32_t* shape,
    const int64_t* a_strides,
    const int64_t* b_strides,
    int ndim) {
  int64_t a_idx = 0;
  int64_t b_idx = 0;

  for (int i = ndim - 1; i >= 0; --i) {
    int64_t pos_in_dim = elem % shape[i];
    elem /= shape[i];
    a_idx += pos_in_dim * a_strides[i];
    b_idx += pos_in_dim * b_strides[i];
  }

  return {a_idx, b_idx};
}

// Launch configuration calculation
template <typename Kernel>
std::pair<dim3, dim3>
get_launch_args(Kernel kernel, const array& out, bool large = false) {
  int threads_per_block = 256;
  int64_t total_threads = out.size();

  if (large) {
    // For large arrays, use more blocks
    int64_t blocks =
        (total_threads + threads_per_block - 1) / threads_per_block;
    return {dim3(blocks), dim3(threads_per_block)};
  } else {
    int blocks = (total_threads + threads_per_block - 1) / threads_per_block;
    return {dim3(blocks), dim3(threads_per_block)};
  }
}

template <typename Kernel>
std::pair<dim3, dim3> get_launch_args(
    Kernel kernel,
    int64_t size,
    const std::vector<int>& shape,
    const std::vector<size_t>& strides,
    bool large = false) {
  int threads_per_block = 256;

  if (large) {
    int64_t blocks = (size + threads_per_block - 1) / threads_per_block;
    return {dim3(blocks), dim3(threads_per_block)};
  } else {
    int blocks = (size + threads_per_block - 1) / threads_per_block;
    return {dim3(blocks), dim3(threads_per_block)};
  }
}

// Cooperative groups thread rank equivalent
namespace cooperative_groups {
class grid_group {
 public:
  __device__ int64_t thread_rank() const {
    return blockIdx.x * blockDim.x + threadIdx.x;
  }
};

__device__ inline grid_group this_grid() {
  return grid_group{};
}
} // namespace cooperative_groups

} // namespace mlx::core::rocm
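A worked example of the index arithmetic above (illustrative): for shape {2, 3} with row-major a_strides {3, 1}, and b broadcast along the first dimension via b_strides {0, 1}, linear element 4 decomposes to row 1, column 1, so a_idx = 1*3 + 1*1 = 4 while b_idx = 1*0 + 1*1 = 1. A host-side mirror of the same loop:

#include <cstdint>
#include <utility>

std::pair<int64_t, int64_t> elem_to_loc_2d_host(
    int64_t elem,
    const int32_t* shape,
    const int64_t* a_strides,
    const int64_t* b_strides) {
  int64_t a_idx = 0;
  int64_t b_idx = 0;
  for (int i = 1; i >= 0; --i) {  // NDIM == 2
    int64_t pos = elem % shape[i];
    elem /= shape[i];
    a_idx += pos * a_strides[i];
    b_idx += pos * b_strides[i];
  }
  return {a_idx, b_idx};
}
// shape {2, 3}, a_strides {3, 1}, b_strides {0, 1}, elem 4 -> {4, 1}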
@@ -1,17 +1,46 @@
 // Copyright © 2025 Apple Inc.

 #include "mlx/backend/rocm/utils.h"
-#include <sstream>
-#include <stdexcept>
+#include "mlx/backend/rocm/device.h"
+#include "mlx/dtype_utils.h"

-namespace mlx::core::rocm {
+#include <fmt/format.h>

-void check_hip_error(const char* msg, hipError_t error) {
-  if (error != hipSuccess) {
-    std::ostringstream oss;
-    oss << "[ROCm] " << msg << ": " << hipGetErrorString(error);
-    throw std::runtime_error(oss.str());
+namespace mlx::core {

+HipStream::HipStream(rocm::Device& device) {
+  device.make_current();
+  CHECK_HIP_ERROR(hipStreamCreateWithFlags(&stream_, hipStreamNonBlocking));
+}

+HipStream::~HipStream() {
+  CHECK_HIP_ERROR(hipStreamDestroy(stream_));
+}

+void check_hip_error(const char* name, hipError_t err) {
+  if (err != hipSuccess) {
+    throw std::runtime_error(
+        fmt::format("{} failed: {}", name, hipGetErrorString(err)));
   }
 }

-} // namespace mlx::core::rocm
+const char* dtype_to_hip_type(const Dtype& dtype) {
+  if (dtype == float16) {
+    return "__half";
+  }
+  if (dtype == bfloat16) {
+    return "__hip_bfloat16";
+  }
+  if (dtype == complex64) {
+    return "hipFloatComplex";
+  }
+#define SPECIALIZE_DtypeToString(CPP_TYPE, DTYPE) \
+  if (dtype == DTYPE) {                           \
+    return #CPP_TYPE;                             \
+  }
+  MLX_FORALL_DTYPES(SPECIALIZE_DtypeToString)
+#undef SPECIALIZE_DtypeToString
+  return nullptr;
+}

+} // namespace mlx::core
@@ -1,12 +1,43 @@
 // Copyright © 2025 Apple Inc.

 // This file includes utilities that are used by C++ code (i.e. .cpp files).

 #pragma once

 #include <hip/hip_runtime.h>

-namespace mlx::core::rocm {
+namespace mlx::core {

-// Utility function to check HIP errors
-void check_hip_error(const char* msg, hipError_t error);
+namespace rocm {
+class Device;
+}

-} // namespace mlx::core::rocm
+struct Dtype;

+// HIP stream managed with RAII.
+class HipStream {
+ public:
+  explicit HipStream(rocm::Device& device);
+  ~HipStream();

+  HipStream(const HipStream&) = delete;
+  HipStream& operator=(const HipStream&) = delete;

+  operator hipStream_t() const {
+    return stream_;
+  }

+ private:
+  hipStream_t stream_;
+};

+// Throw exception if the HIP API does not succeed.
+void check_hip_error(const char* name, hipError_t err);

+// The macro version that prints the command that failed.
+#define CHECK_HIP_ERROR(cmd) check_hip_error(#cmd, (cmd))

+// Convert Dtype to HIP C++ types.
+const char* dtype_to_hip_type(const Dtype& dtype);

+} // namespace mlx::core
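The macro stringizes its argument, so the thrown message names the exact failing call; a sketch of the expansion:

// CHECK_HIP_ERROR(hipSetDevice(1)) expands to
//   check_hip_error("hipSetDevice(1)", (hipSetDevice(1)));
// and on failure throws
//   std::runtime_error("hipSetDevice(1) failed: <hipGetErrorString(err)>")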
@@ -1,6 +1,7 @@
 // Copyright © 2025 Apple Inc.

 #include "mlx/backend/rocm/worker.h"
+#include "mlx/backend/rocm/utils.h"

 namespace mlx::core::rocm {

@@ -17,7 +18,7 @@ Worker::~Worker() {
   }
 }

-void Worker::enqueue(std::function<void()> task) {
+void Worker::add_task(std::function<void()> task) {
   {
     std::lock_guard<std::mutex> lock(mutex_);
     tasks_.push(task);

@@ -25,14 +26,28 @@ void Worker::enqueue(std::function<void()> task) {
   cv_.notify_one();
 }

-void Worker::commit() {
-  std::lock_guard<std::mutex> lock(mutex_);
-  committed_ = true;
+void Worker::consume_in_this_thread() {
+  std::queue<std::function<void()>> local_tasks;
+  {
+    std::lock_guard<std::mutex> lock(mutex_);
+    local_tasks.swap(tasks_);
+  }

+  while (!local_tasks.empty()) {
+    auto task = local_tasks.front();
+    local_tasks.pop();
+    task();
+  }
 }

-void Worker::join() {
-  std::unique_lock<std::mutex> lock(mutex_);
-  cv_.wait(lock, [this] { return tasks_.empty() && committed_; });
+void Worker::commit(hipStream_t stream) {
+  // Synchronize with stream and then process tasks
+  CHECK_HIP_ERROR(hipStreamSynchronize(stream));
+  consume_in_this_thread();
+}

+void Worker::commit() {
+  cv_.notify_all();
 }

 void Worker::worker_loop() {

@@ -3,15 +3,16 @@
 #pragma once

+#include <hip/hip_runtime.h>

 #include <condition_variable>
 #include <functional>
 #include <future>
 #include <mutex>
 #include <queue>
 #include <thread>

 namespace mlx::core::rocm {

-using HipStream = hipStream_t;

+// Simple worker for async task execution synchronized with HIP streams.
 class Worker {
  public:
   Worker();

@@ -20,9 +21,17 @@ class Worker {
   Worker(const Worker&) = delete;
   Worker& operator=(const Worker&) = delete;

-  void enqueue(std::function<void()> task);
+  // Add a task to be executed
+  void add_task(std::function<void()> task);

+  // Run pending tasks immediately in current thread.
+  void consume_in_this_thread();

+  // Commit tasks to be run after stream completion
+  void commit(hipStream_t stream);

+  // Simple commit without stream dependency
+  void commit();
   void join();

  private:
   void worker_loop();

@@ -32,7 +41,6 @@ class Worker {
   std::mutex mutex_;
   std::condition_variable cv_;
   bool stop_{false};
-  bool committed_{false};
 };

 } // namespace mlx::core::rocm
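A hypothetical usage sketch of the reworked Worker API (stream and task are illustrative):

void flush_after_stream(mlx::core::rocm::Worker& w, hipStream_t stream) {
  w.add_task([] { /* e.g. release temporaries captured by the encoder */ });
  // commit(stream) synchronizes the stream, then drains the queue on the
  // calling thread via consume_in_this_thread().
  w.commit(stream);
}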