std and expm1 (#973)

* std and expm1

* actually add expm1

* fix linux

* fix vjp

* relax tol for linux test

* Add it to the compilable primitives

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
Awni Hannun
2024-04-08 14:26:01 -07:00
committed by GitHub
parent 76e63212ff
commit 42afe27e12
19 changed files with 332 additions and 6 deletions

View File

@@ -1430,6 +1430,34 @@ array var(
return var(a, std::vector<int>{axis}, keepdims, ddof, to_stream(s));
}
array std(
const array& a,
bool keepdims,
int ddof /* = 0*/,
StreamOrDevice s /* = {}*/) {
std::vector<int> axes(a.ndim());
std::iota(axes.begin(), axes.end(), 0);
return std(a, axes, keepdims, ddof, to_stream(s));
}
array std(
const array& a,
const std::vector<int>& axes,
bool keepdims /* = false */,
int ddof /* = 0*/,
StreamOrDevice s /* = {}*/) {
return sqrt(var(a, axes, keepdims, ddof, s), s);
}
array std(
const array& a,
int axis,
bool keepdims /* = false */,
int ddof /* = 0*/,
StreamOrDevice s /* = {} */) {
return std(a, std::vector<int>{axis}, keepdims, ddof, to_stream(s));
}
array prod(const array& a, bool keepdims, StreamOrDevice s /* = {}*/) {
std::vector<int> axes(a.ndim());
std::iota(axes.begin(), axes.end(), 0);
@@ -2033,6 +2061,13 @@ array exp(const array& a, StreamOrDevice s /* = {} */) {
return array(a.shape(), dtype, std::make_shared<Exp>(to_stream(s)), {input});
}
array expm1(const array& a, StreamOrDevice s /* = {} */) {
auto dtype = at_least_float(a.dtype());
auto input = astype(a, dtype, s);
return array(
a.shape(), dtype, std::make_shared<Expm1>(to_stream(s)), {input});
}
array sin(const array& a, StreamOrDevice s /* = {} */) {
auto dtype = at_least_float(a.dtype());
auto input = astype(a, dtype, s);