mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Fix compilation with CUDA 11 (#2331)
This commit is contained in:
@@ -2,6 +2,8 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/cuda/device/atomic_ops.cuh"
|
||||
#include "mlx/backend/cuda/device/cast_op.cuh"
|
||||
#include "mlx/backend/cuda/device/utils.cuh"
|
||||
#include "mlx/backend/cuda/reduce/reduce_utils.cuh"
|
||||
|
||||
@@ -40,15 +42,15 @@ struct Sum {
|
||||
}
|
||||
|
||||
__device__ void atomic_update(__nv_bfloat16* x, __nv_bfloat16 y) {
|
||||
atomicAdd(x, y);
|
||||
atomic_add(x, y);
|
||||
}
|
||||
|
||||
__device__ void atomic_update(int* x, int y) {
|
||||
atomicAdd(x, y);
|
||||
atomic_add(x, y);
|
||||
}
|
||||
|
||||
__device__ void atomic_update(float* x, float y) {
|
||||
atomicAdd(x, y);
|
||||
atomic_add(x, y);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -152,7 +154,7 @@ struct ReduceInit<Sum, T> {
|
||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||
return T{0, 0};
|
||||
} else {
|
||||
return typename ReduceResult<Sum, T>::type{0};
|
||||
return cast_to<typename ReduceResult<Sum, T>::type>(0);
|
||||
}
|
||||
}
|
||||
};
|
||||
@@ -163,7 +165,7 @@ struct ReduceInit<Prod, T> {
|
||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||
return T{1, 0};
|
||||
} else {
|
||||
return typename ReduceResult<Prod, T>::type{1};
|
||||
return cast_to<typename ReduceResult<Prod, T>::type>(1);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user