mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
36 lines
820 B
C++
36 lines
820 B
C++
![]() |
// Copyright © 2025 Apple Inc.
|
||
|
|
||
|
#pragma once
|
||
|
|
||
|
#include <hip/hip_runtime.h>
|
||
|
|
||
|
namespace mlx::core::rocm {
|
||
|
|
||
|
// Atomic operations for HIP
|
||
|
__device__ inline float atomicAddFloat(float* address, float val) {
|
||
|
return atomicAdd(address, val);
|
||
|
}
|
||
|
|
||
|
__device__ inline double atomicAddDouble(double* address, double val) {
|
||
|
return atomicAdd(address, val);
|
||
|
}
|
||
|
|
||
|
__device__ inline int atomicAddInt(int* address, int val) {
|
||
|
return atomicAdd(address, val);
|
||
|
}
|
||
|
|
||
|
__device__ inline unsigned int atomicAddUInt(
|
||
|
unsigned int* address,
|
||
|
unsigned int val) {
|
||
|
return atomicAdd(address, val);
|
||
|
}
|
||
|
|
||
|
__device__ inline float atomicMaxFloat(float* address, float val) {
|
||
|
return atomicMax(address, val);
|
||
|
}
|
||
|
|
||
|
__device__ inline float atomicMinFloat(float* address, float val) {
|
||
|
return atomicMin(address, val);
|
||
|
}
|
||
|
|
||
|
} // namespace mlx::core::rocm
|