mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-24 02:41:19 +08:00
Fix formatting
This commit is contained in:
parent
7734bc5c4f
commit
570dd8287a
@ -15,7 +15,7 @@ inline void initialize_buffer(
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||
if (simd_group_id == 0) {
|
||||
for (int i=0; i<N; i++) {
|
||||
for (int i = 0; i < N; i++) {
|
||||
xs[N * simd_lane_id + i] = 0;
|
||||
}
|
||||
}
|
||||
@ -28,16 +28,16 @@ inline void threadgroup_sum(
|
||||
threadgroup float* xs,
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||
for (int i=0; i<N; i++) {
|
||||
for (int i = 0; i < N; i++) {
|
||||
x[i] = simd_sum(x[i]);
|
||||
}
|
||||
if (simd_lane_id == 0) {
|
||||
for (int i=0; i<N; i++) {
|
||||
for (int i = 0; i < N; i++) {
|
||||
xs[N * simd_group_id + i] = x[i];
|
||||
}
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
for (int i=0; i<N; i++) {
|
||||
for (int i = 0; i < N; i++) {
|
||||
x[i] = xs[N * simd_lane_id + i];
|
||||
x[i] = simd_sum(x[i]);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user