diff --git a/mlx/backend/common/CMakeLists.txt b/mlx/backend/common/CMakeLists.txt index 4fca2274e..92f6ab7da 100644 --- a/mlx/backend/common/CMakeLists.txt +++ b/mlx/backend/common/CMakeLists.txt @@ -10,7 +10,7 @@ add_custom_command( COMMAND /bin/bash ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp ${COMPILER} - ${PROJECT_SOURCE_DIR} ${CLANG} + ${PROJECT_SOURCE_DIR} ${CLANG} ${CMAKE_SYSTEM_PROCESSOR} DEPENDS make_compiled_preamble.sh compiled_preamble.h ${PROJECT_SOURCE_DIR}/mlx/types/half_types.h diff --git a/mlx/backend/common/make_compiled_preamble.sh b/mlx/backend/common/make_compiled_preamble.sh index 149dc2886..5f1019e21 100644 --- a/mlx/backend/common/make_compiled_preamble.sh +++ b/mlx/backend/common/make_compiled_preamble.sh @@ -10,15 +10,16 @@ OUTPUT_FILE=$1 GCC=$2 SRCDIR=$3 CLANG=$4 +ARCH=$5 if [ "$CLANG" = "TRUE" ]; then read -r -d '' INCLUDES <<- EOM - #include - #include - #include - #include +#include +#include +#include +#include EOM -CC_FLAGS="" +CC_FLAGS="-arch ${ARCH}" else CC_FLAGS="-std=c++17" fi diff --git a/tests/autograd_tests.cpp b/tests/autograd_tests.cpp index e5e9a270a..e87b0ca06 100644 --- a/tests/autograd_tests.cpp +++ b/tests/autograd_tests.cpp @@ -437,10 +437,10 @@ TEST_CASE("test op vjps") { // Test erf { auto out = vjp([](array in) { return erf(in); }, array(inf), array(1.0f)); - CHECK_EQ(out.second.item(), 0.0f); + CHECK_EQ(out.second.item(), doctest::Approx(0.0f)); out = vjp([](array in) { return erf(in); }, array(-inf), array(2.0f)); - CHECK_EQ(out.second.item(), 0.0f); + CHECK_EQ(out.second.item(), doctest::Approx(0.0f)); out = vjp([](array in) { return erf(in); }, array(0.0f), array(1.0f)); CHECK_EQ(out.second.item(), static_cast(M_2_SQRTPI)); diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index a3638cfec..ae830262b 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -1227,7 +1227,7 @@ TEST_CASE("test arithmetic unary ops") { CHECK(array_equal(exp(array({})), array({})).item()); x = array(neginf); - CHECK_EQ(exp(x).item(), 0.0f); + CHECK_EQ(exp(x).item(), doctest::Approx(0.0f)); // Integer input type x = array(2);