Option for precise softmax (#953)

* precise softmax

* Add an equivalency check

* Make the threadgroup memory definition fixed

* precise cpu softmax

* precise option on cpu

* remove print

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
Author: Awni Hannun
Date: 2024-04-04 08:32:35 -07:00 (committed by GitHub)
Parent: 0caf35f4b8
Commit: e142aaf8a1
11 changed files with 215 additions and 99 deletions
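
The diff below, in mlx/ops.cpp, adds a precise flag to softmax. On the last-axis fast path the flag is forwarded to the Softmax primitive; on the fallback path, when precise is set, the input is upcast to float32 before the reductions and the result is cast back to the input dtype at the end. Both paths compute the numerically stable form

$$\operatorname{softmax}(x)_i = \frac{e^{x_i - \max_j x_j}}{\sum_k e^{x_k - \max_j x_j}},$$

the difference being that with precise the max, the exponentials, and the sum are evaluated in float32 even for float16 or bfloat16 inputs.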

@@ -2619,25 +2619,34 @@ array rsqrt(const array& a, StreamOrDevice s /* = {} */) {
 array softmax(
     const array& a,
     const std::vector<int>& axes,
+    bool precise /* = false */,
     StreamOrDevice s /* = {}*/) {
   if (axes.size() == 1 && (a.ndim() == axes[0] + 1 || axes[0] == -1)) {
     auto dtype = at_least_float(a.dtype());
     return array(
         a.shape(),
         dtype,
-        std::make_shared<Softmax>(to_stream(s)),
+        std::make_shared<Softmax>(to_stream(s), precise),
         {astype(a, dtype, s)});
   } else {
-    auto a_max = stop_gradient(max(a, axes, /*keepdims = */ true, s), s);
-    auto ex = exp(subtract(a, a_max, s), s);
-    return divide(ex, sum(ex, axes, /*keepdims = */ true, s), s);
+    auto in = a;
+    if (precise) {
+      in = astype(a, float32, s);
+    }
+    auto a_max = stop_gradient(max(in, axes, /*keepdims = */ true, s), s);
+    auto ex = exp(subtract(in, a_max, s), s);
+    return astype(
+        divide(ex, sum(ex, axes, /*keepdims = */ true, s), s), a.dtype(), s);
   }
 }
 
-array softmax(const array& a, StreamOrDevice s /* = {}*/) {
+array softmax(
+    const array& a,
+    bool precise /* = false */,
+    StreamOrDevice s /* = {}*/) {
   std::vector<int> axes(a.ndim());
   std::iota(axes.begin(), axes.end(), 0);
-  return softmax(a, axes, s);
+  return softmax(a, axes, precise, s);
 }
 
 array power(const array& a, const array& b, StreamOrDevice s /* = {} */) {
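
For intuition, here is a minimal standalone sketch of what the precise fallback path does. This is not MLX code: the name softmax_precise is hypothetical, std::vector stands in for array, and double plays the role of the float32 upcast while float plays the role of a lower-precision input dtype such as float16. The steps mirror the diff: upcast, subtract the max, exponentiate, normalize, cast back.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

// Sketch only: double ~ MLX's float32 upcast, float ~ the low-precision
// input dtype. Mirrors the precise branch of the fallback path above.
std::vector<float> softmax_precise(const std::vector<float>& a) {
  // "if (precise) { in = astype(a, float32, s); }" -- upcast the input.
  std::vector<double> in(a.begin(), a.end());
  // "auto a_max = stop_gradient(max(in, ...))" -- subtract the max so the
  // exponentials cannot overflow.
  double a_max = *std::max_element(in.begin(), in.end());
  double sum = 0.0;
  for (double& x : in) {
    x = std::exp(x - a_max);
    sum += x;
  }
  // "return astype(divide(ex, sum(ex, ...)), a.dtype())" -- normalize in
  // high precision, then cast back to the input dtype.
  std::vector<float> out(a.size());
  for (std::size_t i = 0; i < a.size(); ++i) {
    out[i] = static_cast<float>(in[i] / sum);
  }
  return out;
}

int main() {
  std::vector<float> logits = {1.0f, 2.0f, 3.0f};
  for (float p : softmax_precise(logits)) {
    std::printf("%.6f\n", p); // 0.090031 0.244728 0.665241
  }
  return 0;
}

In MLX itself the flag is simply the new argument shown in the diff, e.g. softmax(x, axes, /*precise=*/true, s); since precise defaults to false, existing callers are unaffected.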