std and expm1 (#973)

* std and expm1

* actually add expm1

* fix linux

* fix vjp

* relax tol for linux test

* Add it to the compilable primitives

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
Awni Hannun 2024-04-08 14:26:01 -07:00 committed by GitHub
parent 76e63212ff
commit 42afe27e12
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
19 changed files with 332 additions and 6 deletions

View File

@ -5,13 +5,13 @@ Operations
.. currentmodule:: mlx.core
.. autosummary::
.. autosummary::
:toctree: _autosummary
abs
add
all
allclose
allclose
any
arange
arccos
@ -51,6 +51,7 @@ Operations
erf
erfinv
exp
expm1
expand_dims
eye
flatten
@ -117,6 +118,7 @@ Operations
square
squeeze
stack
std
stop_gradient
subtract
sum

View File

@ -310,6 +310,19 @@ void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
}
}
void Expm1::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
auto size = in.data_size();
vvexpm1f(
out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
} else {
eval(inputs, out);
}
}
void Full::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];

View File

@ -57,6 +57,7 @@ DEFAULT(Equal)
DEFAULT(Erf)
DEFAULT(ErfInv)
DEFAULT(Exp)
DEFAULT(Expm1)
DEFAULT(FFT)
DEFAULT(Floor)
DEFAULT(Full)

View File

@ -241,6 +241,13 @@ struct Exp {
}
};
struct Expm1 {
template <typename T>
T operator()(T x) {
return expm1(x);
};
};
struct Floor {
template <typename T>
T operator()(T x) {

View File

@ -359,6 +359,18 @@ void Exp::eval(const std::vector<array>& inputs, array& out) {
}
}
void Expm1::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::Expm1());
} else {
throw std::invalid_argument(
"[expm1] Cannot exponentiate elements in array"
" with non floating point type.");
}
}
void Floor::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];

View File

@ -7,6 +7,7 @@ set(
${CMAKE_CURRENT_SOURCE_DIR}/complex.h
${CMAKE_CURRENT_SOURCE_DIR}/defines.h
${CMAKE_CURRENT_SOURCE_DIR}/erf.h
${CMAKE_CURRENT_SOURCE_DIR}/expm1f.h
${CMAKE_CURRENT_SOURCE_DIR}/indexing.h
${CMAKE_CURRENT_SOURCE_DIR}/unary.h
${CMAKE_CURRENT_SOURCE_DIR}/utils.h

View File

@ -0,0 +1,89 @@
// Copyright © 2023 Apple Inc.
#pragma once
#include <metal_math>
// Original license copied below:
// Copyright (c) 2015-2023 Norbert Juffa
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
//
// 1. Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
/* Compute exponential base e minus 1. Maximum ulp error = 0.997458
i = rint(a/log(2)), f = a-i*log(2). Then expm1(a) = 2**i * (expm1(f)+1) - 1.
Compute r = expm1(f). Then expm1(a)= 2 * (0.5 * 2**i * r + 0.5 * 2**i - 0.5).
With t = 0.5*2**i, expm1(a) = 2*(r * t + t-0.5). However, for best accuracy,
when i == 1, expm1(a)= 2*(r + 0.5), and when i == 0, expm1(a) = r.
NOTE: Scale factor b is only applied if i < 0 or i > 1 (should be power of 2)
*/
float expm1f_scaled_unchecked(float a, float b) {
float f, j, r, s, t, u, v, x, y;
int i;
// exp(a) = 2**i * exp(f); i = rintf (a / log(2))
j = fma(1.442695f, a, 12582912.f); // 0x1.715476p0, 0x1.8p23
j = j - 12582912.0f; // 0x1.8p23
i = (int)j;
f = fma(j, -6.93145752e-1f, a);
// approximate r = exp(f)-1 on interval [-log(2)/2, +log(2)/2]
s = f * f;
if (a == 0.0f)
s = a; // ensure -0 is passed through
// err = 0.997458 ulp1 = 11081805
r = 1.97350979e-4f; // 0x1.9de000p-13
r = fma(r, f, 1.39309070e-3f); // 0x1.6d30bcp-10
r = fma(r, f, 8.33343994e-3f); // 0x1.1111f6p-7
r = fma(r, f, 4.16668020e-2f); // 0x1.55559ep-5
r = fma(r, f, 1.66666716e-1f); // 0x1.55555cp-3
r = fma(r, f, 4.99999970e-1f); // 0x1.fffffep-2
u = (j == 1) ? (f + 0.5f) : f;
v = fma(r, s, u);
s = 0.5f * b;
t = ldexp(s, i);
y = t - s;
x = (t - y) - s; // double-float canonicalization of difference
r = fma(v, t, x) + y;
r = r + r;
if (j == 0)
r = v;
if (j == 1)
r = v + v;
return r;
}
/* Compute exponential base e minus 1. max ulp err = 0.99746 */
float expm1f(float a) {
float r;
r = expm1f_scaled_unchecked(a, 1.0f);
/* handle severe overflow and underflow */
if (abs(a - 1.0f) > 88.0f) {
r = fma(r, r, -1.0f);
}
return r;
}

View File

@ -7,6 +7,7 @@
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/erf.h"
#include "mlx/backend/metal/kernels/expm1f.h"
#include "mlx/backend/metal/kernels/utils.h"
namespace {
@ -183,6 +184,13 @@ struct Exp {
}
};
struct Expm1 {
template <typename T>
T operator()(T x) {
return static_cast<T>(expm1f(static_cast<float>(x)));
};
};
struct Floor {
template <typename T>
T operator()(T x) {

View File

@ -71,6 +71,7 @@ instantiate_unary_types(ceil, Ceil)
instantiate_unary_float(cos, Cos)
instantiate_unary_float(cosh, Cosh)
instantiate_unary_float(exp, Exp)
instantiate_unary_float(expm1, Expm1)
instantiate_unary_types(floor, Floor)
instantiate_unary_float(log, Log)
instantiate_unary_float(log2, Log2)

View File

@ -615,6 +615,10 @@ void Exp::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "exp");
}
void Expm1::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "expm1");
}
void Full::eval_gpu(const std::vector<array>& inputs, array& out) {
auto in = inputs[0];
CopyType ctype;

View File

@ -49,6 +49,7 @@ NO_GPU(Equal)
NO_GPU(Erf)
NO_GPU(ErfInv)
NO_GPU(Exp)
NO_GPU(Expm1)
NO_GPU(FFT)
NO_GPU(Floor)
NO_GPU(Full)

View File

@ -32,7 +32,7 @@ bool is_unary(const Primitive& p) {
typeid(p) == typeid(Sign) || typeid(p) == typeid(Sin) ||
typeid(p) == typeid(Sinh) || typeid(p) == typeid(Square) ||
typeid(p) == typeid(Sqrt) || typeid(p) == typeid(Tan) ||
typeid(p) == typeid(Tanh));
typeid(p) == typeid(Tanh) || typeid(p) == typeid(Expm1));
}
bool is_binary(const Primitive& p) {

View File

@ -1430,6 +1430,34 @@ array var(
return var(a, std::vector<int>{axis}, keepdims, ddof, to_stream(s));
}
array std(
const array& a,
bool keepdims,
int ddof /* = 0*/,
StreamOrDevice s /* = {}*/) {
std::vector<int> axes(a.ndim());
std::iota(axes.begin(), axes.end(), 0);
return std(a, axes, keepdims, ddof, to_stream(s));
}
array std(
const array& a,
const std::vector<int>& axes,
bool keepdims /* = false */,
int ddof /* = 0*/,
StreamOrDevice s /* = {}*/) {
return sqrt(var(a, axes, keepdims, ddof, s), s);
}
array std(
const array& a,
int axis,
bool keepdims /* = false */,
int ddof /* = 0*/,
StreamOrDevice s /* = {} */) {
return std(a, std::vector<int>{axis}, keepdims, ddof, to_stream(s));
}
array prod(const array& a, bool keepdims, StreamOrDevice s /* = {}*/) {
std::vector<int> axes(a.ndim());
std::iota(axes.begin(), axes.end(), 0);
@ -2033,6 +2061,13 @@ array exp(const array& a, StreamOrDevice s /* = {} */) {
return array(a.shape(), dtype, std::make_shared<Exp>(to_stream(s)), {input});
}
array expm1(const array& a, StreamOrDevice s /* = {} */) {
auto dtype = at_least_float(a.dtype());
auto input = astype(a, dtype, s);
return array(
a.shape(), dtype, std::make_shared<Expm1>(to_stream(s)), {input});
}
array sin(const array& a, StreamOrDevice s /* = {} */) {
auto dtype = at_least_float(a.dtype());
auto input = astype(a, dtype, s);

View File

@ -507,13 +507,14 @@ array mean(
bool keepdims = false,
StreamOrDevice s = {});
/** Computes the mean of the elements of an array. */
/** Computes the variance of the elements of an array. */
array var(const array& a, bool keepdims, int ddof = 0, StreamOrDevice s = {});
inline array var(const array& a, StreamOrDevice s = {}) {
return var(a, false, 0, to_stream(s));
}
/** Computes the var of the elements of an array along the given axes */
/** Computes the variance of the elements of an array along the given
* axes */
array var(
const array& a,
const std::vector<int>& axes,
@ -521,7 +522,8 @@ array var(
int ddof = 0,
StreamOrDevice s = {});
/** Computes the var of the elements of an array along the given axis */
/** Computes the variance of the elements of an array along the given
* axis */
array var(
const array& a,
int axis,
@ -529,6 +531,30 @@ array var(
int ddof = 0,
StreamOrDevice s = {});
/** Computes the standard deviation of the elements of an array. */
array std(const array& a, bool keepdims, int ddof = 0, StreamOrDevice s = {});
inline array std(const array& a, StreamOrDevice s = {}) {
return std(a, false, 0, to_stream(s));
}
/** Computes the standard deviatoin of the elements of an array along the given
* axes */
array std(
const array& a,
const std::vector<int>& axes,
bool keepdims = false,
int ddof = 0,
StreamOrDevice s = {});
/** Computes the standard deviation of the elements of an array along the given
* axis */
array std(
const array& a,
int axis,
bool keepdims = false,
int ddof = 0,
StreamOrDevice s = {});
/** The product of all elements of the array. */
array prod(const array& a, bool keepdims, StreamOrDevice s = {});
inline array prod(const array& a, StreamOrDevice s = {}) {
@ -842,6 +868,9 @@ array erf(const array& a, StreamOrDevice s = {});
/** Computes the inverse error function of the elements of an array. */
array erfinv(const array& a, StreamOrDevice s = {});
/** Computes the expm1 function of the elements of an array. */
array expm1(const array& a, StreamOrDevice s = {});
/** Stop the flow of gradients. */
array stop_gradient(const array& a, StreamOrDevice s = {});

View File

@ -1239,6 +1239,34 @@ std::pair<std::vector<array>, std::vector<int>> Exp::vmap(
return {{exp(inputs[0], stream())}, axes};
}
std::vector<array> Expm1::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>& outputs) {
return {multiply(
cotangents[0],
add(outputs[0], array(1.0f, outputs[0].dtype()), stream()),
stream())};
}
std::vector<array> Expm1::jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) {
assert(primals.size() == 1);
assert(argnums.size() == 1);
return {multiply(tangents[0], exp(primals[0], stream()), stream())};
}
std::pair<std::vector<array>, std::vector<int>> Expm1::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
assert(inputs.size() == 1);
assert(axes.size() == 1);
return {{expm1(inputs[0], stream())}, axes};
}
bool FFT::is_equivalent(const Primitive& other) const {
const FFT& r_other = static_cast<const FFT&>(other);
return axes_ == r_other.axes_ && inverse_ == r_other.inverse_ &&

View File

@ -837,6 +837,22 @@ class Exp : public UnaryPrimitive {
void eval(const std::vector<array>& inputs, array& out);
};
class Expm1 : public UnaryPrimitive {
public:
explicit Expm1(Stream stream) : UnaryPrimitive(stream){};
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
DEFINE_GRADS()
DEFINE_PRINT(Expm1)
DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
};
class FFT : public UnaryPrimitive {
public:
explicit FFT(

View File

@ -772,6 +772,25 @@ void init_ops(nb::module_& m) {
Returns:
array: The exponential of ``a``.
)pbdoc");
m.def(
"expm1",
&mlx::core::expm1,
nb::arg(),
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def expm1(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Element-wise exponential minus 1.
Computes ``exp(x) - 1`` with greater precision for small ``x``.
Args:
a (array): Input array.
Returns:
array: The expm1 of ``a``.
)pbdoc");
m.def(
"erf",
&mlx::core::erf,
@ -2150,6 +2169,40 @@ void init_ops(nb::module_& m) {
Returns:
array: The output array of variances.
)pbdoc");
m.def(
"std",
[](const array& a,
const IntOrVec& axis,
bool keepdims,
int ddof,
StreamOrDevice s) {
return mlx::core::std(
a, get_reduce_axes(axis, a.ndim()), keepdims, ddof, s);
},
nb::arg(),
"axis"_a = nb::none(),
"keepdims"_a = false,
"ddof"_a = 0,
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def std(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, ddof: int = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Compute the standard deviation(s) over the given axes.
Args:
a (array): Input array.
axis (int or list(int), optional): Optional axis or
axes to reduce over. If unspecified this defaults
to reducing over the entire array.
keepdims (bool, optional): Keep reduced axes as
singleton dimensions, defaults to `False`.
ddof (int, optional): The divisor to compute the variance
is ``N - ddof``, defaults to 0.
Returns:
array: The output array of standard deviations.
)pbdoc");
m.def(
"split",
[](const array& a,

View File

@ -725,6 +725,11 @@ class TestOps(mlx_tests.MLXTestCase):
out = mx.var(x, ddof=3)
self.assertEqual(out.item(), float("inf"))
def test_std(self):
x = mx.random.uniform(shape=(5, 5))
x_np = np.array(x)
self.assertAlmostEqual(mx.std(x).item(), x_np.std().item(), places=6)
def test_abs(self):
a = mx.array([-1.0, 1.0, -2.0, 3.0])
result = mx.abs(a)
@ -839,6 +844,13 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertTrue(np.allclose(result, expected))
def test_expm1(self):
a = mx.array([0, 0.5, -0.5, 5])
result = mx.expm1(a)
expected = np.expm1(a, dtype=np.float32)
self.assertTrue(np.allclose(result, expected, rtol=1e-5, atol=1e-5))
def test_erf(self):
inputs = [-5, 0.0, 0.5, 1.0, 2.0, 10.0]
x = mx.array(inputs)

View File

@ -1092,6 +1092,20 @@ TEST_CASE("test arithmetic unary ops") {
CHECK(allclose(exp(x), expected).item<bool>());
}
// Test expm1
{
array x(-1.0f);
CHECK_EQ(expm1(x).item<float>(), doctest::Approx(std::expm1(-1.0f)));
x = array(1.0f);
CHECK_EQ(expm1(x).item<float>(), doctest::Approx(std::expm1(1.0f)));
// Integer input type
x = array(1);
CHECK_EQ(expm1(x).dtype(), float32);
CHECK_EQ(expm1(x).item<float>(), doctest::Approx(std::expm1(1.0f)));
}
// Test sine
{
array x(0.0);