Option for precise softmax (#953)
* precise softmax
* Add an equivalency check
* Make the threadgroup memory definition fixed
* precise cpu softmax
* precise option on cpu
* remove print

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
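Numerically, both branches in the diff below compute the standard max-stabilized softmax; the added `precise` flag only changes the dtype the reduction runs in (float32 instead of the input's, with a cast back to the input dtype at the end):

\mathrm{softmax}(x)_i = \frac{\exp(x_i - \max_j x_j)}{\sum_k \exp(x_k - \max_j x_j)}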
mlx/ops.cpp (21 changed lines)

@@ -2619,25 +2619,34 @@ array rsqrt(const array& a, StreamOrDevice s /* = {} */) {
 array softmax(
     const array& a,
     const std::vector<int>& axes,
+    bool precise /* = false */,
     StreamOrDevice s /* = {}*/) {
   if (axes.size() == 1 && (a.ndim() == axes[0] + 1 || axes[0] == -1)) {
     auto dtype = at_least_float(a.dtype());
     return array(
         a.shape(),
         dtype,
-        std::make_shared<Softmax>(to_stream(s)),
+        std::make_shared<Softmax>(to_stream(s), precise),
         {astype(a, dtype, s)});
   } else {
-    auto a_max = stop_gradient(max(a, axes, /*keepdims = */ true, s), s);
-    auto ex = exp(subtract(a, a_max, s), s);
-    return divide(ex, sum(ex, axes, /*keepdims = */ true, s), s);
+    auto in = a;
+    if (precise) {
+      in = astype(a, float32, s);
+    }
+    auto a_max = stop_gradient(max(in, axes, /*keepdims = */ true, s), s);
+    auto ex = exp(subtract(in, a_max, s), s);
+    return astype(
+        divide(ex, sum(ex, axes, /*keepdims = */ true, s), s), a.dtype(), s);
   }
 }
 
-array softmax(const array& a, StreamOrDevice s /* = {}*/) {
+array softmax(
+    const array& a,
+    bool precise /* = false */,
+    StreamOrDevice s /* = {}*/) {
   std::vector<int> axes(a.ndim());
   std::iota(axes.begin(), axes.end(), 0);
-  return softmax(a, axes, s);
+  return softmax(a, axes, precise, s);
 }
 
 array power(const array& a, const array& b, StreamOrDevice s /* = {} */) {
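A minimal usage sketch of the new overloads. Only the softmax signatures are taken from this diff; the header path, ones(), float16, and eval() are assumptions about the surrounding mlx C++ API:

// Minimal sketch; only the softmax calls reflect this commit's signatures.
#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  // A half-precision input, where exp/sum accumulation can lose accuracy.
  array x = ones({8, 1024}, float16);

  // Default: softmax over the last axis, computed in the input precision.
  array y = softmax(x, /* axes = */ {-1});

  // New in this commit: the fallback path casts to float32, reduces, and
  // casts the result back to float16, so the output dtype is unchanged.
  array y_precise = softmax(x, {-1}, /* precise = */ true);

  // The all-axes overload forwards `precise` unchanged.
  array y_all = softmax(x, /* precise = */ true);

  eval({y, y_precise, y_all});
  return 0;
}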