std and expm1 (#973)

* std and expm1

* actually add expm1

* fix linux

* fix vjp

* relax tol for linux test

* Add it to the compilable primitives

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
Awni Hannun
2024-04-08 14:26:01 -07:00
committed by GitHub
parent 76e63212ff
commit 42afe27e12
19 changed files with 332 additions and 6 deletions

View File

@@ -507,13 +507,14 @@ array mean(
bool keepdims = false,
StreamOrDevice s = {});
/** Computes the mean of the elements of an array. */
/** Computes the variance of the elements of an array. */
array var(const array& a, bool keepdims, int ddof = 0, StreamOrDevice s = {});
inline array var(const array& a, StreamOrDevice s = {}) {
return var(a, false, 0, to_stream(s));
}
/** Computes the var of the elements of an array along the given axes */
/** Computes the variance of the elements of an array along the given
* axes */
array var(
const array& a,
const std::vector<int>& axes,
@@ -521,7 +522,8 @@ array var(
int ddof = 0,
StreamOrDevice s = {});
/** Computes the var of the elements of an array along the given axis */
/** Computes the variance of the elements of an array along the given
* axis */
array var(
const array& a,
int axis,
@@ -529,6 +531,30 @@ array var(
int ddof = 0,
StreamOrDevice s = {});
/** Computes the standard deviation of the elements of an array. */
array std(const array& a, bool keepdims, int ddof = 0, StreamOrDevice s = {});
inline array std(const array& a, StreamOrDevice s = {}) {
return std(a, false, 0, to_stream(s));
}
/** Computes the standard deviatoin of the elements of an array along the given
* axes */
array std(
const array& a,
const std::vector<int>& axes,
bool keepdims = false,
int ddof = 0,
StreamOrDevice s = {});
/** Computes the standard deviation of the elements of an array along the given
* axis */
array std(
const array& a,
int axis,
bool keepdims = false,
int ddof = 0,
StreamOrDevice s = {});
/** The product of all elements of the array. */
array prod(const array& a, bool keepdims, StreamOrDevice s = {});
inline array prod(const array& a, StreamOrDevice s = {}) {
@@ -842,6 +868,9 @@ array erf(const array& a, StreamOrDevice s = {});
/** Computes the inverse error function of the elements of an array. */
array erfinv(const array& a, StreamOrDevice s = {});
/** Computes the expm1 function of the elements of an array. */
array expm1(const array& a, StreamOrDevice s = {});
/** Stop the flow of gradients. */
array stop_gradient(const array& a, StreamOrDevice s = {});