mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
std and expm1 (#973)
* std and expm1 * actually add expm1 * fix linux * fix vjp * relax tol for linux test * Add it to the compilable primitives --------- Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
parent
76e63212ff
commit
42afe27e12
@ -5,13 +5,13 @@ Operations
|
|||||||
|
|
||||||
.. currentmodule:: mlx.core
|
.. currentmodule:: mlx.core
|
||||||
|
|
||||||
.. autosummary::
|
.. autosummary::
|
||||||
:toctree: _autosummary
|
:toctree: _autosummary
|
||||||
|
|
||||||
abs
|
abs
|
||||||
add
|
add
|
||||||
all
|
all
|
||||||
allclose
|
allclose
|
||||||
any
|
any
|
||||||
arange
|
arange
|
||||||
arccos
|
arccos
|
||||||
@ -51,6 +51,7 @@ Operations
|
|||||||
erf
|
erf
|
||||||
erfinv
|
erfinv
|
||||||
exp
|
exp
|
||||||
|
expm1
|
||||||
expand_dims
|
expand_dims
|
||||||
eye
|
eye
|
||||||
flatten
|
flatten
|
||||||
@ -117,6 +118,7 @@ Operations
|
|||||||
square
|
square
|
||||||
squeeze
|
squeeze
|
||||||
stack
|
stack
|
||||||
|
std
|
||||||
stop_gradient
|
stop_gradient
|
||||||
subtract
|
subtract
|
||||||
sum
|
sum
|
||||||
|
@ -310,6 +310,19 @@ void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void Expm1::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
const auto& in = inputs[0];
|
||||||
|
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||||
|
set_unary_output_data(in, out);
|
||||||
|
auto size = in.data_size();
|
||||||
|
vvexpm1f(
|
||||||
|
out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
|
||||||
|
} else {
|
||||||
|
eval(inputs, out);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void Full::eval_cpu(const std::vector<array>& inputs, array& out) {
|
void Full::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
auto& in = inputs[0];
|
auto& in = inputs[0];
|
||||||
|
@ -57,6 +57,7 @@ DEFAULT(Equal)
|
|||||||
DEFAULT(Erf)
|
DEFAULT(Erf)
|
||||||
DEFAULT(ErfInv)
|
DEFAULT(ErfInv)
|
||||||
DEFAULT(Exp)
|
DEFAULT(Exp)
|
||||||
|
DEFAULT(Expm1)
|
||||||
DEFAULT(FFT)
|
DEFAULT(FFT)
|
||||||
DEFAULT(Floor)
|
DEFAULT(Floor)
|
||||||
DEFAULT(Full)
|
DEFAULT(Full)
|
||||||
|
@ -241,6 +241,13 @@ struct Exp {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct Expm1 {
|
||||||
|
template <typename T>
|
||||||
|
T operator()(T x) {
|
||||||
|
return expm1(x);
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
struct Floor {
|
struct Floor {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
T operator()(T x) {
|
T operator()(T x) {
|
||||||
|
@ -359,6 +359,18 @@ void Exp::eval(const std::vector<array>& inputs, array& out) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void Expm1::eval(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
const auto& in = inputs[0];
|
||||||
|
if (issubdtype(out.dtype(), inexact)) {
|
||||||
|
unary_fp(in, out, detail::Expm1());
|
||||||
|
} else {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[expm1] Cannot exponentiate elements in array"
|
||||||
|
" with non floating point type.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void Floor::eval(const std::vector<array>& inputs, array& out) {
|
void Floor::eval(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
auto& in = inputs[0];
|
auto& in = inputs[0];
|
||||||
|
@ -7,6 +7,7 @@ set(
|
|||||||
${CMAKE_CURRENT_SOURCE_DIR}/complex.h
|
${CMAKE_CURRENT_SOURCE_DIR}/complex.h
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/defines.h
|
${CMAKE_CURRENT_SOURCE_DIR}/defines.h
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/erf.h
|
${CMAKE_CURRENT_SOURCE_DIR}/erf.h
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/expm1f.h
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/indexing.h
|
${CMAKE_CURRENT_SOURCE_DIR}/indexing.h
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/unary.h
|
${CMAKE_CURRENT_SOURCE_DIR}/unary.h
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/utils.h
|
${CMAKE_CURRENT_SOURCE_DIR}/utils.h
|
||||||
|
89
mlx/backend/metal/kernels/expm1f.h
Normal file
89
mlx/backend/metal/kernels/expm1f.h
Normal file
@ -0,0 +1,89 @@
|
|||||||
|
// Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <metal_math>
|
||||||
|
|
||||||
|
// Original license copied below:
|
||||||
|
// Copyright (c) 2015-2023 Norbert Juffa
|
||||||
|
// All rights reserved.
|
||||||
|
//
|
||||||
|
// Redistribution and use in source and binary forms, with or without
|
||||||
|
// modification, are permitted provided that the following conditions
|
||||||
|
// are met:
|
||||||
|
//
|
||||||
|
// 1. Redistributions of source code must retain the above copyright
|
||||||
|
// notice, this list of conditions and the following disclaimer.
|
||||||
|
//
|
||||||
|
// 2. Redistributions in binary form must reproduce the above copyright
|
||||||
|
// notice, this list of conditions and the following disclaimer in the
|
||||||
|
// documentation and/or other materials provided with the distribution.
|
||||||
|
//
|
||||||
|
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||||
|
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||||
|
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||||
|
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||||
|
// HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||||
|
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||||
|
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||||
|
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||||
|
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||||
|
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
|
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
|
||||||
|
/* Compute exponential base e minus 1. Maximum ulp error = 0.997458
|
||||||
|
|
||||||
|
i = rint(a/log(2)), f = a-i*log(2). Then expm1(a) = 2**i * (expm1(f)+1) - 1.
|
||||||
|
Compute r = expm1(f). Then expm1(a)= 2 * (0.5 * 2**i * r + 0.5 * 2**i - 0.5).
|
||||||
|
With t = 0.5*2**i, expm1(a) = 2*(r * t + t-0.5). However, for best accuracy,
|
||||||
|
when i == 1, expm1(a)= 2*(r + 0.5), and when i == 0, expm1(a) = r.
|
||||||
|
|
||||||
|
NOTE: Scale factor b is only applied if i < 0 or i > 1 (should be power of 2)
|
||||||
|
*/
|
||||||
|
float expm1f_scaled_unchecked(float a, float b) {
|
||||||
|
float f, j, r, s, t, u, v, x, y;
|
||||||
|
int i;
|
||||||
|
|
||||||
|
// exp(a) = 2**i * exp(f); i = rintf (a / log(2))
|
||||||
|
j = fma(1.442695f, a, 12582912.f); // 0x1.715476p0, 0x1.8p23
|
||||||
|
j = j - 12582912.0f; // 0x1.8p23
|
||||||
|
i = (int)j;
|
||||||
|
f = fma(j, -6.93145752e-1f, a);
|
||||||
|
|
||||||
|
// approximate r = exp(f)-1 on interval [-log(2)/2, +log(2)/2]
|
||||||
|
s = f * f;
|
||||||
|
if (a == 0.0f)
|
||||||
|
s = a; // ensure -0 is passed through
|
||||||
|
// err = 0.997458 ulp1 = 11081805
|
||||||
|
r = 1.97350979e-4f; // 0x1.9de000p-13
|
||||||
|
r = fma(r, f, 1.39309070e-3f); // 0x1.6d30bcp-10
|
||||||
|
r = fma(r, f, 8.33343994e-3f); // 0x1.1111f6p-7
|
||||||
|
r = fma(r, f, 4.16668020e-2f); // 0x1.55559ep-5
|
||||||
|
r = fma(r, f, 1.66666716e-1f); // 0x1.55555cp-3
|
||||||
|
r = fma(r, f, 4.99999970e-1f); // 0x1.fffffep-2
|
||||||
|
u = (j == 1) ? (f + 0.5f) : f;
|
||||||
|
v = fma(r, s, u);
|
||||||
|
s = 0.5f * b;
|
||||||
|
t = ldexp(s, i);
|
||||||
|
y = t - s;
|
||||||
|
x = (t - y) - s; // double-float canonicalization of difference
|
||||||
|
r = fma(v, t, x) + y;
|
||||||
|
r = r + r;
|
||||||
|
if (j == 0)
|
||||||
|
r = v;
|
||||||
|
if (j == 1)
|
||||||
|
r = v + v;
|
||||||
|
return r;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Compute exponential base e minus 1. max ulp err = 0.99746 */
|
||||||
|
float expm1f(float a) {
|
||||||
|
float r;
|
||||||
|
|
||||||
|
r = expm1f_scaled_unchecked(a, 1.0f);
|
||||||
|
/* handle severe overflow and underflow */
|
||||||
|
if (abs(a - 1.0f) > 88.0f) {
|
||||||
|
r = fma(r, r, -1.0f);
|
||||||
|
}
|
||||||
|
return r;
|
||||||
|
}
|
@ -7,6 +7,7 @@
|
|||||||
|
|
||||||
#include "mlx/backend/metal/kernels/bf16.h"
|
#include "mlx/backend/metal/kernels/bf16.h"
|
||||||
#include "mlx/backend/metal/kernels/erf.h"
|
#include "mlx/backend/metal/kernels/erf.h"
|
||||||
|
#include "mlx/backend/metal/kernels/expm1f.h"
|
||||||
#include "mlx/backend/metal/kernels/utils.h"
|
#include "mlx/backend/metal/kernels/utils.h"
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
@ -183,6 +184,13 @@ struct Exp {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct Expm1 {
|
||||||
|
template <typename T>
|
||||||
|
T operator()(T x) {
|
||||||
|
return static_cast<T>(expm1f(static_cast<float>(x)));
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
struct Floor {
|
struct Floor {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
T operator()(T x) {
|
T operator()(T x) {
|
||||||
|
@ -71,6 +71,7 @@ instantiate_unary_types(ceil, Ceil)
|
|||||||
instantiate_unary_float(cos, Cos)
|
instantiate_unary_float(cos, Cos)
|
||||||
instantiate_unary_float(cosh, Cosh)
|
instantiate_unary_float(cosh, Cosh)
|
||||||
instantiate_unary_float(exp, Exp)
|
instantiate_unary_float(exp, Exp)
|
||||||
|
instantiate_unary_float(expm1, Expm1)
|
||||||
instantiate_unary_types(floor, Floor)
|
instantiate_unary_types(floor, Floor)
|
||||||
instantiate_unary_float(log, Log)
|
instantiate_unary_float(log, Log)
|
||||||
instantiate_unary_float(log2, Log2)
|
instantiate_unary_float(log2, Log2)
|
||||||
|
@ -615,6 +615,10 @@ void Exp::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
unary_op(inputs, out, "exp");
|
unary_op(inputs, out, "exp");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void Expm1::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
unary_op(inputs, out, "expm1");
|
||||||
|
}
|
||||||
|
|
||||||
void Full::eval_gpu(const std::vector<array>& inputs, array& out) {
|
void Full::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
auto in = inputs[0];
|
auto in = inputs[0];
|
||||||
CopyType ctype;
|
CopyType ctype;
|
||||||
|
@ -49,6 +49,7 @@ NO_GPU(Equal)
|
|||||||
NO_GPU(Erf)
|
NO_GPU(Erf)
|
||||||
NO_GPU(ErfInv)
|
NO_GPU(ErfInv)
|
||||||
NO_GPU(Exp)
|
NO_GPU(Exp)
|
||||||
|
NO_GPU(Expm1)
|
||||||
NO_GPU(FFT)
|
NO_GPU(FFT)
|
||||||
NO_GPU(Floor)
|
NO_GPU(Floor)
|
||||||
NO_GPU(Full)
|
NO_GPU(Full)
|
||||||
|
@ -32,7 +32,7 @@ bool is_unary(const Primitive& p) {
|
|||||||
typeid(p) == typeid(Sign) || typeid(p) == typeid(Sin) ||
|
typeid(p) == typeid(Sign) || typeid(p) == typeid(Sin) ||
|
||||||
typeid(p) == typeid(Sinh) || typeid(p) == typeid(Square) ||
|
typeid(p) == typeid(Sinh) || typeid(p) == typeid(Square) ||
|
||||||
typeid(p) == typeid(Sqrt) || typeid(p) == typeid(Tan) ||
|
typeid(p) == typeid(Sqrt) || typeid(p) == typeid(Tan) ||
|
||||||
typeid(p) == typeid(Tanh));
|
typeid(p) == typeid(Tanh) || typeid(p) == typeid(Expm1));
|
||||||
}
|
}
|
||||||
|
|
||||||
bool is_binary(const Primitive& p) {
|
bool is_binary(const Primitive& p) {
|
||||||
|
35
mlx/ops.cpp
35
mlx/ops.cpp
@ -1430,6 +1430,34 @@ array var(
|
|||||||
return var(a, std::vector<int>{axis}, keepdims, ddof, to_stream(s));
|
return var(a, std::vector<int>{axis}, keepdims, ddof, to_stream(s));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
array std(
|
||||||
|
const array& a,
|
||||||
|
bool keepdims,
|
||||||
|
int ddof /* = 0*/,
|
||||||
|
StreamOrDevice s /* = {}*/) {
|
||||||
|
std::vector<int> axes(a.ndim());
|
||||||
|
std::iota(axes.begin(), axes.end(), 0);
|
||||||
|
return std(a, axes, keepdims, ddof, to_stream(s));
|
||||||
|
}
|
||||||
|
|
||||||
|
array std(
|
||||||
|
const array& a,
|
||||||
|
const std::vector<int>& axes,
|
||||||
|
bool keepdims /* = false */,
|
||||||
|
int ddof /* = 0*/,
|
||||||
|
StreamOrDevice s /* = {}*/) {
|
||||||
|
return sqrt(var(a, axes, keepdims, ddof, s), s);
|
||||||
|
}
|
||||||
|
|
||||||
|
array std(
|
||||||
|
const array& a,
|
||||||
|
int axis,
|
||||||
|
bool keepdims /* = false */,
|
||||||
|
int ddof /* = 0*/,
|
||||||
|
StreamOrDevice s /* = {} */) {
|
||||||
|
return std(a, std::vector<int>{axis}, keepdims, ddof, to_stream(s));
|
||||||
|
}
|
||||||
|
|
||||||
array prod(const array& a, bool keepdims, StreamOrDevice s /* = {}*/) {
|
array prod(const array& a, bool keepdims, StreamOrDevice s /* = {}*/) {
|
||||||
std::vector<int> axes(a.ndim());
|
std::vector<int> axes(a.ndim());
|
||||||
std::iota(axes.begin(), axes.end(), 0);
|
std::iota(axes.begin(), axes.end(), 0);
|
||||||
@ -2033,6 +2061,13 @@ array exp(const array& a, StreamOrDevice s /* = {} */) {
|
|||||||
return array(a.shape(), dtype, std::make_shared<Exp>(to_stream(s)), {input});
|
return array(a.shape(), dtype, std::make_shared<Exp>(to_stream(s)), {input});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
array expm1(const array& a, StreamOrDevice s /* = {} */) {
|
||||||
|
auto dtype = at_least_float(a.dtype());
|
||||||
|
auto input = astype(a, dtype, s);
|
||||||
|
return array(
|
||||||
|
a.shape(), dtype, std::make_shared<Expm1>(to_stream(s)), {input});
|
||||||
|
}
|
||||||
|
|
||||||
array sin(const array& a, StreamOrDevice s /* = {} */) {
|
array sin(const array& a, StreamOrDevice s /* = {} */) {
|
||||||
auto dtype = at_least_float(a.dtype());
|
auto dtype = at_least_float(a.dtype());
|
||||||
auto input = astype(a, dtype, s);
|
auto input = astype(a, dtype, s);
|
||||||
|
35
mlx/ops.h
35
mlx/ops.h
@ -507,13 +507,14 @@ array mean(
|
|||||||
bool keepdims = false,
|
bool keepdims = false,
|
||||||
StreamOrDevice s = {});
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
/** Computes the mean of the elements of an array. */
|
/** Computes the variance of the elements of an array. */
|
||||||
array var(const array& a, bool keepdims, int ddof = 0, StreamOrDevice s = {});
|
array var(const array& a, bool keepdims, int ddof = 0, StreamOrDevice s = {});
|
||||||
inline array var(const array& a, StreamOrDevice s = {}) {
|
inline array var(const array& a, StreamOrDevice s = {}) {
|
||||||
return var(a, false, 0, to_stream(s));
|
return var(a, false, 0, to_stream(s));
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Computes the var of the elements of an array along the given axes */
|
/** Computes the variance of the elements of an array along the given
|
||||||
|
* axes */
|
||||||
array var(
|
array var(
|
||||||
const array& a,
|
const array& a,
|
||||||
const std::vector<int>& axes,
|
const std::vector<int>& axes,
|
||||||
@ -521,7 +522,8 @@ array var(
|
|||||||
int ddof = 0,
|
int ddof = 0,
|
||||||
StreamOrDevice s = {});
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
/** Computes the var of the elements of an array along the given axis */
|
/** Computes the variance of the elements of an array along the given
|
||||||
|
* axis */
|
||||||
array var(
|
array var(
|
||||||
const array& a,
|
const array& a,
|
||||||
int axis,
|
int axis,
|
||||||
@ -529,6 +531,30 @@ array var(
|
|||||||
int ddof = 0,
|
int ddof = 0,
|
||||||
StreamOrDevice s = {});
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Computes the standard deviation of the elements of an array. */
|
||||||
|
array std(const array& a, bool keepdims, int ddof = 0, StreamOrDevice s = {});
|
||||||
|
inline array std(const array& a, StreamOrDevice s = {}) {
|
||||||
|
return std(a, false, 0, to_stream(s));
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Computes the standard deviatoin of the elements of an array along the given
|
||||||
|
* axes */
|
||||||
|
array std(
|
||||||
|
const array& a,
|
||||||
|
const std::vector<int>& axes,
|
||||||
|
bool keepdims = false,
|
||||||
|
int ddof = 0,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Computes the standard deviation of the elements of an array along the given
|
||||||
|
* axis */
|
||||||
|
array std(
|
||||||
|
const array& a,
|
||||||
|
int axis,
|
||||||
|
bool keepdims = false,
|
||||||
|
int ddof = 0,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
/** The product of all elements of the array. */
|
/** The product of all elements of the array. */
|
||||||
array prod(const array& a, bool keepdims, StreamOrDevice s = {});
|
array prod(const array& a, bool keepdims, StreamOrDevice s = {});
|
||||||
inline array prod(const array& a, StreamOrDevice s = {}) {
|
inline array prod(const array& a, StreamOrDevice s = {}) {
|
||||||
@ -842,6 +868,9 @@ array erf(const array& a, StreamOrDevice s = {});
|
|||||||
/** Computes the inverse error function of the elements of an array. */
|
/** Computes the inverse error function of the elements of an array. */
|
||||||
array erfinv(const array& a, StreamOrDevice s = {});
|
array erfinv(const array& a, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Computes the expm1 function of the elements of an array. */
|
||||||
|
array expm1(const array& a, StreamOrDevice s = {});
|
||||||
|
|
||||||
/** Stop the flow of gradients. */
|
/** Stop the flow of gradients. */
|
||||||
array stop_gradient(const array& a, StreamOrDevice s = {});
|
array stop_gradient(const array& a, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
@ -1239,6 +1239,34 @@ std::pair<std::vector<array>, std::vector<int>> Exp::vmap(
|
|||||||
return {{exp(inputs[0], stream())}, axes};
|
return {{exp(inputs[0], stream())}, axes};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<array> Expm1::vjp(
|
||||||
|
const std::vector<array>& primals,
|
||||||
|
const std::vector<array>& cotangents,
|
||||||
|
const std::vector<int>& argnums,
|
||||||
|
const std::vector<array>& outputs) {
|
||||||
|
return {multiply(
|
||||||
|
cotangents[0],
|
||||||
|
add(outputs[0], array(1.0f, outputs[0].dtype()), stream()),
|
||||||
|
stream())};
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<array> Expm1::jvp(
|
||||||
|
const std::vector<array>& primals,
|
||||||
|
const std::vector<array>& tangents,
|
||||||
|
const std::vector<int>& argnums) {
|
||||||
|
assert(primals.size() == 1);
|
||||||
|
assert(argnums.size() == 1);
|
||||||
|
return {multiply(tangents[0], exp(primals[0], stream()), stream())};
|
||||||
|
}
|
||||||
|
|
||||||
|
std::pair<std::vector<array>, std::vector<int>> Expm1::vmap(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
const std::vector<int>& axes) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
assert(axes.size() == 1);
|
||||||
|
return {{expm1(inputs[0], stream())}, axes};
|
||||||
|
}
|
||||||
|
|
||||||
bool FFT::is_equivalent(const Primitive& other) const {
|
bool FFT::is_equivalent(const Primitive& other) const {
|
||||||
const FFT& r_other = static_cast<const FFT&>(other);
|
const FFT& r_other = static_cast<const FFT&>(other);
|
||||||
return axes_ == r_other.axes_ && inverse_ == r_other.inverse_ &&
|
return axes_ == r_other.axes_ && inverse_ == r_other.inverse_ &&
|
||||||
|
@ -837,6 +837,22 @@ class Exp : public UnaryPrimitive {
|
|||||||
void eval(const std::vector<array>& inputs, array& out);
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class Expm1 : public UnaryPrimitive {
|
||||||
|
public:
|
||||||
|
explicit Expm1(Stream stream) : UnaryPrimitive(stream){};
|
||||||
|
|
||||||
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
|
|
||||||
|
DEFINE_VMAP()
|
||||||
|
DEFINE_GRADS()
|
||||||
|
DEFINE_PRINT(Expm1)
|
||||||
|
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||||
|
|
||||||
|
private:
|
||||||
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
|
};
|
||||||
|
|
||||||
class FFT : public UnaryPrimitive {
|
class FFT : public UnaryPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit FFT(
|
explicit FFT(
|
||||||
|
@ -772,6 +772,25 @@ void init_ops(nb::module_& m) {
|
|||||||
Returns:
|
Returns:
|
||||||
array: The exponential of ``a``.
|
array: The exponential of ``a``.
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"expm1",
|
||||||
|
&mlx::core::expm1,
|
||||||
|
nb::arg(),
|
||||||
|
nb::kw_only(),
|
||||||
|
"stream"_a = nb::none(),
|
||||||
|
nb::sig(
|
||||||
|
"def expm1(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||||
|
R"pbdoc(
|
||||||
|
Element-wise exponential minus 1.
|
||||||
|
|
||||||
|
Computes ``exp(x) - 1`` with greater precision for small ``x``.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
a (array): Input array.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
array: The expm1 of ``a``.
|
||||||
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"erf",
|
"erf",
|
||||||
&mlx::core::erf,
|
&mlx::core::erf,
|
||||||
@ -2150,6 +2169,40 @@ void init_ops(nb::module_& m) {
|
|||||||
Returns:
|
Returns:
|
||||||
array: The output array of variances.
|
array: The output array of variances.
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"std",
|
||||||
|
[](const array& a,
|
||||||
|
const IntOrVec& axis,
|
||||||
|
bool keepdims,
|
||||||
|
int ddof,
|
||||||
|
StreamOrDevice s) {
|
||||||
|
return mlx::core::std(
|
||||||
|
a, get_reduce_axes(axis, a.ndim()), keepdims, ddof, s);
|
||||||
|
},
|
||||||
|
nb::arg(),
|
||||||
|
"axis"_a = nb::none(),
|
||||||
|
"keepdims"_a = false,
|
||||||
|
"ddof"_a = 0,
|
||||||
|
nb::kw_only(),
|
||||||
|
"stream"_a = nb::none(),
|
||||||
|
nb::sig(
|
||||||
|
"def std(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, ddof: int = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||||
|
R"pbdoc(
|
||||||
|
Compute the standard deviation(s) over the given axes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
a (array): Input array.
|
||||||
|
axis (int or list(int), optional): Optional axis or
|
||||||
|
axes to reduce over. If unspecified this defaults
|
||||||
|
to reducing over the entire array.
|
||||||
|
keepdims (bool, optional): Keep reduced axes as
|
||||||
|
singleton dimensions, defaults to `False`.
|
||||||
|
ddof (int, optional): The divisor to compute the variance
|
||||||
|
is ``N - ddof``, defaults to 0.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
array: The output array of standard deviations.
|
||||||
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"split",
|
"split",
|
||||||
[](const array& a,
|
[](const array& a,
|
||||||
|
@ -725,6 +725,11 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
out = mx.var(x, ddof=3)
|
out = mx.var(x, ddof=3)
|
||||||
self.assertEqual(out.item(), float("inf"))
|
self.assertEqual(out.item(), float("inf"))
|
||||||
|
|
||||||
|
def test_std(self):
|
||||||
|
x = mx.random.uniform(shape=(5, 5))
|
||||||
|
x_np = np.array(x)
|
||||||
|
self.assertAlmostEqual(mx.std(x).item(), x_np.std().item(), places=6)
|
||||||
|
|
||||||
def test_abs(self):
|
def test_abs(self):
|
||||||
a = mx.array([-1.0, 1.0, -2.0, 3.0])
|
a = mx.array([-1.0, 1.0, -2.0, 3.0])
|
||||||
result = mx.abs(a)
|
result = mx.abs(a)
|
||||||
@ -839,6 +844,13 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
|
|
||||||
self.assertTrue(np.allclose(result, expected))
|
self.assertTrue(np.allclose(result, expected))
|
||||||
|
|
||||||
|
def test_expm1(self):
|
||||||
|
a = mx.array([0, 0.5, -0.5, 5])
|
||||||
|
result = mx.expm1(a)
|
||||||
|
expected = np.expm1(a, dtype=np.float32)
|
||||||
|
|
||||||
|
self.assertTrue(np.allclose(result, expected, rtol=1e-5, atol=1e-5))
|
||||||
|
|
||||||
def test_erf(self):
|
def test_erf(self):
|
||||||
inputs = [-5, 0.0, 0.5, 1.0, 2.0, 10.0]
|
inputs = [-5, 0.0, 0.5, 1.0, 2.0, 10.0]
|
||||||
x = mx.array(inputs)
|
x = mx.array(inputs)
|
||||||
|
@ -1092,6 +1092,20 @@ TEST_CASE("test arithmetic unary ops") {
|
|||||||
CHECK(allclose(exp(x), expected).item<bool>());
|
CHECK(allclose(exp(x), expected).item<bool>());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test expm1
|
||||||
|
{
|
||||||
|
array x(-1.0f);
|
||||||
|
CHECK_EQ(expm1(x).item<float>(), doctest::Approx(std::expm1(-1.0f)));
|
||||||
|
|
||||||
|
x = array(1.0f);
|
||||||
|
CHECK_EQ(expm1(x).item<float>(), doctest::Approx(std::expm1(1.0f)));
|
||||||
|
|
||||||
|
// Integer input type
|
||||||
|
x = array(1);
|
||||||
|
CHECK_EQ(expm1(x).dtype(), float32);
|
||||||
|
CHECK_EQ(expm1(x).item<float>(), doctest::Approx(std::expm1(1.0f)));
|
||||||
|
}
|
||||||
|
|
||||||
// Test sine
|
// Test sine
|
||||||
{
|
{
|
||||||
array x(0.0);
|
array x(0.0);
|
||||||
|
Loading…
Reference in New Issue
Block a user