mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Option for precise softmax (#953)
* precise softmax * Add an equivalency check * Make the threadgroup memory definition fixed * precise cpu softmax * precise option on cpu * remove print --------- Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
@@ -976,14 +976,16 @@ array rsqrt(const array& a, StreamOrDevice s = {});
|
||||
array softmax(
|
||||
const array& a,
|
||||
const std::vector<int>& axes,
|
||||
bool precise = false,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Softmax of an array. */
|
||||
array softmax(const array& a, StreamOrDevice s = {});
|
||||
array softmax(const array& a, bool precise = false, StreamOrDevice s = {});
|
||||
|
||||
/** Softmax of an array. */
|
||||
inline array softmax(const array& a, int axis, StreamOrDevice s = {}) {
|
||||
return softmax(a, std::vector<int>{axis}, s);
|
||||
inline array
|
||||
softmax(const array& a, int axis, bool precise = false, StreamOrDevice s = {}) {
|
||||
return softmax(a, std::vector<int>{axis}, precise, s);
|
||||
}
|
||||
|
||||
/** Raise elements of a to the power of b element-wise */
|
||||
|
||||
Reference in New Issue
Block a user