Make sure softmax doesn't change the actual max

This commit is contained in:
Angelos Katharopoulos 2025-06-22 23:34:32 -07:00
parent 818e8e663e
commit fd1d0821d2
2 changed files with 2 additions and 3 deletions

View File

@@ -51,7 +51,7 @@ __global__ void softmax(const T* in, T* out, int axis_size) {
         make_cast_iterator<AccT>(in),
         vals,
         axis_size,
-        Limits<AccT>::finite_min());
+        Limits<AccT>::min());
     prevmax = maxval;
     maxval = max_op(maxval, cub::ThreadReduce(vals, max_op));
     // Online normalizer calculation for softmax:
@@ -79,7 +79,7 @@ __global__ void softmax(const T* in, T* out, int axis_size) {
   block.sync();
   maxval = warp.thread_rank() < warp.meta_group_size()
       ? local_max[warp.thread_rank()]
-      : Limits<AccT>::finite_min();
+      : Limits<AccT>::min();
   maxval = cg::reduce(warp, maxval, max_op);
   normalizer = normalizer * softmax_exp(prevmax - maxval);
   if (warp.thread_rank() == 0) {

View File

@@ -13,7 +13,6 @@ cuda_skip = {
     "TestLayers.test_upsample",
     "TestOps.test_complex_ops",
     "TestOps.test_dynamic_slicing",
-    "TestOps.test_softmax",
    "TestReduce.test_axis_permutation_sums",
    "TestReduce.test_dtypes",
    "TestReduce.test_expand_sums",