fix softmax / logsumexp (#2042)

This commit is contained in:
Awni Hannun 2025-04-03 08:32:59 -07:00 committed by GitHub
parent 9ba81e3da4
commit c41f7565ed
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 2 additions and 0 deletions

View File

@ -33,6 +33,7 @@ template <typename T, typename AccT = float, int N_READS = 4>
local_max[simd_lane_id] = Limits<AccT>::min; local_max[simd_lane_id] = Limits<AccT>::min;
local_normalizer[simd_lane_id] = 0; local_normalizer[simd_lane_id] = 0;
} }
threadgroup_barrier(mem_flags::mem_threadgroup);
// Get the max // Get the max
AccT maxval = Limits<AccT>::finite_min; AccT maxval = Limits<AccT>::finite_min;

View File

@ -40,6 +40,7 @@ template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
local_max[simd_lane_id] = Limits<AccT>::min; local_max[simd_lane_id] = Limits<AccT>::min;
local_normalizer[simd_lane_id] = 0; local_normalizer[simd_lane_id] = 0;
} }
threadgroup_barrier(mem_flags::mem_threadgroup);
// Get the max // Get the max
AccT maxval = Limits<AccT>::finite_min; AccT maxval = Limits<AccT>::finite_min;