mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
std and expm1 (#973)
* std and expm1 * actually add expm1 * fix linux * fix vjp * relax tol for linux test * Add it to the compilable primitives --------- Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
parent
76e63212ff
commit
42afe27e12
@ -51,6 +51,7 @@ Operations
|
||||
erf
|
||||
erfinv
|
||||
exp
|
||||
expm1
|
||||
expand_dims
|
||||
eye
|
||||
flatten
|
||||
@ -117,6 +118,7 @@ Operations
|
||||
square
|
||||
squeeze
|
||||
stack
|
||||
std
|
||||
stop_gradient
|
||||
subtract
|
||||
sum
|
||||
|
@ -310,6 +310,19 @@ void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
void Expm1::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||
set_unary_output_data(in, out);
|
||||
auto size = in.data_size();
|
||||
vvexpm1f(
|
||||
out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void Full::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
|
@ -57,6 +57,7 @@ DEFAULT(Equal)
|
||||
DEFAULT(Erf)
|
||||
DEFAULT(ErfInv)
|
||||
DEFAULT(Exp)
|
||||
DEFAULT(Expm1)
|
||||
DEFAULT(FFT)
|
||||
DEFAULT(Floor)
|
||||
DEFAULT(Full)
|
||||
|
@ -241,6 +241,13 @@ struct Exp {
|
||||
}
|
||||
};
|
||||
|
||||
struct Expm1 {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return expm1(x);
|
||||
};
|
||||
};
|
||||
|
||||
struct Floor {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
|
@ -359,6 +359,18 @@ void Exp::eval(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
void Expm1::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::Expm1());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[expm1] Cannot exponentiate elements in array"
|
||||
" with non floating point type.");
|
||||
}
|
||||
}
|
||||
|
||||
void Floor::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
|
@ -7,6 +7,7 @@ set(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/complex.h
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/defines.h
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/erf.h
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/expm1f.h
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/indexing.h
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/unary.h
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/utils.h
|
||||
|
89
mlx/backend/metal/kernels/expm1f.h
Normal file
89
mlx/backend/metal/kernels/expm1f.h
Normal file
@ -0,0 +1,89 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <metal_math>
|
||||
|
||||
// Original license copied below:
|
||||
// Copyright (c) 2015-2023 Norbert Juffa
|
||||
// All rights reserved.
|
||||
//
|
||||
// Redistribution and use in source and binary forms, with or without
|
||||
// modification, are permitted provided that the following conditions
|
||||
// are met:
|
||||
//
|
||||
// 1. Redistributions of source code must retain the above copyright
|
||||
// notice, this list of conditions and the following disclaimer.
|
||||
//
|
||||
// 2. Redistributions in binary form must reproduce the above copyright
|
||||
// notice, this list of conditions and the following disclaimer in the
|
||||
// documentation and/or other materials provided with the distribution.
|
||||
//
|
||||
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
// HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
/* Compute exponential base e minus 1. Maximum ulp error = 0.997458
|
||||
|
||||
i = rint(a/log(2)), f = a-i*log(2). Then expm1(a) = 2**i * (expm1(f)+1) - 1.
|
||||
Compute r = expm1(f). Then expm1(a)= 2 * (0.5 * 2**i * r + 0.5 * 2**i - 0.5).
|
||||
With t = 0.5*2**i, expm1(a) = 2*(r * t + t-0.5). However, for best accuracy,
|
||||
when i == 1, expm1(a)= 2*(r + 0.5), and when i == 0, expm1(a) = r.
|
||||
|
||||
NOTE: Scale factor b is only applied if i < 0 or i > 1 (should be power of 2)
|
||||
*/
|
||||
float expm1f_scaled_unchecked(float a, float b) {
|
||||
float f, j, r, s, t, u, v, x, y;
|
||||
int i;
|
||||
|
||||
// exp(a) = 2**i * exp(f); i = rintf (a / log(2))
|
||||
j = fma(1.442695f, a, 12582912.f); // 0x1.715476p0, 0x1.8p23
|
||||
j = j - 12582912.0f; // 0x1.8p23
|
||||
i = (int)j;
|
||||
f = fma(j, -6.93145752e-1f, a);
|
||||
|
||||
// approximate r = exp(f)-1 on interval [-log(2)/2, +log(2)/2]
|
||||
s = f * f;
|
||||
if (a == 0.0f)
|
||||
s = a; // ensure -0 is passed through
|
||||
// err = 0.997458 ulp1 = 11081805
|
||||
r = 1.97350979e-4f; // 0x1.9de000p-13
|
||||
r = fma(r, f, 1.39309070e-3f); // 0x1.6d30bcp-10
|
||||
r = fma(r, f, 8.33343994e-3f); // 0x1.1111f6p-7
|
||||
r = fma(r, f, 4.16668020e-2f); // 0x1.55559ep-5
|
||||
r = fma(r, f, 1.66666716e-1f); // 0x1.55555cp-3
|
||||
r = fma(r, f, 4.99999970e-1f); // 0x1.fffffep-2
|
||||
u = (j == 1) ? (f + 0.5f) : f;
|
||||
v = fma(r, s, u);
|
||||
s = 0.5f * b;
|
||||
t = ldexp(s, i);
|
||||
y = t - s;
|
||||
x = (t - y) - s; // double-float canonicalization of difference
|
||||
r = fma(v, t, x) + y;
|
||||
r = r + r;
|
||||
if (j == 0)
|
||||
r = v;
|
||||
if (j == 1)
|
||||
r = v + v;
|
||||
return r;
|
||||
}
|
||||
|
||||
/* Compute exponential base e minus 1. max ulp err = 0.99746 */
|
||||
float expm1f(float a) {
|
||||
float r;
|
||||
|
||||
r = expm1f_scaled_unchecked(a, 1.0f);
|
||||
/* handle severe overflow and underflow */
|
||||
if (abs(a - 1.0f) > 88.0f) {
|
||||
r = fma(r, r, -1.0f);
|
||||
}
|
||||
return r;
|
||||
}
|
@ -7,6 +7,7 @@
|
||||
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/erf.h"
|
||||
#include "mlx/backend/metal/kernels/expm1f.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
namespace {
|
||||
@ -183,6 +184,13 @@ struct Exp {
|
||||
}
|
||||
};
|
||||
|
||||
struct Expm1 {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return static_cast<T>(expm1f(static_cast<float>(x)));
|
||||
};
|
||||
};
|
||||
|
||||
struct Floor {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
|
@ -71,6 +71,7 @@ instantiate_unary_types(ceil, Ceil)
|
||||
instantiate_unary_float(cos, Cos)
|
||||
instantiate_unary_float(cosh, Cosh)
|
||||
instantiate_unary_float(exp, Exp)
|
||||
instantiate_unary_float(expm1, Expm1)
|
||||
instantiate_unary_types(floor, Floor)
|
||||
instantiate_unary_float(log, Log)
|
||||
instantiate_unary_float(log2, Log2)
|
||||
|
@ -615,6 +615,10 @@ void Exp::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op(inputs, out, "exp");
|
||||
}
|
||||
|
||||
void Expm1::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op(inputs, out, "expm1");
|
||||
}
|
||||
|
||||
void Full::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto in = inputs[0];
|
||||
CopyType ctype;
|
||||
|
@ -49,6 +49,7 @@ NO_GPU(Equal)
|
||||
NO_GPU(Erf)
|
||||
NO_GPU(ErfInv)
|
||||
NO_GPU(Exp)
|
||||
NO_GPU(Expm1)
|
||||
NO_GPU(FFT)
|
||||
NO_GPU(Floor)
|
||||
NO_GPU(Full)
|
||||
|
@ -32,7 +32,7 @@ bool is_unary(const Primitive& p) {
|
||||
typeid(p) == typeid(Sign) || typeid(p) == typeid(Sin) ||
|
||||
typeid(p) == typeid(Sinh) || typeid(p) == typeid(Square) ||
|
||||
typeid(p) == typeid(Sqrt) || typeid(p) == typeid(Tan) ||
|
||||
typeid(p) == typeid(Tanh));
|
||||
typeid(p) == typeid(Tanh) || typeid(p) == typeid(Expm1));
|
||||
}
|
||||
|
||||
bool is_binary(const Primitive& p) {
|
||||
|
35
mlx/ops.cpp
35
mlx/ops.cpp
@ -1430,6 +1430,34 @@ array var(
|
||||
return var(a, std::vector<int>{axis}, keepdims, ddof, to_stream(s));
|
||||
}
|
||||
|
||||
array std(
|
||||
const array& a,
|
||||
bool keepdims,
|
||||
int ddof /* = 0*/,
|
||||
StreamOrDevice s /* = {}*/) {
|
||||
std::vector<int> axes(a.ndim());
|
||||
std::iota(axes.begin(), axes.end(), 0);
|
||||
return std(a, axes, keepdims, ddof, to_stream(s));
|
||||
}
|
||||
|
||||
array std(
|
||||
const array& a,
|
||||
const std::vector<int>& axes,
|
||||
bool keepdims /* = false */,
|
||||
int ddof /* = 0*/,
|
||||
StreamOrDevice s /* = {}*/) {
|
||||
return sqrt(var(a, axes, keepdims, ddof, s), s);
|
||||
}
|
||||
|
||||
array std(
|
||||
const array& a,
|
||||
int axis,
|
||||
bool keepdims /* = false */,
|
||||
int ddof /* = 0*/,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
return std(a, std::vector<int>{axis}, keepdims, ddof, to_stream(s));
|
||||
}
|
||||
|
||||
array prod(const array& a, bool keepdims, StreamOrDevice s /* = {}*/) {
|
||||
std::vector<int> axes(a.ndim());
|
||||
std::iota(axes.begin(), axes.end(), 0);
|
||||
@ -2033,6 +2061,13 @@ array exp(const array& a, StreamOrDevice s /* = {} */) {
|
||||
return array(a.shape(), dtype, std::make_shared<Exp>(to_stream(s)), {input});
|
||||
}
|
||||
|
||||
array expm1(const array& a, StreamOrDevice s /* = {} */) {
|
||||
auto dtype = at_least_float(a.dtype());
|
||||
auto input = astype(a, dtype, s);
|
||||
return array(
|
||||
a.shape(), dtype, std::make_shared<Expm1>(to_stream(s)), {input});
|
||||
}
|
||||
|
||||
array sin(const array& a, StreamOrDevice s /* = {} */) {
|
||||
auto dtype = at_least_float(a.dtype());
|
||||
auto input = astype(a, dtype, s);
|
||||
|
35
mlx/ops.h
35
mlx/ops.h
@ -507,13 +507,14 @@ array mean(
|
||||
bool keepdims = false,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Computes the mean of the elements of an array. */
|
||||
/** Computes the variance of the elements of an array. */
|
||||
array var(const array& a, bool keepdims, int ddof = 0, StreamOrDevice s = {});
|
||||
inline array var(const array& a, StreamOrDevice s = {}) {
|
||||
return var(a, false, 0, to_stream(s));
|
||||
}
|
||||
|
||||
/** Computes the var of the elements of an array along the given axes */
|
||||
/** Computes the variance of the elements of an array along the given
|
||||
* axes */
|
||||
array var(
|
||||
const array& a,
|
||||
const std::vector<int>& axes,
|
||||
@ -521,7 +522,8 @@ array var(
|
||||
int ddof = 0,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Computes the var of the elements of an array along the given axis */
|
||||
/** Computes the variance of the elements of an array along the given
|
||||
* axis */
|
||||
array var(
|
||||
const array& a,
|
||||
int axis,
|
||||
@ -529,6 +531,30 @@ array var(
|
||||
int ddof = 0,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Computes the standard deviation of the elements of an array. */
|
||||
array std(const array& a, bool keepdims, int ddof = 0, StreamOrDevice s = {});
|
||||
inline array std(const array& a, StreamOrDevice s = {}) {
|
||||
return std(a, false, 0, to_stream(s));
|
||||
}
|
||||
|
||||
/** Computes the standard deviatoin of the elements of an array along the given
|
||||
* axes */
|
||||
array std(
|
||||
const array& a,
|
||||
const std::vector<int>& axes,
|
||||
bool keepdims = false,
|
||||
int ddof = 0,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Computes the standard deviation of the elements of an array along the given
|
||||
* axis */
|
||||
array std(
|
||||
const array& a,
|
||||
int axis,
|
||||
bool keepdims = false,
|
||||
int ddof = 0,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** The product of all elements of the array. */
|
||||
array prod(const array& a, bool keepdims, StreamOrDevice s = {});
|
||||
inline array prod(const array& a, StreamOrDevice s = {}) {
|
||||
@ -842,6 +868,9 @@ array erf(const array& a, StreamOrDevice s = {});
|
||||
/** Computes the inverse error function of the elements of an array. */
|
||||
array erfinv(const array& a, StreamOrDevice s = {});
|
||||
|
||||
/** Computes the expm1 function of the elements of an array. */
|
||||
array expm1(const array& a, StreamOrDevice s = {});
|
||||
|
||||
/** Stop the flow of gradients. */
|
||||
array stop_gradient(const array& a, StreamOrDevice s = {});
|
||||
|
||||
|
@ -1239,6 +1239,34 @@ std::pair<std::vector<array>, std::vector<int>> Exp::vmap(
|
||||
return {{exp(inputs[0], stream())}, axes};
|
||||
}
|
||||
|
||||
std::vector<array> Expm1::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>& outputs) {
|
||||
return {multiply(
|
||||
cotangents[0],
|
||||
add(outputs[0], array(1.0f, outputs[0].dtype()), stream()),
|
||||
stream())};
|
||||
}
|
||||
|
||||
std::vector<array> Expm1::jvp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& tangents,
|
||||
const std::vector<int>& argnums) {
|
||||
assert(primals.size() == 1);
|
||||
assert(argnums.size() == 1);
|
||||
return {multiply(tangents[0], exp(primals[0], stream()), stream())};
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, std::vector<int>> Expm1::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
assert(inputs.size() == 1);
|
||||
assert(axes.size() == 1);
|
||||
return {{expm1(inputs[0], stream())}, axes};
|
||||
}
|
||||
|
||||
bool FFT::is_equivalent(const Primitive& other) const {
|
||||
const FFT& r_other = static_cast<const FFT&>(other);
|
||||
return axes_ == r_other.axes_ && inverse_ == r_other.inverse_ &&
|
||||
|
@ -837,6 +837,22 @@ class Exp : public UnaryPrimitive {
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class Expm1 : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Expm1(Stream stream) : UnaryPrimitive(stream){};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
|
||||
DEFINE_VMAP()
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(Expm1)
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class FFT : public UnaryPrimitive {
|
||||
public:
|
||||
explicit FFT(
|
||||
|
@ -772,6 +772,25 @@ void init_ops(nb::module_& m) {
|
||||
Returns:
|
||||
array: The exponential of ``a``.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"expm1",
|
||||
&mlx::core::expm1,
|
||||
nb::arg(),
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def expm1(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Element-wise exponential minus 1.
|
||||
|
||||
Computes ``exp(x) - 1`` with greater precision for small ``x``.
|
||||
|
||||
Args:
|
||||
a (array): Input array.
|
||||
|
||||
Returns:
|
||||
array: The expm1 of ``a``.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"erf",
|
||||
&mlx::core::erf,
|
||||
@ -2150,6 +2169,40 @@ void init_ops(nb::module_& m) {
|
||||
Returns:
|
||||
array: The output array of variances.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"std",
|
||||
[](const array& a,
|
||||
const IntOrVec& axis,
|
||||
bool keepdims,
|
||||
int ddof,
|
||||
StreamOrDevice s) {
|
||||
return mlx::core::std(
|
||||
a, get_reduce_axes(axis, a.ndim()), keepdims, ddof, s);
|
||||
},
|
||||
nb::arg(),
|
||||
"axis"_a = nb::none(),
|
||||
"keepdims"_a = false,
|
||||
"ddof"_a = 0,
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def std(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, ddof: int = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Compute the standard deviation(s) over the given axes.
|
||||
|
||||
Args:
|
||||
a (array): Input array.
|
||||
axis (int or list(int), optional): Optional axis or
|
||||
axes to reduce over. If unspecified this defaults
|
||||
to reducing over the entire array.
|
||||
keepdims (bool, optional): Keep reduced axes as
|
||||
singleton dimensions, defaults to `False`.
|
||||
ddof (int, optional): The divisor to compute the variance
|
||||
is ``N - ddof``, defaults to 0.
|
||||
|
||||
Returns:
|
||||
array: The output array of standard deviations.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"split",
|
||||
[](const array& a,
|
||||
|
@ -725,6 +725,11 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
out = mx.var(x, ddof=3)
|
||||
self.assertEqual(out.item(), float("inf"))
|
||||
|
||||
def test_std(self):
|
||||
x = mx.random.uniform(shape=(5, 5))
|
||||
x_np = np.array(x)
|
||||
self.assertAlmostEqual(mx.std(x).item(), x_np.std().item(), places=6)
|
||||
|
||||
def test_abs(self):
|
||||
a = mx.array([-1.0, 1.0, -2.0, 3.0])
|
||||
result = mx.abs(a)
|
||||
@ -839,6 +844,13 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
|
||||
self.assertTrue(np.allclose(result, expected))
|
||||
|
||||
def test_expm1(self):
|
||||
a = mx.array([0, 0.5, -0.5, 5])
|
||||
result = mx.expm1(a)
|
||||
expected = np.expm1(a, dtype=np.float32)
|
||||
|
||||
self.assertTrue(np.allclose(result, expected, rtol=1e-5, atol=1e-5))
|
||||
|
||||
def test_erf(self):
|
||||
inputs = [-5, 0.0, 0.5, 1.0, 2.0, 10.0]
|
||||
x = mx.array(inputs)
|
||||
|
@ -1092,6 +1092,20 @@ TEST_CASE("test arithmetic unary ops") {
|
||||
CHECK(allclose(exp(x), expected).item<bool>());
|
||||
}
|
||||
|
||||
// Test expm1
|
||||
{
|
||||
array x(-1.0f);
|
||||
CHECK_EQ(expm1(x).item<float>(), doctest::Approx(std::expm1(-1.0f)));
|
||||
|
||||
x = array(1.0f);
|
||||
CHECK_EQ(expm1(x).item<float>(), doctest::Approx(std::expm1(1.0f)));
|
||||
|
||||
// Integer input type
|
||||
x = array(1);
|
||||
CHECK_EQ(expm1(x).dtype(), float32);
|
||||
CHECK_EQ(expm1(x).item<float>(), doctest::Approx(std::expm1(1.0f)));
|
||||
}
|
||||
|
||||
// Test sine
|
||||
{
|
||||
array x(0.0);
|
||||
|
Loading…
Reference in New Issue
Block a user