Option for precise softmax (#953)

* precise softmax

* Add an equivalency check

* Make the threadgroup memory definition fixed

* precise cpu softmax

* precise option on cpu

* remove print

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
Awni Hannun
2024-04-04 08:32:35 -07:00
committed by GitHub
parent 0caf35f4b8
commit e142aaf8a1
11 changed files with 215 additions and 99 deletions

View File

@@ -11,46 +11,48 @@ using namespace metal;
template <typename T>
inline T softmax_exp(T x) {
// Softmax doesn't need high precision exponential cause it is gonna be x
// will be in (-oo, 0] anyway and subsequently it will be divided by
// sum(exp(x_i)).
// Softmax doesn't need high precision exponential cause x is gonna be in
// (-oo, 0] anyway and subsequently it will be divided by sum(exp(x_i)).
return fast::exp(x);
}
template <typename T, int N_READS = SOFTMAX_N_READS>
template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
[[kernel]] void softmax_single_row(
const device T* in,
device T* out,
constant int& axis_size,
threadgroup T* local_max [[threadgroup(0)]],
threadgroup T* local_normalizer [[threadgroup(1)]],
uint gid [[threadgroup_position_in_grid]],
uint _lid [[thread_position_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
int lid = _lid;
T ld[N_READS];
constexpr int SIMD_SIZE = 32;
threadgroup AccT local_max[SIMD_SIZE];
threadgroup AccT local_normalizer[SIMD_SIZE];
AccT ld[N_READS];
in += gid * axis_size + lid * N_READS;
if (lid * N_READS + N_READS <= axis_size) {
for (int i=0; i<N_READS; i++) {
ld[i] = in[i];
for (int i = 0; i < N_READS; i++) {
ld[i] = AccT(in[i]);
}
} else {
for (int i = 0; i < N_READS; i++) {
ld[i] =
((lid * N_READS + i) < axis_size) ? in[i] : T(Limits<T>::finite_min);
}
for (int i = 0; i < N_READS; i++) {
ld[i] = ((lid * N_READS + i) < axis_size) ? AccT(in[i])
: Limits<AccT>::finite_min;
}
}
if (simd_group_id == 0) {
local_max[simd_lane_id] = Limits<T>::finite_min;
local_max[simd_lane_id] = Limits<AccT>::finite_min;
local_normalizer[simd_lane_id] = 0;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Get the max
T maxval = Limits<T>::finite_min;
AccT maxval = Limits<AccT>::finite_min;
for (int i = 0; i < N_READS; i++) {
maxval = (maxval < ld[i]) ? ld[i] : maxval;
}
@@ -69,9 +71,9 @@ template <typename T, int N_READS = SOFTMAX_N_READS>
maxval = local_max[0];
// Compute exp(x_i - maxval) and store the partial sums in local_normalizer
T normalizer = 0;
AccT normalizer = 0;
for (int i = 0; i < N_READS; i++) {
T exp_x = softmax_exp(ld[i] - maxval);
AccT exp_x = softmax_exp(ld[i] - maxval);
ld[i] = exp_x;
normalizer += exp_x;
}
@@ -92,25 +94,23 @@ template <typename T, int N_READS = SOFTMAX_N_READS>
// Normalize and write to the output
out += gid * axis_size + lid * N_READS;
if (lid * N_READS + N_READS <= axis_size) {
for (int i=0; i<N_READS; i++) {
out[i] = ld[i] * normalizer;
for (int i = 0; i < N_READS; i++) {
out[i] = T(ld[i] * normalizer);
}
} else {
for (int i = 0; i < N_READS; i++) {
if ((lid * N_READS + i) < axis_size) {
out[i] = ld[i] * normalizer;
}
for (int i = 0; i < N_READS; i++) {
if ((lid * N_READS + i) < axis_size) {
out[i] = T(ld[i] * normalizer);
}
}
}
}
template <typename T, int N_READS = SOFTMAX_N_READS>
template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
[[kernel]] void softmax_looped(
const device T* in,
device T* out,
constant int& axis_size,
threadgroup T* local_max [[threadgroup(0)]],
threadgroup T* local_normalizer [[threadgroup(1)]],
uint gid [[threadgroup_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint lsize [[threads_per_threadgroup]],
@@ -118,22 +118,27 @@ template <typename T, int N_READS = SOFTMAX_N_READS>
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
in += gid * axis_size;
constexpr int SIMD_SIZE = 32;
threadgroup AccT local_max[SIMD_SIZE];
threadgroup AccT local_normalizer[SIMD_SIZE];
// Get the max and the normalizer in one go
T prevmax;
T maxval = Limits<T>::finite_min;
T normalizer = 0;
AccT prevmax;
AccT maxval = Limits<AccT>::finite_min;
AccT normalizer = 0;
for (int r = 0; r < static_cast<int>(ceildiv(axis_size, N_READS * lsize));
r++) {
int offset = r * lsize * N_READS + lid * N_READS;
T vals[N_READS];
AccT vals[N_READS];
if (offset + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
vals[i] = in[offset + i];
vals[i] = AccT(in[offset + i]);
}
} else {
for (int i = 0; i < N_READS; i++) {
vals[i] =
(offset + i < axis_size) ? in[offset + i] : T(Limits<T>::finite_min);
vals[i] = (offset + i < axis_size) ? AccT(in[offset + i])
: Limits<AccT>::finite_min;
}
}
prevmax = maxval;
@@ -179,50 +184,66 @@ template <typename T, int N_READS = SOFTMAX_N_READS>
r++) {
int offset = r * lsize * N_READS + lid * N_READS;
if (offset + N_READS <= axis_size) {
for (int i=0; i<N_READS; i++) {
out[offset + i] = softmax_exp(in[offset + i] - maxval) * normalizer;
for (int i = 0; i < N_READS; i++) {
out[offset + i] = T(softmax_exp(in[offset + i] - maxval) * normalizer);
}
} else {
for (int i = 0; i < N_READS; i++) {
if (offset + i < axis_size) {
out[offset + i] = softmax_exp(in[offset + i] - maxval) * normalizer;
out[offset + i] =
T(softmax_exp(in[offset + i] - maxval) * normalizer);
}
}
}
}
}
#define instantiate_softmax_single_row(name, itype) \
// clang-format off
#define instantiate_softmax(name, itype) \
template [[host_name("softmax_" #name)]] [[kernel]] void \
softmax_single_row<itype>( \
const device itype* in, \
device itype* out, \
constant int& axis_size, \
threadgroup itype* local_max [[threadgroup(0)]], \
threadgroup itype* local_normalizer [[threadgroup(1)]], \
uint gid [[thread_position_in_grid]], \
uint _lid [[thread_position_in_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_softmax_looped(name, itype) \
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
template [[host_name("softmax_looped_" #name)]] [[kernel]] void \
softmax_looped<itype>( \
const device itype* in, \
device itype* out, \
constant int& axis_size, \
threadgroup itype* local_max [[threadgroup(0)]], \
threadgroup itype* local_normalizer [[threadgroup(1)]], \
uint gid [[threadgroup_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint lsize [[threads_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_softmax(name, itype) \
instantiate_softmax_single_row(name, itype) \
instantiate_softmax_looped(name, itype)
#define instantiate_softmax_precise(name, itype) \
template [[host_name("softmax_precise_" #name)]] [[kernel]] void \
softmax_single_row<itype, float>( \
const device itype* in, \
device itype* out, \
constant int& axis_size, \
uint gid [[thread_position_in_grid]], \
uint _lid [[thread_position_in_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
template [[host_name("softmax_looped_precise_" #name)]] [[kernel]] void \
softmax_looped<itype, float>( \
const device itype* in, \
device itype* out, \
constant int& axis_size, \
uint gid [[threadgroup_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint lsize [[threads_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
instantiate_softmax(float32, float)
instantiate_softmax(float16, half)
instantiate_softmax(bfloat16, bfloat16_t)
instantiate_softmax_precise(float16, half)
instantiate_softmax_precise(bfloat16, bfloat16_t)
// clang-format on

View File

@@ -56,6 +56,9 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
if (axis_size > looped_limit) {
op_name += "looped_";
}
if (in.dtype() != float32 && precise_) {
op_name += "precise_";
}
op_name += type_to_name(out);
auto compute_encoder = d.get_command_encoder(s.index);
{
@@ -82,8 +85,6 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder, in.data_shared_ptr() == nullptr ? out : in, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder->setBytes(&axis_size, sizeof(int), 2);
compute_encoder->setThreadgroupMemoryLength(simd_size * in.itemsize(), 0);
compute_encoder->setThreadgroupMemoryLength(simd_size * in.itemsize(), 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
d.get_command_buffer(s.index)->addCompletedHandler(