Option for precise softmax (#953)

* precise softmax

* Add an equivalency check

* Make the threadgroup memory definition fixed

* precise cpu softmax

* precise option on cpu

* remove print

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
Awni Hannun
2024-04-04 08:32:35 -07:00
committed by GitHub
parent 0caf35f4b8
commit e142aaf8a1
11 changed files with 215 additions and 99 deletions

View File

@@ -2430,12 +2430,13 @@ void init_ops(nb::module_& m) {
)pbdoc");
m.def(
"softmax",
[](const array& a, const IntOrVec& axis, StreamOrDevice s) {
return softmax(a, get_reduce_axes(axis, a.ndim()), s);
[](const array& a, const IntOrVec& axis, bool precise, StreamOrDevice s) {
return softmax(a, get_reduce_axes(axis, a.ndim()), precise, s);
},
nb::arg(),
"axis"_a = nb::none(),
nb::kw_only(),
"precise"_a = false,
"stream"_a = nb::none(),
nb::sig(
"def softmax(a: array, /, axis: Union[None, int, Sequence[int]] = None, *, stream: Union[None, Stream, Device] = None) -> array"),