Option for precise softmax (#953)

* precise softmax

* Add an equivalency check

* Make the threadgroup memory definition fixed

* precise cpu softmax

* precise option on cpu

* remove print

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
Awni Hannun
2024-04-04 08:32:35 -07:00
committed by GitHub
parent 0caf35f4b8
commit e142aaf8a1
11 changed files with 215 additions and 99 deletions

View File

@@ -976,14 +976,16 @@ array rsqrt(const array& a, StreamOrDevice s = {});
array softmax(
const array& a,
const std::vector<int>& axes,
bool precise = false,
StreamOrDevice s = {});
/** Softmax of an array. */
array softmax(const array& a, StreamOrDevice s = {});
array softmax(const array& a, bool precise = false, StreamOrDevice s = {});
/** Softmax of an array. */
inline array softmax(const array& a, int axis, StreamOrDevice s = {}) {
return softmax(a, std::vector<int>{axis}, s);
/**
 * Softmax of an array over a single axis.
 *
 * Wraps `axis` in a one-element vector and forwards to the multi-axis
 * overload; `precise` and `s` are passed through unchanged.
 */
inline array softmax(
    const array& a,
    int axis,
    bool precise = false,
    StreamOrDevice s = {}) {
  const std::vector<int> axes{axis};
  return softmax(a, axes, precise, s);
}
/** Raise elements of a to the power of b element-wise */