mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00
Make sure softmax doesn't change the actual max
This commit is contained in:
parent
818e8e663e
commit
fd1d0821d2
@ -51,7 +51,7 @@ __global__ void softmax(const T* in, T* out, int axis_size) {
|
|||||||
make_cast_iterator<AccT>(in),
|
make_cast_iterator<AccT>(in),
|
||||||
vals,
|
vals,
|
||||||
axis_size,
|
axis_size,
|
||||||
Limits<AccT>::finite_min());
|
Limits<AccT>::min());
|
||||||
prevmax = maxval;
|
prevmax = maxval;
|
||||||
maxval = max_op(maxval, cub::ThreadReduce(vals, max_op));
|
maxval = max_op(maxval, cub::ThreadReduce(vals, max_op));
|
||||||
// Online normalizer calculation for softmax:
|
// Online normalizer calculation for softmax:
|
||||||
@ -79,7 +79,7 @@ __global__ void softmax(const T* in, T* out, int axis_size) {
|
|||||||
block.sync();
|
block.sync();
|
||||||
maxval = warp.thread_rank() < warp.meta_group_size()
|
maxval = warp.thread_rank() < warp.meta_group_size()
|
||||||
? local_max[warp.thread_rank()]
|
? local_max[warp.thread_rank()]
|
||||||
: Limits<AccT>::finite_min();
|
: Limits<AccT>::min();
|
||||||
maxval = cg::reduce(warp, maxval, max_op);
|
maxval = cg::reduce(warp, maxval, max_op);
|
||||||
normalizer = normalizer * softmax_exp(prevmax - maxval);
|
normalizer = normalizer * softmax_exp(prevmax - maxval);
|
||||||
if (warp.thread_rank() == 0) {
|
if (warp.thread_rank() == 0) {
|
||||||
|
@ -13,7 +13,6 @@ cuda_skip = {
|
|||||||
"TestLayers.test_upsample",
|
"TestLayers.test_upsample",
|
||||||
"TestOps.test_complex_ops",
|
"TestOps.test_complex_ops",
|
||||||
"TestOps.test_dynamic_slicing",
|
"TestOps.test_dynamic_slicing",
|
||||||
"TestOps.test_softmax",
|
|
||||||
"TestReduce.test_axis_permutation_sums",
|
"TestReduce.test_axis_permutation_sums",
|
||||||
"TestReduce.test_dtypes",
|
"TestReduce.test_dtypes",
|
||||||
"TestReduce.test_expand_sums",
|
"TestReduce.test_expand_sums",
|
||||||
|
Loading…
Reference in New Issue
Block a user