Option for precise softmax (#953)

* precise softmax

* Add an equivalency check

* Make the threadgroup memory definition fixed

* precise cpu softmax

* precise option on cpu

* remove print

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
Awni Hannun 2024-04-04 08:32:35 -07:00 committed by GitHub
parent 0caf35f4b8
commit e142aaf8a1
11 changed files with 215 additions and 99 deletions
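For a quick look at the user-facing change, here is a minimal usage sketch of the new option through the Python API. It is based on the binding and test changes further down in this commit and is not part of the diff itself; `precise` defaults to False.

    import mlx.core as mx

    x = (10 * mx.random.normal(shape=(1024,))).astype(mx.float16)

    # Default half-precision accumulation vs. the new precise option,
    # which accumulates in float32 and casts back to the input dtype.
    y = mx.softmax(x, axis=-1)
    y_precise = mx.softmax(x, axis=-1, precise=True)

    # precise=True should match a full float32 round trip.
    ref = mx.softmax(x.astype(mx.float32), axis=-1).astype(mx.float16)
    print(mx.allclose(y_precise, ref).item())  # expected: True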

View File

@@ -1,4 +1,4 @@
-// Copyright © 2023 Apple Inc.
+// Copyright © 2023-2024 Apple Inc.

 #include <cassert>
 #include <limits>
@@ -201,7 +201,7 @@ struct NeonFp16SimdOps {
   }
 };

-template <typename T, typename VT, typename Ops, int N>
+template <typename T, typename AccT, typename VT, typename Ops, int N>
 void softmax(const array& in, array& out) {
   Ops ops;
@@ -218,13 +218,21 @@ void softmax(const array& in, array& out) {
     VT vmaximum = ops.init(-std::numeric_limits<float>::infinity());
     size_t s = M;
     while (s >= N) {
-      vmaximum = ops.max(ops.load(current_in_ptr), vmaximum);
+      VT vals;
+      if constexpr (std::is_same<T, AccT>::value) {
+        vals = ops.load(current_in_ptr);
+      } else {
+        for (int i = 0; i < N; ++i) {
+          vals[i] = static_cast<AccT>(current_in_ptr[i]);
+        }
+      }
+      vmaximum = ops.max(vals, vmaximum);
       current_in_ptr += N;
       s -= N;
     }
-    T maximum = ops.reduce_max(vmaximum);
+    AccT maximum = ops.reduce_max(vmaximum);
     while (s-- > 0) {
-      maximum = std::max(maximum, *current_in_ptr);
+      maximum = std::max(maximum, static_cast<AccT>(*current_in_ptr));
       current_in_ptr++;
     }
@@ -234,18 +242,29 @@ void softmax(const array& in, array& out) {
     current_in_ptr = in_ptr;
     s = M;
     while (s >= N) {
-      VT vexp = ops.exp(ops.sub(*(VT*)current_in_ptr, maximum));
-      ops.store(current_out_ptr, vexp);
-      *(VT*)current_out_ptr = vexp;
+      VT vexp;
+      if constexpr (std::is_same<T, AccT>::value) {
+        vexp = ops.load(current_in_ptr);
+      } else {
+        for (int i = 0; i < N; ++i) {
+          vexp[i] = static_cast<AccT>(current_in_ptr[i]);
+        }
+      }
+      vexp = ops.exp(ops.sub(vexp, maximum));
+      if constexpr (std::is_same<T, AccT>::value) {
+        ops.store(current_out_ptr, vexp);
+      }
       vnormalizer = ops.add(vnormalizer, vexp);
       current_in_ptr += N;
       current_out_ptr += N;
       s -= N;
     }
-    T normalizer = ops.reduce_add(vnormalizer);
+    AccT normalizer = ops.reduce_add(vnormalizer);
     while (s-- > 0) {
-      T _exp = std::exp(*current_in_ptr - maximum);
-      *current_out_ptr = _exp;
+      AccT _exp = std::exp(*current_in_ptr - maximum);
+      if (std::is_same<T, AccT>::value) {
+        *current_out_ptr = _exp;
+      }
       normalizer += _exp;
       current_in_ptr++;
       current_out_ptr++;
@@ -254,14 +273,33 @@ void softmax(const array& in, array& out) {
     // Normalize
     current_out_ptr = out_ptr;
+    current_in_ptr = in_ptr;
     s = M;
     while (s >= N) {
-      ops.store(current_out_ptr, ops.mul(*(VT*)current_out_ptr, normalizer));
+      if constexpr (std::is_same<T, AccT>::value) {
+        ops.store(current_out_ptr, ops.mul(*(VT*)current_out_ptr, normalizer));
+      } else {
+        VT vexp;
+        for (int i = 0; i < N; ++i) {
+          vexp[i] = static_cast<AccT>(current_in_ptr[i]);
+        }
+        vexp = ops.mul(ops.exp(ops.sub(vexp, maximum)), normalizer);
+        for (int i = 0; i < N; ++i) {
+          current_out_ptr[i] = vexp[i];
+        }
+        current_in_ptr += N;
+      }
       current_out_ptr += N;
       s -= N;
     }
     while (s-- > 0) {
-      *current_out_ptr *= normalizer;
+      if constexpr (std::is_same<T, AccT>::value) {
+        *current_out_ptr *= normalizer;
+      } else {
+        AccT _exp = std::exp(*current_in_ptr - maximum);
+        *current_out_ptr = static_cast<T>(_exp * normalizer);
+        current_in_ptr++;
+      }
       current_out_ptr++;
     }
   }
@@ -308,15 +346,29 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
           "Softmax is defined only for floating point types");
      break;
    case float32:
-      softmax<float, simd_float16, AccelerateSimdOps<float, simd_float16>, 16>(
-          in, out);
+      softmax<
+          float,
+          float,
+          simd_float16,
+          AccelerateSimdOps<float, simd_float16>,
+          16>(in, out);
      break;
    case float16:
-      softmax<
-          float16_t,
-          float16x8_t,
-          NeonFp16SimdOps<float16_t, float16x8_t>,
-          8>(in, out);
+      if (precise_) {
+        softmax<
+            float16_t,
+            float,
+            simd_float16,
+            AccelerateSimdOps<float, simd_float16>,
+            16>(in, out);
+      } else {
+        softmax<
+            float16_t,
+            float16_t,
+            float16x8_t,
+            NeonFp16SimdOps<float16_t, float16x8_t>,
+            8>(in, out);
+      }
      break;
    case bfloat16:
      eval(inputs, out);

View File

@@ -1,4 +1,4 @@
-// Copyright © 2023 Apple Inc.
+// Copyright © 2023-2024 Apple Inc.

 #include <cassert>
 #include <cmath>
@@ -10,7 +10,7 @@ namespace mlx::core {

 namespace {

-template <typename T>
+template <typename T, typename AccT>
 void softmax(const array& in, array& out) {
   const T* in_ptr = in.data<T>();
   T* out_ptr = out.data<T>();
@@ -22,26 +22,36 @@ void softmax(const array& in, array& out) {
   for (int i = 0; i < M; i++, in_ptr += N, out_ptr += N) {
     // Find the maximum
     current_in_ptr = in_ptr;
-    T maximum = *current_in_ptr;
+    AccT maximum = *current_in_ptr;
     for (int j = 0; j < N; j++, current_in_ptr++) {
-      maximum = (maximum < *current_in_ptr) ? *current_in_ptr : maximum;
+      maximum = (maximum < *current_in_ptr) ? static_cast<AccT>(*current_in_ptr)
+                                            : maximum;
     }

     // Compute the normalizer and the exponentials
-    T normalizer = 0;
+    AccT normalizer = 0;
     current_out_ptr = out_ptr;
     current_in_ptr = in_ptr;
     for (int j = 0; j < N; j++, current_out_ptr++, current_in_ptr++) {
-      T expv = std::exp(*current_in_ptr - maximum);
+      AccT expv = std::exp(*current_in_ptr - maximum);
       normalizer += expv;
-      *current_out_ptr = expv;
+      if constexpr (std::is_same<T, AccT>::value) {
+        *current_out_ptr = expv;
+      }
     }
     normalizer = 1 / normalizer;

     // Normalize
+    current_in_ptr = in_ptr;
     current_out_ptr = out_ptr;
     for (int j = 0; j < N; j++, current_out_ptr++) {
-      *current_out_ptr *= normalizer;
+      if constexpr (std::is_same<T, AccT>::value) {
+        *current_out_ptr *= normalizer;
+      } else {
+        auto v = std::exp(*current_in_ptr - maximum);
+        *current_out_ptr = static_cast<T>(v * normalizer);
+        current_in_ptr++;
+      }
     }
   }
 }
@@ -91,13 +101,21 @@ void Softmax::eval(const std::vector<array>& inputs, array& out) {
           "Softmax is defined only for floating point types");
      break;
    case float32:
-      softmax<float>(in, out);
+      softmax<float, float>(in, out);
      break;
    case float16:
-      softmax<float16_t>(in, out);
+      if (precise_) {
+        softmax<float16_t, float>(in, out);
+      } else {
+        softmax<float16_t, float16_t>(in, out);
+      }
      break;
    case bfloat16:
-      softmax<bfloat16_t>(in, out);
+      if (precise_) {
+        softmax<bfloat16_t, float>(in, out);
+      } else {
+        softmax<bfloat16_t, bfloat16_t>(in, out);
+      }
      break;
    case complex64:
      throw std::invalid_argument(

View File

@@ -11,46 +11,48 @@ using namespace metal;

 template <typename T>
 inline T softmax_exp(T x) {
-  // Softmax doesn't need high precision exponential cause it is gonna be x
-  // will be in (-oo, 0] anyway and subsequently it will be divided by
-  // sum(exp(x_i)).
+  // Softmax doesn't need high precision exponential cause x is gonna be in
+  // (-oo, 0] anyway and subsequently it will be divided by sum(exp(x_i)).
   return fast::exp(x);
 }

-template <typename T, int N_READS = SOFTMAX_N_READS>
+template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
 [[kernel]] void softmax_single_row(
     const device T* in,
     device T* out,
     constant int& axis_size,
-    threadgroup T* local_max [[threadgroup(0)]],
-    threadgroup T* local_normalizer [[threadgroup(1)]],
     uint gid [[threadgroup_position_in_grid]],
     uint _lid [[thread_position_in_threadgroup]],
     uint simd_lane_id [[thread_index_in_simdgroup]],
     uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
   int lid = _lid;

-  T ld[N_READS];
+  constexpr int SIMD_SIZE = 32;
+
+  threadgroup AccT local_max[SIMD_SIZE];
+  threadgroup AccT local_normalizer[SIMD_SIZE];
+
+  AccT ld[N_READS];

   in += gid * axis_size + lid * N_READS;
   if (lid * N_READS + N_READS <= axis_size) {
-    for (int i=0; i<N_READS; i++) {
-      ld[i] = in[i];
+    for (int i = 0; i < N_READS; i++) {
+      ld[i] = AccT(in[i]);
     }
   } else {
     for (int i = 0; i < N_READS; i++) {
-      ld[i] =
-          ((lid * N_READS + i) < axis_size) ? in[i] : T(Limits<T>::finite_min);
+      ld[i] = ((lid * N_READS + i) < axis_size) ? AccT(in[i])
+                                                : Limits<AccT>::finite_min;
     }
   }
   if (simd_group_id == 0) {
-    local_max[simd_lane_id] = Limits<T>::finite_min;
+    local_max[simd_lane_id] = Limits<AccT>::finite_min;
     local_normalizer[simd_lane_id] = 0;
   }
   threadgroup_barrier(mem_flags::mem_threadgroup);

   // Get the max
-  T maxval = Limits<T>::finite_min;
+  AccT maxval = Limits<AccT>::finite_min;
   for (int i = 0; i < N_READS; i++) {
     maxval = (maxval < ld[i]) ? ld[i] : maxval;
   }
@@ -69,9 +71,9 @@ template <typename T, int N_READS = SOFTMAX_N_READS>
   maxval = local_max[0];

   // Compute exp(x_i - maxval) and store the partial sums in local_normalizer
-  T normalizer = 0;
+  AccT normalizer = 0;
   for (int i = 0; i < N_READS; i++) {
-    T exp_x = softmax_exp(ld[i] - maxval);
+    AccT exp_x = softmax_exp(ld[i] - maxval);
     ld[i] = exp_x;
     normalizer += exp_x;
   }
@@ -92,25 +94,23 @@ template <typename T, int N_READS = SOFTMAX_N_READS>
   // Normalize and write to the output
   out += gid * axis_size + lid * N_READS;
   if (lid * N_READS + N_READS <= axis_size) {
-    for (int i=0; i<N_READS; i++) {
-      out[i] = ld[i] * normalizer;
+    for (int i = 0; i < N_READS; i++) {
+      out[i] = T(ld[i] * normalizer);
     }
   } else {
     for (int i = 0; i < N_READS; i++) {
       if ((lid * N_READS + i) < axis_size) {
-        out[i] = ld[i] * normalizer;
+        out[i] = T(ld[i] * normalizer);
       }
     }
   }
 }

-template <typename T, int N_READS = SOFTMAX_N_READS>
+template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
 [[kernel]] void softmax_looped(
     const device T* in,
     device T* out,
     constant int& axis_size,
-    threadgroup T* local_max [[threadgroup(0)]],
-    threadgroup T* local_normalizer [[threadgroup(1)]],
     uint gid [[threadgroup_position_in_grid]],
     uint lid [[thread_position_in_threadgroup]],
     uint lsize [[threads_per_threadgroup]],
@@ -118,22 +118,27 @@ template <typename T, int N_READS = SOFTMAX_N_READS>
     uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
   in += gid * axis_size;

+  constexpr int SIMD_SIZE = 32;
+
+  threadgroup AccT local_max[SIMD_SIZE];
+  threadgroup AccT local_normalizer[SIMD_SIZE];
+
   // Get the max and the normalizer in one go
-  T prevmax;
-  T maxval = Limits<T>::finite_min;
-  T normalizer = 0;
+  AccT prevmax;
+  AccT maxval = Limits<AccT>::finite_min;
+  AccT normalizer = 0;
   for (int r = 0; r < static_cast<int>(ceildiv(axis_size, N_READS * lsize));
        r++) {
     int offset = r * lsize * N_READS + lid * N_READS;
-    T vals[N_READS];
+    AccT vals[N_READS];
     if (offset + N_READS <= axis_size) {
       for (int i = 0; i < N_READS; i++) {
-        vals[i] = in[offset + i];
+        vals[i] = AccT(in[offset + i]);
       }
     } else {
       for (int i = 0; i < N_READS; i++) {
-        vals[i] =
-            (offset + i < axis_size) ? in[offset + i] : T(Limits<T>::finite_min);
+        vals[i] = (offset + i < axis_size) ? AccT(in[offset + i])
+                                           : Limits<AccT>::finite_min;
       }
     }
     prevmax = maxval;
@@ -179,50 +184,66 @@ template <typename T, int N_READS = SOFTMAX_N_READS>
        r++) {
     int offset = r * lsize * N_READS + lid * N_READS;
     if (offset + N_READS <= axis_size) {
-      for (int i=0; i<N_READS; i++) {
-        out[offset + i] = softmax_exp(in[offset + i] - maxval) * normalizer;
+      for (int i = 0; i < N_READS; i++) {
+        out[offset + i] = T(softmax_exp(in[offset + i] - maxval) * normalizer);
       }
     } else {
       for (int i = 0; i < N_READS; i++) {
         if (offset + i < axis_size) {
-          out[offset + i] = softmax_exp(in[offset + i] - maxval) * normalizer;
+          out[offset + i] =
+              T(softmax_exp(in[offset + i] - maxval) * normalizer);
         }
       }
     }
   }
 }

-#define instantiate_softmax_single_row(name, itype) \
+// clang-format off
+#define instantiate_softmax(name, itype) \
   template [[host_name("softmax_" #name)]] [[kernel]] void \
   softmax_single_row<itype>( \
       const device itype* in, \
       device itype* out, \
       constant int& axis_size, \
-      threadgroup itype* local_max [[threadgroup(0)]], \
-      threadgroup itype* local_normalizer [[threadgroup(1)]], \
       uint gid [[thread_position_in_grid]], \
       uint _lid [[thread_position_in_threadgroup]], \
       uint simd_lane_id [[thread_index_in_simdgroup]], \
-      uint simd_group_id [[simdgroup_index_in_threadgroup]]);
-
-#define instantiate_softmax_looped(name, itype) \
+      uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
   template [[host_name("softmax_looped_" #name)]] [[kernel]] void \
   softmax_looped<itype>( \
       const device itype* in, \
       device itype* out, \
       constant int& axis_size, \
-      threadgroup itype* local_max [[threadgroup(0)]], \
-      threadgroup itype* local_normalizer [[threadgroup(1)]], \
       uint gid [[threadgroup_position_in_grid]], \
       uint lid [[thread_position_in_threadgroup]], \
       uint lsize [[threads_per_threadgroup]], \
       uint simd_lane_id [[thread_index_in_simdgroup]], \
       uint simd_group_id [[simdgroup_index_in_threadgroup]]);

-#define instantiate_softmax(name, itype) \
-  instantiate_softmax_single_row(name, itype) \
-  instantiate_softmax_looped(name, itype)
+#define instantiate_softmax_precise(name, itype) \
+  template [[host_name("softmax_precise_" #name)]] [[kernel]] void \
+  softmax_single_row<itype, float>( \
+      const device itype* in, \
+      device itype* out, \
+      constant int& axis_size, \
+      uint gid [[thread_position_in_grid]], \
+      uint _lid [[thread_position_in_threadgroup]], \
+      uint simd_lane_id [[thread_index_in_simdgroup]], \
+      uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
+  template [[host_name("softmax_looped_precise_" #name)]] [[kernel]] void \
+  softmax_looped<itype, float>( \
+      const device itype* in, \
+      device itype* out, \
+      constant int& axis_size, \
+      uint gid [[threadgroup_position_in_grid]], \
+      uint lid [[thread_position_in_threadgroup]], \
+      uint lsize [[threads_per_threadgroup]], \
+      uint simd_lane_id [[thread_index_in_simdgroup]], \
+      uint simd_group_id [[simdgroup_index_in_threadgroup]]);

 instantiate_softmax(float32, float)
 instantiate_softmax(float16, half)
 instantiate_softmax(bfloat16, bfloat16_t)
+instantiate_softmax_precise(float16, half)
+instantiate_softmax_precise(bfloat16, bfloat16_t)
+// clang-format on

View File

@@ -56,6 +56,9 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
   if (axis_size > looped_limit) {
     op_name += "looped_";
   }
+  if (in.dtype() != float32 && precise_) {
+    op_name += "precise_";
+  }
   op_name += type_to_name(out);
   auto compute_encoder = d.get_command_encoder(s.index);
   {
@@ -82,8 +85,6 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
         compute_encoder, in.data_shared_ptr() == nullptr ? out : in, 0);
     set_array_buffer(compute_encoder, out, 1);
     compute_encoder->setBytes(&axis_size, sizeof(int), 2);
-    compute_encoder->setThreadgroupMemoryLength(simd_size * in.itemsize(), 0);
-    compute_encoder->setThreadgroupMemoryLength(simd_size * in.itemsize(), 1);
     compute_encoder->dispatchThreads(grid_dims, group_dims);
   }

   d.get_command_buffer(s.index)->addCompletedHandler(
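The kernel is selected purely by name. A rough sketch of how the string built in the first hunk above maps onto the host_name strings instantiated in the Metal source earlier in this commit (the leading "softmax_" base is assumed; it is set earlier in eval_gpu and not shown in this hunk):

    # Illustrative only; mirrors the name building in Softmax::eval_gpu.
    def kernel_name(dtype: str, looped: bool, precise: bool) -> str:
        name = "softmax_"  # assumed base prefix
        if looped:  # axis_size > looped_limit
            name += "looped_"
        if dtype != "float32" and precise:
            name += "precise_"
        return name + dtype

    # kernel_name("float16", looped=True, precise=True)
    # -> "softmax_looped_precise_float16"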

View File

@@ -550,10 +550,7 @@ array scaled_dot_product_attention(
   if (needs_mask) {
     scores = add(scores, inputs[3], s);
   }
-  scores = astype(
-      softmax(astype(scores, float32, s), std::vector<int>{-1}, s),
-      final_type,
-      s);
+  scores = softmax(scores, std::vector<int>{-1}, true, s);
   auto out = matmul(scores, v, s);
   if (n_repeats > 1) {
     out = reshape(out, {B, n_q_heads, L, -1}, s);

View File

@@ -2619,25 +2619,34 @@ array rsqrt(const array& a, StreamOrDevice s /* = {} */) {
 array softmax(
     const array& a,
     const std::vector<int>& axes,
+    bool precise /* = false */,
     StreamOrDevice s /* = {}*/) {
   if (axes.size() == 1 && (a.ndim() == axes[0] + 1 || axes[0] == -1)) {
     auto dtype = at_least_float(a.dtype());
     return array(
         a.shape(),
         dtype,
-        std::make_shared<Softmax>(to_stream(s)),
+        std::make_shared<Softmax>(to_stream(s), precise),
         {astype(a, dtype, s)});
   } else {
-    auto a_max = stop_gradient(max(a, axes, /*keepdims = */ true, s), s);
-    auto ex = exp(subtract(a, a_max, s), s);
-    return divide(ex, sum(ex, axes, /*keepdims = */ true, s), s);
+    auto in = a;
+    if (precise) {
+      in = astype(a, float32, s);
+    }
+    auto a_max = stop_gradient(max(in, axes, /*keepdims = */ true, s), s);
+    auto ex = exp(subtract(in, a_max, s), s);
+    return astype(
+        divide(ex, sum(ex, axes, /*keepdims = */ true, s), s), a.dtype(), s);
   }
 }

-array softmax(const array& a, StreamOrDevice s /* = {}*/) {
+array softmax(
+    const array& a,
+    bool precise /* = false */,
+    StreamOrDevice s /* = {}*/) {
   std::vector<int> axes(a.ndim());
   std::iota(axes.begin(), axes.end(), 0);
-  return softmax(a, axes, s);
+  return softmax(a, axes, precise, s);
 }

 array power(const array& a, const array& b, StreamOrDevice s /* = {} */) {
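For intuition (this is not part of the diff): with precise enabled, a half-precision input is reduced with a float32 accumulator on the fast path, or upcast, normalized, and cast back on the fallback path above. An illustrative check of the numerical effect:

    import mlx.core as mx

    x = 10 * mx.random.normal(shape=(4096,))
    ref = mx.softmax(x, axis=-1)  # float32 reference

    half = mx.softmax(x.astype(mx.float16), axis=-1)
    half_precise = mx.softmax(x.astype(mx.float16), axis=-1, precise=True)

    # The precise variant should track the float32 reference more closely.
    print(mx.max(mx.abs(half.astype(mx.float32) - ref)).item())
    print(mx.max(mx.abs(half_precise.astype(mx.float32) - ref)).item())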

View File

@@ -976,14 +976,16 @@ array rsqrt(const array& a, StreamOrDevice s = {});
 array softmax(
     const array& a,
     const std::vector<int>& axes,
+    bool precise = false,
     StreamOrDevice s = {});

 /** Softmax of an array. */
-array softmax(const array& a, StreamOrDevice s = {});
+array softmax(const array& a, bool precise = false, StreamOrDevice s = {});

 /** Softmax of an array. */
-inline array softmax(const array& a, int axis, StreamOrDevice s = {}) {
-  return softmax(a, std::vector<int>{axis}, s);
+inline array
+softmax(const array& a, int axis, bool precise = false, StreamOrDevice s = {}) {
+  return softmax(a, std::vector<int>{axis}, precise, s);
 }

 /** Raise elements of a to the power of b element-wise */

View File

@@ -2975,7 +2975,7 @@ std::pair<std::vector<array>, std::vector<int>> Softmax::vmap(
   } else {
     softmax_axes.push_back(-2);
   }
-  return {{softmax(inputs[0], softmax_axes, stream())}, axes};
+  return {{softmax(inputs[0], softmax_axes, precise_, stream())}, axes};
 }

 std::vector<array> Softmax::vjp(
@@ -2998,13 +2998,18 @@ std::vector<array> Softmax::jvp(
     const std::vector<int>& argnums) {
   assert(primals.size() == 1);
   assert(tangents.size() == 1);
-  auto s = softmax(primals[0], std::vector<int>{-1}, stream());
+  auto s = softmax(primals[0], std::vector<int>{-1}, precise_, stream());
   auto sv = multiply(s, tangents[0], stream());
   return {subtract(
       sv,
       multiply(s, sum(sv, std::vector<int>{-1}, true, stream()), stream()))};
 }

+bool Softmax::is_equivalent(const Primitive& other) const {
+  const Softmax& s_other = static_cast<const Softmax&>(other);
+  return precise_ == s_other.precise_;
+}
+
 std::pair<std::vector<array>, std::vector<int>> Sort::vmap(
     const std::vector<array>& inputs,
     const std::vector<int>& axes) {

View File

@@ -1702,7 +1702,8 @@ class SliceUpdate : public UnaryPrimitive {

 class Softmax : public UnaryPrimitive {
  public:
-  explicit Softmax(Stream stream) : UnaryPrimitive(stream){};
+  explicit Softmax(Stream stream, bool precise)
+      : UnaryPrimitive(stream), precise_(precise){};

   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -1710,11 +1711,13 @@ class Softmax : public UnaryPrimitive {
   DEFINE_VMAP()
   DEFINE_GRADS()
   DEFINE_PRINT(Softmax)
-  DEFINE_DEFAULT_IS_EQUIVALENT()
   DEFINE_INPUT_OUTPUT_SHAPE()

+  bool is_equivalent(const Primitive& other) const override;
+
  private:
   void eval(const std::vector<array>& inputs, array& out);
+  bool precise_;
 };

 class Sort : public UnaryPrimitive {

View File

@@ -2430,12 +2430,13 @@ void init_ops(nb::module_& m) {
       )pbdoc");
   m.def(
       "softmax",
-      [](const array& a, const IntOrVec& axis, StreamOrDevice s) {
-        return softmax(a, get_reduce_axes(axis, a.ndim()), s);
+      [](const array& a, const IntOrVec& axis, bool precise, StreamOrDevice s) {
+        return softmax(a, get_reduce_axes(axis, a.ndim()), precise, s);
       },
       nb::arg(),
       "axis"_a = nb::none(),
       nb::kw_only(),
+      "precise"_a = false,
       "stream"_a = nb::none(),
       nb::sig(
           "def softmax(a: array, /, axis: Union[None, int, Sequence[int]] = None, *, stream: Union[None, Stream, Device] = None) -> array"),

View File

@@ -1430,6 +1430,13 @@ class TestOps(mlx_tests.MLXTestCase):
         out = mx.softmax(y[:, 0:2], axis=-1)
         self.assertAlmostEqual(out.sum().item(), 8.0, 5)

+        # Precise
+        for t in [mx.float16, mx.bfloat16]:
+            a = (10 * mx.random.normal(shape=(1024,))).astype(t)
+            out_expect = mx.softmax(a.astype(mx.float32)).astype(t)
+            out = mx.softmax(a, axis=-1, precise=True)
+            self.assertTrue(mx.allclose(out_expect, out))
+
     def test_concatenate(self):
         a_npy = np.random.randn(32, 32, 32)
         b_npy = np.random.randn(32, 32, 32)