mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Fix compilation with CUDA 11 (#2331)
This commit is contained in:
@@ -43,7 +43,7 @@ __global__ void softmax(const T* in, T* out, int axis_size) {
|
||||
// Thread reduce.
|
||||
AccT prevmax;
|
||||
AccT maxval = Limits<AccT>::finite_min();
|
||||
-AccT normalizer = 0;
|
||||
+AccT normalizer = cast_to<AccT>(0);
|
||||
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); r++) {
|
||||
AccT vals[N_READS];
|
||||
cub::LoadDirectBlocked(
|
||||
|
||||
Reference in New Issue
Block a user