mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
add cuda sm 90
This commit is contained in:
@@ -108,7 +108,7 @@ endif()
|
|||||||
# Compute capability 7 is required for synchronization between CPU/GPU with
|
# Compute capability 7 is required for synchronization between CPU/GPU with
|
||||||
# managed memory. TODO: Add more architectures for potential performance gain.
|
# managed memory. TODO: Add more architectures for potential performance gain.
|
||||||
set(MLX_CUDA_ARCHITECTURES
|
set(MLX_CUDA_ARCHITECTURES
|
||||||
"70;80"
|
"70;80;90"
|
||||||
CACHE STRING "CUDA architectures")
|
CACHE STRING "CUDA architectures")
|
||||||
message(STATUS "CUDA architectures: ${MLX_CUDA_ARCHITECTURES}")
|
message(STATUS "CUDA architectures: ${MLX_CUDA_ARCHITECTURES}")
|
||||||
set_target_properties(mlx PROPERTIES CUDA_ARCHITECTURES
|
set_target_properties(mlx PROPERTIES CUDA_ARCHITECTURES
|
||||||
|
|||||||
@@ -49,11 +49,7 @@ inline __device__ void atomic_add(__half* out, __half val) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
inline __device__ void atomic_add(complex64_t* out, complex64_t val) {
|
inline __device__ void atomic_add(complex64_t* out, complex64_t val) {
|
||||||
#if __CUDA_ARCH__ < 900
|
|
||||||
atomic_add_general(out, val);
|
atomic_add_general(out, val);
|
||||||
#else
|
|
||||||
atomicAdd(out, val);
|
|
||||||
#endif
|
|
||||||
}
|
}
|
||||||
|
|
||||||
inline __device__ void atomic_add(__nv_bfloat16* out, __nv_bfloat16 val) {
|
inline __device__ void atomic_add(__nv_bfloat16* out, __nv_bfloat16 val) {
|
||||||
|
|||||||
Reference in New Issue
Block a user