mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Refactor CPU compile preamble (#708)
* refactor cpu preamble * fix include order * fix some issues' * fixes for linux * try to fix includes * add back warning suppression * more linux fixes
This commit is contained in:
parent
0925af43b0
commit
1a4f4c5ea6
@ -81,11 +81,8 @@ void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
} else if (in.dtype() == int32 && in.flags().contiguous) {
|
} else if (in.dtype() == int32 && in.flags().contiguous) {
|
||||||
set_unary_output_data(in, out);
|
set_unary_output_data(in, out);
|
||||||
vDSP_vabsi(in.data<int>(), 1, out.data<int>(), 1, in.data_size());
|
vDSP_vabsi(in.data<int>(), 1, out.data<int>(), 1, in.data_size());
|
||||||
} else if (is_unsigned(in.dtype())) {
|
|
||||||
// No-op for unsigned types
|
|
||||||
out.copy_shared_buffer(in);
|
|
||||||
} else {
|
} else {
|
||||||
unary(in, out, AbsOp());
|
eval(inputs, out);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,3 +1,33 @@
|
|||||||
|
|
||||||
|
if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
|
||||||
|
set(CLANG TRUE)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
add_custom_command(
|
||||||
|
OUTPUT compiled_preamble.cpp
|
||||||
|
COMMAND /bin/bash
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh
|
||||||
|
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp
|
||||||
|
${CMAKE_CXX_COMPILER}
|
||||||
|
${CMAKE_SOURCE_DIR}
|
||||||
|
${CLANG}
|
||||||
|
|
||||||
|
DEPENDS make_compiled_preamble.sh
|
||||||
|
compiled_preamble.h
|
||||||
|
${CMAKE_SOURCE_DIR}/mlx/types/half_types.h
|
||||||
|
${CMAKE_SOURCE_DIR}/mlx/types/fp16.h
|
||||||
|
${CMAKE_SOURCE_DIR}/mlx/types/bf16.h
|
||||||
|
${CMAKE_SOURCE_DIR}/mlx/types/complex.h
|
||||||
|
ops.h
|
||||||
|
)
|
||||||
|
|
||||||
|
add_custom_target(
|
||||||
|
cpu_compiled_preamble
|
||||||
|
DEPENDS compiled_preamble.cpp
|
||||||
|
)
|
||||||
|
|
||||||
|
add_dependencies(mlx cpu_compiled_preamble)
|
||||||
|
|
||||||
target_sources(
|
target_sources(
|
||||||
mlx
|
mlx
|
||||||
PRIVATE
|
PRIVATE
|
||||||
@ -19,4 +49,5 @@ target_sources(
|
|||||||
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/qrf.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/qrf.cpp
|
||||||
|
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp
|
||||||
)
|
)
|
||||||
|
@ -7,6 +7,7 @@
|
|||||||
#include "mlx/allocator.h"
|
#include "mlx/allocator.h"
|
||||||
#include "mlx/backend/common/binary.h"
|
#include "mlx/backend/common/binary.h"
|
||||||
#include "mlx/backend/common/binary_two.h"
|
#include "mlx/backend/common/binary_two.h"
|
||||||
|
#include "mlx/backend/common/ops.h"
|
||||||
#include "mlx/primitives.h"
|
#include "mlx/primitives.h"
|
||||||
#include "mlx/utils.h"
|
#include "mlx/utils.h"
|
||||||
|
|
||||||
@ -73,7 +74,7 @@ void Add::eval(const std::vector<array>& inputs, array& out) {
|
|||||||
assert(inputs.size() == 2);
|
assert(inputs.size() == 2);
|
||||||
auto& a = inputs[0];
|
auto& a = inputs[0];
|
||||||
auto& b = inputs[1];
|
auto& b = inputs[1];
|
||||||
binary(a, b, out, [](auto x, auto y) { return x + y; });
|
binary(a, b, out, detail::Add());
|
||||||
}
|
}
|
||||||
|
|
||||||
void DivMod::eval(
|
void DivMod::eval(
|
||||||
@ -135,106 +136,56 @@ void Divide::eval(const std::vector<array>& inputs, array& out) {
|
|||||||
assert(inputs.size() == 2);
|
assert(inputs.size() == 2);
|
||||||
auto& a = inputs[0];
|
auto& a = inputs[0];
|
||||||
auto& b = inputs[1];
|
auto& b = inputs[1];
|
||||||
binary(a, b, out, [](auto x, auto y) { return x / y; });
|
binary(a, b, out, detail::Divide());
|
||||||
}
|
}
|
||||||
|
|
||||||
struct RemainderFn {
|
|
||||||
template <typename T>
|
|
||||||
std::enable_if_t<std::is_integral_v<T> & !std::is_signed_v<T>, T> operator()(
|
|
||||||
T numerator,
|
|
||||||
T denominator) {
|
|
||||||
return numerator % denominator;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
std::enable_if_t<std::is_integral_v<T> & std::is_signed_v<T>, T> operator()(
|
|
||||||
T numerator,
|
|
||||||
T denominator) {
|
|
||||||
auto r = numerator % denominator;
|
|
||||||
if (r != 0 && (r < 0 != denominator < 0))
|
|
||||||
r += denominator;
|
|
||||||
return r;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
std::enable_if_t<!std::is_integral_v<T>, T> operator()(
|
|
||||||
T numerator,
|
|
||||||
T denominator) {
|
|
||||||
auto r = std::fmod(numerator, denominator);
|
|
||||||
if (r != 0 && (r < 0 != denominator < 0)) {
|
|
||||||
r += denominator;
|
|
||||||
}
|
|
||||||
return r;
|
|
||||||
}
|
|
||||||
|
|
||||||
complex64_t operator()(complex64_t numerator, complex64_t denominator) {
|
|
||||||
return numerator % denominator;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
void Remainder::eval(const std::vector<array>& inputs, array& out) {
|
void Remainder::eval(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 2);
|
assert(inputs.size() == 2);
|
||||||
auto& a = inputs[0];
|
auto& a = inputs[0];
|
||||||
auto& b = inputs[1];
|
auto& b = inputs[1];
|
||||||
binary(a, b, out, RemainderFn{});
|
binary(a, b, out, detail::Remainder());
|
||||||
}
|
}
|
||||||
|
|
||||||
void Equal::eval(const std::vector<array>& inputs, array& out) {
|
void Equal::eval(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 2);
|
assert(inputs.size() == 2);
|
||||||
if (equal_nan_) {
|
if (equal_nan_) {
|
||||||
comparison_op(inputs[0], inputs[1], out, [](auto x, auto y) {
|
comparison_op(inputs[0], inputs[1], out, detail::NaNEqual());
|
||||||
return x == y || (std::isnan(x) && std::isnan(y));
|
|
||||||
});
|
|
||||||
} else {
|
} else {
|
||||||
comparison_op(
|
comparison_op(inputs[0], inputs[1], out, detail::Equal());
|
||||||
inputs[0], inputs[1], out, [](auto x, auto y) { return x == y; });
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void Greater::eval(const std::vector<array>& inputs, array& out) {
|
void Greater::eval(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 2);
|
assert(inputs.size() == 2);
|
||||||
comparison_op(
|
comparison_op(inputs[0], inputs[1], out, detail::Greater());
|
||||||
inputs[0], inputs[1], out, [](auto x, auto y) { return x > y; });
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void GreaterEqual::eval(const std::vector<array>& inputs, array& out) {
|
void GreaterEqual::eval(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 2);
|
assert(inputs.size() == 2);
|
||||||
comparison_op(
|
comparison_op(inputs[0], inputs[1], out, detail::GreaterEqual());
|
||||||
inputs[0], inputs[1], out, [](auto x, auto y) { return x >= y; });
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void Less::eval(const std::vector<array>& inputs, array& out) {
|
void Less::eval(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 2);
|
assert(inputs.size() == 2);
|
||||||
comparison_op(
|
comparison_op(inputs[0], inputs[1], out, detail::Less());
|
||||||
inputs[0], inputs[1], out, [](auto x, auto y) { return x < y; });
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void LessEqual::eval(const std::vector<array>& inputs, array& out) {
|
void LessEqual::eval(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 2);
|
assert(inputs.size() == 2);
|
||||||
comparison_op(
|
comparison_op(inputs[0], inputs[1], out, detail::LessEqual());
|
||||||
inputs[0], inputs[1], out, [](auto x, auto y) { return x <= y; });
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void LogAddExp::eval(const std::vector<array>& inputs, array& out) {
|
void LogAddExp::eval(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 2);
|
assert(inputs.size() == 2);
|
||||||
auto& a = inputs[0];
|
auto& a = inputs[0];
|
||||||
auto& b = inputs[1];
|
auto& b = inputs[1];
|
||||||
auto op = [](auto x, auto y) {
|
|
||||||
constexpr float inf = std::numeric_limits<float>::infinity();
|
|
||||||
auto maxval = (x > y) ? x : y;
|
|
||||||
auto minval = (x > y) ? y : x;
|
|
||||||
return (minval == -inf || maxval == inf)
|
|
||||||
? maxval
|
|
||||||
: static_cast<decltype(x)>(
|
|
||||||
maxval + std::log1p(std::exp(minval - maxval)));
|
|
||||||
};
|
|
||||||
if (is_floating_point(out.dtype())) {
|
if (is_floating_point(out.dtype())) {
|
||||||
if (out.dtype() == float32) {
|
if (out.dtype() == float32) {
|
||||||
binary_op<float>(a, b, out, op);
|
binary_op<float>(a, b, out, detail::LogAddExp());
|
||||||
} else if (out.dtype() == float16) {
|
} else if (out.dtype() == float16) {
|
||||||
binary_op<float16_t>(a, b, out, op);
|
binary_op<float16_t>(a, b, out, detail::LogAddExp());
|
||||||
} else if (out.dtype() == bfloat16) {
|
} else if (out.dtype() == bfloat16) {
|
||||||
binary_op<bfloat16_t>(a, b, out, op);
|
binary_op<bfloat16_t>(a, b, out, detail::LogAddExp());
|
||||||
} else {
|
} else {
|
||||||
std::ostringstream err;
|
std::ostringstream err;
|
||||||
err << "[logaddexp] Does not support " << out.dtype();
|
err << "[logaddexp] Does not support " << out.dtype();
|
||||||
@ -251,84 +202,40 @@ void Maximum::eval(const std::vector<array>& inputs, array& out) {
|
|||||||
assert(inputs.size() == 2);
|
assert(inputs.size() == 2);
|
||||||
auto& a = inputs[0];
|
auto& a = inputs[0];
|
||||||
auto& b = inputs[1];
|
auto& b = inputs[1];
|
||||||
|
binary(a, b, out, detail::Maximum());
|
||||||
if (is_floating_point(out.dtype())) {
|
|
||||||
binary(a, b, out, [](auto x, auto y) {
|
|
||||||
if (std::isnan(x)) {
|
|
||||||
return x;
|
|
||||||
}
|
|
||||||
return (x > y) ? x : y;
|
|
||||||
});
|
|
||||||
} else {
|
|
||||||
binary(a, b, out, [](auto x, auto y) { return (x > y) ? x : y; });
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void Minimum::eval(const std::vector<array>& inputs, array& out) {
|
void Minimum::eval(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 2);
|
assert(inputs.size() == 2);
|
||||||
auto& a = inputs[0];
|
auto& a = inputs[0];
|
||||||
auto& b = inputs[1];
|
auto& b = inputs[1];
|
||||||
if (is_floating_point(out.dtype())) {
|
binary(a, b, out, detail::Minimum());
|
||||||
binary(a, b, out, [](auto x, auto y) {
|
|
||||||
if (std::isnan(x)) {
|
|
||||||
return x;
|
|
||||||
}
|
|
||||||
return (x < y) ? x : y;
|
|
||||||
});
|
|
||||||
} else {
|
|
||||||
binary(a, b, out, [](auto x, auto y) { return (x < y) ? x : y; });
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void Multiply::eval(const std::vector<array>& inputs, array& out) {
|
void Multiply::eval(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 2);
|
assert(inputs.size() == 2);
|
||||||
auto& a = inputs[0];
|
auto& a = inputs[0];
|
||||||
auto& b = inputs[1];
|
auto& b = inputs[1];
|
||||||
binary(a, b, out, [](auto x, auto y) { return x * y; });
|
binary(a, b, out, detail::Multiply());
|
||||||
}
|
}
|
||||||
|
|
||||||
void NotEqual::eval(const std::vector<array>& inputs, array& out) {
|
void NotEqual::eval(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 2);
|
assert(inputs.size() == 2);
|
||||||
comparison_op(
|
comparison_op(inputs[0], inputs[1], out, detail::NotEqual());
|
||||||
inputs[0], inputs[1], out, [](auto x, auto y) { return x != y; });
|
|
||||||
}
|
}
|
||||||
|
|
||||||
struct PowerFn {
|
|
||||||
template <typename T>
|
|
||||||
std::enable_if_t<!std::is_integral_v<T>, T> operator()(T base, T exp) {
|
|
||||||
return std::pow(base, exp);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
std::enable_if_t<std::is_integral_v<T>, T> operator()(T base, T exp) {
|
|
||||||
if (exp < 0) {
|
|
||||||
throw std::invalid_argument(
|
|
||||||
"Integers cannot be raise to negative powers");
|
|
||||||
}
|
|
||||||
T res = 1;
|
|
||||||
while (exp) {
|
|
||||||
if (exp & 1) {
|
|
||||||
res *= base;
|
|
||||||
}
|
|
||||||
exp >>= 1;
|
|
||||||
base *= base;
|
|
||||||
}
|
|
||||||
return res;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
void Power::eval(const std::vector<array>& inputs, array& out) {
|
void Power::eval(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 2);
|
assert(inputs.size() == 2);
|
||||||
auto& a = inputs[0];
|
auto& a = inputs[0];
|
||||||
auto& b = inputs[1];
|
auto& b = inputs[1];
|
||||||
binary(a, b, out, PowerFn{});
|
binary(a, b, out, detail::Power());
|
||||||
}
|
}
|
||||||
|
|
||||||
void Subtract::eval(const std::vector<array>& inputs, array& out) {
|
void Subtract::eval(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 2);
|
assert(inputs.size() == 2);
|
||||||
auto& a = inputs[0];
|
auto& a = inputs[0];
|
||||||
auto& b = inputs[1];
|
auto& b = inputs[1];
|
||||||
binary(a, b, out, [](auto x, auto y) { return x - y; });
|
binary(a, b, out, detail::Subtract());
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -178,7 +178,13 @@ void* compile(
|
|||||||
build_command << "g++ -std=c++17 -O2 -Wall -fPIC -shared "
|
build_command << "g++ -std=c++17 -O2 -Wall -fPIC -shared "
|
||||||
<< source_file_path << " -o " << shared_lib_path;
|
<< source_file_path << " -o " << shared_lib_path;
|
||||||
std::string build_command_str = build_command.str();
|
std::string build_command_str = build_command.str();
|
||||||
system(build_command_str.c_str());
|
auto return_code = system(build_command_str.c_str());
|
||||||
|
if (return_code) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[Compile::eval_cpu] Failed to compile function " << kernel_name
|
||||||
|
<< " with error code " << return_code << "." << std::endl;
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// load library
|
// load library
|
||||||
@ -421,7 +427,7 @@ void Compiled::eval_cpu(
|
|||||||
// If it doesn't exist, compile it
|
// If it doesn't exist, compile it
|
||||||
if (fn_ptr == nullptr) {
|
if (fn_ptr == nullptr) {
|
||||||
std::ostringstream kernel;
|
std::ostringstream kernel;
|
||||||
kernel << preamble << std::endl;
|
kernel << get_kernel_preamble() << std::endl;
|
||||||
kernel << "extern \"C\" {" << std::endl;
|
kernel << "extern \"C\" {" << std::endl;
|
||||||
build_kernel(
|
build_kernel(
|
||||||
kernel,
|
kernel,
|
||||||
|
File diff suppressed because it is too large
Load Diff
@ -1,11 +0,0 @@
|
|||||||
// Copyright © 2023 Apple Inc.
|
|
||||||
|
|
||||||
namespace mlx::core {
|
|
||||||
|
|
||||||
/* Approximation to the inverse error function.
|
|
||||||
* Based on code from:
|
|
||||||
* https://stackoverflow.com/questions/27229371/inverse-error-function-in-c#answer-49743348
|
|
||||||
*/
|
|
||||||
float erfinv(float a);
|
|
||||||
|
|
||||||
} // namespace mlx::core
|
|
34
mlx/backend/common/make_compiled_preamble.sh
Normal file
34
mlx/backend/common/make_compiled_preamble.sh
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
#
|
||||||
|
# This script generates a C++ function that provides the CPU
|
||||||
|
# code for use with kernel generation.
|
||||||
|
#
|
||||||
|
# Copyright © 2023-24 Apple Inc.
|
||||||
|
|
||||||
|
|
||||||
|
OUTPUT_FILE=$1
|
||||||
|
GCC=$2
|
||||||
|
SRCDIR=$3
|
||||||
|
CLANG=$4
|
||||||
|
|
||||||
|
if [ $CLANG = "TRUE" ]; then
|
||||||
|
read -r -d '' INCLUDES <<- EOM
|
||||||
|
#include <cmath>
|
||||||
|
#include <complex>
|
||||||
|
#include <cstdint>
|
||||||
|
#include <vector>
|
||||||
|
EOM
|
||||||
|
|
||||||
|
fi
|
||||||
|
|
||||||
|
CONTENT=$($GCC -I $SRCDIR -E $SRCDIR/mlx/backend/common/compiled_preamble.h 2>/dev/null)
|
||||||
|
|
||||||
|
cat << EOF > "$OUTPUT_FILE"
|
||||||
|
const char* get_kernel_preamble() {
|
||||||
|
return R"preamble(
|
||||||
|
$INCLUDES
|
||||||
|
$CONTENT
|
||||||
|
using namespace mlx::core::detail;
|
||||||
|
)preamble";
|
||||||
|
}
|
||||||
|
EOF
|
591
mlx/backend/common/ops.h
Normal file
591
mlx/backend/common/ops.h
Normal file
@ -0,0 +1,591 @@
|
|||||||
|
// Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <cmath>
|
||||||
|
#include <complex>
|
||||||
|
|
||||||
|
namespace mlx::core::detail {
|
||||||
|
|
||||||
|
typedef union {
|
||||||
|
int i;
|
||||||
|
float f;
|
||||||
|
} IntOrFloat;
|
||||||
|
|
||||||
|
inline float fast_exp(float x) {
|
||||||
|
if (x == -std::numeric_limits<float>::infinity()) {
|
||||||
|
return 0.0f;
|
||||||
|
} else if (x == std::numeric_limits<float>::infinity() || std::isnan(x)) {
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
x *= 1.442695; // multiply with log_2(e)
|
||||||
|
float ipart, fpart;
|
||||||
|
IntOrFloat epart;
|
||||||
|
x = std::max(-80.f, std::min(x, 80.f));
|
||||||
|
ipart = std::floor(x + 0.5);
|
||||||
|
fpart = x - ipart;
|
||||||
|
|
||||||
|
x = 1.535336188319500e-4f;
|
||||||
|
x = x * fpart + 1.339887440266574e-3f;
|
||||||
|
x = x * fpart + 9.618437357674640e-3f;
|
||||||
|
x = x * fpart + 5.550332471162809e-2f;
|
||||||
|
x = x * fpart + 2.402264791363012e-1f;
|
||||||
|
x = x * fpart + 6.931472028550421e-1f;
|
||||||
|
x = x * fpart + 1.000000000000000f;
|
||||||
|
|
||||||
|
// generate 2**ipart in the floating point representation using integer
|
||||||
|
// bitshifting
|
||||||
|
epart.i = (int(ipart) + 127) << 23;
|
||||||
|
|
||||||
|
return epart.f * x;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline float fast_erf(float a) {
|
||||||
|
float r, s, t, u;
|
||||||
|
t = std::abs(a);
|
||||||
|
s = a * a;
|
||||||
|
if (t > 0.927734375f) {
|
||||||
|
// maximum error 0.99527 ulp
|
||||||
|
r = std::fma(
|
||||||
|
-1.72853470e-5f, t, 3.83197126e-4f); // -0x1.220000p-16,0x1.91cfb2p-12
|
||||||
|
u = std::fma(
|
||||||
|
-3.88396438e-3f, t, 2.42546219e-2f); // -0x1.fd1438p-9, 0x1.8d6342p-6
|
||||||
|
r = std::fma(r, s, u);
|
||||||
|
r = std::fma(r, t, -1.06777877e-1f); // -0x1.b55cb8p-4
|
||||||
|
r = std::fma(r, t, -6.34846687e-1f); // -0x1.450aa0p-1
|
||||||
|
r = std::fma(r, t, -1.28717512e-1f); // -0x1.079d0cp-3
|
||||||
|
r = std::fma(r, t, -t);
|
||||||
|
// TODO, replace with expm1 when implemented
|
||||||
|
r = 1.0f - std::exp(r);
|
||||||
|
r = std::copysign(r, a);
|
||||||
|
} else {
|
||||||
|
// maximum error 0.98929 ulp
|
||||||
|
r = -5.96761703e-4f; // -0x1.38e000p-11
|
||||||
|
r = std::fma(r, s, 4.99119423e-3f); // 0x1.471a58p-8
|
||||||
|
r = std::fma(r, s, -2.67681349e-2f); // -0x1.b691b2p-6
|
||||||
|
r = std::fma(r, s, 1.12819925e-1f); // 0x1.ce1c44p-4
|
||||||
|
r = std::fma(r, s, -3.76125336e-1f); // -0x1.812700p-2
|
||||||
|
r = std::fma(r, s, 1.28379166e-1f); // 0x1.06eba8p-3
|
||||||
|
r = std::fma(r, a, a);
|
||||||
|
}
|
||||||
|
return r;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline float fast_erfinv(float a) {
|
||||||
|
auto t = std::fma(a, 0.0f - a, 1.0f);
|
||||||
|
t = std::log(t);
|
||||||
|
float p;
|
||||||
|
if (std::abs(t) > 6.125f) { // maximum ulp error = 2.35793
|
||||||
|
p = 3.03697567e-10f; // 0x1.4deb44p-32
|
||||||
|
p = std::fma(p, t, 2.93243101e-8f); // 0x1.f7c9aep-26
|
||||||
|
p = std::fma(p, t, 1.22150334e-6f); // 0x1.47e512p-20
|
||||||
|
p = std::fma(p, t, 2.84108955e-5f); // 0x1.dca7dep-16
|
||||||
|
p = std::fma(p, t, 3.93552968e-4f); // 0x1.9cab92p-12
|
||||||
|
p = std::fma(p, t, 3.02698812e-3f); // 0x1.8cc0dep-9
|
||||||
|
p = std::fma(p, t, 4.83185798e-3f); // 0x1.3ca920p-8
|
||||||
|
p = std::fma(p, t, -2.64646143e-1f); // -0x1.0eff66p-2
|
||||||
|
p = std::fma(p, t, 8.40016484e-1f); // 0x1.ae16a4p-1
|
||||||
|
} else { // maximum ulp error = 2.35002
|
||||||
|
p = 5.43877832e-9f; // 0x1.75c000p-28
|
||||||
|
p = std::fma(p, t, 1.43285448e-7f); // 0x1.33b402p-23
|
||||||
|
p = std::fma(p, t, 1.22774793e-6f); // 0x1.499232p-20
|
||||||
|
p = std::fma(p, t, 1.12963626e-7f); // 0x1.e52cd2p-24
|
||||||
|
p = std::fma(p, t, -5.61530760e-5f); // -0x1.d70bd0p-15
|
||||||
|
p = std::fma(p, t, -1.47697632e-4f); // -0x1.35be90p-13
|
||||||
|
p = std::fma(p, t, 2.31468678e-3f); // 0x1.2f6400p-9
|
||||||
|
p = std::fma(p, t, 1.15392581e-2f); // 0x1.7a1e50p-7
|
||||||
|
p = std::fma(p, t, -2.32015476e-1f); // -0x1.db2aeep-3
|
||||||
|
p = std::fma(p, t, 8.86226892e-1f); // 0x1.c5bf88p-1
|
||||||
|
}
|
||||||
|
return a * p;
|
||||||
|
}
|
||||||
|
|
||||||
|
struct Abs {
|
||||||
|
template <typename T>
|
||||||
|
T operator()(T x) {
|
||||||
|
return std::abs(x);
|
||||||
|
};
|
||||||
|
uint8_t operator()(uint8_t x) {
|
||||||
|
return x;
|
||||||
|
};
|
||||||
|
uint16_t operator()(uint16_t x) {
|
||||||
|
return x;
|
||||||
|
};
|
||||||
|
uint32_t operator()(uint32_t x) {
|
||||||
|
return x;
|
||||||
|
};
|
||||||
|
uint64_t operator()(uint64_t x) {
|
||||||
|
return x;
|
||||||
|
};
|
||||||
|
bool operator()(bool x) {
|
||||||
|
return x;
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
struct ArcCos {
|
||||||
|
template <typename T>
|
||||||
|
T operator()(T x) {
|
||||||
|
return std::acos(x);
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
struct ArcCosh {
|
||||||
|
template <typename T>
|
||||||
|
T operator()(T x) {
|
||||||
|
return std::acosh(x);
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
struct ArcSin {
|
||||||
|
template <typename T>
|
||||||
|
T operator()(T x) {
|
||||||
|
return std::asin(x);
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
struct ArcSinh {
|
||||||
|
template <typename T>
|
||||||
|
T operator()(T x) {
|
||||||
|
return std::asinh(x);
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
struct ArcTan {
|
||||||
|
template <typename T>
|
||||||
|
T operator()(T x) {
|
||||||
|
return std::atan(x);
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
struct ArcTanh {
|
||||||
|
template <typename T>
|
||||||
|
T operator()(T x) {
|
||||||
|
return std::atanh(x);
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Ceil {
|
||||||
|
template <typename T>
|
||||||
|
T operator()(T x) {
|
||||||
|
return std::ceil(x);
|
||||||
|
};
|
||||||
|
int8_t operator()(int8_t x) {
|
||||||
|
return x;
|
||||||
|
};
|
||||||
|
int16_t operator()(int16_t x) {
|
||||||
|
return x;
|
||||||
|
};
|
||||||
|
int32_t operator()(int32_t x) {
|
||||||
|
return x;
|
||||||
|
};
|
||||||
|
int64_t operator()(int64_t x) {
|
||||||
|
return x;
|
||||||
|
};
|
||||||
|
uint8_t operator()(uint8_t x) {
|
||||||
|
return x;
|
||||||
|
};
|
||||||
|
uint16_t operator()(uint16_t x) {
|
||||||
|
return x;
|
||||||
|
};
|
||||||
|
uint32_t operator()(uint32_t x) {
|
||||||
|
return x;
|
||||||
|
};
|
||||||
|
uint64_t operator()(uint64_t x) {
|
||||||
|
return x;
|
||||||
|
};
|
||||||
|
bool operator()(bool x) {
|
||||||
|
return x;
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Cos {
|
||||||
|
template <typename T>
|
||||||
|
T operator()(T x) {
|
||||||
|
return std::cos(x);
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Cosh {
|
||||||
|
template <typename T>
|
||||||
|
T operator()(T x) {
|
||||||
|
return std::cosh(x);
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Erf {
|
||||||
|
template <typename T>
|
||||||
|
T operator()(T x) {
|
||||||
|
return static_cast<T>(fast_erf(static_cast<float>(x)));
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
struct ErfInv {
|
||||||
|
template <typename T>
|
||||||
|
T operator()(T x) {
|
||||||
|
return static_cast<T>(fast_erfinv(static_cast<float>(x)));
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Exp {
|
||||||
|
template <typename T>
|
||||||
|
T operator()(T x) {
|
||||||
|
return fast_exp(x);
|
||||||
|
};
|
||||||
|
|
||||||
|
complex64_t operator()(complex64_t x) {
|
||||||
|
return std::exp(x);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Floor {
|
||||||
|
template <typename T>
|
||||||
|
T operator()(T x) {
|
||||||
|
return std::floor(x);
|
||||||
|
};
|
||||||
|
int8_t operator()(int8_t x) {
|
||||||
|
return x;
|
||||||
|
};
|
||||||
|
int16_t operator()(int16_t x) {
|
||||||
|
return x;
|
||||||
|
};
|
||||||
|
int32_t operator()(int32_t x) {
|
||||||
|
return x;
|
||||||
|
};
|
||||||
|
int64_t operator()(int64_t x) {
|
||||||
|
return x;
|
||||||
|
};
|
||||||
|
uint8_t operator()(uint8_t x) {
|
||||||
|
return x;
|
||||||
|
};
|
||||||
|
uint16_t operator()(uint16_t x) {
|
||||||
|
return x;
|
||||||
|
};
|
||||||
|
uint32_t operator()(uint32_t x) {
|
||||||
|
return x;
|
||||||
|
};
|
||||||
|
uint64_t operator()(uint64_t x) {
|
||||||
|
return x;
|
||||||
|
};
|
||||||
|
bool operator()(bool x) {
|
||||||
|
return x;
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Log {
|
||||||
|
template <typename T>
|
||||||
|
T operator()(T x) {
|
||||||
|
return std::log(x);
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Log2 {
|
||||||
|
template <typename T>
|
||||||
|
T operator()(T x) {
|
||||||
|
return std::log2(x);
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Log10 {
|
||||||
|
template <typename T>
|
||||||
|
T operator()(T x) {
|
||||||
|
return std::log10(x);
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Log1p {
|
||||||
|
template <typename T>
|
||||||
|
T operator()(T x) {
|
||||||
|
return log1p(x);
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
struct LogicalNot {
|
||||||
|
template <typename T>
|
||||||
|
T operator()(T x) {
|
||||||
|
return !x;
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Negative {
|
||||||
|
template <typename T>
|
||||||
|
T operator()(T x) {
|
||||||
|
return -x;
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Round {
|
||||||
|
template <typename T>
|
||||||
|
T operator()(T x) {
|
||||||
|
return std::rint(x);
|
||||||
|
}
|
||||||
|
|
||||||
|
complex64_t operator()(complex64_t x) {
|
||||||
|
return {std::rint(x.real()), std::rint(x.imag())};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Sigmoid {
|
||||||
|
template <typename T>
|
||||||
|
T operator()(T x) {
|
||||||
|
auto one = static_cast<decltype(x)>(1.0);
|
||||||
|
return one / (one + fast_exp(-x));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Sign {
|
||||||
|
template <typename T>
|
||||||
|
T operator()(T x) {
|
||||||
|
return (x > T(0)) - (x < T(0));
|
||||||
|
}
|
||||||
|
uint8_t operator()(uint8_t x) {
|
||||||
|
return x != 0;
|
||||||
|
}
|
||||||
|
uint16_t operator()(uint16_t x) {
|
||||||
|
return x != 0;
|
||||||
|
}
|
||||||
|
uint32_t operator()(uint32_t x) {
|
||||||
|
return x != 0;
|
||||||
|
}
|
||||||
|
uint64_t operator()(uint64_t x) {
|
||||||
|
return x != 0;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Sin {
|
||||||
|
template <typename T>
|
||||||
|
T operator()(T x) {
|
||||||
|
return std::sin(x);
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Sinh {
|
||||||
|
template <typename T>
|
||||||
|
T operator()(T x) {
|
||||||
|
return std::sinh(x);
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Square {
|
||||||
|
template <typename T>
|
||||||
|
T operator()(T x) {
|
||||||
|
return x * x;
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Sqrt {
|
||||||
|
template <typename T>
|
||||||
|
T operator()(T x) {
|
||||||
|
return std::sqrt(x);
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Rsqrt {
|
||||||
|
template <typename T>
|
||||||
|
T operator()(T x) {
|
||||||
|
return static_cast<decltype(x)>(1.0) / std::sqrt(x);
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Tan {
|
||||||
|
template <typename T>
|
||||||
|
T operator()(T x) {
|
||||||
|
return std::tan(x);
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Tanh {
|
||||||
|
template <typename T>
|
||||||
|
T operator()(T x) {
|
||||||
|
return std::tanh(x);
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Add {
|
||||||
|
template <typename T>
|
||||||
|
T operator()(T x, T y) {
|
||||||
|
return x + y;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Divide {
|
||||||
|
template <typename T>
|
||||||
|
T operator()(T x, T y) {
|
||||||
|
return x / y;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Remainder {
|
||||||
|
template <typename T>
|
||||||
|
std::enable_if_t<std::is_integral_v<T> & !std::is_signed_v<T>, T> operator()(
|
||||||
|
T numerator,
|
||||||
|
T denominator) {
|
||||||
|
return numerator % denominator;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
std::enable_if_t<std::is_integral_v<T> & std::is_signed_v<T>, T> operator()(
|
||||||
|
T numerator,
|
||||||
|
T denominator) {
|
||||||
|
auto r = numerator % denominator;
|
||||||
|
if (r != 0 && (r < 0 != denominator < 0))
|
||||||
|
r += denominator;
|
||||||
|
return r;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
std::enable_if_t<!std::is_integral_v<T>, T> operator()(
|
||||||
|
T numerator,
|
||||||
|
T denominator) {
|
||||||
|
auto r = std::fmod(numerator, denominator);
|
||||||
|
if (r != 0 && (r < 0 != denominator < 0)) {
|
||||||
|
r += denominator;
|
||||||
|
}
|
||||||
|
return r;
|
||||||
|
}
|
||||||
|
|
||||||
|
complex64_t operator()(complex64_t numerator, complex64_t denominator) {
|
||||||
|
return numerator % denominator;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Equal {
|
||||||
|
template <typename T>
|
||||||
|
bool operator()(T x, T y) {
|
||||||
|
return x == y;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct NaNEqual {
|
||||||
|
template <typename T>
|
||||||
|
bool operator()(T x, T y) {
|
||||||
|
return x == y || (std::isnan(x) && std::isnan(y));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Greater {
|
||||||
|
template <typename T>
|
||||||
|
bool operator()(T x, T y) {
|
||||||
|
return x > y;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct GreaterEqual {
|
||||||
|
template <typename T>
|
||||||
|
bool operator()(T x, T y) {
|
||||||
|
return x >= y;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Less {
|
||||||
|
template <typename T>
|
||||||
|
bool operator()(T x, T y) {
|
||||||
|
return x < y;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct LessEqual {
|
||||||
|
template <typename T>
|
||||||
|
bool operator()(T x, T y) {
|
||||||
|
return x <= y;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Maximum {
|
||||||
|
template <typename T>
|
||||||
|
std::enable_if_t<std::is_integral_v<T>, T> operator()(T x, T y) {
|
||||||
|
return (x > y) ? x : y;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
std::enable_if_t<!std::is_integral_v<T>, T> operator()(T x, T y) {
|
||||||
|
if (std::isnan(x)) {
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
return (x > y) ? x : y;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Minimum {
|
||||||
|
template <typename T>
|
||||||
|
std::enable_if_t<std::is_integral_v<T>, T> operator()(T x, T y) {
|
||||||
|
return x < y ? x : y;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
std::enable_if_t<!std::is_integral_v<T>, T> operator()(T x, T y) {
|
||||||
|
if (std::isnan(x)) {
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
return x < y ? x : y;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct LogAddExp {
|
||||||
|
template <typename T>
|
||||||
|
T operator()(T x, T y) {
|
||||||
|
constexpr float inf = std::numeric_limits<float>::infinity();
|
||||||
|
auto maxval = Maximum()(x, y);
|
||||||
|
auto minval = Minimum()(x, y);
|
||||||
|
return (minval == -inf || maxval == inf)
|
||||||
|
? maxval
|
||||||
|
: static_cast<decltype(x)>(
|
||||||
|
maxval + std::log1p(fast_exp(minval - maxval)));
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Multiply {
|
||||||
|
template <typename T>
|
||||||
|
T operator()(T x, T y) {
|
||||||
|
return x * y;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct NotEqual {
|
||||||
|
template <typename T>
|
||||||
|
bool operator()(T x, T y) {
|
||||||
|
return x != y;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Power {
|
||||||
|
template <typename T>
|
||||||
|
std::enable_if_t<!std::is_integral_v<T>, T> operator()(T base, T exp) {
|
||||||
|
return std::pow(base, exp);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
std::enable_if_t<std::is_integral_v<T>, T> operator()(T base, T exp) {
|
||||||
|
T res = 1;
|
||||||
|
while (exp) {
|
||||||
|
if (exp & 1) {
|
||||||
|
res *= base;
|
||||||
|
}
|
||||||
|
exp >>= 1;
|
||||||
|
base *= base;
|
||||||
|
}
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Subtract {
|
||||||
|
template <typename T>
|
||||||
|
T operator()(T x, T y) {
|
||||||
|
return x - y;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct LogicalAnd {
|
||||||
|
template <typename T>
|
||||||
|
T operator()(T x, T y) {
|
||||||
|
return x && y;
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
struct LogicalOr {
|
||||||
|
template <typename T>
|
||||||
|
T operator()(T x, T y) {
|
||||||
|
return x || y;
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace mlx::core::detail
|
@ -10,7 +10,7 @@
|
|||||||
#include "mlx/backend/common/arange.h"
|
#include "mlx/backend/common/arange.h"
|
||||||
#include "mlx/backend/common/binary.h"
|
#include "mlx/backend/common/binary.h"
|
||||||
#include "mlx/backend/common/copy.h"
|
#include "mlx/backend/common/copy.h"
|
||||||
#include "mlx/backend/common/erf.h"
|
#include "mlx/backend/common/ops.h"
|
||||||
#include "mlx/backend/common/threefry.h"
|
#include "mlx/backend/common/threefry.h"
|
||||||
#include "mlx/backend/common/unary.h"
|
#include "mlx/backend/common/unary.h"
|
||||||
#include "mlx/backend/common/utils.h"
|
#include "mlx/backend/common/utils.h"
|
||||||
@ -26,7 +26,7 @@ void Abs::eval(const std::vector<array>& inputs, array& out) {
|
|||||||
// No-op for unsigned types
|
// No-op for unsigned types
|
||||||
out.copy_shared_buffer(in);
|
out.copy_shared_buffer(in);
|
||||||
} else {
|
} else {
|
||||||
unary(in, out, AbsOp());
|
unary(in, out, detail::Abs());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -38,7 +38,7 @@ void ArcCos::eval(const std::vector<array>& inputs, array& out) {
|
|||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
const auto& in = inputs[0];
|
const auto& in = inputs[0];
|
||||||
if (is_floating_point(out.dtype())) {
|
if (is_floating_point(out.dtype())) {
|
||||||
unary_fp(in, out, [](auto x) { return std::acos(x); });
|
unary_fp(in, out, detail::ArcCos());
|
||||||
} else {
|
} else {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"[arccos] Cannot compute inverse cosine of elements in array"
|
"[arccos] Cannot compute inverse cosine of elements in array"
|
||||||
@ -50,7 +50,7 @@ void ArcCosh::eval(const std::vector<array>& inputs, array& out) {
|
|||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
const auto& in = inputs[0];
|
const auto& in = inputs[0];
|
||||||
if (is_floating_point(out.dtype())) {
|
if (is_floating_point(out.dtype())) {
|
||||||
unary_fp(in, out, [](auto x) { return std::acosh(x); });
|
unary_fp(in, out, detail::ArcCosh());
|
||||||
} else {
|
} else {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"[arccosh] Cannot compute inverse hyperbolic cosine of elements in"
|
"[arccosh] Cannot compute inverse hyperbolic cosine of elements in"
|
||||||
@ -62,7 +62,7 @@ void ArcSin::eval(const std::vector<array>& inputs, array& out) {
|
|||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
const auto& in = inputs[0];
|
const auto& in = inputs[0];
|
||||||
if (is_floating_point(out.dtype())) {
|
if (is_floating_point(out.dtype())) {
|
||||||
unary_fp(in, out, [](auto x) { return std::asin(x); });
|
unary_fp(in, out, detail::ArcSin());
|
||||||
} else {
|
} else {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"[arcsin] Cannot compute inverse sine of elements in array"
|
"[arcsin] Cannot compute inverse sine of elements in array"
|
||||||
@ -74,7 +74,7 @@ void ArcSinh::eval(const std::vector<array>& inputs, array& out) {
|
|||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
const auto& in = inputs[0];
|
const auto& in = inputs[0];
|
||||||
if (is_floating_point(out.dtype())) {
|
if (is_floating_point(out.dtype())) {
|
||||||
unary_fp(in, out, [](auto x) { return std::asinh(x); });
|
unary_fp(in, out, detail::ArcSinh());
|
||||||
} else {
|
} else {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"[arcsinh] Cannot compute inverse hyperbolic sine of elements in"
|
"[arcsinh] Cannot compute inverse hyperbolic sine of elements in"
|
||||||
@ -86,7 +86,7 @@ void ArcTan::eval(const std::vector<array>& inputs, array& out) {
|
|||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
const auto& in = inputs[0];
|
const auto& in = inputs[0];
|
||||||
if (is_floating_point(out.dtype())) {
|
if (is_floating_point(out.dtype())) {
|
||||||
unary_fp(in, out, [](auto x) { return std::atan(x); });
|
unary_fp(in, out, detail::ArcTan());
|
||||||
} else {
|
} else {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"[arctan] Cannot compute inverse tangent of elements in array"
|
"[arctan] Cannot compute inverse tangent of elements in array"
|
||||||
@ -98,7 +98,7 @@ void ArcTanh::eval(const std::vector<array>& inputs, array& out) {
|
|||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
const auto& in = inputs[0];
|
const auto& in = inputs[0];
|
||||||
if (is_floating_point(out.dtype())) {
|
if (is_floating_point(out.dtype())) {
|
||||||
unary_fp(in, out, [](auto x) { return std::atanh(x); });
|
unary_fp(in, out, detail::ArcTanh());
|
||||||
} else {
|
} else {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"[arctanh] Cannot compute inverse hyperbolic tangent of elements in"
|
"[arctanh] Cannot compute inverse hyperbolic tangent of elements in"
|
||||||
@ -172,7 +172,7 @@ void Ceil::eval(const std::vector<array>& inputs, array& out) {
|
|||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
auto& in = inputs[0];
|
auto& in = inputs[0];
|
||||||
if (not is_integral(in.dtype())) {
|
if (not is_integral(in.dtype())) {
|
||||||
unary_fp(in, out, [](auto x) { return std::ceil(x); });
|
unary_fp(in, out, detail::Ceil());
|
||||||
} else {
|
} else {
|
||||||
// No-op integer types
|
// No-op integer types
|
||||||
out.copy_shared_buffer(in);
|
out.copy_shared_buffer(in);
|
||||||
@ -212,7 +212,7 @@ void Cos::eval(const std::vector<array>& inputs, array& out) {
|
|||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
const auto& in = inputs[0];
|
const auto& in = inputs[0];
|
||||||
if (is_floating_point(out.dtype())) {
|
if (is_floating_point(out.dtype())) {
|
||||||
unary_fp(in, out, [](auto x) { return std::cos(x); });
|
unary_fp(in, out, detail::Cos());
|
||||||
} else {
|
} else {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"[cos] Cannot compute cosine of elements in array"
|
"[cos] Cannot compute cosine of elements in array"
|
||||||
@ -224,7 +224,7 @@ void Cosh::eval(const std::vector<array>& inputs, array& out) {
|
|||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
const auto& in = inputs[0];
|
const auto& in = inputs[0];
|
||||||
if (is_floating_point(out.dtype())) {
|
if (is_floating_point(out.dtype())) {
|
||||||
unary_fp(in, out, [](auto x) { return std::cosh(x); });
|
unary_fp(in, out, detail::Cosh());
|
||||||
} else {
|
} else {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"[cosh] Cannot compute hyperbolic cosine of elements in array"
|
"[cosh] Cannot compute hyperbolic cosine of elements in array"
|
||||||
@ -256,17 +256,13 @@ void Erf::eval(const std::vector<array>& inputs, array& out) {
|
|||||||
const auto& in = inputs[0];
|
const auto& in = inputs[0];
|
||||||
switch (out.dtype()) {
|
switch (out.dtype()) {
|
||||||
case float32:
|
case float32:
|
||||||
unary_op<float>(in, out, [](auto x) { return std::erf(x); });
|
unary_op<float>(in, out, detail::Erf());
|
||||||
break;
|
break;
|
||||||
case float16:
|
case float16:
|
||||||
unary_op<float16_t>(in, out, [](auto x) {
|
unary_op<float16_t>(in, out, detail::Erf());
|
||||||
return static_cast<float16_t>(std::erf(static_cast<float>(x)));
|
|
||||||
});
|
|
||||||
break;
|
break;
|
||||||
case bfloat16:
|
case bfloat16:
|
||||||
unary_op<bfloat16_t>(in, out, [](auto x) {
|
unary_op<bfloat16_t>(in, out, detail::Erf());
|
||||||
return static_cast<bfloat16_t>(std::erf(static_cast<float>(x)));
|
|
||||||
});
|
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
@ -280,17 +276,13 @@ void ErfInv::eval(const std::vector<array>& inputs, array& out) {
|
|||||||
const auto& in = inputs[0];
|
const auto& in = inputs[0];
|
||||||
switch (out.dtype()) {
|
switch (out.dtype()) {
|
||||||
case float32:
|
case float32:
|
||||||
unary_op<float>(in, out, [](auto x) { return erfinv(x); });
|
unary_op<float>(in, out, detail::ErfInv());
|
||||||
break;
|
break;
|
||||||
case float16:
|
case float16:
|
||||||
unary_op<float16_t>(in, out, [](auto x) {
|
unary_op<float16_t>(in, out, detail::ErfInv());
|
||||||
return static_cast<float16_t>(erfinv(static_cast<float>(x)));
|
|
||||||
});
|
|
||||||
break;
|
break;
|
||||||
case bfloat16:
|
case bfloat16:
|
||||||
unary_op<bfloat16_t>(in, out, [](auto x) {
|
unary_op<bfloat16_t>(in, out, detail::ErfInv());
|
||||||
return static_cast<bfloat16_t>(erfinv(static_cast<float>(x)));
|
|
||||||
});
|
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
@ -302,9 +294,8 @@ void ErfInv::eval(const std::vector<array>& inputs, array& out) {
|
|||||||
void Exp::eval(const std::vector<array>& inputs, array& out) {
|
void Exp::eval(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
const auto& in = inputs[0];
|
const auto& in = inputs[0];
|
||||||
|
|
||||||
if (is_floating_point(out.dtype())) {
|
if (is_floating_point(out.dtype())) {
|
||||||
unary_fp(in, out, [](auto x) { return std::exp(x); });
|
unary_fp(in, out, detail::Exp());
|
||||||
} else {
|
} else {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"[exp] Cannot exponentiate elements in array"
|
"[exp] Cannot exponentiate elements in array"
|
||||||
@ -316,7 +307,7 @@ void Floor::eval(const std::vector<array>& inputs, array& out) {
|
|||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
auto& in = inputs[0];
|
auto& in = inputs[0];
|
||||||
if (not is_integral(in.dtype())) {
|
if (not is_integral(in.dtype())) {
|
||||||
unary_fp(in, out, [](auto x) { return std::floor(x); });
|
unary_fp(in, out, detail::Floor());
|
||||||
} else {
|
} else {
|
||||||
// No-op integer types
|
// No-op integer types
|
||||||
out.copy_shared_buffer(in);
|
out.copy_shared_buffer(in);
|
||||||
@ -344,13 +335,13 @@ void Log::eval(const std::vector<array>& inputs, array& out) {
|
|||||||
if (is_floating_point(out.dtype())) {
|
if (is_floating_point(out.dtype())) {
|
||||||
switch (base_) {
|
switch (base_) {
|
||||||
case Base::e:
|
case Base::e:
|
||||||
unary_fp(in, out, [](auto x) { return std::log(x); });
|
unary_fp(in, out, detail::Log());
|
||||||
break;
|
break;
|
||||||
case Base::two:
|
case Base::two:
|
||||||
unary_fp(in, out, [](auto x) { return std::log2(x); });
|
unary_fp(in, out, detail::Log2());
|
||||||
break;
|
break;
|
||||||
case Base::ten:
|
case Base::ten:
|
||||||
unary_fp(in, out, [](auto x) { return std::log10(x); });
|
unary_fp(in, out, detail::Log10());
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@ -364,7 +355,7 @@ void Log1p::eval(const std::vector<array>& inputs, array& out) {
|
|||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
const auto& in = inputs[0];
|
const auto& in = inputs[0];
|
||||||
if (is_floating_point(out.dtype())) {
|
if (is_floating_point(out.dtype())) {
|
||||||
unary_fp(in, out, [](auto x) { return std::log1p(x); });
|
unary_fp(in, out, detail::Log1p());
|
||||||
} else {
|
} else {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"[log1p] Cannot compute log of elements in array with"
|
"[log1p] Cannot compute log of elements in array with"
|
||||||
@ -375,27 +366,27 @@ void Log1p::eval(const std::vector<array>& inputs, array& out) {
|
|||||||
void LogicalNot::eval(const std::vector<array>& inputs, array& out) {
|
void LogicalNot::eval(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
auto& in = inputs[0];
|
auto& in = inputs[0];
|
||||||
unary(in, out, [](auto x) { return !x; });
|
unary(in, out, detail::LogicalNot());
|
||||||
}
|
}
|
||||||
|
|
||||||
void LogicalAnd::eval(const std::vector<array>& inputs, array& out) {
|
void LogicalAnd::eval(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 2); // LogicalAnd requires two input arrays
|
assert(inputs.size() == 2); // LogicalAnd requires two input arrays
|
||||||
auto& in1 = inputs[0];
|
auto& in1 = inputs[0];
|
||||||
auto& in2 = inputs[1];
|
auto& in2 = inputs[1];
|
||||||
binary(in1, in2, out, [](auto x, auto y) { return x && y; });
|
binary(in1, in2, out, detail::LogicalAnd());
|
||||||
}
|
}
|
||||||
|
|
||||||
void LogicalOr::eval(const std::vector<array>& inputs, array& out) {
|
void LogicalOr::eval(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 2); // LogicalOr requires two input arrays
|
assert(inputs.size() == 2); // LogicalOr requires two input arrays
|
||||||
auto& in1 = inputs[0];
|
auto& in1 = inputs[0];
|
||||||
auto& in2 = inputs[1];
|
auto& in2 = inputs[1];
|
||||||
binary(in1, in2, out, [](auto x, auto y) { return x || y; });
|
binary(in1, in2, out, detail::LogicalOr());
|
||||||
}
|
}
|
||||||
|
|
||||||
void Negative::eval(const std::vector<array>& inputs, array& out) {
|
void Negative::eval(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
auto& in = inputs[0];
|
auto& in = inputs[0];
|
||||||
unary(in, out, [](auto x) { return -x; });
|
unary(in, out, detail::Negative());
|
||||||
}
|
}
|
||||||
|
|
||||||
void Pad::eval(const std::vector<array>& inputs, array& out) {
|
void Pad::eval(const std::vector<array>& inputs, array& out) {
|
||||||
@ -498,7 +489,7 @@ void Round::eval(const std::vector<array>& inputs, array& out) {
|
|||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
auto& in = inputs[0];
|
auto& in = inputs[0];
|
||||||
if (not is_integral(in.dtype())) {
|
if (not is_integral(in.dtype())) {
|
||||||
unary_fp(in, out, RoundOp());
|
unary_fp(in, out, detail::Round());
|
||||||
} else {
|
} else {
|
||||||
// No-op integer types
|
// No-op integer types
|
||||||
out.copy_shared_buffer(in);
|
out.copy_shared_buffer(in);
|
||||||
@ -509,11 +500,7 @@ void Sigmoid::eval(const std::vector<array>& inputs, array& out) {
|
|||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
const auto& in = inputs[0];
|
const auto& in = inputs[0];
|
||||||
if (is_floating_point(out.dtype())) {
|
if (is_floating_point(out.dtype())) {
|
||||||
auto sigmoid_op = [](auto x) {
|
unary_fp(in, out, detail::Sigmoid());
|
||||||
auto one = static_cast<decltype(x)>(1.0);
|
|
||||||
return one / (one + std::exp(-x));
|
|
||||||
};
|
|
||||||
unary_fp(in, out, sigmoid_op);
|
|
||||||
} else {
|
} else {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"[sigmoid] Cannot sigmoid of elements in array with"
|
"[sigmoid] Cannot sigmoid of elements in array with"
|
||||||
@ -527,7 +514,7 @@ void Sign::eval(const std::vector<array>& inputs, array& out) {
|
|||||||
if (in.dtype() == bool_) {
|
if (in.dtype() == bool_) {
|
||||||
out.copy_shared_buffer(in);
|
out.copy_shared_buffer(in);
|
||||||
} else {
|
} else {
|
||||||
unary(in, out, SignOp());
|
unary(in, out, detail::Sign());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -535,7 +522,7 @@ void Sin::eval(const std::vector<array>& inputs, array& out) {
|
|||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
const auto& in = inputs[0];
|
const auto& in = inputs[0];
|
||||||
if (is_floating_point(out.dtype())) {
|
if (is_floating_point(out.dtype())) {
|
||||||
unary_fp(in, out, [](auto x) { return std::sin(x); });
|
unary_fp(in, out, detail::Sin());
|
||||||
} else {
|
} else {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"[sin] Cannot compute sine of elements in array"
|
"[sin] Cannot compute sine of elements in array"
|
||||||
@ -547,7 +534,7 @@ void Sinh::eval(const std::vector<array>& inputs, array& out) {
|
|||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
const auto& in = inputs[0];
|
const auto& in = inputs[0];
|
||||||
if (is_floating_point(out.dtype())) {
|
if (is_floating_point(out.dtype())) {
|
||||||
unary_fp(in, out, [](auto x) { return std::sinh(x); });
|
unary_fp(in, out, detail::Sinh());
|
||||||
} else {
|
} else {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"[sinh] Cannot compute hyperbolic sine of elements in array"
|
"[sinh] Cannot compute hyperbolic sine of elements in array"
|
||||||
@ -656,18 +643,16 @@ void Split::eval(
|
|||||||
void Square::eval(const std::vector<array>& inputs, array& out) {
|
void Square::eval(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
auto& in = inputs[0];
|
auto& in = inputs[0];
|
||||||
unary(in, out, [](auto x) { return x * x; });
|
unary(in, out, detail::Square());
|
||||||
}
|
}
|
||||||
|
|
||||||
void Sqrt::eval(const std::vector<array>& inputs, array& out) {
|
void Sqrt::eval(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
auto& in = inputs[0];
|
auto& in = inputs[0];
|
||||||
if (recip_) {
|
if (recip_) {
|
||||||
unary_fp(in, out, [](auto x) {
|
unary_fp(in, out, detail::Rsqrt());
|
||||||
return static_cast<decltype(x)>(1.0) / sqrt(x);
|
|
||||||
});
|
|
||||||
} else {
|
} else {
|
||||||
unary_fp(in, out, [](auto x) { return sqrt(x); });
|
unary_fp(in, out, detail::Sqrt());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -680,7 +665,7 @@ void Tan::eval(const std::vector<array>& inputs, array& out) {
|
|||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
const auto& in = inputs[0];
|
const auto& in = inputs[0];
|
||||||
if (is_floating_point(out.dtype())) {
|
if (is_floating_point(out.dtype())) {
|
||||||
unary_fp(in, out, [](auto x) { return std::tan(x); });
|
unary_fp(in, out, detail::Tan());
|
||||||
} else {
|
} else {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"[tan] Cannot compute tangent of elements in array"
|
"[tan] Cannot compute tangent of elements in array"
|
||||||
@ -692,7 +677,7 @@ void Tanh::eval(const std::vector<array>& inputs, array& out) {
|
|||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
const auto& in = inputs[0];
|
const auto& in = inputs[0];
|
||||||
if (is_floating_point(out.dtype())) {
|
if (is_floating_point(out.dtype())) {
|
||||||
unary_fp(in, out, [](auto x) { return std::tanh(x); });
|
unary_fp(in, out, detail::Tanh());
|
||||||
} else {
|
} else {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"[tanh] Cannot compute hyperbolic tangent of elements in array"
|
"[tanh] Cannot compute hyperbolic tangent of elements in array"
|
||||||
|
@ -11,59 +11,6 @@ namespace mlx::core {
|
|||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
struct AbsOp {
|
|
||||||
template <typename T>
|
|
||||||
T operator()(T x) {
|
|
||||||
return std::abs(x);
|
|
||||||
}
|
|
||||||
uint8_t operator()(uint8_t x) {
|
|
||||||
return x;
|
|
||||||
}
|
|
||||||
uint16_t operator()(uint16_t x) {
|
|
||||||
return x;
|
|
||||||
}
|
|
||||||
uint32_t operator()(uint32_t x) {
|
|
||||||
return x;
|
|
||||||
}
|
|
||||||
uint64_t operator()(uint64_t x) {
|
|
||||||
return x;
|
|
||||||
}
|
|
||||||
bool operator()(bool x) {
|
|
||||||
return x;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
struct SignOp {
|
|
||||||
template <typename T>
|
|
||||||
T operator()(T x) {
|
|
||||||
return (x > T(0)) - (x < T(0));
|
|
||||||
}
|
|
||||||
|
|
||||||
uint8_t operator()(uint8_t x) {
|
|
||||||
return x != 0;
|
|
||||||
}
|
|
||||||
uint16_t operator()(uint16_t x) {
|
|
||||||
return x != 0;
|
|
||||||
}
|
|
||||||
uint32_t operator()(uint32_t x) {
|
|
||||||
return x != 0;
|
|
||||||
}
|
|
||||||
uint64_t operator()(uint64_t x) {
|
|
||||||
return x != 0;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
struct RoundOp {
|
|
||||||
template <typename T>
|
|
||||||
T operator()(T x) {
|
|
||||||
return std::rint(x);
|
|
||||||
}
|
|
||||||
|
|
||||||
complex64_t operator()(complex64_t x) {
|
|
||||||
return {std::rint(x.real()), std::rint(x.imag())};
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
void set_unary_output_data(const array& in, array& out) {
|
void set_unary_output_data(const array& in, array& out) {
|
||||||
if (in.is_donatable() && in.itemsize() == out.itemsize()) {
|
if (in.is_donatable() && in.itemsize() == out.itemsize()) {
|
||||||
out.copy_shared_buffer(in);
|
out.copy_shared_buffer(in);
|
||||||
|
@ -38,9 +38,9 @@ inline bool operator>(const complex64_t& a, const complex64_t& b) {
|
|||||||
inline complex64_t operator%(complex64_t a, complex64_t b) {
|
inline complex64_t operator%(complex64_t a, complex64_t b) {
|
||||||
auto real = a.real() - (b.real() * static_cast<int64_t>(a.real() / b.real()));
|
auto real = a.real() - (b.real() * static_cast<int64_t>(a.real() / b.real()));
|
||||||
auto imag = a.imag() - (b.imag() * static_cast<int64_t>(a.imag() / b.imag()));
|
auto imag = a.imag() - (b.imag() * static_cast<int64_t>(a.imag() / b.imag()));
|
||||||
if (real != 0 && (real < 0 != b.real() < 0))
|
if (real != 0 && ((real < 0) != (b.real() < 0)))
|
||||||
real += b.real();
|
real += b.real();
|
||||||
if (imag != 0 && (imag < 0 != b.imag() < 0))
|
if (imag != 0 && ((imag < 0) != (b.imag() < 0)))
|
||||||
imag += b.imag();
|
imag += b.imag();
|
||||||
return {real, imag};
|
return {real, imag};
|
||||||
}
|
}
|
||||||
|
@ -1002,7 +1002,7 @@ TEST_CASE("test arithmetic unary ops") {
|
|||||||
CHECK_EQ(exp(x).item<float>(), 1.0);
|
CHECK_EQ(exp(x).item<float>(), 1.0);
|
||||||
|
|
||||||
x = array(2.0);
|
x = array(2.0);
|
||||||
CHECK_EQ(exp(x).item<float>(), std::exp(2.0f));
|
CHECK_EQ(exp(x).item<float>(), doctest::Approx(std::exp(2.0f)));
|
||||||
|
|
||||||
CHECK(array_equal(exp(array({})), array({})).item<bool>());
|
CHECK(array_equal(exp(array({})), array({})).item<bool>());
|
||||||
|
|
||||||
@ -1012,7 +1012,7 @@ TEST_CASE("test arithmetic unary ops") {
|
|||||||
// Integer input type
|
// Integer input type
|
||||||
x = array(2);
|
x = array(2);
|
||||||
CHECK_EQ(x.dtype(), int32);
|
CHECK_EQ(x.dtype(), int32);
|
||||||
CHECK_EQ(exp(x).item<float>(), std::exp(2.0f));
|
CHECK_EQ(exp(x).item<float>(), doctest::Approx(std::exp(2.0f)));
|
||||||
|
|
||||||
// Input is irregularly strided
|
// Input is irregularly strided
|
||||||
x = broadcast_to(array(1.0f), {2, 2, 2});
|
x = broadcast_to(array(1.0f), {2, 2, 2});
|
||||||
@ -1020,7 +1020,7 @@ TEST_CASE("test arithmetic unary ops") {
|
|||||||
|
|
||||||
x = split(array({0.0f, 1.0f, 2.0f, 3.0f}, {2, 2}), 2, 1)[0];
|
x = split(array({0.0f, 1.0f, 2.0f, 3.0f}, {2, 2}), 2, 1)[0];
|
||||||
auto expected = array({std::exp(0.0f), std::exp(2.0f)}, {2, 1});
|
auto expected = array({std::exp(0.0f), std::exp(2.0f)}, {2, 1});
|
||||||
CHECK(array_equal(exp(x), expected).item<bool>());
|
CHECK(allclose(exp(x), expected).item<bool>());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test sine
|
// Test sine
|
||||||
|
Loading…
Reference in New Issue
Block a user