Fix formatting

This commit is contained in:
Angelos Katharopoulos 2025-06-06 01:09:31 -07:00
parent 7734bc5c4f
commit 570dd8287a

View File

@ -15,7 +15,7 @@ inline void initialize_buffer(
uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) { uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
if (simd_group_id == 0) { if (simd_group_id == 0) {
for (int i=0; i<N; i++) { for (int i = 0; i < N; i++) {
xs[N * simd_lane_id + i] = 0; xs[N * simd_lane_id + i] = 0;
} }
} }
@ -28,16 +28,16 @@ inline void threadgroup_sum(
threadgroup float* xs, threadgroup float* xs,
uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) { uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
for (int i=0; i<N; i++) { for (int i = 0; i < N; i++) {
x[i] = simd_sum(x[i]); x[i] = simd_sum(x[i]);
} }
if (simd_lane_id == 0) { if (simd_lane_id == 0) {
for (int i=0; i<N; i++) { for (int i = 0; i < N; i++) {
xs[N * simd_group_id + i] = x[i]; xs[N * simd_group_id + i] = x[i];
} }
} }
threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup_barrier(mem_flags::mem_threadgroup);
for (int i=0; i<N; i++) { for (int i = 0; i < N; i++) {
x[i] = xs[N * simd_lane_id + i]; x[i] = xs[N * simd_lane_id + i];
x[i] = simd_sum(x[i]); x[i] = simd_sum(x[i]);
} }