mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
std and expm1 (#973)
* std and expm1 * actually add expm1 * fix linux * fix vjp * relax tol for linux test * Add it to the compilable primitives --------- Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
35
mlx/ops.cpp
35
mlx/ops.cpp
@@ -1430,6 +1430,34 @@ array var(
|
||||
return var(a, std::vector<int>{axis}, keepdims, ddof, to_stream(s));
|
||||
}
|
||||
|
||||
array std(
|
||||
const array& a,
|
||||
bool keepdims,
|
||||
int ddof /* = 0*/,
|
||||
StreamOrDevice s /* = {}*/) {
|
||||
std::vector<int> axes(a.ndim());
|
||||
std::iota(axes.begin(), axes.end(), 0);
|
||||
return std(a, axes, keepdims, ddof, to_stream(s));
|
||||
}
|
||||
|
||||
array std(
|
||||
const array& a,
|
||||
const std::vector<int>& axes,
|
||||
bool keepdims /* = false */,
|
||||
int ddof /* = 0*/,
|
||||
StreamOrDevice s /* = {}*/) {
|
||||
return sqrt(var(a, axes, keepdims, ddof, s), s);
|
||||
}
|
||||
|
||||
array std(
|
||||
const array& a,
|
||||
int axis,
|
||||
bool keepdims /* = false */,
|
||||
int ddof /* = 0*/,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
return std(a, std::vector<int>{axis}, keepdims, ddof, to_stream(s));
|
||||
}
|
||||
|
||||
array prod(const array& a, bool keepdims, StreamOrDevice s /* = {}*/) {
|
||||
std::vector<int> axes(a.ndim());
|
||||
std::iota(axes.begin(), axes.end(), 0);
|
||||
@@ -2033,6 +2061,13 @@ array exp(const array& a, StreamOrDevice s /* = {} */) {
|
||||
return array(a.shape(), dtype, std::make_shared<Exp>(to_stream(s)), {input});
|
||||
}
|
||||
|
||||
array expm1(const array& a, StreamOrDevice s /* = {} */) {
|
||||
auto dtype = at_least_float(a.dtype());
|
||||
auto input = astype(a, dtype, s);
|
||||
return array(
|
||||
a.shape(), dtype, std::make_shared<Expm1>(to_stream(s)), {input});
|
||||
}
|
||||
|
||||
array sin(const array& a, StreamOrDevice s /* = {} */) {
|
||||
auto dtype = at_least_float(a.dtype());
|
||||
auto input = astype(a, dtype, s);
|
||||
|
||||
Reference in New Issue
Block a user