mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
fix softmax / logsumexp (#2042)
This commit is contained in:
parent
9ba81e3da4
commit
c41f7565ed
@ -33,6 +33,7 @@ template <typename T, typename AccT = float, int N_READS = 4>
|
|||||||
local_max[simd_lane_id] = Limits<AccT>::min;
|
local_max[simd_lane_id] = Limits<AccT>::min;
|
||||||
local_normalizer[simd_lane_id] = 0;
|
local_normalizer[simd_lane_id] = 0;
|
||||||
}
|
}
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
// Get the max
|
// Get the max
|
||||||
AccT maxval = Limits<AccT>::finite_min;
|
AccT maxval = Limits<AccT>::finite_min;
|
||||||
|
@ -40,6 +40,7 @@ template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
|
|||||||
local_max[simd_lane_id] = Limits<AccT>::min;
|
local_max[simd_lane_id] = Limits<AccT>::min;
|
||||||
local_normalizer[simd_lane_id] = 0;
|
local_normalizer[simd_lane_id] = 0;
|
||||||
}
|
}
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
// Get the max
|
// Get the max
|
||||||
AccT maxval = Limits<AccT>::finite_min;
|
AccT maxval = Limits<AccT>::finite_min;
|
||||||
|
Loading…
Reference in New Issue
Block a user