mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
std and expm1 (#973)
* std and expm1 * actually add expm1 * fix linux * fix vjp * relax tol for linux test * Add it to the compilable primitives --------- Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
35
mlx/ops.h
35
mlx/ops.h
@@ -507,13 +507,14 @@ array mean(
|
||||
bool keepdims = false,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Computes the mean of the elements of an array. */
|
||||
/** Computes the variance of the elements of an array. */
|
||||
array var(const array& a, bool keepdims, int ddof = 0, StreamOrDevice s = {});
|
||||
inline array var(const array& a, StreamOrDevice s = {}) {
|
||||
return var(a, false, 0, to_stream(s));
|
||||
}
|
||||
|
||||
/** Computes the var of the elements of an array along the given axes */
|
||||
/** Computes the variance of the elements of an array along the given
|
||||
* axes */
|
||||
array var(
|
||||
const array& a,
|
||||
const std::vector<int>& axes,
|
||||
@@ -521,7 +522,8 @@ array var(
|
||||
int ddof = 0,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Computes the var of the elements of an array along the given axis */
|
||||
/** Computes the variance of the elements of an array along the given
|
||||
* axis */
|
||||
array var(
|
||||
const array& a,
|
||||
int axis,
|
||||
@@ -529,6 +531,30 @@ array var(
|
||||
int ddof = 0,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Computes the standard deviation of the elements of an array. */
|
||||
array std(const array& a, bool keepdims, int ddof = 0, StreamOrDevice s = {});
|
||||
inline array std(const array& a, StreamOrDevice s = {}) {
|
||||
return std(a, false, 0, to_stream(s));
|
||||
}
|
||||
|
||||
/** Computes the standard deviatoin of the elements of an array along the given
|
||||
* axes */
|
||||
array std(
|
||||
const array& a,
|
||||
const std::vector<int>& axes,
|
||||
bool keepdims = false,
|
||||
int ddof = 0,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Computes the standard deviation of the elements of an array along the given
|
||||
* axis */
|
||||
array std(
|
||||
const array& a,
|
||||
int axis,
|
||||
bool keepdims = false,
|
||||
int ddof = 0,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** The product of all elements of the array. */
|
||||
array prod(const array& a, bool keepdims, StreamOrDevice s = {});
|
||||
inline array prod(const array& a, StreamOrDevice s = {}) {
|
||||
@@ -842,6 +868,9 @@ array erf(const array& a, StreamOrDevice s = {});
|
||||
/** Computes the inverse error function of the elements of an array. */
|
||||
array erfinv(const array& a, StreamOrDevice s = {});
|
||||
|
||||
/** Computes the expm1 function of the elements of an array. */
|
||||
array expm1(const array& a, StreamOrDevice s = {});
|
||||
|
||||
/** Stop the flow of gradients. */
|
||||
array stop_gradient(const array& a, StreamOrDevice s = {});
|
||||
|
||||
|
||||
Reference in New Issue
Block a user