mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
parent
4f9b60dd53
commit
f3dfa36a3a
@ -10,7 +10,7 @@ add_custom_command(
|
|||||||
COMMAND
|
COMMAND
|
||||||
/bin/bash ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh
|
/bin/bash ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh
|
||||||
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp ${COMPILER}
|
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp ${COMPILER}
|
||||||
${PROJECT_SOURCE_DIR} ${CLANG}
|
${PROJECT_SOURCE_DIR} ${CLANG} ${CMAKE_SYSTEM_PROCESSOR}
|
||||||
DEPENDS make_compiled_preamble.sh
|
DEPENDS make_compiled_preamble.sh
|
||||||
compiled_preamble.h
|
compiled_preamble.h
|
||||||
${PROJECT_SOURCE_DIR}/mlx/types/half_types.h
|
${PROJECT_SOURCE_DIR}/mlx/types/half_types.h
|
||||||
|
@ -10,15 +10,16 @@ OUTPUT_FILE=$1
|
|||||||
GCC=$2
|
GCC=$2
|
||||||
SRCDIR=$3
|
SRCDIR=$3
|
||||||
CLANG=$4
|
CLANG=$4
|
||||||
|
ARCH=$5
|
||||||
|
|
||||||
if [ "$CLANG" = "TRUE" ]; then
|
if [ "$CLANG" = "TRUE" ]; then
|
||||||
read -r -d '' INCLUDES <<- EOM
|
read -r -d '' INCLUDES <<- EOM
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <complex>
|
#include <complex>
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
EOM
|
EOM
|
||||||
CC_FLAGS=""
|
CC_FLAGS="-arch ${ARCH}"
|
||||||
else
|
else
|
||||||
CC_FLAGS="-std=c++17"
|
CC_FLAGS="-std=c++17"
|
||||||
fi
|
fi
|
||||||
|
@ -437,10 +437,10 @@ TEST_CASE("test op vjps") {
|
|||||||
// Test erf
|
// Test erf
|
||||||
{
|
{
|
||||||
auto out = vjp([](array in) { return erf(in); }, array(inf), array(1.0f));
|
auto out = vjp([](array in) { return erf(in); }, array(inf), array(1.0f));
|
||||||
CHECK_EQ(out.second.item<float>(), 0.0f);
|
CHECK_EQ(out.second.item<float>(), doctest::Approx(0.0f));
|
||||||
|
|
||||||
out = vjp([](array in) { return erf(in); }, array(-inf), array(2.0f));
|
out = vjp([](array in) { return erf(in); }, array(-inf), array(2.0f));
|
||||||
CHECK_EQ(out.second.item<float>(), 0.0f);
|
CHECK_EQ(out.second.item<float>(), doctest::Approx(0.0f));
|
||||||
|
|
||||||
out = vjp([](array in) { return erf(in); }, array(0.0f), array(1.0f));
|
out = vjp([](array in) { return erf(in); }, array(0.0f), array(1.0f));
|
||||||
CHECK_EQ(out.second.item<float>(), static_cast<float>(M_2_SQRTPI));
|
CHECK_EQ(out.second.item<float>(), static_cast<float>(M_2_SQRTPI));
|
||||||
|
@ -1227,7 +1227,7 @@ TEST_CASE("test arithmetic unary ops") {
|
|||||||
CHECK(array_equal(exp(array({})), array({})).item<bool>());
|
CHECK(array_equal(exp(array({})), array({})).item<bool>());
|
||||||
|
|
||||||
x = array(neginf);
|
x = array(neginf);
|
||||||
CHECK_EQ(exp(x).item<float>(), 0.0f);
|
CHECK_EQ(exp(x).item<float>(), doctest::Approx(0.0f));
|
||||||
|
|
||||||
// Integer input type
|
// Integer input type
|
||||||
x = array(2);
|
x = array(2);
|
||||||
|
Loading…
Reference in New Issue
Block a user