mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-15 04:51:13 +08:00
Fix layernorm race condition (#2340)
This commit is contained in:
parent
0e0d9ac522
commit
f5299f72cd
@ -31,6 +31,7 @@ inline void threadgroup_sum(
|
|||||||
for (int i = 0; i < N; i++) {
|
for (int i = 0; i < N; i++) {
|
||||||
x[i] = simd_sum(x[i]);
|
x[i] = simd_sum(x[i]);
|
||||||
}
|
}
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
if (simd_lane_id == 0) {
|
if (simd_lane_id == 0) {
|
||||||
for (int i = 0; i < N; i++) {
|
for (int i = 0; i < N; i++) {
|
||||||
xs[N * simd_group_id + i] = x[i];
|
xs[N * simd_group_id + i] = x[i];
|
||||||
|
Loading…
Reference in New Issue
Block a user