mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00
Remove redundant simd_sum in logsumexp (#2210)
This commit is contained in:
parent
35c87741cf
commit
7774b87cbd
@@ -134,10 +134,7 @@ template <typename T, typename AccT = float, int N_READS = 4>
|
|||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
normalizer = simd_sum(local_normalizer[simd_lane_id]);
|
normalizer = simd_sum(local_normalizer[simd_lane_id]);
|
||||||
|
|
||||||
if (simd_group_id == 0) {
|
if (lid == 0) {
|
||||||
normalizer = simd_sum(local_normalizer[simd_lane_id]);
|
out[gid] = isinf(maxval) ? T(maxval) : T(log(normalizer) + maxval);
|
||||||
if (simd_lane_id == 0) {
|
|
||||||
out[gid] = isinf(maxval) ? T(maxval) : T(log(normalizer) + maxval);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user