14 constant
int& axis_size,
15 uint gid [[threadgroup_position_in_grid]],
16 uint _lid [[thread_position_in_threadgroup]],
17 uint simd_lane_id [[thread_index_in_simdgroup]],
18 uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
21 constexpr int SIMD_SIZE = 32;
23 threadgroup AccT local_max[SIMD_SIZE];
24 threadgroup AccT local_normalizer[SIMD_SIZE];
28 in += gid * axis_size + lid * N_READS;
29 if (lid * N_READS + N_READS <= axis_size) {
30 for (
int i = 0; i < N_READS; i++) {
34 for (
int i = 0; i < N_READS; i++) {
35 ld[i] = ((lid * N_READS + i) < axis_size) ? AccT(in[i])
39 if (simd_group_id == 0) {
41 local_normalizer[simd_lane_id] = 0;
43 threadgroup_barrier(mem_flags::mem_threadgroup);
47 for (
int i = 0; i < N_READS; i++) {
48 maxval = (maxval < ld[i]) ? ld[i] : maxval;
50 maxval = simd_max(maxval);
51 if (simd_lane_id == 0) {
52 local_max[simd_group_id] = maxval;
54 threadgroup_barrier(mem_flags::mem_threadgroup);
55 if (simd_group_id == 0) {
56 maxval = simd_max(local_max[simd_lane_id]);
57 if (simd_lane_id == 0) {
58 local_max[0] = maxval;
61 threadgroup_barrier(mem_flags::mem_threadgroup);
62 maxval = local_max[0];
66 for (
int i = 0; i < N_READS; i++) {
71 normalizer = simd_sum(normalizer);
72 if (simd_lane_id == 0) {
73 local_normalizer[simd_group_id] = normalizer;
75 threadgroup_barrier(mem_flags::mem_threadgroup);
76 if (simd_group_id == 0) {
77 normalizer = simd_sum(local_normalizer[simd_lane_id]);
78 if (simd_lane_id == 0) {
79 local_normalizer[0] = normalizer;
82 threadgroup_barrier(mem_flags::mem_threadgroup);
83 normalizer = 1 / local_normalizer[0];
86 out += gid * axis_size + lid * N_READS;
87 if (lid * N_READS + N_READS <= axis_size) {
88 for (
int i = 0; i < N_READS; i++) {
89 out[i] = T(ld[i] * normalizer);
92 for (
int i = 0; i < N_READS; i++) {
93 if ((lid * N_READS + i) < axis_size) {
94 out[i] = T(ld[i] * normalizer);
104 constant
int& axis_size,
105 uint gid [[threadgroup_position_in_grid]],
106 uint lid [[thread_position_in_threadgroup]],
107 uint lsize [[threads_per_threadgroup]],
108 uint simd_lane_id [[thread_index_in_simdgroup]],
109 uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
110 in += gid * axis_size;
112 constexpr int SIMD_SIZE = 32;
114 threadgroup AccT local_max[SIMD_SIZE];
115 threadgroup AccT local_normalizer[SIMD_SIZE];
121 for (
int r = 0; r < static_cast<int>(
ceildiv(axis_size, N_READS * lsize));
123 int offset = r * lsize * N_READS + lid * N_READS;
125 if (offset + N_READS <= axis_size) {
126 for (
int i = 0; i < N_READS; i++) {
127 vals[i] = AccT(in[offset + i]);
130 for (
int i = 0; i < N_READS; i++) {
131 vals[i] = (offset + i < axis_size) ? AccT(in[offset + i])
136 for (
int i = 0; i < N_READS; i++) {
137 maxval = (maxval < vals[i]) ? vals[i] : maxval;
140 for (
int i = 0; i < N_READS; i++) {
151 maxval = simd_max(maxval);
153 normalizer = simd_sum(normalizer);
158 if (simd_lane_id == 0) {
159 local_max[simd_group_id] = maxval;
161 threadgroup_barrier(mem_flags::mem_threadgroup);
162 maxval = simd_max(local_max[simd_lane_id]);
164 if (simd_lane_id == 0) {
165 local_normalizer[simd_group_id] = normalizer;
167 threadgroup_barrier(mem_flags::mem_threadgroup);
168 normalizer = simd_sum(local_normalizer[simd_lane_id]);
169 normalizer = 1 / normalizer;
173 out += gid * axis_size;
174 for (
int r = 0; r < static_cast<int>(
ceildiv(axis_size, N_READS * lsize));
176 int offset = r * lsize * N_READS + lid * N_READS;
177 if (offset + N_READS <= axis_size) {
178 for (
int i = 0; i < N_READS; i++) {
179 out[offset + i] = T(
softmax_exp(in[offset + i] - maxval) * normalizer);
182 for (
int i = 0; i < N_READS; i++) {
183 if (offset + i < axis_size) {
185 T(
softmax_exp(in[offset + i] - maxval) * normalizer);
void softmax_single_row(const device T *in, device T *out, constant int &axis_size, uint gid, uint _lid, uint simd_lane_id, uint simd_group_id)
Definition softmax.h:11
void softmax_looped(const device T *in, device T *out, constant int &axis_size, uint gid, uint lid, uint lsize, uint simd_lane_id, uint simd_group_id)
Definition softmax.h:101