mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-23 16:46:50 +08:00
Increment 2: Implement major ops and add structure similar to cuda
This commit is contained in:
parent
ac5adfa963
commit
cc4de6a607
@ -2,19 +2,205 @@
|
||||
|
||||
#include "mlx/backend/rocm/allocator.h"
|
||||
#include "mlx/backend/rocm/utils.h"
|
||||
#include "mlx/backend/rocm/worker.h"
|
||||
|
||||
namespace mlx::core::rocm {
|
||||
#include <fmt/format.h>
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <unistd.h>
|
||||
|
||||
void* allocate(size_t size) {
|
||||
void* ptr;
|
||||
check_hip_error("hipMalloc", hipMalloc(&ptr, size));
|
||||
return ptr;
|
||||
#include <cassert>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace rocm {
|
||||
|
||||
RocmAllocator::RocmAllocator()
|
||||
: buffer_cache_(
|
||||
getpagesize(),
|
||||
[](RocmBuffer* buf) { return buf->size; },
|
||||
[this](RocmBuffer* buf) {
|
||||
rocm_free(buf->data);
|
||||
delete buf;
|
||||
}) {
|
||||
// TODO: Set memory limit for multi-device.
|
||||
size_t free, total;
|
||||
CHECK_HIP_ERROR(hipMemGetInfo(&free, &total));
|
||||
memory_limit_ = total * 0.8;
|
||||
max_pool_size_ = memory_limit_;
|
||||
}
|
||||
|
||||
void deallocate(void* ptr) {
|
||||
if (ptr) {
|
||||
check_hip_error("hipFree", hipFree(ptr));
|
||||
Buffer RocmAllocator::malloc(size_t size) {
|
||||
// Find available buffer from cache.
|
||||
std::unique_lock lock(mutex_);
|
||||
RocmBuffer* buf = buffer_cache_.reuse_from_cache(size);
|
||||
if (!buf) {
|
||||
// If we have a lot of memory pressure or are over the maximum cache size,
|
||||
// try to reclaim memory from the cache.
|
||||
size_t mem_required = get_active_memory() + get_cache_memory() + size;
|
||||
if (mem_required >= memory_limit_) {
|
||||
buffer_cache_.release_cached_buffers(mem_required - memory_limit_);
|
||||
}
|
||||
|
||||
lock.unlock();
|
||||
buf = new RocmBuffer{nullptr, size};
|
||||
hipError_t err = hipMallocManaged(&buf->data, size);
|
||||
if (err != hipSuccess && err != hipErrorMemoryAllocation) {
|
||||
throw std::runtime_error(
|
||||
fmt::format("hipMallocManaged failed: {}.", hipGetErrorString(err)));
|
||||
}
|
||||
lock.lock();
|
||||
}
|
||||
active_memory_ += size;
|
||||
peak_memory_ = std::max(active_memory_, peak_memory_);
|
||||
|
||||
// Maintain the cache below the requested limit.
|
||||
if (get_cache_memory() > max_pool_size_) {
|
||||
buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
|
||||
}
|
||||
|
||||
return Buffer{buf};
|
||||
}
|
||||
|
||||
void RocmAllocator::free(Buffer buffer) {
|
||||
auto* buf = static_cast<RocmBuffer*>(buffer.ptr());
|
||||
if (!buf) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::unique_lock lock(mutex_);
|
||||
active_memory_ -= buf->size;
|
||||
if (get_cache_memory() < max_pool_size_) {
|
||||
buffer_cache_.recycle_to_cache(buf);
|
||||
} else {
|
||||
lock.unlock();
|
||||
rocm_free(buf->data);
|
||||
delete buf;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core::rocm
|
||||
size_t RocmAllocator::size(Buffer buffer) const {
|
||||
auto* buf = static_cast<RocmBuffer*>(buffer.ptr());
|
||||
if (!buf) {
|
||||
return 0;
|
||||
}
|
||||
return buf->size;
|
||||
}
|
||||
|
||||
void RocmAllocator::register_this_thread() {
|
||||
std::lock_guard lock(worker_mutex_);
|
||||
allowed_threads_.insert(std::this_thread::get_id());
|
||||
}
|
||||
|
||||
void RocmAllocator::rocm_free(void* buf) {
|
||||
// If rocm_free() is called from a unregistered thread, reschedule the call to
|
||||
// worker.
|
||||
{
|
||||
std::lock_guard lock(worker_mutex_);
|
||||
if (allowed_threads_.count(std::this_thread::get_id()) == 0) {
|
||||
if (!worker_) {
|
||||
worker_.reset(new Worker);
|
||||
}
|
||||
worker_->add_task([this, buf]() { this->rocm_free(buf); });
|
||||
worker_->end_batch();
|
||||
worker_->commit();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
hipFree(buf);
|
||||
}
|
||||
|
||||
size_t RocmAllocator::get_active_memory() const {
|
||||
return active_memory_;
|
||||
}
|
||||
|
||||
size_t RocmAllocator::get_peak_memory() const {
|
||||
return peak_memory_;
|
||||
}
|
||||
|
||||
void RocmAllocator::reset_peak_memory() {
|
||||
std::lock_guard lock(mutex_);
|
||||
peak_memory_ = 0;
|
||||
}
|
||||
|
||||
size_t RocmAllocator::get_memory_limit() {
|
||||
return memory_limit_;
|
||||
}
|
||||
|
||||
size_t RocmAllocator::set_memory_limit(size_t limit) {
|
||||
std::lock_guard lock(mutex_);
|
||||
std::swap(limit, memory_limit_);
|
||||
return limit;
|
||||
}
|
||||
|
||||
size_t RocmAllocator::get_cache_memory() const {
|
||||
return buffer_cache_.cache_size();
|
||||
}
|
||||
|
||||
size_t RocmAllocator::set_cache_limit(size_t limit) {
|
||||
std::lock_guard lk(mutex_);
|
||||
std::swap(limit, max_pool_size_);
|
||||
return limit;
|
||||
}
|
||||
|
||||
void RocmAllocator::clear_cache() {
|
||||
std::lock_guard lk(mutex_);
|
||||
buffer_cache_.clear();
|
||||
}
|
||||
|
||||
RocmAllocator& allocator() {
|
||||
// By creating the |allocator_| on heap, the destructor of RocmAllocator
|
||||
// will not be called on exit and buffers in the cache will be leaked. This
|
||||
// can save some time at program exit.
|
||||
static RocmAllocator* allocator_ = new RocmAllocator;
|
||||
return *allocator_;
|
||||
}
|
||||
|
||||
} // namespace rocm
|
||||
|
||||
namespace allocator {
|
||||
|
||||
Allocator& allocator() {
|
||||
return rocm::allocator();
|
||||
}
|
||||
|
||||
void* Buffer::raw_ptr() {
|
||||
if (!ptr_) {
|
||||
return nullptr;
|
||||
}
|
||||
return static_cast<rocm::RocmBuffer*>(ptr_)->data;
|
||||
}
|
||||
|
||||
} // namespace allocator
|
||||
|
||||
size_t get_active_memory() {
|
||||
return rocm::allocator().get_active_memory();
|
||||
}
|
||||
size_t get_peak_memory() {
|
||||
return rocm::allocator().get_peak_memory();
|
||||
}
|
||||
void reset_peak_memory() {
|
||||
return rocm::allocator().reset_peak_memory();
|
||||
}
|
||||
size_t set_memory_limit(size_t limit) {
|
||||
return rocm::allocator().set_memory_limit(limit);
|
||||
}
|
||||
size_t get_memory_limit() {
|
||||
return rocm::allocator().get_memory_limit();
|
||||
}
|
||||
size_t get_cache_memory() {
|
||||
return rocm::allocator().get_cache_memory();
|
||||
}
|
||||
size_t set_cache_limit(size_t limit) {
|
||||
return rocm::allocator().set_cache_limit(limit);
|
||||
}
|
||||
void clear_cache() {
|
||||
rocm::allocator().clear_cache();
|
||||
}
|
||||
|
||||
// Not supported in ROCm.
|
||||
size_t set_wired_limit(size_t) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@ -2,11 +2,66 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstddef>
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/common/buffer_cache.h"
|
||||
|
||||
#include <mutex>
|
||||
#include <set>
|
||||
#include <thread>
|
||||
#include <utility>
|
||||
|
||||
namespace mlx::core::rocm {
|
||||
|
||||
void* allocate(size_t size);
|
||||
void deallocate(void* ptr);
|
||||
class Worker;
|
||||
|
||||
using allocator::Buffer;
|
||||
|
||||
// Stores ROCm-managed unified memory.
|
||||
struct RocmBuffer {
|
||||
void* data;
|
||||
size_t size;
|
||||
};
|
||||
|
||||
class RocmAllocator : public allocator::Allocator {
|
||||
public:
|
||||
Buffer malloc(size_t size) override;
|
||||
void free(Buffer buffer) override;
|
||||
size_t size(Buffer buffer) const override;
|
||||
|
||||
// Register current thread as safe to free buffers.
|
||||
// In ROCm freeing a buffer implicitly synchronizes stream, and for threads
|
||||
// that may be waited by gpu stream (for example cpu stream threads), freeing
|
||||
// buffers there would result in dead lock.
|
||||
void register_this_thread();
|
||||
|
||||
// Call hipFree in the safe thread.
|
||||
void rocm_free(void* buf);
|
||||
|
||||
size_t get_active_memory() const;
|
||||
size_t get_peak_memory() const;
|
||||
void reset_peak_memory();
|
||||
size_t get_memory_limit();
|
||||
size_t set_memory_limit(size_t limit);
|
||||
size_t get_cache_memory() const;
|
||||
size_t set_cache_limit(size_t limit);
|
||||
void clear_cache();
|
||||
|
||||
private:
|
||||
RocmAllocator();
|
||||
friend RocmAllocator& allocator();
|
||||
|
||||
std::mutex worker_mutex_;
|
||||
std::unique_ptr<Worker> worker_;
|
||||
std::set<std::thread::id> allowed_threads_;
|
||||
|
||||
std::mutex mutex_;
|
||||
size_t memory_limit_;
|
||||
size_t max_pool_size_;
|
||||
BufferCache<RocmBuffer> buffer_cache_;
|
||||
size_t active_memory_{0};
|
||||
size_t peak_memory_{0};
|
||||
};
|
||||
|
||||
RocmAllocator& allocator();
|
||||
|
||||
} // namespace mlx::core::rocm
|
60
mlx/backend/rocm/copy/copy.hpp
Normal file
60
mlx/backend/rocm/copy/copy.hpp
Normal file
@ -0,0 +1,60 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <cstddef>
|
||||
|
||||
namespace mlx::core::rocm {
|
||||
|
||||
// Copy function declarations
|
||||
void copy_contiguous(
|
||||
const void* src,
|
||||
void* dst,
|
||||
size_t size,
|
||||
hipStream_t stream);
|
||||
|
||||
void copy_general(
|
||||
const void* src,
|
||||
void* dst,
|
||||
const int* src_shape,
|
||||
const size_t* src_strides,
|
||||
const int* dst_shape,
|
||||
const size_t* dst_strides,
|
||||
int ndim,
|
||||
size_t size,
|
||||
size_t dtype_size,
|
||||
hipStream_t stream);
|
||||
|
||||
void copy_general_dynamic(
|
||||
const void* src,
|
||||
void* dst,
|
||||
const int* src_shape,
|
||||
const size_t* src_strides,
|
||||
const int* dst_shape,
|
||||
const size_t* dst_strides,
|
||||
int ndim,
|
||||
size_t size,
|
||||
size_t dtype_size,
|
||||
hipStream_t stream);
|
||||
|
||||
void copy_general_input(
|
||||
const void* src,
|
||||
void* dst,
|
||||
const int* src_shape,
|
||||
const size_t* src_strides,
|
||||
const int* dst_shape,
|
||||
const size_t* dst_strides,
|
||||
int ndim,
|
||||
size_t size,
|
||||
size_t dtype_size,
|
||||
hipStream_t stream);
|
||||
|
||||
// Utility functions for element location calculation
|
||||
__device__ size_t
|
||||
elem_to_loc(size_t elem, const int* shape, const size_t* strides, int ndim);
|
||||
|
||||
__device__ size_t
|
||||
loc_to_elem(size_t loc, const int* shape, const size_t* strides, int ndim);
|
||||
|
||||
} // namespace mlx::core::rocm
|
38
mlx/backend/rocm/copy/copy_contiguous.hip
Normal file
38
mlx/backend/rocm/copy/copy_contiguous.hip
Normal file
@ -0,0 +1,38 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/rocm/copy/copy.hpp"
|
||||
#include "mlx/backend/rocm/kernel_utils.hpp"
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
namespace mlx::core::rocm {
|
||||
|
||||
__global__ void copy_contiguous_kernel(
|
||||
const char* src,
|
||||
char* dst,
|
||||
size_t size) {
|
||||
size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (tid < size) {
|
||||
dst[tid] = src[tid];
|
||||
}
|
||||
}
|
||||
|
||||
void copy_contiguous(
|
||||
const void* src,
|
||||
void* dst,
|
||||
size_t size,
|
||||
hipStream_t stream) {
|
||||
if (size == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int threads_per_block = 256;
|
||||
const int blocks = (size + threads_per_block - 1) / threads_per_block;
|
||||
|
||||
copy_contiguous_kernel<<<blocks, threads_per_block, 0, stream>>>(
|
||||
static_cast<const char*>(src),
|
||||
static_cast<char*>(dst),
|
||||
size);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::rocm
|
17
mlx/backend/rocm/device/arange.hpp
Normal file
17
mlx/backend/rocm/device/arange.hpp
Normal file
@ -0,0 +1,17 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
namespace mlx::core::rocm {
|
||||
|
||||
template <typename T>
|
||||
__global__ void arange_kernel(T* out, T start, T step, size_t size) {
|
||||
size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (tid < size) {
|
||||
out[tid] = start + static_cast<T>(tid) * step;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core::rocm
|
36
mlx/backend/rocm/device/atomic_ops.hpp
Normal file
36
mlx/backend/rocm/device/atomic_ops.hpp
Normal file
@ -0,0 +1,36 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
namespace mlx::core::rocm {
|
||||
|
||||
// Atomic operations for HIP
|
||||
__device__ inline float atomicAddFloat(float* address, float val) {
|
||||
return atomicAdd(address, val);
|
||||
}
|
||||
|
||||
__device__ inline double atomicAddDouble(double* address, double val) {
|
||||
return atomicAdd(address, val);
|
||||
}
|
||||
|
||||
__device__ inline int atomicAddInt(int* address, int val) {
|
||||
return atomicAdd(address, val);
|
||||
}
|
||||
|
||||
__device__ inline unsigned int atomicAddUInt(
|
||||
unsigned int* address,
|
||||
unsigned int val) {
|
||||
return atomicAdd(address, val);
|
||||
}
|
||||
|
||||
__device__ inline float atomicMaxFloat(float* address, float val) {
|
||||
return atomicMax(address, val);
|
||||
}
|
||||
|
||||
__device__ inline float atomicMinFloat(float* address, float val) {
|
||||
return atomicMin(address, val);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::rocm
|
21
mlx/backend/rocm/device/cast_op.hpp
Normal file
21
mlx/backend/rocm/device/cast_op.hpp
Normal file
@ -0,0 +1,21 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
namespace mlx::core::rocm {
|
||||
|
||||
template <typename To, typename From>
|
||||
struct CastOp {
|
||||
__device__ To operator()(From x) const {
|
||||
return static_cast<To>(x);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename To, typename From>
|
||||
__device__ inline To cast_op(From x) {
|
||||
return static_cast<To>(x);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::rocm
|
14
mlx/backend/rocm/device/config.h
Normal file
14
mlx/backend/rocm/device/config.h
Normal file
@ -0,0 +1,14 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
// ROCm/HIP specific configuration
|
||||
#define ROCM_MAX_THREADS_PER_BLOCK 1024
|
||||
#define ROCM_WARP_SIZE 64
|
||||
#define ROCM_MAX_BLOCKS_PER_GRID 65535
|
||||
|
||||
namespace mlx::core::rocm {
|
||||
constexpr int kMaxThreadsPerBlock = ROCM_MAX_THREADS_PER_BLOCK;
|
||||
constexpr int kWarpSize = ROCM_WARP_SIZE;
|
||||
constexpr int kMaxBlocksPerGrid = ROCM_MAX_BLOCKS_PER_GRID;
|
||||
} // namespace mlx::core::rocm
|
87
mlx/backend/rocm/device/fp16_math.hpp
Normal file
87
mlx/backend/rocm/device/fp16_math.hpp
Normal file
@ -0,0 +1,87 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <hip/hip_fp16.h>
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
namespace mlx::core::rocm {
|
||||
|
||||
// HIP/ROCm equivalents of CUDA half precision math functions
|
||||
inline __device__ __half2 h2sin(__half2 x) {
|
||||
return __half2{hsin(x.x), hsin(x.y)};
|
||||
}
|
||||
|
||||
inline __device__ __half2 h2cos(__half2 x) {
|
||||
return __half2{hcos(x.x), hcos(x.y)};
|
||||
}
|
||||
|
||||
inline __device__ __half2 h2exp(__half2 x) {
|
||||
return __half2{hexp(x.x), hexp(x.y)};
|
||||
}
|
||||
|
||||
inline __device__ __half2 h2log(__half2 x) {
|
||||
return __half2{hlog(x.x), hlog(x.y)};
|
||||
}
|
||||
|
||||
inline __device__ __half2 h2sqrt(__half2 x) {
|
||||
return __half2{hsqrt(x.x), hsqrt(x.y)};
|
||||
}
|
||||
|
||||
inline __device__ __half2 h2rsqrt(__half2 x) {
|
||||
return __half2{hrsqrt(x.x), hrsqrt(x.y)};
|
||||
}
|
||||
|
||||
inline __device__ __half2 h2ceil(__half2 x) {
|
||||
return __half2{hceil(x.x), hceil(x.y)};
|
||||
}
|
||||
|
||||
inline __device__ __half2 h2floor(__half2 x) {
|
||||
return __half2{hfloor(x.x), hfloor(x.y)};
|
||||
}
|
||||
|
||||
inline __device__ __half2 h2rint(__half2 x) {
|
||||
return __half2{hrint(x.x), hrint(x.y)};
|
||||
}
|
||||
|
||||
inline __device__ __half2 h2trunc(__half2 x) {
|
||||
return __half2{htrunc(x.x), htrunc(x.y)};
|
||||
}
|
||||
|
||||
// Additional math functions for half precision
|
||||
inline __device__ __half habs(__half x) {
|
||||
return __half{fabsf(__half2float(x))};
|
||||
}
|
||||
|
||||
inline __device__ __half2 h2abs(__half2 x) {
|
||||
return __half2{habs(x.x), habs(x.y)};
|
||||
}
|
||||
|
||||
inline __device__ __half hneg(__half x) {
|
||||
return __half{-__half2float(x)};
|
||||
}
|
||||
|
||||
inline __device__ __half2 h2neg(__half2 x) {
|
||||
return __half2{hneg(x.x), hneg(x.y)};
|
||||
}
|
||||
|
||||
// BFloat16 support functions
|
||||
#ifdef __HIP_BFLOAT16__
|
||||
inline __device__ __hip_bfloat16 habs(__hip_bfloat16 x) {
|
||||
return __hip_bfloat16{fabsf(__bfloat162float(x))};
|
||||
}
|
||||
|
||||
inline __device__ __hip_bfloat162 h2abs(__hip_bfloat162 x) {
|
||||
return __hip_bfloat162{habs(x.x), habs(x.y)};
|
||||
}
|
||||
|
||||
inline __device__ __hip_bfloat16 hneg(__hip_bfloat16 x) {
|
||||
return __hip_bfloat16{-__bfloat162float(x)};
|
||||
}
|
||||
|
||||
inline __device__ __hip_bfloat162 h2neg(__hip_bfloat162 x) {
|
||||
return __hip_bfloat162{hneg(x.x), hneg(x.y)};
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace mlx::core::rocm
|
52
mlx/backend/rocm/device/hip_complex_math.hpp
Normal file
52
mlx/backend/rocm/device/hip_complex_math.hpp
Normal file
@ -0,0 +1,52 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <hip/hip_complex.h>
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
namespace mlx::core::rocm {
|
||||
|
||||
// HIP complex math functions
|
||||
__device__ inline hipFloatComplex hip_complex_add(
|
||||
hipFloatComplex a,
|
||||
hipFloatComplex b) {
|
||||
return make_hipFloatComplex(
|
||||
hipCrealf(a) + hipCrealf(b), hipCimagf(a) + hipCimagf(b));
|
||||
}
|
||||
|
||||
__device__ inline hipFloatComplex hip_complex_sub(
|
||||
hipFloatComplex a,
|
||||
hipFloatComplex b) {
|
||||
return make_hipFloatComplex(
|
||||
hipCrealf(a) - hipCrealf(b), hipCimagf(a) - hipCimagf(b));
|
||||
}
|
||||
|
||||
__device__ inline hipFloatComplex hip_complex_mul(
|
||||
hipFloatComplex a,
|
||||
hipFloatComplex b) {
|
||||
float real = hipCrealf(a) * hipCrealf(b) - hipCimagf(a) * hipCimagf(b);
|
||||
float imag = hipCrealf(a) * hipCimagf(b) + hipCimagf(a) * hipCrealf(b);
|
||||
return make_hipFloatComplex(real, imag);
|
||||
}
|
||||
|
||||
__device__ inline hipFloatComplex hip_complex_div(
|
||||
hipFloatComplex a,
|
||||
hipFloatComplex b) {
|
||||
float denom = hipCrealf(b) * hipCrealf(b) + hipCimagf(b) * hipCimagf(b);
|
||||
float real =
|
||||
(hipCrealf(a) * hipCrealf(b) + hipCimagf(a) * hipCimagf(b)) / denom;
|
||||
float imag =
|
||||
(hipCimagf(a) * hipCrealf(b) - hipCrealf(a) * hipCimagf(b)) / denom;
|
||||
return make_hipFloatComplex(real, imag);
|
||||
}
|
||||
|
||||
__device__ inline float hip_complex_abs(hipFloatComplex z) {
|
||||
return sqrtf(hipCrealf(z) * hipCrealf(z) + hipCimagf(z) * hipCimagf(z));
|
||||
}
|
||||
|
||||
__device__ inline hipFloatComplex hip_complex_conj(hipFloatComplex z) {
|
||||
return make_hipFloatComplex(hipCrealf(z), -hipCimagf(z));
|
||||
}
|
||||
|
||||
} // namespace mlx::core::rocm
|
16
mlx/backend/rocm/device/ternary_ops.hpp
Normal file
16
mlx/backend/rocm/device/ternary_ops.hpp
Normal file
@ -0,0 +1,16 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
namespace mlx::core::rocm {
|
||||
|
||||
struct Select {
|
||||
template <typename T>
|
||||
__device__ T operator()(bool condition, T a, T b) const {
|
||||
return condition ? a : b;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace mlx::core::rocm
|
368
mlx/backend/rocm/device/unary_ops.hpp
Normal file
368
mlx/backend/rocm/device/unary_ops.hpp
Normal file
@ -0,0 +1,368 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/rocm/device/fp16_math.hpp"
|
||||
#include "mlx/backend/rocm/device/utils.hpp"
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
namespace mlx::core::rocm {
|
||||
|
||||
struct Abs {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (std::is_unsigned_v<T>) {
|
||||
return x;
|
||||
} else if constexpr (std::is_same_v<T, hipFloatComplex>) {
|
||||
return {
|
||||
sqrt(hipCrealf(x) * hipCrealf(x) + hipCimagf(x) * hipCimagf(x)), 0};
|
||||
} else {
|
||||
return abs(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct ArcCos {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return acos(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct ArcCosh {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return acosh(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct ArcSin {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return asin(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct ArcSinh {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return asinh(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct ArcTan {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return atan(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct ArcTanh {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return atanh(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct BitwiseInvert {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return ~x;
|
||||
}
|
||||
};
|
||||
|
||||
struct Ceil {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (std::is_integral_v<T>) {
|
||||
return x;
|
||||
} else {
|
||||
return ceil(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Conjugate {
|
||||
__device__ hipFloatComplex operator()(hipFloatComplex x) {
|
||||
return {hipCrealf(x), -hipCimagf(x)};
|
||||
}
|
||||
};
|
||||
|
||||
struct Cos {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (std::is_same_v<T, hipFloatComplex>) {
|
||||
return {
|
||||
cos(hipCrealf(x)) * cosh(hipCimagf(x)),
|
||||
-sin(hipCrealf(x)) * sinh(hipCimagf(x))};
|
||||
} else {
|
||||
return cos(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Cosh {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (std::is_same_v<T, hipFloatComplex>) {
|
||||
return {
|
||||
cosh(hipCrealf(x)) * cos(hipCimagf(x)),
|
||||
sinh(hipCrealf(x)) * sin(hipCimagf(x))};
|
||||
} else {
|
||||
return cosh(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Erf {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (std::is_same_v<T, __half>) {
|
||||
return erf(__half2float(x));
|
||||
} else if constexpr (std::is_same_v<T, __hip_bfloat16>) {
|
||||
return erf(__bfloat162float(x));
|
||||
} else {
|
||||
return erf(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct ErfInv {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (std::is_same_v<T, __half>) {
|
||||
return erfinv(__half2float(x));
|
||||
} else if constexpr (std::is_same_v<T, __hip_bfloat16>) {
|
||||
return erfinv(__bfloat162float(x));
|
||||
} else {
|
||||
return erfinv(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Exp {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (std::is_same_v<T, hipFloatComplex>) {
|
||||
auto m = exp(hipCrealf(x));
|
||||
return {m * cos(hipCimagf(x)), m * sinh(hipCimagf(x))};
|
||||
} else {
|
||||
return exp(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Expm1 {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (std::is_same_v<T, __half>) {
|
||||
return expm1(__half2float(x));
|
||||
} else if constexpr (std::is_same_v<T, __hip_bfloat16>) {
|
||||
return expm1(__bfloat162float(x));
|
||||
} else {
|
||||
return expm1(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Floor {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (std::is_integral_v<T>) {
|
||||
return x;
|
||||
} else {
|
||||
return floor(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Imag {
|
||||
__device__ float operator()(hipFloatComplex x) {
|
||||
return hipCimagf(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Log {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (std::is_same_v<T, hipFloatComplex>) {
|
||||
auto r = log(hipCrealf(Abs{}(x)));
|
||||
auto i = atan2f(hipCimagf(x), hipCrealf(x));
|
||||
return {r, i};
|
||||
} else {
|
||||
return log(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Log2 {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (std::is_same_v<T, hipFloatComplex>) {
|
||||
auto y = Log{}(x);
|
||||
return {hipCrealf(y) / M_LN2, hipCimagf(y) / M_LN2};
|
||||
} else {
|
||||
return log2(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Log10 {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (std::is_same_v<T, hipFloatComplex>) {
|
||||
auto y = Log{}(x);
|
||||
return {hipCrealf(y) / M_LN10, hipCimagf(y) / M_LN10};
|
||||
} else {
|
||||
return log10(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Log1p {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return log1p(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct LogicalNot {
|
||||
__device__ bool operator()(bool x) {
|
||||
return !x;
|
||||
}
|
||||
};
|
||||
|
||||
struct Negative {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (std::is_same_v<T, hipFloatComplex>) {
|
||||
return 0 - x;
|
||||
} else {
|
||||
return -x;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Real {
|
||||
__device__ float operator()(hipFloatComplex x) {
|
||||
return hipCrealf(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Round {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (std::is_same_v<T, hipFloatComplex>) {
|
||||
return {rint(hipCrealf(x)), rint(hipCimagf(x))};
|
||||
} else {
|
||||
return rint(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Rsqrt {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return rsqrt(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Sigmoid {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
T y = 1 / (1 + exp(-abs(x)));
|
||||
return (x < 0) ? 1 - y : y;
|
||||
}
|
||||
};
|
||||
|
||||
struct Sign {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (std::is_unsigned_v<T>) {
|
||||
return x != 0;
|
||||
} else if constexpr (std::is_same_v<T, hipFloatComplex>) {
|
||||
if (hipCrealf(x) == 0 && hipCimagf(x) == 0) {
|
||||
return x;
|
||||
} else {
|
||||
return x / Abs()(x);
|
||||
}
|
||||
} else if constexpr (std::is_same_v<T, __hip_bfloat16>) {
|
||||
return static_cast<float>((x > T(0.f)) - (x < T(0.f)));
|
||||
} else {
|
||||
return (x > T(0)) - (x < T(0));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Sin {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (std::is_same_v<T, hipFloatComplex>) {
|
||||
return {
|
||||
sin(hipCrealf(x)) * cosh(hipCimagf(x)),
|
||||
cos(hipCrealf(x)) * sinh(hipCimagf(x))};
|
||||
} else {
|
||||
return sin(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Sinh {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (std::is_same_v<T, hipFloatComplex>) {
|
||||
return {
|
||||
sinh(hipCrealf(x)) * cos(hipCimagf(x)),
|
||||
cosh(hipCrealf(x)) * sin(hipCimagf(x))};
|
||||
} else {
|
||||
return sinh(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Square {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return x * x;
|
||||
}
|
||||
};
|
||||
|
||||
struct Sqrt {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return sqrt(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Tan {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (std::is_same_v<T, hipFloatComplex>) {
|
||||
float tan_a = tan(hipCrealf(x));
|
||||
float tanh_b = tanh(hipCimagf(x));
|
||||
float t1 = tan_a * tanh_b;
|
||||
float denom = 1. + t1 * t1;
|
||||
return {(tan_a - tanh_b * t1) / denom, (tanh_b + tan_a * t1) / denom};
|
||||
} else {
|
||||
return tan(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Tanh {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (std::is_same_v<T, hipFloatComplex>) {
|
||||
float tanh_a = tanh(hipCrealf(x));
|
||||
float tan_b = tan(hipCimagf(x));
|
||||
float t1 = tanh_a * tan_b;
|
||||
float denom = 1. + t1 * t1;
|
||||
return {(tanh_a + tan_b * t1) / denom, (tan_b - tanh_a * t1) / denom};
|
||||
} else {
|
||||
return tanh(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace mlx::core::rocm
|
173
mlx/backend/rocm/device/utils.hpp
Normal file
173
mlx/backend/rocm/device/utils.hpp
Normal file
@ -0,0 +1,173 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <hip/hip_complex.h>
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
namespace mlx::core::rocm {
|
||||
|
||||
// HIP/ROCm type definitions
|
||||
using hip_complex = hipFloatComplex;
|
||||
|
||||
// Utility functions for HIP device code
|
||||
template <typename T>
|
||||
struct hip_type {
|
||||
using type = T;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct hip_type<bool> {
|
||||
using type = bool;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct hip_type<int8_t> {
|
||||
using type = int8_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct hip_type<uint8_t> {
|
||||
using type = uint8_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct hip_type<int16_t> {
|
||||
using type = int16_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct hip_type<uint16_t> {
|
||||
using type = uint16_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct hip_type<int32_t> {
|
||||
using type = int32_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct hip_type<uint32_t> {
|
||||
using type = uint32_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct hip_type<int64_t> {
|
||||
using type = int64_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct hip_type<uint64_t> {
|
||||
using type = uint64_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct hip_type<float> {
|
||||
using type = float;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct hip_type<double> {
|
||||
using type = double;
|
||||
};
|
||||
|
||||
#ifdef __HIP_PLATFORM_HCC__
|
||||
template <>
|
||||
struct hip_type<__half> {
|
||||
using type = __half;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct hip_type<__hip_bfloat16> {
|
||||
using type = __hip_bfloat16;
|
||||
};
|
||||
#endif
|
||||
|
||||
template <typename T>
|
||||
using hip_type_t = typename hip_type<T>::type;
|
||||
|
||||
// Element-wise operations support
|
||||
template <typename T>
|
||||
constexpr bool is_floating_point_v = std::is_floating_point_v<T>;
|
||||
|
||||
template <typename T>
|
||||
constexpr bool is_integral_v = std::is_integral_v<T>;
|
||||
|
||||
template <typename T>
|
||||
constexpr bool is_signed_v = std::is_signed_v<T>;
|
||||
|
||||
template <typename T>
|
||||
constexpr bool is_unsigned_v = std::is_unsigned_v<T>;
|
||||
|
||||
// Complex number helper functions
|
||||
inline __device__ hipFloatComplex make_complex(float real, float imag) {
|
||||
return make_hipFloatComplex(real, imag);
|
||||
}
|
||||
|
||||
inline __device__ float hip_real(hipFloatComplex z) {
|
||||
return hipCrealf(z);
|
||||
}
|
||||
|
||||
inline __device__ float hip_imag(hipFloatComplex z) {
|
||||
return hipCimagf(z);
|
||||
}
|
||||
|
||||
inline __device__ hipFloatComplex hip_conj(hipFloatComplex z) {
|
||||
return make_hipFloatComplex(hipCrealf(z), -hipCimagf(z));
|
||||
}
|
||||
|
||||
inline __device__ float hip_abs(hipFloatComplex z) {
|
||||
return sqrtf(hipCrealf(z) * hipCrealf(z) + hipCimagf(z) * hipCimagf(z));
|
||||
}
|
||||
|
||||
// Memory access utilities
|
||||
template <typename T>
|
||||
inline __device__ T hip_load_global(const T* ptr) {
|
||||
return *ptr;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline __device__ void hip_store_global(T* ptr, T value) {
|
||||
*ptr = value;
|
||||
}
|
||||
|
||||
// Grid and block utilities
|
||||
inline __device__ int hip_thread_idx() {
|
||||
return threadIdx.x;
|
||||
}
|
||||
|
||||
inline __device__ int hip_block_idx() {
|
||||
return blockIdx.x;
|
||||
}
|
||||
|
||||
inline __device__ int hip_block_dim() {
|
||||
return blockDim.x;
|
||||
}
|
||||
|
||||
inline __device__ int hip_grid_dim() {
|
||||
return gridDim.x;
|
||||
}
|
||||
|
||||
inline __device__ int hip_global_thread_idx() {
|
||||
return blockIdx.x * blockDim.x + threadIdx.x;
|
||||
}
|
||||
|
||||
// Synchronization
|
||||
inline __device__ void hip_sync_threads() {
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// Math constants for HIP (equivalent to CUDA's math_constants.h)
|
||||
#ifndef M_PI
|
||||
#define M_PI 3.14159265358979323846
|
||||
#endif
|
||||
|
||||
#ifndef M_LN2
|
||||
#define M_LN2 0.693147180559945309417
|
||||
#endif
|
||||
|
||||
#ifndef M_LN10
|
||||
#define M_LN10 2.302585092994045684018
|
||||
#endif
|
||||
|
||||
} // namespace mlx::core::rocm
|
153
mlx/backend/rocm/iterators/general_iterator.hpp
Normal file
153
mlx/backend/rocm/iterators/general_iterator.hpp
Normal file
@ -0,0 +1,153 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <cstdint>
|
||||
|
||||
namespace mlx::core::rocm {
|
||||
|
||||
template <typename IdxType>
|
||||
struct GeneralIterator {
|
||||
using difference_type = ptrdiff_t;
|
||||
using value_type = IdxType;
|
||||
using pointer = IdxType*;
|
||||
using reference = IdxType&;
|
||||
using iterator_category = std::random_access_iterator_tag;
|
||||
|
||||
const IdxType* base_ptr;
|
||||
IdxType offset;
|
||||
const int* shape;
|
||||
const size_t* strides;
|
||||
int ndim;
|
||||
size_t size;
|
||||
|
||||
__device__ GeneralIterator(
|
||||
const IdxType* base_ptr,
|
||||
IdxType offset,
|
||||
const int* shape,
|
||||
const size_t* strides,
|
||||
int ndim,
|
||||
size_t size)
|
||||
: base_ptr(base_ptr),
|
||||
offset(offset),
|
||||
shape(shape),
|
||||
strides(strides),
|
||||
ndim(ndim),
|
||||
size(size) {}
|
||||
|
||||
__device__ GeneralIterator operator+(difference_type n) const {
|
||||
return GeneralIterator(base_ptr, offset + n, shape, strides, ndim, size);
|
||||
}
|
||||
|
||||
__device__ GeneralIterator operator-(difference_type n) const {
|
||||
return GeneralIterator(base_ptr, offset - n, shape, strides, ndim, size);
|
||||
}
|
||||
|
||||
__device__ difference_type operator-(const GeneralIterator& other) const {
|
||||
return offset - other.offset;
|
||||
}
|
||||
|
||||
__device__ GeneralIterator& operator+=(difference_type n) {
|
||||
offset += n;
|
||||
return *this;
|
||||
}
|
||||
|
||||
__device__ GeneralIterator& operator-=(difference_type n) {
|
||||
offset -= n;
|
||||
return *this;
|
||||
}
|
||||
|
||||
__device__ GeneralIterator& operator++() {
|
||||
++offset;
|
||||
return *this;
|
||||
}
|
||||
|
||||
__device__ GeneralIterator operator++(int) {
|
||||
GeneralIterator temp = *this;
|
||||
++offset;
|
||||
return temp;
|
||||
}
|
||||
|
||||
__device__ GeneralIterator& operator--() {
|
||||
--offset;
|
||||
return *this;
|
||||
}
|
||||
|
||||
__device__ GeneralIterator operator--(int) {
|
||||
GeneralIterator temp = *this;
|
||||
--offset;
|
||||
return temp;
|
||||
}
|
||||
|
||||
__device__ bool operator==(const GeneralIterator& other) const {
|
||||
return offset == other.offset;
|
||||
}
|
||||
|
||||
__device__ bool operator!=(const GeneralIterator& other) const {
|
||||
return offset != other.offset;
|
||||
}
|
||||
|
||||
__device__ bool operator<(const GeneralIterator& other) const {
|
||||
return offset < other.offset;
|
||||
}
|
||||
|
||||
__device__ bool operator>(const GeneralIterator& other) const {
|
||||
return offset > other.offset;
|
||||
}
|
||||
|
||||
__device__ bool operator<=(const GeneralIterator& other) const {
|
||||
return offset <= other.offset;
|
||||
}
|
||||
|
||||
__device__ bool operator>=(const GeneralIterator& other) const {
|
||||
return offset >= other.offset;
|
||||
}
|
||||
|
||||
__device__ IdxType operator*() const {
|
||||
return base_ptr[elem_to_loc(offset, shape, strides, ndim)];
|
||||
}
|
||||
|
||||
__device__ IdxType operator[](difference_type n) const {
|
||||
return base_ptr[elem_to_loc(offset + n, shape, strides, ndim)];
|
||||
}
|
||||
|
||||
private:
|
||||
__device__ size_t elem_to_loc(
|
||||
size_t elem,
|
||||
const int* shape,
|
||||
const size_t* strides,
|
||||
int ndim) const {
|
||||
size_t loc = 0;
|
||||
for (int i = ndim - 1; i >= 0; --i) {
|
||||
auto q_and_r = div(elem, static_cast<size_t>(shape[i]));
|
||||
loc += q_and_r.rem * strides[i];
|
||||
elem = q_and_r.quot;
|
||||
}
|
||||
return loc;
|
||||
}
|
||||
|
||||
__device__ div_t div(size_t numer, size_t denom) const {
|
||||
div_t result;
|
||||
result.quot = numer / denom;
|
||||
result.rem = numer % denom;
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename IdxType>
|
||||
__device__ std::pair<GeneralIterator<IdxType>, GeneralIterator<IdxType>>
|
||||
make_general_iterators(
|
||||
const IdxType* base_ptr,
|
||||
size_t size,
|
||||
const int* shape,
|
||||
const size_t* strides,
|
||||
int ndim) {
|
||||
auto begin =
|
||||
GeneralIterator<IdxType>(base_ptr, 0, shape, strides, ndim, size);
|
||||
auto end =
|
||||
GeneralIterator<IdxType>(base_ptr, size, shape, strides, ndim, size);
|
||||
return std::make_pair(begin, end);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::rocm
|
106
mlx/backend/rocm/iterators/strided_iterator.hpp
Normal file
106
mlx/backend/rocm/iterators/strided_iterator.hpp
Normal file
@ -0,0 +1,106 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <cstdint>
|
||||
|
||||
namespace mlx::core::rocm {
|
||||
|
||||
template <typename T>
|
||||
struct StridedIterator {
|
||||
using difference_type = ptrdiff_t;
|
||||
using value_type = T;
|
||||
using pointer = T*;
|
||||
using reference = T&;
|
||||
using iterator_category = std::random_access_iterator_tag;
|
||||
|
||||
T* ptr;
|
||||
size_t stride;
|
||||
|
||||
__device__ StridedIterator(T* ptr, size_t stride)
|
||||
: ptr(ptr), stride(stride) {}
|
||||
|
||||
__device__ StridedIterator operator+(difference_type n) const {
|
||||
return StridedIterator(ptr + n * stride, stride);
|
||||
}
|
||||
|
||||
__device__ StridedIterator operator-(difference_type n) const {
|
||||
return StridedIterator(ptr - n * stride, stride);
|
||||
}
|
||||
|
||||
__device__ difference_type operator-(const StridedIterator& other) const {
|
||||
return (ptr - other.ptr) / stride;
|
||||
}
|
||||
|
||||
__device__ StridedIterator& operator+=(difference_type n) {
|
||||
ptr += n * stride;
|
||||
return *this;
|
||||
}
|
||||
|
||||
__device__ StridedIterator& operator-=(difference_type n) {
|
||||
ptr -= n * stride;
|
||||
return *this;
|
||||
}
|
||||
|
||||
__device__ StridedIterator& operator++() {
|
||||
ptr += stride;
|
||||
return *this;
|
||||
}
|
||||
|
||||
__device__ StridedIterator operator++(int) {
|
||||
StridedIterator temp = *this;
|
||||
ptr += stride;
|
||||
return temp;
|
||||
}
|
||||
|
||||
__device__ StridedIterator& operator--() {
|
||||
ptr -= stride;
|
||||
return *this;
|
||||
}
|
||||
|
||||
__device__ StridedIterator operator--(int) {
|
||||
StridedIterator temp = *this;
|
||||
ptr -= stride;
|
||||
return temp;
|
||||
}
|
||||
|
||||
__device__ bool operator==(const StridedIterator& other) const {
|
||||
return ptr == other.ptr;
|
||||
}
|
||||
|
||||
__device__ bool operator!=(const StridedIterator& other) const {
|
||||
return ptr != other.ptr;
|
||||
}
|
||||
|
||||
__device__ bool operator<(const StridedIterator& other) const {
|
||||
return ptr < other.ptr;
|
||||
}
|
||||
|
||||
__device__ bool operator>(const StridedIterator& other) const {
|
||||
return ptr > other.ptr;
|
||||
}
|
||||
|
||||
__device__ bool operator<=(const StridedIterator& other) const {
|
||||
return ptr <= other.ptr;
|
||||
}
|
||||
|
||||
__device__ bool operator>=(const StridedIterator& other) const {
|
||||
return ptr >= other.ptr;
|
||||
}
|
||||
|
||||
__device__ T& operator*() const {
|
||||
return *ptr;
|
||||
}
|
||||
|
||||
__device__ T& operator[](difference_type n) const {
|
||||
return *(ptr + n * stride);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
__device__ StridedIterator<T> make_strided_iterator(T* ptr, size_t stride) {
|
||||
return StridedIterator<T>(ptr, stride);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::rocm
|
@ -1,6 +1,406 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/rocm/device.h"
|
||||
#include "mlx/backend/rocm/iterators/strided_iterator.hpp"
|
||||
#include "mlx/backend/rocm/kernel_utils.hpp"
|
||||
#include "mlx/backend/rocm/reduce/reduce.hpp"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <hip/hip_cooperative_groups.h>
|
||||
#include <rocprim/block/block_load.hpp>
|
||||
#include <rocprim/block/block_reduce.hpp>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace rocm {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
inline __device__ float3 plus_f3(const float3& a, const float3& b) {
|
||||
return {a.x + b.x, a.y + b.y, a.z + b.z};
|
||||
}
|
||||
|
||||
// Similar to rocprim::BlockReduce, but result is broadcasted to every thread.
|
||||
template <typename T, int BLOCK_DIM>
|
||||
struct BlockBroadcastReduce {
|
||||
static_assert(WARP_SIZE <= BLOCK_DIM && BLOCK_DIM <= WARP_SIZE * WARP_SIZE);
|
||||
static_assert(BLOCK_DIM % WARP_SIZE == 0);
|
||||
using TempStorage = T[BLOCK_DIM / WARP_SIZE];
|
||||
|
||||
cg::thread_block& block;
|
||||
TempStorage& temp;
|
||||
|
||||
template <typename Op>
|
||||
__device__ T Reduce(const T& input, const Op& op, const T& init_value) {
|
||||
auto warp = cg::tiled_partition<WARP_SIZE>(block);
|
||||
T x = cg::reduce(warp, input, op);
|
||||
if (warp.thread_rank() == 0) {
|
||||
temp[warp.meta_group_rank()] = x;
|
||||
}
|
||||
block.sync();
|
||||
x = warp.thread_rank() < warp.meta_group_size() ? temp[warp.thread_rank()]
|
||||
: init_value;
|
||||
return cg::reduce(warp, x, op);
|
||||
}
|
||||
|
||||
__device__ T Sum(const T& input) {
|
||||
return Reduce(input, hip_plus<T>{}, T{});
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, int BLOCK_DIM, int N_READS = 4>
|
||||
__global__ void layer_norm(
|
||||
const T* x,
|
||||
const T* w,
|
||||
const T* b,
|
||||
T* out,
|
||||
float eps,
|
||||
int32_t axis_size,
|
||||
int64_t w_stride,
|
||||
int64_t b_stride) {
|
||||
auto grid = cg::this_grid();
|
||||
auto block = cg::this_thread_block();
|
||||
|
||||
using BlockReduceT = BlockBroadcastReduce<float, BLOCK_DIM>;
|
||||
__shared__ typename BlockReduceT::TempStorage temp;
|
||||
|
||||
x += grid.block_rank() * axis_size;
|
||||
out += grid.block_rank() * axis_size;
|
||||
|
||||
// Sum.
|
||||
float sum = 0;
|
||||
for (int r = 0; r < hip_ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
|
||||
auto index = r * BLOCK_DIM + block.thread_rank();
|
||||
T xn[N_READS] = {};
|
||||
rocprim::block_load_direct_blocked(index, x, xn, axis_size);
|
||||
sum += static_cast<float>(rocprim::thread_reduce(xn, hip_plus<T>{}));
|
||||
}
|
||||
sum = BlockReduceT{block, temp}.Sum(sum);
|
||||
|
||||
// Mean.
|
||||
float mean = sum / axis_size;
|
||||
|
||||
// Normalizer.
|
||||
float normalizer = 0;
|
||||
for (int r = 0; r < hip_ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
|
||||
auto index = r * BLOCK_DIM + block.thread_rank();
|
||||
T xn[N_READS];
|
||||
rocprim::block_load_direct_blocked(index, x, xn, axis_size, mean);
|
||||
for (int i = 0; i < N_READS; ++i) {
|
||||
float t = static_cast<float>(xn[i]) - mean;
|
||||
normalizer += t * t;
|
||||
}
|
||||
}
|
||||
normalizer = BlockReduceT{block, temp}.Sum(normalizer);
|
||||
normalizer = rsqrt(normalizer / axis_size + eps);
|
||||
|
||||
// Outputs.
|
||||
for (int r = 0; r < hip_ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
|
||||
auto index = r * BLOCK_DIM + block.thread_rank();
|
||||
T xn[N_READS];
|
||||
T wn[N_READS];
|
||||
T bn[N_READS];
|
||||
rocprim::block_load_direct_blocked(index, x, xn, axis_size);
|
||||
rocprim::block_load_direct_blocked(index, strided_iterator(w, w_stride), wn, axis_size);
|
||||
rocprim::block_load_direct_blocked(index, strided_iterator(b, b_stride), bn, axis_size);
|
||||
for (int i = 0; i < N_READS; ++i) {
|
||||
float norm = (static_cast<float>(xn[i]) - mean) * normalizer;
|
||||
xn[i] = wn[i] * static_cast<T>(norm) + bn[i];
|
||||
}
|
||||
rocprim::block_store_direct_blocked(index, out, xn, axis_size);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, bool HAS_W, int BLOCK_DIM, int N_READS = 4>
|
||||
__global__ void layer_norm_vjp(
|
||||
const T* x,
|
||||
const T* w,
|
||||
const T* g,
|
||||
T* gx,
|
||||
T* gw,
|
||||
float eps,
|
||||
int32_t axis_size,
|
||||
int64_t w_stride) {
|
||||
auto grid = cg::this_grid();
|
||||
auto block = cg::this_thread_block();
|
||||
|
||||
using BlockReduceF = BlockBroadcastReduce<float, BLOCK_DIM>;
|
||||
using BlockReduceF3 = BlockBroadcastReduce<float3, BLOCK_DIM>;
|
||||
__shared__ union {
|
||||
typename BlockReduceF::TempStorage f;
|
||||
typename BlockReduceF3::TempStorage f3;
|
||||
} temp;
|
||||
|
||||
x += grid.block_rank() * axis_size;
|
||||
g += grid.block_rank() * axis_size;
|
||||
gx += grid.block_rank() * axis_size;
|
||||
gw += grid.block_rank() * axis_size;
|
||||
|
||||
// Sum.
|
||||
float sum = 0;
|
||||
for (int r = 0; r < hip_ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
|
||||
auto index = r * BLOCK_DIM + block.thread_rank();
|
||||
T xn[N_READS] = {};
|
||||
rocprim::block_load_direct_blocked(index, x, xn, axis_size);
|
||||
sum += static_cast<float>(rocprim::thread_reduce(xn, hip_plus<T>{}));
|
||||
}
|
||||
sum = BlockReduceF{block, temp.f}.Sum(sum);
|
||||
|
||||
// Mean.
|
||||
float mean = sum / axis_size;
|
||||
|
||||
// Normalizer.
|
||||
float3 factors = {};
|
||||
for (int r = 0; r < hip_ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
|
||||
T xn[N_READS];
|
||||
T wn[N_READS] = {};
|
||||
T gn[N_READS] = {};
|
||||
auto index = r * BLOCK_DIM + block.thread_rank();
|
||||
rocprim::block_load_direct_blocked(index, x, xn, axis_size, mean);
|
||||
rocprim::block_load_direct_blocked(index, g, gn, axis_size);
|
||||
rocprim::block_load_direct_blocked(index, strided_iterator(w, w_stride), wn, axis_size);
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
float t = static_cast<float>(xn[i]) - mean;
|
||||
float wi = wn[i];
|
||||
float gi = gn[i];
|
||||
float wg = wi * gi;
|
||||
factors = plus_f3(factors, {wg, wg * t, t * t});
|
||||
}
|
||||
}
|
||||
factors = BlockReduceF3{block, temp.f3}.Reduce(factors, plus_f3, {});
|
||||
float meanwg = factors.x / axis_size;
|
||||
float meanwgxc = factors.y / axis_size;
|
||||
float normalizer2 = 1 / (factors.z / axis_size + eps);
|
||||
float normalizer = sqrt(normalizer2);
|
||||
|
||||
// Outputs.
|
||||
for (int r = 0; r < hip_ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
|
||||
auto index = r * BLOCK_DIM + block.thread_rank();
|
||||
T xn[N_READS];
|
||||
T wn[N_READS];
|
||||
T gn[N_READS];
|
||||
rocprim::block_load_direct_blocked(index, x, xn, axis_size);
|
||||
rocprim::block_load_direct_blocked(index, g, gn, axis_size);
|
||||
rocprim::block_load_direct_blocked(index, strided_iterator(w, w_stride), wn, axis_size);
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
float xi = (static_cast<float>(xn[i]) - mean) * normalizer;
|
||||
float wi = wn[i];
|
||||
float gi = gn[i];
|
||||
xn[i] = normalizer * (wi * gi - meanwg) - xi * meanwgxc * normalizer2;
|
||||
if constexpr (HAS_W) {
|
||||
wn[i] = gi * xi;
|
||||
}
|
||||
}
|
||||
rocprim::block_store_direct_blocked(index, gx, xn, axis_size);
|
||||
if constexpr (HAS_W) {
|
||||
rocprim::block_store_direct_blocked(index, gw, wn, axis_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Utility functions
|
||||
template <typename T>
|
||||
struct hip_plus {
|
||||
__device__ T operator()(const T& a, const T& b) const {
|
||||
return a + b;
|
||||
}
|
||||
};
|
||||
|
||||
inline __device__ int hip_ceil_div(int a, int b) {
|
||||
return (a + b - 1) / b;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ inline auto strided_iterator(const T* ptr, int64_t stride) {
|
||||
return ptr + stride; // Simplified strided iterator
|
||||
}
|
||||
|
||||
} // namespace rocm
|
||||
|
||||
namespace fast {
|
||||
|
||||
bool LayerNorm::use_fallback(Stream s) {
|
||||
return s.device == Device::cpu;
|
||||
}
|
||||
|
||||
// TODO: There are duplicate code with backend/metal/normalization.cpp
|
||||
void LayerNorm::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
auto& s = stream();
|
||||
auto& out = outputs[0];
|
||||
|
||||
// Make sure that the last dimension is contiguous.
|
||||
auto set_output = [&s, &out](const array& x) {
|
||||
bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
|
||||
if (no_copy && x.ndim() > 1) {
|
||||
auto s = x.strides()[x.ndim() - 2];
|
||||
no_copy &= (s == 0 || s == x.shape().back());
|
||||
}
|
||||
if (no_copy) {
|
||||
if (x.is_donatable()) {
|
||||
out.copy_shared_buffer(x);
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc(x.data_size() * x.itemsize()),
|
||||
x.data_size(),
|
||||
x.strides(),
|
||||
x.flags());
|
||||
}
|
||||
return x;
|
||||
} else {
|
||||
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
|
||||
copy_gpu(x, x_copy, CopyType::General, s);
|
||||
out.copy_shared_buffer(x_copy);
|
||||
return x_copy;
|
||||
}
|
||||
};
|
||||
|
||||
const array x = set_output(inputs[0]);
|
||||
const array& w = inputs[1];
|
||||
const array& b = inputs[2];
|
||||
|
||||
int32_t axis_size = x.shape().back();
|
||||
int32_t n_rows = x.data_size() / axis_size;
|
||||
int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
|
||||
int64_t b_stride = (b.ndim() == 1) ? b.strides()[0] : 0;
|
||||
|
||||
auto& encoder = rocm::get_command_encoder(s);
|
||||
encoder.set_input_array(x);
|
||||
encoder.set_input_array(w);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](hipStream_t stream) {
|
||||
MLX_SWITCH_FLOAT_TYPES_CHECKED(out.dtype(), "layernorm", CTYPE, {
|
||||
using DataType = hip_type_t<CTYPE>;
|
||||
constexpr uint32_t N_READS = 4;
|
||||
MLX_SWITCH_BLOCK_DIM(rocm::hip_ceil_div(axis_size, N_READS), BLOCK_DIM, {
|
||||
auto kernel = rocm::layer_norm<DataType, BLOCK_DIM, N_READS>;
|
||||
hipLaunchKernelGGL(kernel, n_rows, BLOCK_DIM, 0, stream,
|
||||
x.data<DataType>(),
|
||||
w.data<DataType>(),
|
||||
b.data<DataType>(),
|
||||
out.data<DataType>(),
|
||||
eps_,
|
||||
axis_size,
|
||||
w_stride,
|
||||
b_stride);
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
void LayerNormVJP::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
auto& s = stream();
|
||||
auto& encoder = rocm::get_command_encoder(s);
|
||||
|
||||
// Ensure row contiguity. We could relax this step by checking that the array
|
||||
// is contiguous (no broadcasts or holes) and that the input strides are the
|
||||
// same as the cotangent strides but for now this is simpler.
|
||||
auto check_input = [&s](const array& x) -> std::pair<array, bool> {
|
||||
if (x.flags().row_contiguous) {
|
||||
return {x, false};
|
||||
}
|
||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
||||
copy_gpu(x, x_copy, CopyType::General, s);
|
||||
return {x_copy, true};
|
||||
};
|
||||
bool donate_x = inputs[0].is_donatable();
|
||||
bool donate_g = inputs[3].is_donatable();
|
||||
auto [x, copied] = check_input(inputs[0]);
|
||||
donate_x |= copied;
|
||||
const array& w = inputs[1];
|
||||
const array& b = inputs[2];
|
||||
auto [g, g_copied] = check_input(inputs[3]);
|
||||
donate_g |= g_copied;
|
||||
array& gx = outputs[0];
|
||||
array& gw = outputs[1];
|
||||
array& gb = outputs[2];
|
||||
|
||||
// Check whether we had a weight.
|
||||
bool has_w = w.ndim() != 0;
|
||||
|
||||
// Allocate space for the outputs.
|
||||
bool g_in_gx = false;
|
||||
if (donate_x) {
|
||||
gx.copy_shared_buffer(x);
|
||||
} else if (donate_g) {
|
||||
gx.copy_shared_buffer(g);
|
||||
g_in_gx = true;
|
||||
} else {
|
||||
gx.set_data(allocator::malloc(gx.nbytes()));
|
||||
}
|
||||
if (g_copied && !g_in_gx) {
|
||||
encoder.add_temporary(g);
|
||||
}
|
||||
|
||||
int32_t axis_size = x.shape().back();
|
||||
int32_t n_rows = x.data_size() / axis_size;
|
||||
int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
|
||||
|
||||
// Allocate a temporary to store the gradients for w and allocate the output
|
||||
// gradient accumulators.
|
||||
array gw_temp =
|
||||
(has_w) ? array({n_rows, x.shape().back()}, gw.dtype(), nullptr, {}) : w;
|
||||
if (has_w) {
|
||||
if (!g_in_gx && donate_g) {
|
||||
gw_temp.copy_shared_buffer(g);
|
||||
} else {
|
||||
gw_temp.set_data(allocator::malloc(gw_temp.nbytes()));
|
||||
encoder.add_temporary(gw_temp);
|
||||
}
|
||||
}
|
||||
gw.set_data(allocator::malloc(gw.nbytes()));
|
||||
gb.set_data(allocator::malloc(gb.nbytes()));
|
||||
|
||||
// Finish with the gradient for b in case we had a b.
|
||||
if (gb.ndim() == 1 && gb.size() == axis_size) {
|
||||
ReductionPlan plan(
|
||||
ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size});
|
||||
col_reduce(encoder, g, gb, Reduce::ReduceType::Sum, {0}, plan);
|
||||
}
|
||||
|
||||
encoder.set_input_array(x);
|
||||
encoder.set_input_array(w);
|
||||
encoder.set_input_array(g);
|
||||
encoder.set_output_array(gx);
|
||||
encoder.set_output_array(gw_temp);
|
||||
encoder.launch_kernel([&, x = x, g = g](hipStream_t stream) {
|
||||
MLX_SWITCH_FLOAT_TYPES_CHECKED(gx.dtype(), "layernorm_vjp", CTYPE, {
|
||||
using DataType = hip_type_t<CTYPE>;
|
||||
constexpr int N_READS = 4;
|
||||
MLX_SWITCH_BOOL(has_w, HAS_W, {
|
||||
MLX_SWITCH_BLOCK_DIM(rocm::hip_ceil_div(axis_size, N_READS), BLOCK_DIM, {
|
||||
auto kernel = rocm::layer_norm_vjp<DataType, HAS_W, BLOCK_DIM, N_READS>;
|
||||
hipLaunchKernelGGL(kernel, n_rows, BLOCK_DIM, 0, stream,
|
||||
x.data<DataType>(),
|
||||
w.data<DataType>(),
|
||||
g.data<DataType>(),
|
||||
gx.data<DataType>(),
|
||||
gw_temp.data<DataType>(),
|
||||
eps_,
|
||||
axis_size,
|
||||
w_stride);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
if (has_w) {
|
||||
ReductionPlan plan(
|
||||
ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size});
|
||||
col_reduce(encoder, gw_temp, gw, Reduce::ReduceType::Sum, {0}, plan);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace fast
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
namespace mlx::core::rocm {
|
||||
|
||||
|
311
mlx/backend/rocm/reduce/col_reduce.hip
Normal file
311
mlx/backend/rocm/reduce/col_reduce.hip
Normal file
@ -0,0 +1,311 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/rocm/device.h"
|
||||
#include "mlx/backend/rocm/device/cast_op.hpp"
|
||||
#include "mlx/backend/rocm/reduce/reduce.hpp"
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <hip/hip_cooperative_groups.h>
|
||||
#include <rocprim/block/block_load.hpp>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace rocm {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
struct ColReduceArgs {
|
||||
// The size of the contiguous column reduction.
|
||||
size_t reduction_size;
|
||||
int64_t reduction_stride;
|
||||
|
||||
// Input shape and strides excluding the reduction axes.
|
||||
Shape shape;
|
||||
Strides strides;
|
||||
int ndim;
|
||||
|
||||
// Input shape and strides of the reduction axes (including last dimension).
|
||||
Shape reduce_shape;
|
||||
Strides reduce_strides;
|
||||
int reduce_ndim;
|
||||
|
||||
// The number of column we are reducing. Namely prod(reduce_shape).
|
||||
size_t non_col_reductions;
|
||||
|
||||
ColReduceArgs(
|
||||
const array& in,
|
||||
const ReductionPlan& plan,
|
||||
const std::vector<int>& axes) {
|
||||
assert(!plan.shape.empty());
|
||||
reduction_size = plan.shape.back();
|
||||
reduction_stride = plan.strides.back();
|
||||
|
||||
int64_t stride_back = 1;
|
||||
auto [shape_vec, strides_vec] = shapes_without_reduction_axes(in, axes);
|
||||
while (!shape_vec.empty() && stride_back < reduction_stride) {
|
||||
stride_back *= shape_vec.back();
|
||||
shape_vec.pop_back();
|
||||
strides_vec.pop_back();
|
||||
}
|
||||
std::tie(shape_vec, strides_vec) =
|
||||
collapse_contiguous_dims(shape_vec, strides_vec);
|
||||
shape = const_param(shape_vec);
|
||||
strides = const_param(strides_vec);
|
||||
ndim = shape_vec.size();
|
||||
|
||||
reduce_shape = const_param(plan.shape);
|
||||
reduce_strides = const_param(plan.strides);
|
||||
reduce_ndim = plan.shape.size();
|
||||
|
||||
non_col_reductions = 1;
|
||||
for (int i = 0; i < reduce_ndim - 1; i++) {
|
||||
non_col_reductions *= reduce_shape[i];
|
||||
}
|
||||
}
|
||||
};
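// For example, the gw reduction issued by the norm VJPs in this change uses
// plan.shape = {n_rows}, plan.strides = {axis_size} and axes = {0} on a
// row-contiguous (n_rows, axis_size) input, which gives reduction_size =
// n_rows, reduction_stride = axis_size, ndim = 0 and non_col_reductions = 1:
// each output element is the sum of one column of length n_rows.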
|
||||
|
||||
template <typename T, typename U, typename Op, int NDIM, int N_READS = 4>
|
||||
__global__ void col_reduce_small(
|
||||
const T* in,
|
||||
U* out,
|
||||
const ColReduceArgs args) {
|
||||
auto grid = cg::this_grid();
|
||||
auto block = cg::this_thread_block();
|
||||
|
||||
int column =
|
||||
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
|
||||
if (column * N_READS >= args.reduction_stride) {
|
||||
return;
|
||||
}
|
||||
|
||||
int out_idx = grid.block_rank() / grid.dim_blocks().x;
|
||||
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
|
||||
|
||||
Op op;
|
||||
U totals[N_READS];
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
totals[i] = ReduceInit<Op, T>::value();
|
||||
}
|
||||
|
||||
// Read input to local.
|
||||
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
|
||||
loop.next(
|
||||
block.thread_index().y,
|
||||
args.reduce_shape.data(),
|
||||
args.reduce_strides.data());
|
||||
for (size_t r = block.thread_index().y;
|
||||
r < args.non_col_reductions * args.reduction_size;
|
||||
r += block.dim_threads().y) {
|
||||
U vals[N_READS];
|
||||
rocprim::block_load_direct_blocked(
|
||||
column,
|
||||
make_cast_iterator<U>(in + loop.location()),
|
||||
vals,
|
||||
args.reduction_stride,
|
||||
ReduceInit<Op, T>::value());
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
totals[i] = op(vals[i], totals[i]);
|
||||
}
|
||||
loop.next(
|
||||
block.dim_threads().y,
|
||||
args.reduce_shape.data(),
|
||||
args.reduce_strides.data());
|
||||
}
|
||||
|
||||
// Do block reduce when each column has more than 1 element to reduce.
|
||||
if (block.dim_threads().y > 1) {
|
||||
__shared__ U shared_vals[32 * 8 * N_READS];
|
||||
size_t col =
|
||||
block.thread_index().y * block.dim_threads().x + block.thread_index().x;
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
shared_vals[col * N_READS + i] = totals[i];
|
||||
}
|
||||
block.sync();
|
||||
if (block.thread_index().y == 0) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
totals[i] = shared_vals[block.thread_index().x * N_READS + i];
|
||||
}
|
||||
for (int j = 1; j < block.dim_threads().y; j++) {
|
||||
col = j * block.dim_threads().x + block.thread_index().x;
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
totals[i] = op(shared_vals[col * N_READS + i], totals[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Write result.
|
||||
if (block.thread_index().y == 0) {
|
||||
rocprim::block_store_direct_blocked(
|
||||
column,
|
||||
out + out_idx * args.reduction_stride,
|
||||
totals,
|
||||
args.reduction_stride);
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
typename Op,
|
||||
int NDIM,
|
||||
int BM,
|
||||
int BN,
|
||||
int N_READS = 4>
|
||||
__global__ void col_reduce_looped(
|
||||
const T* in,
|
||||
U* out,
|
||||
const ColReduceArgs args) {
|
||||
auto grid = cg::this_grid();
|
||||
auto block = cg::this_thread_block();
|
||||
auto warp = cg::tiled_partition<WARP_SIZE>(block);
|
||||
|
||||
constexpr int n_warps = BN / N_READS;
|
||||
|
||||
int out_idx = grid.block_rank() / grid.dim_blocks().x;
|
||||
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
|
||||
|
||||
Op op;
|
||||
U totals[N_READS];
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
totals[i] = ReduceInit<Op, T>::value();
|
||||
}
|
||||
|
||||
// Read input to local.
|
||||
int r = block.thread_rank() / n_warps;
|
||||
int column = block.thread_rank() % n_warps;
|
||||
int in_offset = grid.block_index().x * BN;
|
||||
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
|
||||
loop.next(r, args.reduce_shape.data(), args.reduce_strides.data());
|
||||
for (; r < args.non_col_reductions * args.reduction_size; r += BM) {
|
||||
U vals[N_READS];
|
||||
rocprim::block_load_direct_blocked(
|
||||
column,
|
||||
make_cast_iterator<U>(in + loop.location() + in_offset),
|
||||
vals,
|
||||
args.reduction_stride - in_offset,
|
||||
ReduceInit<Op, T>::value());
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
totals[i] = op(vals[i], totals[i]);
|
||||
}
|
||||
loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
|
||||
}
|
||||
|
||||
// Do warp reduce for each output.
|
||||
constexpr int n_outputs = BN / n_warps;
|
||||
static_assert(BM == 32 && n_outputs == N_READS);
|
||||
__shared__ U shared_vals[BM * BN];
|
||||
size_t col = block.thread_index().y * BN + block.thread_index().x * N_READS;
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
shared_vals[col + i] = totals[i];
|
||||
}
|
||||
block.sync();
|
||||
col = warp.thread_rank() * BN + warp.meta_group_rank() * n_outputs;
|
||||
for (int i = 0; i < n_outputs; i++) {
|
||||
totals[i] = cg::reduce(warp, shared_vals[col + i], op);
|
||||
}
|
||||
|
||||
// Write result.
|
||||
if (warp.thread_rank() == 0) {
|
||||
size_t out_offset = grid.block_index().x * BN;
|
||||
rocprim::block_store_direct_blocked(
|
||||
warp.meta_group_rank(),
|
||||
out + out_idx * args.reduction_stride + out_offset,
|
||||
totals,
|
||||
args.reduction_stride - out_offset);
|
||||
}
|
||||
}
|
||||
|
||||
// Utility functions and templates
|
||||
template <int NDIM, bool USE_FAST_PATH>
|
||||
struct LoopedElemToLoc {
|
||||
size_t location;
|
||||
|
||||
__device__ LoopedElemToLoc(int reduce_ndim) : location(0) {}
|
||||
|
||||
__device__ void next(size_t step, const int* shape, const size_t* strides) {
|
||||
// Simplified implementation; a full version would handle multi-dimensional indexing.
|
||||
location += step;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
__device__ inline T* make_cast_iterator(const T* ptr) {
|
||||
return const_cast<T*>(ptr);
|
||||
}
|
||||
|
||||
__device__ inline size_t elem_to_loc(
|
||||
size_t elem,
|
||||
const int* shape,
|
||||
const size_t* strides,
|
||||
int ndim) {
|
||||
size_t loc = 0;
|
||||
for (int i = ndim - 1; i >= 0; --i) {
|
||||
size_t q = elem / shape[i];
|
||||
size_t r = elem % shape[i];
|
||||
loc += r * strides[i];
|
||||
elem = q;
|
||||
}
|
||||
return loc;
|
||||
}
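// Example: for shape {2, 3} and strides {6, 2}, flat element 4 is (row 1,
// col 1) and maps to location 1 * 6 + 1 * 2 = 8.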
|
||||
|
||||
} // namespace rocm
|
||||
|
||||
inline auto output_grid_for_col_reduce(
|
||||
const array& out,
|
||||
const rocm::ColReduceArgs& args) {
|
||||
auto out_shape = out.shape();
|
||||
auto out_strides = out.strides();
|
||||
while (!out_shape.empty() && out_strides.back() < args.reduction_stride) {
|
||||
out_shape.pop_back();
|
||||
out_strides.pop_back();
|
||||
}
|
||||
return get_2d_grid_dims(out_shape, out_strides);
|
||||
}
|
||||
|
||||
void col_reduce(
|
||||
rocm::CommandEncoder& encoder,
|
||||
const array& in,
|
||||
array& out,
|
||||
Reduce::ReduceType reduce_type,
|
||||
const std::vector<int>& axes,
|
||||
const ReductionPlan& plan) {
|
||||
rocm::ColReduceArgs args(in, plan, axes);
|
||||
|
||||
encoder.launch_kernel([&](hipStream_t stream) {
|
||||
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
|
||||
using InType = hip_type_t<CTYPE>;
|
||||
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
|
||||
using OutType = rocm::ReduceResult<OP, InType>::type;
|
||||
MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, {
|
||||
constexpr int N_READS = 4;
|
||||
dim3 block_dims;
|
||||
dim3 num_blocks = output_grid_for_col_reduce(out, args);
|
||||
num_blocks.z = num_blocks.y;
|
||||
num_blocks.y = num_blocks.x;
|
||||
auto kernel =
|
||||
rocm::col_reduce_small<InType, OutType, OP, NDIM, N_READS>;
|
||||
size_t total = args.non_col_reductions * args.reduction_size;
|
||||
if (total < 32) {
|
||||
size_t stride_blocks =
|
||||
hip_ceil_div(args.reduction_stride, N_READS);
|
||||
block_dims.x = std::min(stride_blocks, 32ul);
|
||||
block_dims.y = std::min(total, 8ul);
|
||||
num_blocks.x = hip_ceil_div(stride_blocks, block_dims.x);
|
||||
} else {
|
||||
constexpr int BM = 32;
|
||||
constexpr int BN = 32;
|
||||
block_dims.x = BM * BN / N_READS;
|
||||
num_blocks.x = hip_ceil_div(args.reduction_stride, BN);
|
||||
kernel = rocm::
|
||||
col_reduce_looped<InType, OutType, OP, NDIM, BM, BN, N_READS>;
|
||||
}
|
||||
hipLaunchKernelGGL(kernel, num_blocks, block_dims, 0, stream,
|
||||
in.data<InType>(), out.data<OutType>(), args);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
119
mlx/backend/rocm/reduce/reduce.hpp
Normal file
@ -0,0 +1,119 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <cstddef>
#include <limits>
|
||||
|
||||
namespace mlx::core::rocm {
|
||||
|
||||
// Reduction operation types
|
||||
template <typename Op, typename T>
|
||||
struct ReduceInit {
|
||||
static constexpr T value();
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct ReduceInit<struct Sum, T> {
|
||||
static constexpr T value() {
|
||||
return T(0);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct ReduceInit<struct Max, T> {
|
||||
static constexpr T value() {
|
||||
return -std::numeric_limits<T>::infinity();
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct ReduceInit<struct Min, T> {
|
||||
static constexpr T value() {
|
||||
return std::numeric_limits<T>::infinity();
|
||||
}
|
||||
};
|
||||
|
||||
// Reduction operations
|
||||
struct Sum {
|
||||
template <typename T>
|
||||
__device__ T operator()(T a, T b) const {
|
||||
return a + b;
|
||||
}
|
||||
};
|
||||
|
||||
struct Max {
|
||||
template <typename T>
|
||||
__device__ T operator()(T a, T b) const {
|
||||
return fmax(a, b);
|
||||
}
|
||||
};
|
||||
|
||||
struct Min {
|
||||
template <typename T>
|
||||
__device__ T operator()(T a, T b) const {
|
||||
return fmin(a, b);
|
||||
}
|
||||
};
|
||||
|
||||
struct Prod {
|
||||
template <typename T>
|
||||
__device__ T operator()(T a, T b) const {
|
||||
return a * b;
|
||||
}
|
||||
};
|
||||
|
||||
// Utility functions for reductions
|
||||
template <typename T>
|
||||
__device__ T warp_reduce(T val, T (*op)(T, T)) {
|
||||
for (int offset = warpSize / 2; offset > 0; offset /= 2) {
|
||||
val = op(val, __shfl_down(val, offset));
|
||||
}
|
||||
return val;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ T block_reduce(T val, T (*op)(T, T)) {
|
||||
static __shared__ T shared[32];
|
||||
int lane = threadIdx.x % warpSize;
|
||||
int wid = threadIdx.x / warpSize;
|
||||
|
||||
val = warp_reduce(val, op);
|
||||
|
||||
if (lane == 0)
|
||||
shared[wid] = val;
|
||||
__syncthreads();
|
||||
|
||||
val = (threadIdx.x < blockDim.x / warpSize) ? shared[lane] : 0;
|
||||
if (wid == 0)
|
||||
val = warp_reduce(val, op);
|
||||
|
||||
return val;
|
||||
}
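// Note: the fill above uses 0 for threads past the last warp, so as written
// block_reduce is only exact for sum-like ops; Max/Min/Prod would need the
// op's identity there. It also assumes blockDim.x is a multiple of warpSize.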
|
||||
|
||||
// Column reduction arguments
|
||||
struct ColReduceArgs {
|
||||
size_t reduction_size;
|
||||
int64_t reduction_stride;
|
||||
int* shape;
|
||||
size_t* strides;
|
||||
int ndim;
|
||||
int* reduce_shape;
|
||||
size_t* reduce_strides;
|
||||
int reduce_ndim;
|
||||
size_t non_col_reductions;
|
||||
};
|
||||
|
||||
// Row reduction arguments
|
||||
struct RowReduceArgs {
|
||||
size_t reduction_size;
|
||||
int64_t reduction_stride;
|
||||
int* shape;
|
||||
size_t* strides;
|
||||
int ndim;
|
||||
int* reduce_shape;
|
||||
size_t* reduce_strides;
|
||||
int reduce_ndim;
|
||||
};
|
||||
|
||||
} // namespace mlx::core::rocm
|
@ -1,13 +1,375 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/rocm/device.h"
|
||||
#include "mlx/backend/rocm/iterators/strided_iterator.hpp"
|
||||
#include "mlx/backend/rocm/kernel_utils.hpp"
|
||||
#include "mlx/backend/rocm/reduce/reduce.hpp"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <hip/hip_cooperative_groups.h>
|
||||
#include <rocprim/block/block_load.hpp>
|
||||
#include <rocprim/block/block_reduce.hpp>
|
||||
|
||||
namespace mlx::core::rocm {
|
||||
namespace mlx::core {
|
||||
|
||||
__global__ void rms_norm_kernel(float* input, float* output, int n) {
|
||||
// Placeholder implementation
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
(void)input; (void)output; (void)n; (void)idx;
|
||||
namespace rocm {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
// Similar to rocprim::BlockReduce, but result is broadcasted to every thread.
|
||||
template <typename T, int BLOCK_DIM>
|
||||
struct BlockBroadcastReduce {
|
||||
static_assert(WARP_SIZE <= BLOCK_DIM && BLOCK_DIM <= WARP_SIZE * WARP_SIZE);
|
||||
static_assert(BLOCK_DIM % WARP_SIZE == 0);
|
||||
using TempStorage = T[BLOCK_DIM / WARP_SIZE];
|
||||
|
||||
cg::thread_block& block;
|
||||
TempStorage& temp;
|
||||
|
||||
template <typename Op>
|
||||
__device__ T Reduce(const T& input, const Op& op, const T& init_value) {
|
||||
auto warp = cg::tiled_partition<WARP_SIZE>(block);
|
||||
T x = cg::reduce(warp, input, op);
|
||||
if (warp.thread_rank() == 0) {
|
||||
temp[warp.meta_group_rank()] = x;
|
||||
}
|
||||
block.sync();
|
||||
x = warp.thread_rank() < warp.meta_group_size() ? temp[warp.thread_rank()]
|
||||
: init_value;
|
||||
return cg::reduce(warp, x, op);
|
||||
}
|
||||
|
||||
__device__ T Sum(const T& input) {
|
||||
return Reduce(input, hip_plus<T>{}, T{});
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, int BLOCK_DIM, int N_READS = 4>
|
||||
__global__ void rms_norm(
|
||||
const T* x,
|
||||
const T* w,
|
||||
T* out,
|
||||
float eps,
|
||||
int32_t axis_size,
|
||||
int64_t w_stride) {
|
||||
auto grid = cg::this_grid();
|
||||
auto block = cg::this_thread_block();
|
||||
|
||||
using BlockReduceT = BlockBroadcastReduce<float, BLOCK_DIM>;
|
||||
__shared__ typename BlockReduceT::TempStorage temp;
|
||||
|
||||
x += grid.block_rank() * axis_size;
|
||||
out += grid.block_rank() * axis_size;
|
||||
|
||||
// Sum of squares.
|
||||
float sum_sq = 0;
|
||||
for (int r = 0; r < hip_ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
|
||||
auto index = r * BLOCK_DIM + block.thread_rank();
|
||||
T xn[N_READS] = {};
|
||||
rocprim::block_load_direct_blocked(index, x, xn, axis_size);
|
||||
for (int i = 0; i < N_READS; ++i) {
|
||||
float val = static_cast<float>(xn[i]);
|
||||
sum_sq += val * val;
|
||||
}
|
||||
}
|
||||
sum_sq = BlockReduceT{block, temp}.Sum(sum_sq);
|
||||
|
||||
// RMS normalizer.
|
||||
float rms_normalizer = rsqrt(sum_sq / axis_size + eps);
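// Each output below is then y_i = w_i * x_i / sqrt(mean_j(x_j^2) + eps).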
|
||||
|
||||
// Outputs.
|
||||
for (int r = 0; r < hip_ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
|
||||
auto index = r * BLOCK_DIM + block.thread_rank();
|
||||
T xn[N_READS];
|
||||
T wn[N_READS];
|
||||
rocprim::block_load_direct_blocked(index, x, xn, axis_size);
|
||||
rocprim::block_load_direct_blocked(index, strided_iterator(w, w_stride), wn, axis_size);
|
||||
for (int i = 0; i < N_READS; ++i) {
|
||||
float norm = static_cast<float>(xn[i]) * rms_normalizer;
|
||||
xn[i] = wn[i] * static_cast<T>(norm);
|
||||
}
|
||||
rocprim::block_store_direct_blocked(index, out, xn, axis_size);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core::rocm
|
||||
template <typename T, bool HAS_W, int BLOCK_DIM, int N_READS = 4>
|
||||
__global__ void rms_norm_vjp(
|
||||
const T* x,
|
||||
const T* w,
|
||||
const T* g,
|
||||
T* gx,
|
||||
T* gw,
|
||||
float eps,
|
||||
int32_t axis_size,
|
||||
int64_t w_stride) {
|
||||
auto grid = cg::this_grid();
|
||||
auto block = cg::this_thread_block();
|
||||
|
||||
using BlockReduceF = BlockBroadcastReduce<float, BLOCK_DIM>;
|
||||
using BlockReduceF2 = BlockBroadcastReduce<float2, BLOCK_DIM>;
|
||||
__shared__ union {
|
||||
typename BlockReduceF::TempStorage f;
|
||||
typename BlockReduceF2::TempStorage f2;
|
||||
} temp;
|
||||
|
||||
x += grid.block_rank() * axis_size;
|
||||
g += grid.block_rank() * axis_size;
|
||||
gx += grid.block_rank() * axis_size;
|
||||
gw += grid.block_rank() * axis_size;
|
||||
|
||||
// Sum of squares.
|
||||
float sum_sq = 0;
|
||||
for (int r = 0; r < hip_ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
|
||||
auto index = r * BLOCK_DIM + block.thread_rank();
|
||||
T xn[N_READS] = {};
|
||||
rocprim::block_load_direct_blocked(index, x, xn, axis_size);
|
||||
for (int i = 0; i < N_READS; ++i) {
|
||||
float val = static_cast<float>(xn[i]);
|
||||
sum_sq += val * val;
|
||||
}
|
||||
}
|
||||
sum_sq = BlockReduceF{block, temp.f}.Sum(sum_sq);
|
||||
|
||||
// RMS normalizer.
|
||||
float rms_normalizer = rsqrt(sum_sq / axis_size + eps);
|
||||
|
||||
// Compute gradient terms.
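// factors.x accumulates sum_i(w_i * g_i) and factors.y accumulates
// sum_i(w_i * g_i * x_i); after the block reduce they become mean_wg and
// mean_wgx below.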
|
||||
float2 factors = {};
|
||||
for (int r = 0; r < hip_ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
|
||||
T xn[N_READS];
|
||||
T wn[N_READS] = {};
|
||||
T gn[N_READS] = {};
|
||||
auto index = r * BLOCK_DIM + block.thread_rank();
|
||||
rocprim::block_load_direct_blocked(index, x, xn, axis_size);
|
||||
rocprim::block_load_direct_blocked(index, g, gn, axis_size);
|
||||
rocprim::block_load_direct_blocked(index, strided_iterator(w, w_stride), wn, axis_size);
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
float xi = static_cast<float>(xn[i]);
|
||||
float wi = wn[i];
|
||||
float gi = gn[i];
|
||||
float wg = wi * gi;
|
||||
factors.x += wg;
|
||||
factors.y += wg * xi;
|
||||
}
|
||||
}
|
||||
auto plus_f2 = [] __device__ (const float2& a, const float2& b) -> float2 {
|
||||
return {a.x + b.x, a.y + b.y};
|
||||
};
|
||||
factors = BlockReduceF2{block, temp.f2}.Reduce(factors, plus_f2, {});
|
||||
float mean_wg = factors.x / axis_size;
|
||||
float mean_wgx = factors.y / axis_size;
|
||||
float rms3 = rms_normalizer * rms_normalizer * rms_normalizer;
|
||||
|
||||
// Outputs.
|
||||
for (int r = 0; r < hip_ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
|
||||
auto index = r * BLOCK_DIM + block.thread_rank();
|
||||
T xn[N_READS];
|
||||
T wn[N_READS];
|
||||
T gn[N_READS];
|
||||
rocprim::block_load_direct_blocked(index, x, xn, axis_size);
|
||||
rocprim::block_load_direct_blocked(index, g, gn, axis_size);
|
||||
rocprim::block_load_direct_blocked(index, strided_iterator(w, w_stride), wn, axis_size);
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
float xi = static_cast<float>(xn[i]);
|
||||
float wi = wn[i];
|
||||
float gi = gn[i];
|
||||
float norm = xi * rms_normalizer;
|
||||
xn[i] = rms_normalizer * (wi * gi - mean_wg) - norm * mean_wgx * rms3;
|
||||
if constexpr (HAS_W) {
|
||||
wn[i] = gi * norm;
|
||||
}
|
||||
}
|
||||
rocprim::block_store_direct_blocked(index, gx, xn, axis_size);
|
||||
if constexpr (HAS_W) {
|
||||
rocprim::block_store_direct_blocked(index, gw, wn, axis_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Utility functions
|
||||
template <typename T>
|
||||
struct hip_plus {
|
||||
__device__ T operator()(const T& a, const T& b) const {
|
||||
return a + b;
|
||||
}
|
||||
};
|
||||
|
||||
inline __host__ __device__ int hip_ceil_div(int a, int b) {
|
||||
return (a + b - 1) / b;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ inline auto strided_iterator(const T* ptr, int64_t stride) {
|
||||
// Simplified placeholder: ignores `stride` (only correct for a unit stride);
// a full implementation would advance by `stride` elements per step.
return ptr;
|
||||
}
|
||||
|
||||
} // namespace rocm
|
||||
|
||||
namespace fast {
|
||||
|
||||
bool RMSNorm::use_fallback(Stream s) {
|
||||
return s.device == Device::cpu;
|
||||
}
|
||||
|
||||
void RMSNorm::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
auto& s = stream();
|
||||
auto& out = outputs[0];
|
||||
|
||||
// Make sure that the last dimension is contiguous.
|
||||
auto set_output = [&s, &out](const array& x) {
|
||||
bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
|
||||
if (no_copy && x.ndim() > 1) {
|
||||
auto s = x.strides()[x.ndim() - 2];
|
||||
no_copy &= (s == 0 || s == x.shape().back());
|
||||
}
|
||||
if (no_copy) {
|
||||
if (x.is_donatable()) {
|
||||
out.copy_shared_buffer(x);
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc(x.data_size() * x.itemsize()),
|
||||
x.data_size(),
|
||||
x.strides(),
|
||||
x.flags());
|
||||
}
|
||||
return x;
|
||||
} else {
|
||||
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
|
||||
copy_gpu(x, x_copy, CopyType::General, s);
|
||||
out.copy_shared_buffer(x_copy);
|
||||
return x_copy;
|
||||
}
|
||||
};
|
||||
|
||||
const array x = set_output(inputs[0]);
|
||||
const array& w = inputs[1];
|
||||
|
||||
int32_t axis_size = x.shape().back();
|
||||
int32_t n_rows = x.data_size() / axis_size;
|
||||
int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
|
||||
|
||||
auto& encoder = rocm::get_command_encoder(s);
|
||||
encoder.set_input_array(x);
|
||||
encoder.set_input_array(w);
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](hipStream_t stream) {
|
||||
MLX_SWITCH_FLOAT_TYPES_CHECKED(out.dtype(), "rmsnorm", CTYPE, {
|
||||
using DataType = hip_type_t<CTYPE>;
|
||||
constexpr uint32_t N_READS = 4;
|
||||
MLX_SWITCH_BLOCK_DIM(rocm::hip_ceil_div(axis_size, N_READS), BLOCK_DIM, {
|
||||
auto kernel = rocm::rms_norm<DataType, BLOCK_DIM, N_READS>;
|
||||
hipLaunchKernelGGL(kernel, n_rows, BLOCK_DIM, 0, stream,
|
||||
x.data<DataType>(),
|
||||
w.data<DataType>(),
|
||||
out.data<DataType>(),
|
||||
eps_,
|
||||
axis_size,
|
||||
w_stride);
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
void RMSNormVJP::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
auto& s = stream();
|
||||
auto& encoder = rocm::get_command_encoder(s);
|
||||
|
||||
// Ensure row contiguity. We could relax this step by checking that the array
|
||||
// is contiguous (no broadcasts or holes) and that the input strides are the
|
||||
// same as the cotangent strides but for now this is simpler.
|
||||
auto check_input = [&s](const array& x) -> std::pair<array, bool> {
|
||||
if (x.flags().row_contiguous) {
|
||||
return {x, false};
|
||||
}
|
||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
||||
copy_gpu(x, x_copy, CopyType::General, s);
|
||||
return {x_copy, true};
|
||||
};
|
||||
bool donate_x = inputs[0].is_donatable();
|
||||
bool donate_g = inputs[2].is_donatable();
|
||||
auto [x, copied] = check_input(inputs[0]);
|
||||
donate_x |= copied;
|
||||
const array& w = inputs[1];
|
||||
auto [g, g_copied] = check_input(inputs[2]);
|
||||
donate_g |= g_copied;
|
||||
array& gx = outputs[0];
|
||||
array& gw = outputs[1];
|
||||
|
||||
// Check whether we had a weight.
|
||||
bool has_w = w.ndim() != 0;
|
||||
|
||||
// Allocate space for the outputs.
|
||||
bool g_in_gx = false;
|
||||
if (donate_x) {
|
||||
gx.copy_shared_buffer(x);
|
||||
} else if (donate_g) {
|
||||
gx.copy_shared_buffer(g);
|
||||
g_in_gx = true;
|
||||
} else {
|
||||
gx.set_data(allocator::malloc(gx.nbytes()));
|
||||
}
|
||||
if (g_copied && !g_in_gx) {
|
||||
encoder.add_temporary(g);
|
||||
}
|
||||
|
||||
int32_t axis_size = x.shape().back();
|
||||
int32_t n_rows = x.data_size() / axis_size;
|
||||
int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
|
||||
|
||||
// Allocate a temporary to store the gradients for w and allocate the output
|
||||
// gradient accumulators.
|
||||
array gw_temp =
|
||||
(has_w) ? array({n_rows, x.shape().back()}, gw.dtype(), nullptr, {}) : w;
|
||||
if (has_w) {
|
||||
if (!g_in_gx && donate_g) {
|
||||
gw_temp.copy_shared_buffer(g);
|
||||
} else {
|
||||
gw_temp.set_data(allocator::malloc(gw_temp.nbytes()));
|
||||
encoder.add_temporary(gw_temp);
|
||||
}
|
||||
}
|
||||
gw.set_data(allocator::malloc(gw.nbytes()));
|
||||
|
||||
encoder.set_input_array(x);
|
||||
encoder.set_input_array(w);
|
||||
encoder.set_input_array(g);
|
||||
encoder.set_output_array(gx);
|
||||
encoder.set_output_array(gw_temp);
|
||||
encoder.launch_kernel([&, x = x, g = g](hipStream_t stream) {
|
||||
MLX_SWITCH_FLOAT_TYPES_CHECKED(gx.dtype(), "rmsnorm_vjp", CTYPE, {
|
||||
using DataType = hip_type_t<CTYPE>;
|
||||
constexpr int N_READS = 4;
|
||||
MLX_SWITCH_BOOL(has_w, HAS_W, {
|
||||
MLX_SWITCH_BLOCK_DIM(rocm::hip_ceil_div(axis_size, N_READS), BLOCK_DIM, {
|
||||
auto kernel = rocm::rms_norm_vjp<DataType, HAS_W, BLOCK_DIM, N_READS>;
|
||||
hipLaunchKernelGGL(kernel, n_rows, BLOCK_DIM, 0, stream,
|
||||
x.data<DataType>(),
|
||||
w.data<DataType>(),
|
||||
g.data<DataType>(),
|
||||
gx.data<DataType>(),
|
||||
gw_temp.data<DataType>(),
|
||||
eps_,
|
||||
axis_size,
|
||||
w_stride);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
if (has_w) {
|
||||
ReductionPlan plan(
|
||||
ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size});
|
||||
col_reduce(encoder, gw_temp, gw, Reduce::ReduceType::Sum, {0}, plan);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace fast
|
||||
|
||||
} // namespace mlx::core
|
@ -1,13 +1,383 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/rocm/device.h"
|
||||
#include "mlx/backend/rocm/kernel_utils.hpp"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
namespace mlx::core::rocm {
|
||||
namespace mlx::core {
|
||||
|
||||
__global__ void rope_kernel(float* input, float* output, int n) {
|
||||
// Placeholder for RoPE implementation
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
(void)input; (void)output; (void)n; (void)idx;
|
||||
namespace rocm {
|
||||
|
||||
template <typename T, bool traditional, bool forward>
|
||||
__device__ void rope_single_impl(
|
||||
const T* in,
|
||||
T* out,
|
||||
int32_t offset,
|
||||
float inv_freq,
|
||||
float scale,
|
||||
int64_t stride,
|
||||
uint2 pos,
|
||||
uint2 dims) {
|
||||
float L = scale * static_cast<float>(offset);
|
||||
|
||||
// Compute costheta, sintheta
|
||||
float theta = L * inv_freq;
|
||||
float costheta = cos(theta);
|
||||
float sintheta = sin(theta);
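// Each (x1, x2) pair is rotated by theta: the forward pass applies
// [cos -sin; sin cos], and the backward pass applies its transpose
// (the inverse rotation).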
|
||||
|
||||
// Compute the input and output indices
|
||||
uint index_1, index_2;
|
||||
if (traditional) {
|
||||
index_1 = 2 * pos.x + pos.y * stride;
|
||||
index_2 = index_1 + 1;
|
||||
} else {
|
||||
index_1 = pos.x + pos.y * stride;
|
||||
index_2 = index_1 + dims.x;
|
||||
}
|
||||
|
||||
// Read and write the output
|
||||
float x1 = static_cast<float>(in[index_1]);
|
||||
float x2 = static_cast<float>(in[index_2]);
|
||||
float rx1;
|
||||
float rx2;
|
||||
if (forward) {
|
||||
rx1 = x1 * costheta - x2 * sintheta;
|
||||
rx2 = x1 * sintheta + x2 * costheta;
|
||||
} else {
|
||||
rx1 = x2 * sintheta + x1 * costheta;
|
||||
rx2 = x2 * costheta - x1 * sintheta;
|
||||
}
|
||||
out[index_1] = static_cast<T>(rx1);
|
||||
out[index_2] = static_cast<T>(rx2);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::rocm
|
||||
template <typename T, bool traditional, bool forward>
|
||||
__global__ void rope_single(
|
||||
const T* in,
|
||||
T* out,
|
||||
const int32_t* offset,
|
||||
float scale,
|
||||
float base,
|
||||
int64_t stride,
|
||||
uint2 dims) {
|
||||
uint2 pos = make_uint2(
|
||||
blockIdx.x * blockDim.x + threadIdx.x,
|
||||
blockIdx.y * blockDim.y + threadIdx.y);
|
||||
if (pos.x >= dims.x || pos.y >= dims.y) {
|
||||
return;
|
||||
}
|
||||
|
||||
float d = static_cast<float>(pos.x) / static_cast<float>(dims.x);
|
||||
float inv_freq = exp2(-d * base);
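// The host passes log2 of the RoPE base for `base`, so exp2(-d * base)
// equals base^(-d), the usual RoPE inverse frequency for fractional
// dimension d = pos.x / dims.x.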
|
||||
rope_single_impl<T, traditional, forward>(
|
||||
in, out, *offset, inv_freq, scale, stride, pos, dims);
|
||||
}
|
||||
|
||||
template <typename T, bool traditional, bool forward>
|
||||
__global__ void rope_single_freqs(
|
||||
const T* in,
|
||||
T* out,
|
||||
const int32_t* offset,
|
||||
const float* freqs,
|
||||
float scale,
|
||||
int64_t stride,
|
||||
uint2 dims,
|
||||
int64_t freq_stride) {
|
||||
uint2 pos = make_uint2(
|
||||
blockIdx.x * blockDim.x + threadIdx.x,
|
||||
blockIdx.y * blockDim.y + threadIdx.y);
|
||||
if (pos.x >= dims.x || pos.y >= dims.y) {
|
||||
return;
|
||||
}
|
||||
|
||||
float inv_freq = 1.0 / freqs[freq_stride * pos.x];
|
||||
rope_single_impl<T, traditional, forward>(
|
||||
in, out, *offset, inv_freq, scale, stride, pos, dims);
|
||||
}
|
||||
|
||||
template <typename T, bool traditional, bool forward, int N = 4>
|
||||
__device__ void rope_impl(
|
||||
const T* in,
|
||||
T* out,
|
||||
int offset,
|
||||
float inv_freq,
|
||||
float scale,
|
||||
const hip_array<int64_t, 3> strides,
|
||||
const hip_array<int64_t, 3> out_strides,
|
||||
int64_t n_batch,
|
||||
uint3 pos,
|
||||
uint3 dims) {
|
||||
float L = scale * static_cast<float>(pos.y + offset);
|
||||
|
||||
// Compute costheta, sintheta
|
||||
float theta = L * inv_freq;
|
||||
float costheta = cos(theta);
|
||||
float sintheta = sin(theta);
|
||||
|
||||
// Compute the input and output indices
|
||||
size_t in_index_1, in_index_2;
|
||||
size_t out_index_1, out_index_2;
|
||||
if (traditional) {
|
||||
out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] +
|
||||
N * pos.z * out_strides[0];
|
||||
out_index_2 = out_index_1 + 1;
|
||||
in_index_1 =
|
||||
2 * pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0];
|
||||
in_index_2 = in_index_1 + strides[2];
|
||||
} else {
|
||||
out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] +
|
||||
N * pos.z * out_strides[0];
|
||||
out_index_2 = out_index_1 + dims.x * out_strides[2];
|
||||
in_index_1 =
|
||||
pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0];
|
||||
in_index_2 = in_index_1 + dims.x * strides[2];
|
||||
}
|
||||
for (int i = 0; i < N && pos.z * N + i < n_batch; ++i) {
|
||||
// Read and write the output
|
||||
float x1 = static_cast<float>(in[in_index_1]);
|
||||
float x2 = static_cast<float>(in[in_index_2]);
|
||||
float rx1;
|
||||
float rx2;
|
||||
if (forward) {
|
||||
rx1 = x1 * costheta - x2 * sintheta;
|
||||
rx2 = x1 * sintheta + x2 * costheta;
|
||||
} else {
|
||||
rx1 = x2 * sintheta + x1 * costheta;
|
||||
rx2 = x2 * costheta - x1 * sintheta;
|
||||
}
|
||||
out[out_index_1] = static_cast<T>(rx1);
|
||||
out[out_index_2] = static_cast<T>(rx2);
|
||||
in_index_1 += strides[0];
|
||||
in_index_2 += strides[0];
|
||||
out_index_1 += out_strides[0];
|
||||
out_index_2 += out_strides[0];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, bool traditional, bool forward>
|
||||
__global__ void rope(
|
||||
const T* in,
|
||||
T* out,
|
||||
const int32_t* offset,
|
||||
float scale,
|
||||
float base,
|
||||
const hip_array<int64_t, 3> strides,
|
||||
const hip_array<int64_t, 3> out_strides,
|
||||
int64_t n_batch,
|
||||
uint3 dims) {
|
||||
uint3 pos = make_uint3(
|
||||
blockIdx.x * blockDim.x + threadIdx.x,
|
||||
blockIdx.y * blockDim.y + threadIdx.y,
|
||||
blockIdx.z * blockDim.z + threadIdx.z);
|
||||
if (pos.x >= dims.x || pos.y >= dims.y || pos.z >= dims.z) {
|
||||
return;
|
||||
}
|
||||
|
||||
float d = static_cast<float>(pos.x) / static_cast<float>(dims.x);
|
||||
float inv_freq = exp2(-d * base);
|
||||
rope_impl<T, traditional, forward>(
|
||||
in,
|
||||
out,
|
||||
*offset,
|
||||
inv_freq,
|
||||
scale,
|
||||
strides,
|
||||
out_strides,
|
||||
n_batch,
|
||||
pos,
|
||||
dims);
|
||||
}
|
||||
|
||||
template <typename T, bool traditional, bool forward>
|
||||
__global__ void rope_freqs(
|
||||
const T* in,
|
||||
T* out,
|
||||
const int32_t* offset,
|
||||
const float* freqs,
|
||||
float scale,
|
||||
float base,
|
||||
const hip_array<int64_t, 3> strides,
|
||||
const hip_array<int64_t, 3> out_strides,
|
||||
int64_t n_batch,
|
||||
uint3 dims,
|
||||
int64_t freq_stride) {
|
||||
uint3 pos = make_uint3(
|
||||
blockIdx.x * blockDim.x + threadIdx.x,
|
||||
blockIdx.y * blockDim.y + threadIdx.y,
|
||||
blockIdx.z * blockDim.z + threadIdx.z);
|
||||
if (pos.x >= dims.x || pos.y >= dims.y || pos.z >= dims.z) {
|
||||
return;
|
||||
}
|
||||
|
||||
float inv_freq = 1.0 / freqs[freq_stride * pos.x];
|
||||
rope_impl<T, traditional, forward>(
|
||||
in,
|
||||
out,
|
||||
*offset,
|
||||
inv_freq,
|
||||
scale,
|
||||
strides,
|
||||
out_strides,
|
||||
n_batch,
|
||||
pos,
|
||||
dims);
|
||||
}
|
||||
|
||||
} // namespace rocm
|
||||
|
||||
namespace fast {
|
||||
|
||||
bool RoPE::use_fallback(Stream s) {
|
||||
return s.device == Device::cpu;
|
||||
}
|
||||
|
||||
void RoPE::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
auto& s = stream();
|
||||
auto& in = inputs[0];
|
||||
auto& offset = inputs[1];
|
||||
auto& out = outputs[0];
|
||||
|
||||
if (in.ndim() < 3) {
|
||||
throw std::runtime_error("[RoPE] Input must have at least 3 dimensions");
|
||||
}
|
||||
|
||||
hip_array<int64_t, 3> strides;
|
||||
hip_array<int64_t, 3> out_strides;
|
||||
bool donated = false;
|
||||
int ndim = in.ndim();
|
||||
int dispatch_ndim = in.ndim();
|
||||
while (in.shape(-dispatch_ndim) == 1 && dispatch_ndim > 3) {
|
||||
dispatch_ndim--;
|
||||
}
|
||||
size_t mat_size = in.shape(-2) * in.shape(-1);
|
||||
|
||||
// We apply rope to less than the whole vector so copy to output and then
|
||||
// apply in-place.
|
||||
if (dims_ < in.shape(-1)) {
|
||||
donated = true;
|
||||
auto ctype =
|
||||
(in.flags().row_contiguous) ? CopyType::Vector : CopyType::General;
|
||||
copy_gpu(in, out, ctype, s);
|
||||
strides[0] = mat_size;
|
||||
strides[1] = out.strides()[ndim - 2];
|
||||
strides[2] = out.strides()[ndim - 1];
|
||||
}
|
||||
|
||||
// Either copy or apply in-place
|
||||
else if (in.flags().row_contiguous) {
|
||||
if (in.is_donatable()) {
|
||||
donated = true;
|
||||
out.copy_shared_buffer(in);
|
||||
} else {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
}
|
||||
strides[0] = mat_size;
|
||||
strides[1] = in.strides()[ndim - 2];
|
||||
strides[2] = in.strides()[ndim - 1];
|
||||
} else if (dispatch_ndim == 3) {
|
||||
// Handle non-contiguous 3D inputs
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
strides[0] = in.strides()[ndim - 3];
|
||||
strides[1] = in.strides()[ndim - 2];
|
||||
strides[2] = in.strides()[ndim - 1];
|
||||
} else {
|
||||
// Copy non-contiguous > 3D inputs into the output and treat
|
||||
// input as donated
|
||||
donated = true;
|
||||
copy_gpu(in, out, CopyType::General, s);
|
||||
strides[0] = mat_size;
|
||||
strides[1] = out.strides()[ndim - 2];
|
||||
strides[2] = out.strides()[ndim - 1];
|
||||
}
|
||||
out_strides[0] = mat_size;
|
||||
out_strides[1] = out.strides()[ndim - 2];
|
||||
out_strides[2] = out.strides()[ndim - 1];
|
||||
|
||||
// Some flags to help us dispatch below
|
||||
bool single = in.flags().row_contiguous && (mat_size == in.shape(-1));
|
||||
bool with_freqs = inputs.size() == 3;
|
||||
|
||||
auto& encoder = rocm::get_command_encoder(s);
|
||||
encoder.set_input_array(donated ? out : in);
|
||||
encoder.set_input_array(offset);
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](hipStream_t stream) {
|
||||
MLX_SWITCH_FLOAT_TYPES_CHECKED(in.dtype(), "rope", CTYPE, {
|
||||
using DataType = hip_type_t<CTYPE>;
|
||||
MLX_SWITCH_BOOL(traditional_, TRADITIONAL, {
|
||||
MLX_SWITCH_BOOL(forward_, FORWARD, {
|
||||
if (single && !with_freqs) {
|
||||
auto kernel = rocm::rope_single<DataType, TRADITIONAL, FORWARD>;
|
||||
uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size);
|
||||
auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);
|
||||
hipLaunchKernelGGL(kernel, grid, block, 0, stream,
|
||||
(donated ? out : in).data<DataType>(),
|
||||
out.data<DataType>(),
|
||||
offset.data<int32_t>(),
|
||||
scale_,
|
||||
std::log2(base_),
|
||||
mat_size,
|
||||
dims);
|
||||
} else if (single) {
|
||||
auto kernel = rocm::rope_single_freqs<DataType, TRADITIONAL, FORWARD>;
|
||||
uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size);
|
||||
auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);
|
||||
hipLaunchKernelGGL(kernel, grid, block, 0, stream,
|
||||
(donated ? out : in).data<DataType>(),
|
||||
out.data<DataType>(),
|
||||
offset.data<int32_t>(),
|
||||
inputs[2].data<float>(),
|
||||
scale_,
|
||||
mat_size,
|
||||
dims,
|
||||
inputs[2].strides(0));
|
||||
} else if (with_freqs) {
|
||||
auto kernel = rocm::rope_freqs<DataType, TRADITIONAL, FORWARD>;
|
||||
uint3 dims =
|
||||
make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size);
|
||||
dims.z = (dims.z + 3) / 4;
|
||||
auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z);
|
||||
hipLaunchKernelGGL(kernel, grid, block, 0, stream,
|
||||
(donated ? out : in).data<DataType>(),
|
||||
out.data<DataType>(),
|
||||
offset.data<int32_t>(),
|
||||
inputs[2].data<float>(),
|
||||
scale_,
|
||||
std::log2(base_),
|
||||
strides,
|
||||
out_strides,
|
||||
in.size() / mat_size,
|
||||
dims,
|
||||
inputs[2].strides(0));
|
||||
} else {
|
||||
auto kernel = rocm::rope<DataType, TRADITIONAL, FORWARD>;
|
||||
uint3 dims =
|
||||
make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size);
|
||||
dims.z = (dims.z + 3) / 4;
|
||||
auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z);
|
||||
hipLaunchKernelGGL(kernel, grid, block, 0, stream,
|
||||
(donated ? out : in).data<DataType>(),
|
||||
out.data<DataType>(),
|
||||
offset.data<int32_t>(),
|
||||
scale_,
|
||||
std::log2(base_),
|
||||
strides,
|
||||
out_strides,
|
||||
in.size() / mat_size,
|
||||
dims);
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace fast
|
||||
|
||||
} // namespace mlx::core
|
@ -1,22 +1,179 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/rocm/device.h"
|
||||
#include "mlx/backend/rocm/device/cast_op.hpp"
|
||||
#include "mlx/backend/rocm/device/fp16_math.hpp"
|
||||
#include "mlx/backend/rocm/kernel_utils.hpp"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <hip/hip_cooperative_groups.h>
|
||||
#include <rocprim/block/block_load.hpp>
|
||||
|
||||
namespace mlx::core::rocm {
|
||||
#include <cassert>
|
||||
|
||||
__global__ void softmax_kernel(float* input, float* output, int n) {
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
if (idx < n) {
|
||||
// Simplified softmax placeholder - real implementation needs reduction
|
||||
output[idx] = expf(input[idx]);
|
||||
namespace mlx::core {
|
||||
|
||||
namespace rocm {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
template <typename T>
|
||||
inline __device__ T softmax_exp(T x) {
|
||||
// Softmax doesn't need a high-precision exponential because x is in
|
||||
// (-oo, 0] anyway and subsequently it will be divided by sum(exp(x_i)).
|
||||
return __expf(x);
|
||||
}
|
||||
|
||||
template <typename T, typename AccT, int BLOCK_DIM, int N_READS = 4>
|
||||
__global__ void softmax(const T* in, T* out, int axis_size) {
|
||||
auto grid = cg::this_grid();
|
||||
auto block = cg::this_thread_block();
|
||||
auto warp = cg::tiled_partition<WARP_SIZE>(block);
|
||||
|
||||
in += grid.block_rank() * axis_size;
|
||||
out += grid.block_rank() * axis_size;
|
||||
|
||||
// Thread reduce.
|
||||
AccT prevmax;
|
||||
AccT maxval = -INFINITY;
|
||||
AccT normalizer = 0;
|
||||
for (int r = 0; r < hip_ceil_div(axis_size, BLOCK_DIM * N_READS); r++) {
|
||||
AccT vals[N_READS];
|
||||
rocprim::block_load_direct_blocked(
|
||||
r * BLOCK_DIM + block.thread_rank(),
|
||||
make_cast_iterator<AccT>(in),
|
||||
vals,
|
||||
axis_size,
|
||||
-INFINITY);
|
||||
prevmax = maxval;
|
||||
maxval = fmax(maxval, rocprim::thread_reduce(vals, hip_max<AccT>()));
|
||||
// Online normalizer calculation for softmax:
|
||||
// https://github.com/NVIDIA/online-softmax
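// With running maximum m and running normalizer l, a chunk with maximum
// m_c updates them as m_new = max(m, m_c) and
// l_new = l * exp(m - m_new) + sum_i exp(x_i - m_new),
// which is the rescale-then-accumulate done below.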
|
||||
normalizer = normalizer * softmax_exp(prevmax - maxval);
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
normalizer = normalizer + softmax_exp(vals[i] - maxval);
|
||||
}
|
||||
}
|
||||
|
||||
// First warp reduce.
|
||||
prevmax = maxval;
|
||||
maxval = cg::reduce(warp, maxval, hip_max<AccT>());
|
||||
normalizer = normalizer * softmax_exp(prevmax - maxval);
|
||||
normalizer = cg::reduce(warp, normalizer, hip_plus<AccT>());
|
||||
|
||||
__shared__ AccT local_max[WARP_SIZE];
|
||||
__shared__ AccT local_normalizer[WARP_SIZE];
|
||||
|
||||
// Write to shared memory and do second warp reduce.
|
||||
prevmax = maxval;
|
||||
if (warp.thread_rank() == 0) {
|
||||
local_max[warp.meta_group_rank()] = maxval;
|
||||
}
|
||||
block.sync();
|
||||
maxval = warp.thread_rank() < warp.meta_group_size()
|
||||
? local_max[warp.thread_rank()]
|
||||
: -INFINITY;
|
||||
maxval = cg::reduce(warp, maxval, hip_max<AccT>());
|
||||
normalizer = normalizer * softmax_exp(prevmax - maxval);
|
||||
if (warp.thread_rank() == 0) {
|
||||
local_normalizer[warp.meta_group_rank()] = normalizer;
|
||||
}
|
||||
block.sync();
|
||||
normalizer = warp.thread_rank() < warp.meta_group_size()
|
||||
? local_normalizer[warp.thread_rank()]
|
||||
: AccT{};
|
||||
normalizer = cg::reduce(warp, normalizer, hip_plus<AccT>());
|
||||
normalizer = 1 / normalizer;
|
||||
|
||||
// Write output.
|
||||
for (int r = 0; r < hip_ceil_div(axis_size, BLOCK_DIM * N_READS); r++) {
|
||||
auto index = r * BLOCK_DIM + block.thread_rank();
|
||||
T vals[N_READS];
|
||||
rocprim::block_load_direct_blocked(index, in, vals, axis_size);
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
vals[i] = softmax_exp(static_cast<AccT>(vals[i]) - maxval) * normalizer;
|
||||
}
|
||||
rocprim::block_store_direct_blocked(index, out, vals, axis_size);
|
||||
}
|
||||
}
|
||||
|
||||
void launch_softmax(float* input, float* output, int n, hipStream_t stream) {
|
||||
int threads = 256;
|
||||
int blocks = (n + threads - 1) / threads;
|
||||
hipLaunchKernelGGL(softmax_kernel, dim3(blocks), dim3(threads), 0, stream, input, output, n);
|
||||
// Utility functions for ROCm
|
||||
template <typename T>
|
||||
struct hip_max {
|
||||
__device__ T operator()(const T& a, const T& b) const {
|
||||
return fmax(a, b);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct hip_plus {
|
||||
__device__ T operator()(const T& a, const T& b) const {
|
||||
return a + b;
|
||||
}
|
||||
};
|
||||
|
||||
inline __host__ __device__ int hip_ceil_div(int a, int b) {
|
||||
return (a + b - 1) / b;
|
||||
}
|
||||
|
||||
} // namespace mlx::core::rocm
|
||||
template <typename T>
|
||||
__device__ inline T* make_cast_iterator(const T* ptr) {
|
||||
return const_cast<T*>(ptr);
|
||||
}
|
||||
|
||||
} // namespace rocm
|
||||
|
||||
void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& s = stream();
|
||||
|
||||
// Make sure that the last dimension is contiguous.
|
||||
auto set_output = [&s, &out](const array& x) {
|
||||
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
|
||||
if (x.is_donatable()) {
|
||||
out.copy_shared_buffer(x);
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc(x.data_size() * x.itemsize()),
|
||||
x.data_size(),
|
||||
x.strides(),
|
||||
x.flags());
|
||||
}
|
||||
return x;
|
||||
} else {
|
||||
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
|
||||
copy_gpu(x, x_copy, CopyType::General, s);
|
||||
out.copy_shared_buffer(x_copy);
|
||||
return x_copy;
|
||||
}
|
||||
};
|
||||
|
||||
array in = set_output(inputs[0]);
|
||||
bool precise = in.dtype() != float32 && precise_;
|
||||
|
||||
int axis_size = in.shape().back();
|
||||
int n_rows = in.data_size() / axis_size;
|
||||
|
||||
auto& encoder = rocm::get_command_encoder(s);
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](hipStream_t stream) {
|
||||
MLX_SWITCH_FLOAT_TYPES_CHECKED(out.dtype(), "softmax", CTYPE, {
|
||||
using DataType = hip_type_t<CTYPE>;
|
||||
constexpr int N_READS = 4;
|
||||
MLX_SWITCH_BLOCK_DIM(rocm::hip_ceil_div(axis_size, N_READS), BLOCK_DIM, {
|
||||
auto kernel = rocm::softmax<DataType, DataType, BLOCK_DIM, N_READS>;
|
||||
if (precise) {
|
||||
kernel = rocm::softmax<DataType, float, BLOCK_DIM, N_READS>;
|
||||
}
|
||||
hipLaunchKernelGGL(kernel, n_rows, BLOCK_DIM, 0, stream,
|
||||
in.data<DataType>(), out.data<DataType>(), axis_size);
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@ -1 +1,178 @@
|
||||
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/rocm/device.h"
|
||||
#include "mlx/backend/rocm/kernel_utils.hpp"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <rocthrust/device_ptr.h>
|
||||
#include <rocthrust/transform.h>
|
||||
#include <rocprim/device/device_segmented_sort.hpp>
|
||||
|
||||
#include <cassert>
|
||||
#include <numeric>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename T>
|
||||
struct ModOp {
|
||||
T divisor;
|
||||
__device__ T operator()(T x) {
|
||||
return x % divisor;
|
||||
}
|
||||
};
|
||||
|
||||
// We cannot use ops in eval, so make a utility.
|
||||
array swapaxes_in_eval(const array& in, int axis1, int axis2) {
|
||||
std::vector<int> axes(in.ndim());
|
||||
std::iota(axes.begin(), axes.end(), 0);
|
||||
std::swap(axes[axis1], axes[axis2]);
|
||||
// TODO: Share the code with Transpose::eval.
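// e.g. a row-contiguous (2, 3, 4) array has strides (12, 4, 1); swapping
// axes 0 and 2 returns a view with shape (4, 3, 2) and strides (1, 4, 12)
// without moving any data.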
|
||||
Shape shape(axes.size());
|
||||
Strides strides(in.ndim());
|
||||
for (size_t ax = 0; ax < axes.size(); ++ax) {
|
||||
shape[ax] = in.shape()[axes[ax]];
|
||||
strides[ax] = in.strides()[axes[ax]];
|
||||
}
|
||||
auto flags = in.flags();
|
||||
if (flags.contiguous) {
|
||||
auto [_, row_contiguous, col_contiguous] = check_contiguity(shape, strides);
|
||||
flags.row_contiguous = row_contiguous;
|
||||
flags.col_contiguous = col_contiguous;
|
||||
}
|
||||
array out(shape, in.dtype(), nullptr, {});
|
||||
out.copy_shared_buffer(in, strides, flags, in.data_size());
|
||||
return out;
|
||||
}
|
||||
|
||||
template <typename... Args>
|
||||
void segmented_sort_pairs(rocm::CommandEncoder& encoder, Args&&... args) {
|
||||
// Allocate temporary storage.
|
||||
size_t size;
|
||||
CHECK_HIP_ERROR(
|
||||
rocprim::segmented_sort_pairs(nullptr, size, args...));
|
||||
array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
|
||||
encoder.add_temporary(temp);
|
||||
// Run op.
|
||||
CHECK_HIP_ERROR(rocprim::segmented_sort_pairs(
|
||||
temp.data<void>(), size, args...));
|
||||
}
|
||||
|
||||
template <typename... Args>
|
||||
void segmented_sort(rocm::CommandEncoder& encoder, Args&&... args) {
|
||||
// Allocate temporary storage.
|
||||
size_t size;
|
||||
CHECK_HIP_ERROR(
|
||||
rocprim::segmented_sort_keys(nullptr, size, args...));
|
||||
array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
|
||||
encoder.add_temporary(temp);
|
||||
// Run op.
|
||||
CHECK_HIP_ERROR(rocprim::segmented_sort_keys(
|
||||
temp.data<void>(), size, args...));
|
||||
}
|
||||
|
||||
void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
|
||||
array out = out_;
|
||||
auto& encoder = rocm::get_command_encoder(s);
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
|
||||
if (axis < 0) {
|
||||
axis += in.ndim();
|
||||
}
|
||||
int nsort = in.shape(axis);
|
||||
int nsegments = in.data_size() / nsort;
|
||||
int last_dim = in.ndim() - 1;
|
||||
|
||||
// If we are not sorting the innermost dimension of a contiguous array,
|
||||
// transpose and make a copy.
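// e.g. sorting axis 0 of a row-contiguous (4, 8) array: swap axes to get
// shape (8, 4), copy so the sort axis is innermost, sort 8 segments of
// length 4, then swap back into the output.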
|
||||
bool is_segmented_sort = in.flags().contiguous && in.strides()[axis] == 1;
|
||||
if (!is_segmented_sort) {
|
||||
array trans = swapaxes_in_eval(in, axis, last_dim);
|
||||
in = array(trans.shape(), trans.dtype(), nullptr, {});
|
||||
copy_gpu(trans, in, CopyType::General, s);
|
||||
encoder.add_temporary(in);
|
||||
out = array(allocator::malloc(out.nbytes()), in.shape(), out.dtype());
|
||||
encoder.add_temporary(out);
|
||||
} else {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
}
|
||||
|
||||
encoder.launch_kernel([&](hipStream_t stream) {
|
||||
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
|
||||
if constexpr (!std::is_same_v<CTYPE, complex64_t>) {
|
||||
using Type = hip_type_t<CTYPE>;
|
||||
auto offsets = rocthrust::make_transform_iterator(
|
||||
rocthrust::make_counting_iterator(0),
|
||||
[nsort] __device__(int i) { return i * nsort; });
|
||||
if (argsort) {
|
||||
// Indices in the sorted dimension.
|
||||
array indices(
|
||||
allocator::malloc(out.nbytes()), in.shape(), out.dtype());
|
||||
encoder.add_temporary(indices);
|
||||
rocthrust::transform(
|
||||
rocm::thrust_policy(stream),
|
||||
rocthrust::counting_iterator<uint32_t>(0),
|
||||
rocthrust::counting_iterator<uint32_t>(indices.data_size()),
|
||||
rocthrust::device_pointer_cast(indices.data<uint32_t>()),
|
||||
ModOp<uint32_t>{static_cast<uint32_t>(nsort)});
|
||||
|
||||
// Argsort itself does not need the sorted values, but the
|
||||
// API requires us to provide an array to store it.
|
||||
array discard(allocator::malloc(in.nbytes()), in.shape(), in.dtype());
|
||||
encoder.add_temporary(discard);
|
||||
|
||||
segmented_sort_pairs(
|
||||
encoder,
|
||||
in.data<Type>(),
|
||||
discard.data<Type>(),
|
||||
indices.data<uint32_t>(),
|
||||
out.data<uint32_t>(),
|
||||
in.data_size(),
|
||||
nsegments,
|
||||
offsets,
|
||||
offsets + 1,
|
||||
stream);
|
||||
} else {
|
||||
segmented_sort(
|
||||
encoder,
|
||||
in.data<Type>(),
|
||||
out.data<Type>(),
|
||||
in.data_size(),
|
||||
nsegments,
|
||||
offsets,
|
||||
offsets + 1,
|
||||
stream);
|
||||
}
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
"ROCm backend does not support sorting complex numbers");
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
if (!is_segmented_sort) {
|
||||
// Swap the sorted axis back.
|
||||
// TODO: Do in-place transpose instead of using a temporary out array.
|
||||
copy_gpu(swapaxes_in_eval(out, axis, last_dim), out_, CopyType::General, s);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void ArgSort::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
gpu_sort(stream(), inputs[0], out, axis_, true);
|
||||
}
|
||||
|
||||
void Sort::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
gpu_sort(stream(), inputs[0], out, axis_, false);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@ -1,8 +1,136 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include "mlx/backend/common/ternary.h"
|
||||
#include "mlx/backend/rocm/device.h"
|
||||
#include "mlx/backend/rocm/device/ternary_ops.hpp"
|
||||
#include "mlx/backend/rocm/kernel_utils.hpp"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core::rocm {
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <rocthrust/device_ptr.h>
|
||||
#include <rocthrust/transform.h>

#include <fmt/format.h>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace rocm {
|
||||
|
||||
template <typename Op, typename Condition, typename A, typename B, typename Out>
|
||||
constexpr bool supports_ternary_op() {
|
||||
if (std::is_same_v<Op, Select>) {
|
||||
return std::is_same_v<Condition, bool> && std::is_same_v<A, Out> && std::is_same_v<B, Out>;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace rocm
|
||||
|
||||
template <typename Op>
|
||||
void ternary_op_gpu_inplace(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
const std::string& op,
|
||||
const Stream& s) {
|
||||
auto& condition = inputs[0];
|
||||
auto& a = inputs[1];
|
||||
auto& b = inputs[2];
|
||||
|
||||
if (condition.size() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto& encoder = rocm::get_command_encoder(s);
|
||||
encoder.set_input_array(condition);
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out);
|
||||
|
||||
encoder.launch_kernel([&](hipStream_t stream) {
|
||||
MLX_SWITCH_ALL_TYPES(condition.dtype(), CONDITION_TYPE, {
|
||||
MLX_SWITCH_ALL_TYPES(a.dtype(), A_TYPE, {
|
||||
MLX_SWITCH_ALL_TYPES(b.dtype(), B_TYPE, {
|
||||
MLX_SWITCH_ALL_TYPES(out.dtype(), OUT_TYPE, {
|
||||
if constexpr (rocm::supports_ternary_op<Op, CONDITION_TYPE, A_TYPE, B_TYPE, OUT_TYPE>()) {
|
||||
using ConditionType = hip_type_t<CONDITION_TYPE>;
|
||||
using AType = hip_type_t<A_TYPE>;
|
||||
using BType = hip_type_t<B_TYPE>;
|
||||
using OutType = hip_type_t<OUT_TYPE>;
|
||||
|
||||
auto policy = rocm::thrust_policy(stream);
|
||||
auto condition_ptr = rocthrust::device_pointer_cast(condition.data<ConditionType>());
|
||||
auto a_ptr = rocthrust::device_pointer_cast(a.data<AType>());
|
||||
auto b_ptr = rocthrust::device_pointer_cast(b.data<BType>());
|
||||
auto out_ptr = rocthrust::device_pointer_cast(out.data<OutType>());
|
||||
|
||||
if (condition.flags().contiguous && a.flags().contiguous && b.flags().contiguous) {
|
||||
auto ternary_op = [=] __device__ (const auto& tuple) -> OutType {
|
||||
return Op{}(rocthrust::get<0>(tuple), rocthrust::get<1>(tuple), rocthrust::get<2>(tuple));
|
||||
};
|
||||
|
||||
auto zip_begin = rocthrust::make_zip_iterator(
|
||||
rocthrust::make_tuple(condition_ptr, a_ptr, b_ptr));
|
||||
auto zip_end = rocthrust::make_zip_iterator(
|
||||
rocthrust::make_tuple(condition_ptr + condition.data_size(),
|
||||
a_ptr + a.data_size(),
|
||||
b_ptr + b.data_size()));
|
||||
|
||||
rocthrust::transform(policy, zip_begin, zip_end, out_ptr, ternary_op);
|
||||
} else {
|
||||
// Handle non-contiguous arrays with general iterators
|
||||
auto [condition_shape, condition_strides] = collapse_contiguous_dims(condition);
|
||||
auto [a_shape, a_strides] = collapse_contiguous_dims(a);
|
||||
auto [b_shape, b_strides] = collapse_contiguous_dims(b);
|
||||
|
||||
auto [condition_begin, condition_end] = rocm::make_general_iterators<int64_t>(
|
||||
condition_ptr, condition.size(), condition_shape, condition_strides);
|
||||
auto [a_begin, a_end] = rocm::make_general_iterators<int64_t>(
|
||||
a_ptr, a.size(), a_shape, a_strides);
|
||||
auto [b_begin, b_end] = rocm::make_general_iterators<int64_t>(
|
||||
b_ptr, b.size(), b_shape, b_strides);
|
||||
|
||||
auto ternary_op = [=] __device__ (const auto& tuple) -> OutType {
|
||||
return Op{}(rocthrust::get<0>(tuple), rocthrust::get<1>(tuple), rocthrust::get<2>(tuple));
|
||||
};
|
||||
|
||||
auto zip_begin = rocthrust::make_zip_iterator(
|
||||
rocthrust::make_tuple(condition_begin, a_begin, b_begin));
|
||||
auto zip_end = rocthrust::make_zip_iterator(
|
||||
rocthrust::make_tuple(condition_end, a_end, b_end));
|
||||
|
||||
rocthrust::transform(policy, zip_begin, zip_end, out_ptr, ternary_op);
|
||||
}
|
||||
} else {
|
||||
throw std::runtime_error(fmt::format(
|
||||
"Can not do ternary op {} on inputs of {}, {}, {} with output of {}.",
|
||||
op,
|
||||
dtype_to_string(condition.dtype()),
|
||||
dtype_to_string(a.dtype()),
|
||||
dtype_to_string(b.dtype()),
|
||||
dtype_to_string(out.dtype())));
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
void ternary_op_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
const std::string& op,
|
||||
const Stream& s) {
|
||||
set_ternary_output_data(inputs, out);
|
||||
ternary_op_gpu_inplace<Op>(inputs, out, op, s);
|
||||
}
|
||||
|
||||
void Select::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& s = out.primitive().stream();
|
||||
ternary_op_gpu<rocm::Select>(inputs, out, get_primitive_string(this), s);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
__global__ void select_kernel(float* condition, float* a, float* b, float* output, int n) {
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
@ -1,8 +1,197 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include "mlx/backend/common/unary.h"
|
||||
#include "mlx/backend/rocm/device.h"
|
||||
#include "mlx/backend/rocm/device/hip_complex_math.hpp"
|
||||
#include "mlx/backend/rocm/device/unary_ops.hpp"
|
||||
#include "mlx/backend/rocm/iterators/general_iterator.hpp"
|
||||
#include "mlx/backend/rocm/kernel_utils.hpp"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core::rocm {
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <rocthrust/device_ptr.h>
|
||||
#include <rocthrust/transform.h>

#include <fmt/format.h>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace rocm {
|
||||
|
||||
template <typename Op, typename In, typename Out>
constexpr bool supports_unary_op() {
  if (std::is_same_v<Op, Abs> || std::is_same_v<Op, Negative> ||
      std::is_same_v<Op, Sign>) {
    return std::is_same_v<In, Out>;
  }
  if (std::is_same_v<Op, ArcCos> || std::is_same_v<Op, ArcCosh> ||
      std::is_same_v<Op, ArcSin> || std::is_same_v<Op, ArcSinh> ||
      std::is_same_v<Op, ArcTan> || std::is_same_v<Op, ArcTanh> ||
      std::is_same_v<Op, Erf> || std::is_same_v<Op, ErfInv> ||
      std::is_same_v<Op, Expm1> || std::is_same_v<Op, Sigmoid> ||
      std::is_same_v<Op, Sqrt> || std::is_same_v<Op, Rsqrt>) {
    return std::is_same_v<In, Out> && is_floating_v<In>;
  }
  if (std::is_same_v<Op, Log> || std::is_same_v<Op, Log2> ||
      std::is_same_v<Op, Log10> || std::is_same_v<Op, Log1p>) {
    return std::is_same_v<In, Out> && is_inexact_v<In>;
  }
  if (std::is_same_v<Op, BitwiseInvert>) {
    return std::is_same_v<In, Out> && std::is_integral_v<In> &&
        !std::is_same_v<In, bool>;
  }
  if (std::is_same_v<Op, Ceil> || std::is_same_v<Op, Floor> ||
      std::is_same_v<Op, Square>) {
    return std::is_same_v<In, Out> && !std::is_same_v<In, complex64_t>;
  }
  if (std::is_same_v<Op, Conjugate>) {
    return std::is_same_v<In, Out> && std::is_same_v<In, complex64_t>;
  }
  if (std::is_same_v<Op, Cos> || std::is_same_v<Op, Cosh> ||
      std::is_same_v<Op, Exp> || std::is_same_v<Op, Round> ||
      std::is_same_v<Op, Sin> || std::is_same_v<Op, Sinh> ||
      std::is_same_v<Op, Tan> || std::is_same_v<Op, Tanh>) {
    return std::is_same_v<In, Out> &&
        (is_floating_v<In> || std::is_same_v<In, complex64_t>);
  }
  if (std::is_same_v<Op, Imag> || std::is_same_v<Op, Real>) {
    return std::is_same_v<In, complex64_t> && std::is_same_v<Out, float>;
  }
  if (std::is_same_v<Op, LogicalNot>) {
    return std::is_same_v<In, Out> && std::is_same_v<In, bool>;
  }
  return false;
}
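
// Illustrative only (these asserts are not in the source), showing how the
// filter above resolves at compile time:
//   static_assert(supports_unary_op<Abs, float, float>());
//   static_assert(!supports_unary_op<BitwiseInvert, float, float>());
//   static_assert(supports_unary_op<Real, complex64_t, float>());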

} // namespace rocm

template <typename Op>
void unary_op_gpu_inplace(
    const std::vector<array>& inputs,
    array& out,
    const std::string& op,
    const Stream& s) {
  auto& in = inputs[0];
  if (in.size() == 0) {
    return;
  }

  auto& encoder = rocm::get_command_encoder(s);
  encoder.set_input_array(in);
  encoder.set_output_array(out);
  encoder.launch_kernel([&](hipStream_t stream) {
    MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE_IN, {
      MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, {
        if constexpr (rocm::supports_unary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
          using InType = hip_type_t<CTYPE_IN>;
          using OutType = hip_type_t<CTYPE_OUT>;
          auto policy = rocm::thrust_policy(stream);
          auto in_ptr = rocthrust::device_pointer_cast(in.data<InType>());
          auto out_ptr = rocthrust::device_pointer_cast(out.data<OutType>());
          if (in.flags().contiguous) {
            rocthrust::transform(
                policy, in_ptr, in_ptr + in.data_size(), out_ptr, Op());
          } else {
            auto [shape, strides] = collapse_contiguous_dims(in);
            auto [in_begin, in_end] = rocm::make_general_iterators<int64_t>(
                in_ptr, in.size(), shape, strides);
            rocthrust::transform(policy, in_begin, in_end, out_ptr, Op());
          }
        } else {
          throw std::runtime_error(fmt::format(
              "Can not do unary op {} on input of {} with output of {}.",
              op,
              dtype_to_string(in.dtype()),
              dtype_to_string(out.dtype())));
        }
      });
    });
  });
}

template <typename Op>
void unary_op_gpu(
    const std::vector<array>& inputs,
    array& out,
    const std::string& op,
    const Stream& s) {
  set_unary_output_data(inputs[0], out);
  unary_op_gpu_inplace<Op>(inputs, out, op, s);
}
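
// The primitives below go through unary_op_gpu, which assigns the output
// buffer via set_unary_output_data and then reuses the in-place path above.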
#define UNARY_GPU(func)                                                     \
  void func::eval_gpu(const std::vector<array>& inputs, array& out) {      \
    auto& s = out.primitive().stream();                                     \
    unary_op_gpu<rocm::func>(inputs, out, get_primitive_string(this), s);  \
  }
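
// For reference, UNARY_GPU(Abs) expands to:
//   void Abs::eval_gpu(const std::vector<array>& inputs, array& out) {
//     auto& s = out.primitive().stream();
//     unary_op_gpu<rocm::Abs>(inputs, out, get_primitive_string(this), s);
//   }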

UNARY_GPU(Abs)
UNARY_GPU(ArcCos)
UNARY_GPU(ArcCosh)
UNARY_GPU(ArcSin)
UNARY_GPU(ArcSinh)
UNARY_GPU(ArcTan)
UNARY_GPU(ArcTanh)
UNARY_GPU(BitwiseInvert)
UNARY_GPU(Ceil)
UNARY_GPU(Conjugate)
UNARY_GPU(Cos)
UNARY_GPU(Cosh)
UNARY_GPU(Erf)
UNARY_GPU(ErfInv)
UNARY_GPU(Exp)
UNARY_GPU(Expm1)
UNARY_GPU(Floor)
UNARY_GPU(Imag)
UNARY_GPU(Log1p)
UNARY_GPU(LogicalNot)
UNARY_GPU(Negative)
UNARY_GPU(Real)
UNARY_GPU(Sigmoid)
UNARY_GPU(Sign)
UNARY_GPU(Sin)
UNARY_GPU(Sinh)
UNARY_GPU(Square)
UNARY_GPU(Tan)
UNARY_GPU(Tanh)

void Log::eval_gpu(const std::vector<array>& inputs, array& out) {
  auto& s = out.primitive().stream();
  auto op = get_primitive_string(this);
  switch (base_) {
    case Base::e:
      unary_op_gpu<rocm::Log>(inputs, out, op, s);
      break;
    case Base::two:
      unary_op_gpu<rocm::Log2>(inputs, out, op, s);
      break;
    case Base::ten:
      unary_op_gpu<rocm::Log10>(inputs, out, op, s);
      break;
  }
}

void Round::eval_gpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  auto& s = out.primitive().stream();
  if (issubdtype(in.dtype(), inexact)) {
    unary_op_gpu<rocm::Round>(inputs, out, get_primitive_string(this), s);
  } else {
    // No-op for integer types.
    out.copy_shared_buffer(in);
  }
}

void Sqrt::eval_gpu(const std::vector<array>& inputs, array& out) {
  auto& s = out.primitive().stream();
  if (recip_) {
    unary_op_gpu<rocm::Rsqrt>(inputs, out, "Rsqrt", s);
  } else {
    unary_op_gpu<rocm::Sqrt>(inputs, out, "Sqrt", s);
  }
}

} // namespace mlx::core

__global__ void relu_kernel(float* input, float* output, int n) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < n) output[idx] = fmaxf(input[idx], 0.0f);
}