diff --git a/docs/src/python/ops.rst b/docs/src/python/ops.rst
index a10b126af..b5a89e308 100644
--- a/docs/src/python/ops.rst
+++ b/docs/src/python/ops.rst
@@ -5,13 +5,13 @@ Operations
 
 .. currentmodule:: mlx.core
 
-.. autosummary:: 
+.. autosummary::
    :toctree: _autosummary
 
    abs
    add
    all
-   allclose 
+   allclose
    any
    arange
    arccos
@@ -51,6 +51,7 @@ Operations
    erf
    erfinv
    exp
+   expm1
    expand_dims
    eye
    flatten
@@ -117,6 +118,7 @@ Operations
    square
    squeeze
    stack
+   std
    stop_gradient
    subtract
    sum
diff --git a/mlx/backend/accelerate/primitives.cpp b/mlx/backend/accelerate/primitives.cpp
index 4ca3da1b8..d8fa52b19 100644
--- a/mlx/backend/accelerate/primitives.cpp
+++ b/mlx/backend/accelerate/primitives.cpp
@@ -310,6 +310,19 @@ void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
   }
 }
 
+void Expm1::eval_cpu(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 1);
+  const auto& in = inputs[0];
+  if (out.dtype() == float32 && in.flags().contiguous) {
+    set_unary_output_data(in, out);
+    auto size = in.data_size();
+    vvexpm1f(
+        out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
+  } else {
+    eval(inputs, out);
+  }
+}
+
 void Full::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   auto& in = inputs[0];
diff --git a/mlx/backend/common/default_primitives.cpp b/mlx/backend/common/default_primitives.cpp
index 83fb86da9..219d52ad3 100644
--- a/mlx/backend/common/default_primitives.cpp
+++ b/mlx/backend/common/default_primitives.cpp
@@ -57,6 +57,7 @@ DEFAULT(Equal)
 DEFAULT(Erf)
 DEFAULT(ErfInv)
 DEFAULT(Exp)
+DEFAULT(Expm1)
 DEFAULT(FFT)
 DEFAULT(Floor)
 DEFAULT(Full)
diff --git a/mlx/backend/common/ops.h b/mlx/backend/common/ops.h
index b5b0953b2..0aff1de37 100644
--- a/mlx/backend/common/ops.h
+++ b/mlx/backend/common/ops.h
@@ -241,6 +241,13 @@ struct Exp {
   }
 };
 
+struct Expm1 {
+  template <typename T>
+  T operator()(T x) {
+    return expm1(x);
+  };
+};
+
 struct Floor {
   template <typename T>
   T operator()(T x) {
diff --git a/mlx/backend/common/primitives.cpp b/mlx/backend/common/primitives.cpp
index df14ca33b..08b0775c8 100644
--- a/mlx/backend/common/primitives.cpp
+++ b/mlx/backend/common/primitives.cpp
@@ -359,6 +359,18 @@ void Exp::eval(const std::vector<array>& inputs, array& out) {
   }
 }
 
+void Expm1::eval(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 1);
+  const auto& in = inputs[0];
+  if (issubdtype(out.dtype(), inexact)) {
+    unary_fp(in, out, detail::Expm1());
+  } else {
+    throw std::invalid_argument(
+        "[expm1] Cannot exponentiate elements in array"
+        " with non floating point type.");
+  }
+}
+
 void Floor::eval(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   auto& in = inputs[0];
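Aside (not part of the patch): the reason to ship a dedicated `expm1` kernel rather than composing `exp` and a subtraction is precision for small inputs. The NumPy-only sketch below demonstrates the cancellation problem; the printed figures are illustrative.

```python
import numpy as np

# For small x, exp(x) - 1 in float32 loses most of its significant
# digits to cancellation; expm1 avoids the subtraction entirely.
x = np.float32(1e-7)

naive = np.exp(x) - np.float32(1.0)     # ~1.1920929e-07: ~19% relative error
fused = np.expm1(x)                     # ~1.0000000e-07
reference = np.expm1(np.float64(1e-7))  # ~1.00000005e-07

print(naive, fused, reference)
```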
diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt
index 64ee1889c..e8ca1356c 100644
--- a/mlx/backend/metal/kernels/CMakeLists.txt
+++ b/mlx/backend/metal/kernels/CMakeLists.txt
@@ -7,6 +7,7 @@ set(
   ${CMAKE_CURRENT_SOURCE_DIR}/complex.h
   ${CMAKE_CURRENT_SOURCE_DIR}/defines.h
   ${CMAKE_CURRENT_SOURCE_DIR}/erf.h
+  ${CMAKE_CURRENT_SOURCE_DIR}/expm1f.h
   ${CMAKE_CURRENT_SOURCE_DIR}/indexing.h
   ${CMAKE_CURRENT_SOURCE_DIR}/unary.h
   ${CMAKE_CURRENT_SOURCE_DIR}/utils.h
diff --git a/mlx/backend/metal/kernels/expm1f.h b/mlx/backend/metal/kernels/expm1f.h
new file mode 100644
index 000000000..b649dd99a
--- /dev/null
+++ b/mlx/backend/metal/kernels/expm1f.h
@@ -0,0 +1,89 @@
+// Copyright © 2023 Apple Inc.
+
+#pragma once
+
+#include <metal_math>
+
+// Original license copied below:
+//  Copyright (c) 2015-2023 Norbert Juffa
+//  All rights reserved.
+//
+//  Redistribution and use in source and binary forms, with or without
+//  modification, are permitted provided that the following conditions
+//  are met:
+//
+//  1. Redistributions of source code must retain the above copyright
+//     notice, this list of conditions and the following disclaimer.
+//
+//  2. Redistributions in binary form must reproduce the above copyright
+//     notice, this list of conditions and the following disclaimer in the
+//     documentation and/or other materials provided with the distribution.
+//
+//  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+//  "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+//  LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+//  A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+//  HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+//  SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+//  LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+//  DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+//  THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+//  (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+//  OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+/* Compute exponential base e minus 1. Maximum ulp error = 0.997458
+
+   i = rint(a/log(2)), f = a-i*log(2). Then expm1(a) = 2**i * (expm1(f)+1) - 1.
+   Compute r = expm1(f). Then expm1(a)= 2 * (0.5 * 2**i * r + 0.5 * 2**i - 0.5).
+   With t = 0.5*2**i, expm1(a) = 2*(r * t + t-0.5). However, for best accuracy,
+   when i == 1, expm1(a)= 2*(r + 0.5), and when i == 0, expm1(a) = r.
+
+   NOTE: Scale factor b is only applied if i < 0 or i > 1 (should be power of 2)
+*/
+float expm1f_scaled_unchecked(float a, float b) {
+  float f, j, r, s, t, u, v, x, y;
+  int i;
+
+  // exp(a) = 2**i * exp(f); i = rintf (a / log(2))
+  j = fma(1.442695f, a, 12582912.f); // 0x1.715476p0, 0x1.8p23
+  j = j - 12582912.0f; // 0x1.8p23
+  i = (int)j;
+  f = fma(j, -6.93145752e-1f, a);
+
+  // approximate r = exp(f)-1 on interval [-log(2)/2, +log(2)/2]
+  s = f * f;
+  if (a == 0.0f)
+    s = a; // ensure -0 is passed through
+  // err = 0.997458 ulp1 = 11081805
+  r = 1.97350979e-4f; // 0x1.9de000p-13
+  r = fma(r, f, 1.39309070e-3f); // 0x1.6d30bcp-10
+  r = fma(r, f, 8.33343994e-3f); // 0x1.1111f6p-7
+  r = fma(r, f, 4.16668020e-2f); // 0x1.55559ep-5
+  r = fma(r, f, 1.66666716e-1f); // 0x1.55555cp-3
+  r = fma(r, f, 4.99999970e-1f); // 0x1.fffffep-2
+  u = (j == 1) ? (f + 0.5f) : f;
+  v = fma(r, s, u);
+  s = 0.5f * b;
+  t = ldexp(s, i);
+  y = t - s;
+  x = (t - y) - s; // double-float canonicalization of difference
+  r = fma(v, t, x) + y;
+  r = r + r;
+  if (j == 0)
+    r = v;
+  if (j == 1)
+    r = v + v;
+  return r;
+}
+
+/* Compute exponential base e minus 1. max ulp err = 0.99746 */
+float expm1f(float a) {
+  float r;
+
+  r = expm1f_scaled_unchecked(a, 1.0f);
+  /* handle severe overflow and underflow */
+  if (abs(a) > 88.0f) {
+    r = pow(2, 128) * a; // +/-INF
+    r = fma(r, r, -1.0f); // INF
+  }
+  return r;
+}
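For review context: the range reduction described in the comment above is easy to sanity-check outside Metal. The sketch below is my own illustration in double precision (`expm1_via_reduction` is a hypothetical helper, not part of the patch); the kernel does the same thing with a polynomial approximation for `expm1(f)` and an `ldexp`-based scaling.

```python
import math

def expm1_via_reduction(a: float) -> float:
    # i = rint(a / log(2)), f = a - i * log(2), so |f| <= log(2)/2.
    i = round(a / math.log(2.0))
    f = a - i * math.log(2.0)
    # expm1(a) = 2**i * (expm1(f) + 1) - 1
    return math.ldexp(math.expm1(f) + 1.0, i) - 1.0

for a in (-3.7, -0.25, 0.0, 0.3, 2.0, 10.0):
    assert math.isclose(expm1_via_reduction(a), math.expm1(a), rel_tol=1e-12)
```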
diff --git a/mlx/backend/metal/kernels/unary.h b/mlx/backend/metal/kernels/unary.h
index e0d80ab10..dd380f2c5 100644
--- a/mlx/backend/metal/kernels/unary.h
+++ b/mlx/backend/metal/kernels/unary.h
@@ -7,6 +7,7 @@
 
 #include "mlx/backend/metal/kernels/bf16.h"
 #include "mlx/backend/metal/kernels/erf.h"
+#include "mlx/backend/metal/kernels/expm1f.h"
 #include "mlx/backend/metal/kernels/utils.h"
 
 namespace {
@@ -183,6 +184,13 @@ struct Exp {
   }
 };
 
+struct Expm1 {
+  template <typename T>
+  T operator()(T x) {
+    return static_cast<T>(expm1f(static_cast<float>(x)));
+  };
+};
+
 struct Floor {
   template <typename T>
   T operator()(T x) {
diff --git a/mlx/backend/metal/kernels/unary.metal b/mlx/backend/metal/kernels/unary.metal
index 154db0520..69b5580c5 100644
--- a/mlx/backend/metal/kernels/unary.metal
+++ b/mlx/backend/metal/kernels/unary.metal
@@ -71,6 +71,7 @@ instantiate_unary_types(ceil, Ceil)
 instantiate_unary_float(cos, Cos)
 instantiate_unary_float(cosh, Cosh)
 instantiate_unary_float(exp, Exp)
+instantiate_unary_float(expm1, Expm1)
 instantiate_unary_types(floor, Floor)
 instantiate_unary_float(log, Log)
 instantiate_unary_float(log2, Log2)
diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp
index 13a991f43..a1497a34f 100644
--- a/mlx/backend/metal/primitives.cpp
+++ b/mlx/backend/metal/primitives.cpp
@@ -615,6 +615,10 @@ void Exp::eval_gpu(const std::vector<array>& inputs, array& out) {
   unary_op(inputs, out, "exp");
 }
 
+void Expm1::eval_gpu(const std::vector<array>& inputs, array& out) {
+  unary_op(inputs, out, "expm1");
+}
+
 void Full::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto in = inputs[0];
   CopyType ctype;
diff --git a/mlx/backend/no_metal/primitives.cpp b/mlx/backend/no_metal/primitives.cpp
index 8222c59af..211ccbf9d 100644
--- a/mlx/backend/no_metal/primitives.cpp
+++ b/mlx/backend/no_metal/primitives.cpp
@@ -49,6 +49,7 @@ NO_GPU(Equal)
 NO_GPU(Erf)
 NO_GPU(ErfInv)
 NO_GPU(Exp)
+NO_GPU(Expm1)
 NO_GPU(FFT)
 NO_GPU(Floor)
 NO_GPU(Full)
diff --git a/mlx/compile.cpp b/mlx/compile.cpp
index 3a00a01b7..d187c1e10 100644
--- a/mlx/compile.cpp
+++ b/mlx/compile.cpp
@@ -32,7 +32,7 @@ bool is_unary(const Primitive& p) {
       typeid(p) == typeid(Sign) || typeid(p) == typeid(Sin) ||
       typeid(p) == typeid(Sinh) || typeid(p) == typeid(Square) ||
       typeid(p) == typeid(Sqrt) || typeid(p) == typeid(Tan) ||
-      typeid(p) == typeid(Tanh));
+      typeid(p) == typeid(Tanh) || typeid(p) == typeid(Expm1));
 }
 
 bool is_binary(const Primitive& p) {
diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index a06bb4637..9f997ebaf 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -1430,6 +1430,34 @@ array var(
   return var(a, std::vector<int>{axis}, keepdims, ddof, to_stream(s));
 }
 
+array std(
+    const array& a,
+    bool keepdims,
+    int ddof /* = 0*/,
+    StreamOrDevice s /* = {}*/) {
+  std::vector<int> axes(a.ndim());
+  std::iota(axes.begin(), axes.end(), 0);
+  return std(a, axes, keepdims, ddof, to_stream(s));
+}
+
+array std(
+    const array& a,
+    const std::vector<int>& axes,
+    bool keepdims /* = false */,
+    int ddof /* = 0*/,
+    StreamOrDevice s /* = {}*/) {
+  return sqrt(var(a, axes, keepdims, ddof, s), s);
+}
+
+array std(
+    const array& a,
+    int axis,
+    bool keepdims /* = false */,
+    int ddof /* = 0*/,
+    StreamOrDevice s /* = {} */) {
+  return std(a, std::vector<int>{axis}, keepdims, ddof, to_stream(s));
+}
+
 array prod(const array& a, bool keepdims, StreamOrDevice s /* = {}*/) {
   std::vector<int> axes(a.ndim());
   std::iota(axes.begin(), axes.end(), 0);
@@ -2033,6 +2061,13 @@ array exp(const array& a, StreamOrDevice s /* = {} */) {
   return array(a.shape(), dtype, std::make_shared<Exp>(to_stream(s)), {input});
 }
 
+array expm1(const array& a, StreamOrDevice s /* = {} */) {
+  auto dtype = at_least_float(a.dtype());
+  auto input = astype(a, dtype, s);
+  return array(
+      a.shape(), dtype, std::make_shared<Expm1>(to_stream(s)), {input});
+}
+
 array sin(const array& a, StreamOrDevice s /* = {} */) {
   auto dtype = at_least_float(a.dtype());
   auto input = astype(a, dtype, s);
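A note on the implementation choice above: `std` is defined as `sqrt(var(...))` rather than as a new primitive, so it composes with the existing reductions and gets gradients for free. A quick check against NumPy (assumes a build that includes this change):

```python
import mlx.core as mx
import numpy as np

x = mx.random.normal(shape=(4, 8))
x_np = np.array(x)

# Full reduction, with and without Bessel's correction (ddof=1).
assert np.allclose(mx.std(x).item(), x_np.std(), atol=1e-6)
assert np.allclose(mx.std(x, ddof=1).item(), x_np.std(ddof=1), atol=1e-6)

# By construction std is sqrt(var), axis by axis.
assert np.allclose(
    np.array(mx.std(x, axis=1)), np.sqrt(np.array(mx.var(x, axis=1)))
)
```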
diff --git a/mlx/ops.h b/mlx/ops.h
index e59ceaddb..0118a2f35 100644
--- a/mlx/ops.h
+++ b/mlx/ops.h
@@ -507,13 +507,14 @@ array mean(
     bool keepdims = false,
     StreamOrDevice s = {});
 
-/** Computes the mean of the elements of an array. */
+/** Computes the variance of the elements of an array. */
 array var(const array& a, bool keepdims, int ddof = 0, StreamOrDevice s = {});
 inline array var(const array& a, StreamOrDevice s = {}) {
   return var(a, false, 0, to_stream(s));
 }
 
-/** Computes the var of the elements of an array along the given axes */
+/** Computes the variance of the elements of an array along the given
+ * axes */
 array var(
     const array& a,
     const std::vector<int>& axes,
@@ -521,7 +522,8 @@ array var(
     int ddof = 0,
     StreamOrDevice s = {});
 
-/** Computes the var of the elements of an array along the given axis */
+/** Computes the variance of the elements of an array along the given
+ * axis */
 array var(
     const array& a,
     int axis,
@@ -529,6 +531,30 @@ array var(
     int ddof = 0,
     StreamOrDevice s = {});
 
+/** Computes the standard deviation of the elements of an array. */
+array std(const array& a, bool keepdims, int ddof = 0, StreamOrDevice s = {});
+inline array std(const array& a, StreamOrDevice s = {}) {
+  return std(a, false, 0, to_stream(s));
+}
+
+/** Computes the standard deviation of the elements of an array along the
+ * given axes */
+array std(
+    const array& a,
+    const std::vector<int>& axes,
+    bool keepdims = false,
+    int ddof = 0,
+    StreamOrDevice s = {});
+
+/** Computes the standard deviation of the elements of an array along the
+ * given axis */
+array std(
+    const array& a,
+    int axis,
+    bool keepdims = false,
+    int ddof = 0,
+    StreamOrDevice s = {});
+
 /** The product of all elements of the array. */
 array prod(const array& a, bool keepdims, StreamOrDevice s = {});
 inline array prod(const array& a, StreamOrDevice s = {}) {
@@ -842,6 +868,9 @@ array erf(const array& a, StreamOrDevice s = {});
 
 /** Computes the inverse error function of the elements of an array. */
 array erfinv(const array& a, StreamOrDevice s = {});
 
+/** Computes the expm1 function of the elements of an array. */
+array expm1(const array& a, StreamOrDevice s = {});
+
 /** Stop the flow of gradients. */
 array stop_gradient(const array& a, StreamOrDevice s = {});
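The three `std` overloads mirror the existing `var` API: reduce over everything, over a vector of axes, or over a single axis. On the Python side these collapse into one optional `axis` argument; a sketch of the expected reduced shapes, assuming this change is built:

```python
import mlx.core as mx

x = mx.zeros((2, 3, 4))

print(mx.std(x).shape)                         # every axis reduced -> scalar
print(mx.std(x, axis=1).shape)                 # axis 1 removed -> (2, 4)
print(mx.std(x, axis=(0, 2)).shape)            # axes 0, 2 removed -> (3,)
print(mx.std(x, axis=1, keepdims=True).shape)  # axis 1 kept as 1 -> (2, 1, 4)
```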
diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp
index 09a456c47..31b64a965 100644
--- a/mlx/primitives.cpp
+++ b/mlx/primitives.cpp
@@ -1239,6 +1239,34 @@ std::pair<std::vector<array>, std::vector<int>> Exp::vmap(
   return {{exp(inputs[0], stream())}, axes};
 }
 
+std::vector<array> Expm1::vjp(
+    const std::vector<array>& primals,
+    const std::vector<array>& cotangents,
+    const std::vector<int>& argnums,
+    const std::vector<array>& outputs) {
+  return {multiply(
+      cotangents[0],
+      add(outputs[0], array(1.0f, outputs[0].dtype()), stream()),
+      stream())};
+}
+
+std::vector<array> Expm1::jvp(
+    const std::vector<array>& primals,
+    const std::vector<array>& tangents,
+    const std::vector<int>& argnums) {
+  assert(primals.size() == 1);
+  assert(argnums.size() == 1);
+  return {multiply(tangents[0], exp(primals[0], stream()), stream())};
+}
+
+std::pair<std::vector<array>, std::vector<int>> Expm1::vmap(
+    const std::vector<array>& inputs,
+    const std::vector<int>& axes) {
+  assert(inputs.size() == 1);
+  assert(axes.size() == 1);
+  return {{expm1(inputs[0], stream())}, axes};
+}
+
 bool FFT::is_equivalent(const Primitive& other) const {
   const FFT& r_other = static_cast<const FFT&>(other);
   return axes_ == r_other.axes_ && inverse_ == r_other.inverse_ &&
diff --git a/mlx/primitives.h b/mlx/primitives.h
index fd05b23f0..ff2f01bc1 100644
--- a/mlx/primitives.h
+++ b/mlx/primitives.h
@@ -837,6 +837,22 @@ class Exp : public UnaryPrimitive {
   void eval(const std::vector<array>& inputs, array& out);
 };
 
+class Expm1 : public UnaryPrimitive {
+ public:
+  explicit Expm1(Stream stream) : UnaryPrimitive(stream){};
+
+  void eval_cpu(const std::vector<array>& inputs, array& out) override;
+  void eval_gpu(const std::vector<array>& inputs, array& out) override;
+
+  DEFINE_VMAP()
+  DEFINE_GRADS()
+  DEFINE_PRINT(Expm1)
+  DEFINE_INPUT_OUTPUT_SHAPE()
+
+ private:
+  void eval(const std::vector<array>& inputs, array& out);
+};
+
 class FFT : public UnaryPrimitive {
  public:
   explicit FFT(
diff --git a/python/src/ops.cpp b/python/src/ops.cpp
index d8e31daf0..1c08b55a5 100644
--- a/python/src/ops.cpp
+++ b/python/src/ops.cpp
@@ -772,6 +772,25 @@ void init_ops(nb::module_& m) {
         Returns:
             array: The exponential of ``a``.
       )pbdoc");
+  m.def(
+      "expm1",
+      &mlx::core::expm1,
+      nb::arg(),
+      nb::kw_only(),
+      "stream"_a = nb::none(),
+      nb::sig(
+          "def expm1(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"),
+      R"pbdoc(
+        Element-wise exponential minus 1.
+
+        Computes ``exp(x) - 1`` with greater precision for small ``x``.
+
+        Args:
+            a (array): Input array.
+
+        Returns:
+            array: The expm1 of ``a``.
+      )pbdoc");
   m.def(
       "erf",
       &mlx::core::erf,
@@ -2150,6 +2169,40 @@ void init_ops(nb::module_& m) {
         Returns:
             array: The output array of variances.
       )pbdoc");
+  m.def(
+      "std",
+      [](const array& a,
+         const IntOrVec& axis,
+         bool keepdims,
+         int ddof,
+         StreamOrDevice s) {
+        return mlx::core::std(
+            a, get_reduce_axes(axis, a.ndim()), keepdims, ddof, s);
+      },
+      nb::arg(),
+      "axis"_a = nb::none(),
+      "keepdims"_a = false,
+      "ddof"_a = 0,
+      nb::kw_only(),
+      "stream"_a = nb::none(),
+      nb::sig(
+          "def std(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, ddof: int = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
+      R"pbdoc(
+        Compute the standard deviation(s) over the given axes.
+
+        Args:
+            a (array): Input array.
+            axis (int or list(int), optional): Optional axis or
+              axes to reduce over. If unspecified this defaults
+              to reducing over the entire array.
+            keepdims (bool, optional): Keep reduced axes as
+              singleton dimensions, defaults to `False`.
+            ddof (int, optional): The divisor to compute the standard
+              deviation is ``N - ddof``, defaults to 0.
+
+        Returns:
+            array: The output array of standard deviations.
+      )pbdoc");
   m.def(
       "split",
       [](const array& a,
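On the gradient definitions above: the vjp scales the cotangent by `output + 1`, reusing the saved output instead of recomputing `exp`. This is valid because d/dx expm1(x) = exp(x) = expm1(x) + 1. A quick end-to-end check of that identity (assumes a build with these bindings):

```python
import mlx.core as mx

# grad of sum(expm1(x)) should equal exp(x) elementwise,
# which is exactly expm1(x) + 1 (what the vjp computes).
f = lambda x: mx.expm1(x).sum()
x = mx.array([-2.0, -0.5, 0.0, 0.5, 2.0])

g = mx.grad(f)(x)
assert mx.allclose(g, mx.exp(x)).item()
assert mx.allclose(g, mx.expm1(x) + 1).item()
```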
diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py
index e2bd749bc..4906b5b55 100644
--- a/python/tests/test_ops.py
+++ b/python/tests/test_ops.py
@@ -725,6 +725,11 @@ class TestOps(mlx_tests.MLXTestCase):
         out = mx.var(x, ddof=3)
         self.assertEqual(out.item(), float("inf"))
 
+    def test_std(self):
+        x = mx.random.uniform(shape=(5, 5))
+        x_np = np.array(x)
+        self.assertAlmostEqual(mx.std(x).item(), x_np.std().item(), places=6)
+
     def test_abs(self):
         a = mx.array([-1.0, 1.0, -2.0, 3.0])
         result = mx.abs(a)
@@ -839,6 +844,13 @@ class TestOps(mlx_tests.MLXTestCase):
 
         self.assertTrue(np.allclose(result, expected))
 
+    def test_expm1(self):
+        a = mx.array([0, 0.5, -0.5, 5])
+        result = mx.expm1(a)
+        expected = np.expm1(a, dtype=np.float32)
+
+        self.assertTrue(np.allclose(result, expected, rtol=1e-5, atol=1e-5))
+
     def test_erf(self):
         inputs = [-5, 0.0, 0.5, 1.0, 2.0, 10.0]
         x = mx.array(inputs)
diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp
index fd00d339d..545526c2e 100644
--- a/tests/ops_tests.cpp
+++ b/tests/ops_tests.cpp
@@ -1092,6 +1092,20 @@ TEST_CASE("test arithmetic unary ops") {
     CHECK(allclose(exp(x), expected).item<bool>());
   }
 
+  // Test expm1
+  {
+    array x(-1.0f);
+    CHECK_EQ(expm1(x).item<float>(), doctest::Approx(std::expm1(-1.0f)));
+
+    x = array(1.0f);
+    CHECK_EQ(expm1(x).item<float>(), doctest::Approx(std::expm1(1.0f)));
+
+    // Integer input type
+    x = array(1);
+    CHECK_EQ(expm1(x).dtype(), float32);
+    CHECK_EQ(expm1(x).item<float>(), doctest::Approx(std::expm1(1.0f)));
+  }
+
   // Test sine
   {
     array x(0.0);