softmax.h
// Copyright © 2023-2024 Apple Inc.

template <typename T>
inline T softmax_exp(T x) {
  // Softmax doesn't need a high-precision exponential because x is in
  // (-inf, 0] and the result is subsequently divided by sum(exp(x_i)).
  return fast::exp(x);
}

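// softmax_single_row: one threadgroup computes the softmax of one row that is
// small enough to cache in registers. Each thread loads N_READS contiguous
// elements, the row maximum and the normalizer sum(exp(x_i - max)) are reduced
// across the threadgroup, and the normalized values are written back from the
// cached registers.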
template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
[[kernel]] void softmax_single_row(
    const device T* in,
    device T* out,
    constant int& axis_size,
    uint gid [[threadgroup_position_in_grid]],
    uint _lid [[thread_position_in_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  int lid = _lid;

  constexpr int SIMD_SIZE = 32;

  threadgroup AccT local_max[SIMD_SIZE];
  threadgroup AccT local_normalizer[SIMD_SIZE];

  AccT ld[N_READS];

  in += gid * size_t(axis_size) + lid * N_READS;
  if (lid * N_READS + N_READS <= axis_size) {
    for (int i = 0; i < N_READS; i++) {
      ld[i] = AccT(in[i]);
    }
  } else {
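    // Out-of-range lanes are padded with Limits<AccT>::min, which cannot win
    // the max reduction and whose exponential contributes zero to the sum.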
    for (int i = 0; i < N_READS; i++) {
      ld[i] =
          ((lid * N_READS + i) < axis_size) ? AccT(in[i]) : Limits<AccT>::min;
    }
  }
  if (simd_group_id == 0) {
    local_max[simd_lane_id] = Limits<AccT>::min;
    local_normalizer[simd_lane_id] = 0;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Get the max
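  // The reduction is two-level: reduce within each simdgroup first, have lane 0
  // of every simdgroup publish its result to threadgroup memory, then let
  // simdgroup 0 reduce across simdgroups and broadcast through local_max[0].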
  AccT maxval = Limits<AccT>::finite_min;
  for (int i = 0; i < N_READS; i++) {
    maxval = (maxval < ld[i]) ? ld[i] : maxval;
  }
  maxval = simd_max(maxval);
  if (simd_lane_id == 0) {
    local_max[simd_group_id] = maxval;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  if (simd_group_id == 0) {
    maxval = simd_max(local_max[simd_lane_id]);
    if (simd_lane_id == 0) {
      local_max[0] = maxval;
    }
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  maxval = local_max[0];
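  // At this point every thread holds the full row maximum.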

  // Compute exp(x_i - maxval) and store the partial sums in local_normalizer
  AccT normalizer = 0;
  for (int i = 0; i < N_READS; i++) {
    AccT exp_x = softmax_exp(ld[i] - maxval);
    ld[i] = exp_x;
    normalizer += exp_x;
  }
  normalizer = simd_sum(normalizer);
  if (simd_lane_id == 0) {
    local_normalizer[simd_group_id] = normalizer;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  if (simd_group_id == 0) {
    normalizer = simd_sum(local_normalizer[simd_lane_id]);
    if (simd_lane_id == 0) {
      local_normalizer[0] = normalizer;
    }
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  normalizer = 1 / local_normalizer[0];
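  // Take the reciprocal once so the write-out loop uses a multiply per element
  // instead of a divide.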

  // Normalize and write to the output
  out += gid * size_t(axis_size) + lid * N_READS;
  if (lid * N_READS + N_READS <= axis_size) {
    for (int i = 0; i < N_READS; i++) {
      out[i] = T(ld[i] * normalizer);
    }
  } else {
    for (int i = 0; i < N_READS; i++) {
      if ((lid * N_READS + i) < axis_size) {
        out[i] = T(ld[i] * normalizer);
      }
    }
  }
}

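// softmax_looped: intended for rows too long to cache in registers. The
// threadgroup streams over the row in chunks of lsize * N_READS elements,
// maintaining a running maximum and a running normalizer that is rescaled
// whenever the maximum grows (online softmax). A second pass over the input
// recomputes the exponentials and writes the normalized output.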
template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
[[kernel]] void softmax_looped(
    const device T* in,
    device T* out,
    constant int& axis_size,
    uint gid [[threadgroup_position_in_grid]],
    uint lid [[thread_position_in_threadgroup]],
    uint lsize [[threads_per_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  in += gid * size_t(axis_size);

  constexpr int SIMD_SIZE = 32;

  threadgroup AccT local_max[SIMD_SIZE];
  threadgroup AccT local_normalizer[SIMD_SIZE];

  // Get the max and the normalizer in one go
  AccT prevmax;
  AccT maxval = Limits<AccT>::finite_min;
  AccT normalizer = 0;
  for (int r = 0; r < static_cast<int>(ceildiv(axis_size, N_READS * lsize));
       r++) {
    int offset = r * lsize * N_READS + lid * N_READS;
    AccT vals[N_READS];
    if (offset + N_READS <= axis_size) {
      for (int i = 0; i < N_READS; i++) {
        vals[i] = AccT(in[offset + i]);
      }
    } else {
      for (int i = 0; i < N_READS; i++) {
        vals[i] = (offset + i < axis_size) ? AccT(in[offset + i])
                                           : Limits<AccT>::finite_min;
      }
    }
    prevmax = maxval;
    for (int i = 0; i < N_READS; i++) {
      maxval = (maxval < vals[i]) ? vals[i] : maxval;
    }
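    // maxval may have grown in this chunk; the terms already accumulated in
    // normalizer were computed relative to prevmax, so rescale them by
    // exp(prevmax - maxval) before adding the new terms.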
    normalizer *= softmax_exp(prevmax - maxval);
    for (int i = 0; i < N_READS; i++) {
      normalizer += softmax_exp(vals[i] - maxval);
    }
  }
  // Each thread now holds a partial max and a partial normalizer covering
  // N_READS * ceildiv(axis_size, N_READS * lsize) elements. They are combined
  // as follows:
  //  1. Find the max across simd groups.
  //  2. Rescale the partial normalizers to account for a possible change in
  //     the max.
  //  3. Sum all the normalizers.
  prevmax = maxval;
  maxval = simd_max(maxval);
  normalizer *= softmax_exp(prevmax - maxval);
  normalizer = simd_sum(normalizer);

  // Now the normalizer and max value are correct for each simdgroup. We write
  // them to shared memory and combine them.
  prevmax = maxval;
  if (simd_lane_id == 0) {
    local_max[simd_group_id] = maxval;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  maxval = simd_max(local_max[simd_lane_id]);
  normalizer *= softmax_exp(prevmax - maxval);
  if (simd_lane_id == 0) {
    local_normalizer[simd_group_id] = normalizer;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  normalizer = simd_sum(local_normalizer[simd_lane_id]);
  normalizer = 1 / normalizer;

  // Finally, given the normalizer and max value, we can directly write the
  // softmax output.
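  // The row does not fit in registers, so the exponentials are recomputed from
  // global memory in this second pass rather than cached from the first.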
  out += gid * size_t(axis_size);
  for (int r = 0; r < static_cast<int>(ceildiv(axis_size, N_READS * lsize));
       r++) {
    int offset = r * lsize * N_READS + lid * N_READS;
    if (offset + N_READS <= axis_size) {
      for (int i = 0; i < N_READS; i++) {
        out[offset + i] = T(softmax_exp(in[offset + i] - maxval) * normalizer);
      }
    } else {
      for (int i = 0; i < N_READS; i++) {
        if (offset + i < axis_size) {
          out[offset + i] =
              T(softmax_exp(in[offset + i] - maxval) * normalizer);
        }
      }
    }
  }
}
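
// Reference (illustrative, not part of the kernels above): softmax_looped
// computes the max and the normalizer in a single streaming pass using the
// running-rescale ("online softmax") recurrence. A minimal host-side C++ sketch
// of the same recurrence over one row, assuming `row` is a mutable range of
// floats (requires <cmath>, <algorithm>, <limits>):
//
//   float maxval = -std::numeric_limits<float>::infinity();
//   float normalizer = 0.0f;
//   for (float x : row) {            // one pass for the max and the sum
//     float prevmax = maxval;
//     maxval = std::max(maxval, x);
//     normalizer =
//         normalizer * std::exp(prevmax - maxval) + std::exp(x - maxval);
//   }
//   for (float& x : row) {           // second pass writes the output
//     x = std::exp(x - maxval) / normalizer;
//   }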