Mirror of https://github.com/ml-explore/mlx.git (synced 2025-06-24 09:21:16 +08:00)
Option for precise softmax (#953)
* precise softmax
* Add an equivalency check
* Make the threadgroup memory definition fixed
* precise cpu softmax
* precise option on cpu
* remove print

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
parent 0caf35f4b8
commit e142aaf8a1
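For reference, the option surfaces in Python as a keyword-only `precise` argument to `mx.softmax` (see the binding and test hunks further down). A minimal usage sketch, assuming an environment with `mlx` installed; the comparison mirrors the test added in this commit:

import mlx.core as mx

# A half-precision row with large-magnitude logits, where accumulating
# the normalizer in float16 can lose accuracy.
a = (10 * mx.random.normal(shape=(1024,))).astype(mx.float16)

out_default = mx.softmax(a, axis=-1)                # accumulate in the input dtype
out_precise = mx.softmax(a, axis=-1, precise=True)  # accumulate in float32

# Reference: upcast around the softmax by hand.
out_ref = mx.softmax(a.astype(mx.float32), axis=-1).astype(mx.float16)

print(mx.allclose(out_precise, out_ref))

With `precise=True` the output stays in the input dtype; only the max and normalizer accumulation is done in float32.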
@@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.

#include <cassert>
#include <limits>
@@ -201,7 +201,7 @@ struct NeonFp16SimdOps {
}
};

template <typename T, typename VT, typename Ops, int N>
template <typename T, typename AccT, typename VT, typename Ops, int N>
void softmax(const array& in, array& out) {
Ops ops;

@@ -218,13 +218,21 @@ void softmax(const array& in, array& out) {
VT vmaximum = ops.init(-std::numeric_limits<float>::infinity());
size_t s = M;
while (s >= N) {
vmaximum = ops.max(ops.load(current_in_ptr), vmaximum);
VT vals;
if constexpr (std::is_same<T, AccT>::value) {
vals = ops.load(current_in_ptr);
} else {
for (int i = 0; i < N; ++i) {
vals[i] = static_cast<AccT>(current_in_ptr[i]);
}
}
vmaximum = ops.max(vals, vmaximum);
current_in_ptr += N;
s -= N;
}
T maximum = ops.reduce_max(vmaximum);
AccT maximum = ops.reduce_max(vmaximum);
while (s-- > 0) {
maximum = std::max(maximum, *current_in_ptr);
maximum = std::max(maximum, static_cast<AccT>(*current_in_ptr));
current_in_ptr++;
}

@@ -234,18 +242,29 @@ void softmax(const array& in, array& out) {
current_in_ptr = in_ptr;
s = M;
while (s >= N) {
VT vexp = ops.exp(ops.sub(*(VT*)current_in_ptr, maximum));
ops.store(current_out_ptr, vexp);
*(VT*)current_out_ptr = vexp;
VT vexp;
if constexpr (std::is_same<T, AccT>::value) {
vexp = ops.load(current_in_ptr);
} else {
for (int i = 0; i < N; ++i) {
vexp[i] = static_cast<AccT>(current_in_ptr[i]);
}
}
vexp = ops.exp(ops.sub(vexp, maximum));
if constexpr (std::is_same<T, AccT>::value) {
ops.store(current_out_ptr, vexp);
}
vnormalizer = ops.add(vnormalizer, vexp);
current_in_ptr += N;
current_out_ptr += N;
s -= N;
}
T normalizer = ops.reduce_add(vnormalizer);
AccT normalizer = ops.reduce_add(vnormalizer);
while (s-- > 0) {
T _exp = std::exp(*current_in_ptr - maximum);
*current_out_ptr = _exp;
AccT _exp = std::exp(*current_in_ptr - maximum);
if (std::is_same<T, AccT>::value) {
*current_out_ptr = _exp;
}
normalizer += _exp;
current_in_ptr++;
current_out_ptr++;
@@ -254,14 +273,33 @@ void softmax(const array& in, array& out) {

// Normalize
current_out_ptr = out_ptr;
current_in_ptr = in_ptr;
s = M;
while (s >= N) {
ops.store(current_out_ptr, ops.mul(*(VT*)current_out_ptr, normalizer));
if constexpr (std::is_same<T, AccT>::value) {
ops.store(current_out_ptr, ops.mul(*(VT*)current_out_ptr, normalizer));
} else {
VT vexp;
for (int i = 0; i < N; ++i) {
vexp[i] = static_cast<AccT>(current_in_ptr[i]);
}
vexp = ops.mul(ops.exp(ops.sub(vexp, maximum)), normalizer);
for (int i = 0; i < N; ++i) {
current_out_ptr[i] = vexp[i];
}
current_in_ptr += N;
}
current_out_ptr += N;
s -= N;
}
while (s-- > 0) {
*current_out_ptr *= normalizer;
if constexpr (std::is_same<T, AccT>::value) {
*current_out_ptr *= normalizer;
} else {
AccT _exp = std::exp(*current_in_ptr - maximum);
*current_out_ptr = static_cast<T>(_exp * normalizer);
current_in_ptr++;
}
current_out_ptr++;
}
}
@@ -308,15 +346,29 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
"Softmax is defined only for floating point types");
break;
case float32:
softmax<float, simd_float16, AccelerateSimdOps<float, simd_float16>, 16>(
in, out);
softmax<
float,
float,
simd_float16,
AccelerateSimdOps<float, simd_float16>,
16>(in, out);
break;
case float16:
softmax<
float16_t,
float16x8_t,
NeonFp16SimdOps<float16_t, float16x8_t>,
8>(in, out);
if (precise_) {
softmax<
float16_t,
float,
simd_float16,
AccelerateSimdOps<float, simd_float16>,
16>(in, out);
} else {
softmax<
float16_t,
float16_t,
float16x8_t,
NeonFp16SimdOps<float16_t, float16x8_t>,
8>(in, out);
}
break;
case bfloat16:
eval(inputs, out);
@@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.

#include <cassert>
#include <cmath>
@@ -10,7 +10,7 @@ namespace mlx::core {

namespace {

template <typename T>
template <typename T, typename AccT>
void softmax(const array& in, array& out) {
const T* in_ptr = in.data<T>();
T* out_ptr = out.data<T>();
@@ -22,26 +22,36 @@ void softmax(const array& in, array& out) {
for (int i = 0; i < M; i++, in_ptr += N, out_ptr += N) {
// Find the maximum
current_in_ptr = in_ptr;
T maximum = *current_in_ptr;
AccT maximum = *current_in_ptr;
for (int j = 0; j < N; j++, current_in_ptr++) {
maximum = (maximum < *current_in_ptr) ? *current_in_ptr : maximum;
maximum = (maximum < *current_in_ptr) ? static_cast<AccT>(*current_in_ptr)
: maximum;
}

// Compute the normalizer and the exponentials
T normalizer = 0;
AccT normalizer = 0;
current_out_ptr = out_ptr;
current_in_ptr = in_ptr;
for (int j = 0; j < N; j++, current_out_ptr++, current_in_ptr++) {
T expv = std::exp(*current_in_ptr - maximum);
AccT expv = std::exp(*current_in_ptr - maximum);
normalizer += expv;
*current_out_ptr = expv;
if constexpr (std::is_same<T, AccT>::value) {
*current_out_ptr = expv;
}
}
normalizer = 1 / normalizer;

// Normalize
current_in_ptr = in_ptr;
current_out_ptr = out_ptr;
for (int j = 0; j < N; j++, current_out_ptr++) {
*current_out_ptr *= normalizer;
if constexpr (std::is_same<T, AccT>::value) {
*current_out_ptr *= normalizer;
} else {
auto v = std::exp(*current_in_ptr - maximum);
*current_out_ptr = static_cast<T>(v * normalizer);
current_in_ptr++;
}
}
}
}
@@ -91,13 +101,21 @@ void Softmax::eval(const std::vector<array>& inputs, array& out) {
"Softmax is defined only for floating point types");
break;
case float32:
softmax<float>(in, out);
softmax<float, float>(in, out);
break;
case float16:
softmax<float16_t>(in, out);
if (precise_) {
softmax<float16_t, float>(in, out);
} else {
softmax<float16_t, float16_t>(in, out);
}
break;
case bfloat16:
softmax<bfloat16_t>(in, out);
if (precise_) {
softmax<bfloat16_t, float>(in, out);
} else {
softmax<bfloat16_t, bfloat16_t>(in, out);
}
break;
case complex64:
throw std::invalid_argument(
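The CPU kernels above are templated on a storage type T and an accumulation type AccT: values are read as AccT, the max and the normalizer are reduced in AccT, and results are cast back to T on store. A rough NumPy sketch of that split, illustrative only and not MLX code:

import numpy as np

def softmax_rows(x, acc_dtype=np.float32):
    # Mirrors the T/AccT split: read the storage dtype, accumulate the max
    # and the normalizer in acc_dtype, cast back to the storage dtype on write.
    out = np.empty_like(x)
    for i in range(x.shape[0]):
        row = x[i].astype(acc_dtype)            # load as AccT
        e = np.exp(row - row.max())             # max and exponentials in AccT
        out[i] = (e / e.sum()).astype(x.dtype)  # store as T
    return out

x = (10 * np.random.randn(8, 1024)).astype(np.float16)
err = np.abs(softmax_rows(x) - softmax_rows(x, acc_dtype=np.float16)).max()
print(err)  # the float16 accumulation is typically the less accurate of the two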
@@ -11,46 +11,48 @@ using namespace metal;

template <typename T>
inline T softmax_exp(T x) {
// Softmax doesn't need high precision exponential cause it is gonna be x
// will be in (-oo, 0] anyway and subsequently it will be divided by
// sum(exp(x_i)).
// Softmax doesn't need high precision exponential cause x is gonna be in
// (-oo, 0] anyway and subsequently it will be divided by sum(exp(x_i)).
return fast::exp(x);
}

template <typename T, int N_READS = SOFTMAX_N_READS>
template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
[[kernel]] void softmax_single_row(
const device T* in,
device T* out,
constant int& axis_size,
threadgroup T* local_max [[threadgroup(0)]],
threadgroup T* local_normalizer [[threadgroup(1)]],
uint gid [[threadgroup_position_in_grid]],
uint _lid [[thread_position_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
int lid = _lid;

T ld[N_READS];
constexpr int SIMD_SIZE = 32;

threadgroup AccT local_max[SIMD_SIZE];
threadgroup AccT local_normalizer[SIMD_SIZE];

AccT ld[N_READS];

in += gid * axis_size + lid * N_READS;
if (lid * N_READS + N_READS <= axis_size) {
for (int i=0; i<N_READS; i++) {
ld[i] = in[i];
for (int i = 0; i < N_READS; i++) {
ld[i] = AccT(in[i]);
}
} else {
for (int i = 0; i < N_READS; i++) {
ld[i] =
((lid * N_READS + i) < axis_size) ? in[i] : T(Limits<T>::finite_min);
}
for (int i = 0; i < N_READS; i++) {
ld[i] = ((lid * N_READS + i) < axis_size) ? AccT(in[i])
: Limits<AccT>::finite_min;
}
}
if (simd_group_id == 0) {
local_max[simd_lane_id] = Limits<T>::finite_min;
local_max[simd_lane_id] = Limits<AccT>::finite_min;
local_normalizer[simd_lane_id] = 0;
}
threadgroup_barrier(mem_flags::mem_threadgroup);

// Get the max
T maxval = Limits<T>::finite_min;
AccT maxval = Limits<AccT>::finite_min;
for (int i = 0; i < N_READS; i++) {
maxval = (maxval < ld[i]) ? ld[i] : maxval;
}
@@ -69,9 +71,9 @@ template <typename T, int N_READS = SOFTMAX_N_READS>
maxval = local_max[0];

// Compute exp(x_i - maxval) and store the partial sums in local_normalizer
T normalizer = 0;
AccT normalizer = 0;
for (int i = 0; i < N_READS; i++) {
T exp_x = softmax_exp(ld[i] - maxval);
AccT exp_x = softmax_exp(ld[i] - maxval);
ld[i] = exp_x;
normalizer += exp_x;
}
@@ -92,25 +94,23 @@ template <typename T, int N_READS = SOFTMAX_N_READS>
// Normalize and write to the output
out += gid * axis_size + lid * N_READS;
if (lid * N_READS + N_READS <= axis_size) {
for (int i=0; i<N_READS; i++) {
out[i] = ld[i] * normalizer;
for (int i = 0; i < N_READS; i++) {
out[i] = T(ld[i] * normalizer);
}
} else {
for (int i = 0; i < N_READS; i++) {
if ((lid * N_READS + i) < axis_size) {
out[i] = ld[i] * normalizer;
}
for (int i = 0; i < N_READS; i++) {
if ((lid * N_READS + i) < axis_size) {
out[i] = T(ld[i] * normalizer);
}
}
}
}

template <typename T, int N_READS = SOFTMAX_N_READS>
template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
[[kernel]] void softmax_looped(
const device T* in,
device T* out,
constant int& axis_size,
threadgroup T* local_max [[threadgroup(0)]],
threadgroup T* local_normalizer [[threadgroup(1)]],
uint gid [[threadgroup_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint lsize [[threads_per_threadgroup]],
@@ -118,22 +118,27 @@ template <typename T, int N_READS = SOFTMAX_N_READS>
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
in += gid * axis_size;

constexpr int SIMD_SIZE = 32;

threadgroup AccT local_max[SIMD_SIZE];
threadgroup AccT local_normalizer[SIMD_SIZE];

// Get the max and the normalizer in one go
T prevmax;
T maxval = Limits<T>::finite_min;
T normalizer = 0;
AccT prevmax;
AccT maxval = Limits<AccT>::finite_min;
AccT normalizer = 0;
for (int r = 0; r < static_cast<int>(ceildiv(axis_size, N_READS * lsize));
r++) {
int offset = r * lsize * N_READS + lid * N_READS;
T vals[N_READS];
AccT vals[N_READS];
if (offset + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
vals[i] = in[offset + i];
vals[i] = AccT(in[offset + i]);
}
} else {
for (int i = 0; i < N_READS; i++) {
vals[i] =
(offset + i < axis_size) ? in[offset + i] : T(Limits<T>::finite_min);
vals[i] = (offset + i < axis_size) ? AccT(in[offset + i])
: Limits<AccT>::finite_min;
}
}
prevmax = maxval;
@@ -179,50 +184,66 @@ template <typename T, int N_READS = SOFTMAX_N_READS>
r++) {
int offset = r * lsize * N_READS + lid * N_READS;
if (offset + N_READS <= axis_size) {
for (int i=0; i<N_READS; i++) {
out[offset + i] = softmax_exp(in[offset + i] - maxval) * normalizer;
for (int i = 0; i < N_READS; i++) {
out[offset + i] = T(softmax_exp(in[offset + i] - maxval) * normalizer);
}
} else {
for (int i = 0; i < N_READS; i++) {
if (offset + i < axis_size) {
out[offset + i] = softmax_exp(in[offset + i] - maxval) * normalizer;
out[offset + i] =
T(softmax_exp(in[offset + i] - maxval) * normalizer);
}
}
}
}
}

#define instantiate_softmax_single_row(name, itype) \
// clang-format off
#define instantiate_softmax(name, itype) \
template [[host_name("softmax_" #name)]] [[kernel]] void \
softmax_single_row<itype>( \
const device itype* in, \
device itype* out, \
constant int& axis_size, \
threadgroup itype* local_max [[threadgroup(0)]], \
threadgroup itype* local_normalizer [[threadgroup(1)]], \
uint gid [[thread_position_in_grid]], \
uint _lid [[thread_position_in_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);

#define instantiate_softmax_looped(name, itype) \
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
template [[host_name("softmax_looped_" #name)]] [[kernel]] void \
softmax_looped<itype>( \
const device itype* in, \
device itype* out, \
constant int& axis_size, \
threadgroup itype* local_max [[threadgroup(0)]], \
threadgroup itype* local_normalizer [[threadgroup(1)]], \
uint gid [[threadgroup_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint lsize [[threads_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);

#define instantiate_softmax(name, itype) \
instantiate_softmax_single_row(name, itype) \
instantiate_softmax_looped(name, itype)
#define instantiate_softmax_precise(name, itype) \
template [[host_name("softmax_precise_" #name)]] [[kernel]] void \
softmax_single_row<itype, float>( \
const device itype* in, \
device itype* out, \
constant int& axis_size, \
uint gid [[thread_position_in_grid]], \
uint _lid [[thread_position_in_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
template [[host_name("softmax_looped_precise_" #name)]] [[kernel]] void \
softmax_looped<itype, float>( \
const device itype* in, \
device itype* out, \
constant int& axis_size, \
uint gid [[threadgroup_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint lsize [[threads_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);

instantiate_softmax(float32, float)
instantiate_softmax(float16, half)
instantiate_softmax(bfloat16, bfloat16_t)
instantiate_softmax_precise(float16, half)
instantiate_softmax_precise(bfloat16, bfloat16_t)
// clang-format on
@@ -56,6 +56,9 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
if (axis_size > looped_limit) {
op_name += "looped_";
}
if (in.dtype() != float32 && precise_) {
op_name += "precise_";
}
op_name += type_to_name(out);
auto compute_encoder = d.get_command_encoder(s.index);
{
@@ -82,8 +85,6 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder, in.data_shared_ptr() == nullptr ? out : in, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder->setBytes(&axis_size, sizeof(int), 2);
compute_encoder->setThreadgroupMemoryLength(simd_size * in.itemsize(), 0);
compute_encoder->setThreadgroupMemoryLength(simd_size * in.itemsize(), 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
d.get_command_buffer(s.index)->addCompletedHandler(
@@ -550,10 +550,7 @@ array scaled_dot_product_attention(
if (needs_mask) {
scores = add(scores, inputs[3], s);
}
scores = astype(
softmax(astype(scores, float32, s), std::vector<int>{-1}, s),
final_type,
s);
scores = softmax(scores, std::vector<int>{-1}, true, s);
auto out = matmul(scores, v, s);
if (n_repeats > 1) {
out = reshape(out, {B, n_q_heads, L, -1}, s);
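The attention fallback above previously upcast the scores to float32 around the softmax and cast back; it now passes precise = true to softmax instead. A sketch of the equivalent pattern in Python, with illustrative shapes and names rather than the actual mlx.fast implementation:

import mlx.core as mx

B, H, L, D = 1, 8, 128, 64  # illustrative shapes
q = mx.random.normal(shape=(B, H, L, D)).astype(mx.float16)
k = mx.random.normal(shape=(B, H, L, D)).astype(mx.float16)
v = mx.random.normal(shape=(B, H, L, D)).astype(mx.float16)

scores = (q * D**-0.5) @ mx.transpose(k, (0, 1, 3, 2))

# Before: scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(mx.float16)
scores = mx.softmax(scores, axis=-1, precise=True)
out = scores @ v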
mlx/ops.cpp
@@ -2619,25 +2619,34 @@ array rsqrt(const array& a, StreamOrDevice s /* = {} */) {
array softmax(
const array& a,
const std::vector<int>& axes,
bool precise /* = false */,
StreamOrDevice s /* = {}*/) {
if (axes.size() == 1 && (a.ndim() == axes[0] + 1 || axes[0] == -1)) {
auto dtype = at_least_float(a.dtype());
return array(
a.shape(),
dtype,
std::make_shared<Softmax>(to_stream(s)),
std::make_shared<Softmax>(to_stream(s), precise),
{astype(a, dtype, s)});
} else {
auto a_max = stop_gradient(max(a, axes, /*keepdims = */ true, s), s);
auto ex = exp(subtract(a, a_max, s), s);
return divide(ex, sum(ex, axes, /*keepdims = */ true, s), s);
auto in = a;
if (precise) {
in = astype(a, float32, s);
}
auto a_max = stop_gradient(max(in, axes, /*keepdims = */ true, s), s);
auto ex = exp(subtract(in, a_max, s), s);
return astype(
divide(ex, sum(ex, axes, /*keepdims = */ true, s), s), a.dtype(), s);
}
}

array softmax(const array& a, StreamOrDevice s /* = {}*/) {
array softmax(
const array& a,
bool precise /* = false */,
StreamOrDevice s /* = {}*/) {
std::vector<int> axes(a.ndim());
std::iota(axes.begin(), axes.end(), 0);
return softmax(a, axes, s);
return softmax(a, axes, precise, s);
}

array power(const array& a, const array& b, StreamOrDevice s /* = {} */) {
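When the reduction is not over a single trailing axis, softmax above falls back to composing primitives; with precise set it upcasts to float32 first and casts the result back to the input dtype. A rough Python rendering of that fallback branch (softmax_general is an illustrative name, not an MLX function):

import mlx.core as mx

def softmax_general(a, axes, precise=False):
    # Optionally upcast, subtract the (gradient-stopped) max, exponentiate,
    # normalize over the given axes, and cast back to the input dtype.
    x = a.astype(mx.float32) if precise else a
    x_max = mx.stop_gradient(mx.max(x, axis=axes, keepdims=True))
    e = mx.exp(x - x_max)
    return (e / mx.sum(e, axis=axes, keepdims=True)).astype(a.dtype)

a = mx.random.normal(shape=(4, 5, 6)).astype(mx.float16)
print(mx.allclose(softmax_general(a, [0, 2], precise=True),
                  mx.softmax(a, axis=[0, 2], precise=True)))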
@@ -976,14 +976,16 @@ array rsqrt(const array& a, StreamOrDevice s = {});
array softmax(
const array& a,
const std::vector<int>& axes,
bool precise = false,
StreamOrDevice s = {});

/** Softmax of an array. */
array softmax(const array& a, StreamOrDevice s = {});
array softmax(const array& a, bool precise = false, StreamOrDevice s = {});

/** Softmax of an array. */
inline array softmax(const array& a, int axis, StreamOrDevice s = {}) {
return softmax(a, std::vector<int>{axis}, s);
inline array
softmax(const array& a, int axis, bool precise = false, StreamOrDevice s = {}) {
return softmax(a, std::vector<int>{axis}, precise, s);
}

/** Raise elements of a to the power of b element-wise */
@@ -2975,7 +2975,7 @@ std::pair<std::vector<array>, std::vector<int>> Softmax::vmap(
} else {
softmax_axes.push_back(-2);
}
return {{softmax(inputs[0], softmax_axes, stream())}, axes};
return {{softmax(inputs[0], softmax_axes, precise_, stream())}, axes};
}

std::vector<array> Softmax::vjp(
@@ -2998,13 +2998,18 @@ std::vector<array> Softmax::jvp(
const std::vector<int>& argnums) {
assert(primals.size() == 1);
assert(tangents.size() == 1);
auto s = softmax(primals[0], std::vector<int>{-1}, stream());
auto s = softmax(primals[0], std::vector<int>{-1}, precise_, stream());
auto sv = multiply(s, tangents[0], stream());
return {subtract(
sv,
multiply(s, sum(sv, std::vector<int>{-1}, true, stream()), stream()))};
}

bool Softmax::is_equivalent(const Primitive& other) const {
const Softmax& s_other = static_cast<const Softmax&>(other);
return precise_ == s_other.precise_;
}

std::pair<std::vector<array>, std::vector<int>> Sort::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
@@ -1702,7 +1702,8 @@ class SliceUpdate : public UnaryPrimitive {

class Softmax : public UnaryPrimitive {
public:
explicit Softmax(Stream stream) : UnaryPrimitive(stream){};
explicit Softmax(Stream stream, bool precise)
: UnaryPrimitive(stream), precise_(precise){};

void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -1710,11 +1711,13 @@ class Softmax : public UnaryPrimitive {
DEFINE_VMAP()
DEFINE_GRADS()
DEFINE_PRINT(Softmax)
DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()

bool is_equivalent(const Primitive& other) const override;

private:
void eval(const std::vector<array>& inputs, array& out);
bool precise_;
};

class Sort : public UnaryPrimitive {
@@ -2430,12 +2430,13 @@ void init_ops(nb::module_& m) {
)pbdoc");
m.def(
"softmax",
[](const array& a, const IntOrVec& axis, StreamOrDevice s) {
return softmax(a, get_reduce_axes(axis, a.ndim()), s);
[](const array& a, const IntOrVec& axis, bool precise, StreamOrDevice s) {
return softmax(a, get_reduce_axes(axis, a.ndim()), precise, s);
},
nb::arg(),
"axis"_a = nb::none(),
nb::kw_only(),
"precise"_a = false,
"stream"_a = nb::none(),
nb::sig(
"def softmax(a: array, /, axis: Union[None, int, Sequence[int]] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
@@ -1430,6 +1430,13 @@ class TestOps(mlx_tests.MLXTestCase):
out = mx.softmax(y[:, 0:2], axis=-1)
self.assertAlmostEqual(out.sum().item(), 8.0, 5)

# Precise
for t in [mx.float16, mx.bfloat16]:
a = (10 * mx.random.normal(shape=(1024,))).astype(t)
out_expect = mx.softmax(a.astype(mx.float32)).astype(t)
out = mx.softmax(a, axis=-1, precise=True)
self.assertTrue(mx.allclose(out_expect, out))

def test_concatenate(self):
a_npy = np.random.randn(32, 32, 32)
b_npy = np.random.randn(32, 32, 32)