From f3dfa36a3aa67dfc4488996bf7f218f976bef9aa Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Wed, 11 Dec 2024 07:47:18 -0800 Subject: [PATCH] Fix x86 tests (#1691) * fix x86 tests * comment --- mlx/backend/common/CMakeLists.txt | 2 +- mlx/backend/common/make_compiled_preamble.sh | 11 ++++++----- tests/autograd_tests.cpp | 4 ++-- tests/ops_tests.cpp | 2 +- 4 files changed, 10 insertions(+), 9 deletions(-) diff --git a/mlx/backend/common/CMakeLists.txt b/mlx/backend/common/CMakeLists.txt index 4fca2274e..92f6ab7da 100644 --- a/mlx/backend/common/CMakeLists.txt +++ b/mlx/backend/common/CMakeLists.txt @@ -10,7 +10,7 @@ add_custom_command( COMMAND /bin/bash ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp ${COMPILER} - ${PROJECT_SOURCE_DIR} ${CLANG} + ${PROJECT_SOURCE_DIR} ${CLANG} ${CMAKE_SYSTEM_PROCESSOR} DEPENDS make_compiled_preamble.sh compiled_preamble.h ${PROJECT_SOURCE_DIR}/mlx/types/half_types.h diff --git a/mlx/backend/common/make_compiled_preamble.sh b/mlx/backend/common/make_compiled_preamble.sh index 149dc2886..5f1019e21 100644 --- a/mlx/backend/common/make_compiled_preamble.sh +++ b/mlx/backend/common/make_compiled_preamble.sh @@ -10,15 +10,16 @@ OUTPUT_FILE=$1 GCC=$2 SRCDIR=$3 CLANG=$4 +ARCH=$5 if [ "$CLANG" = "TRUE" ]; then read -r -d '' INCLUDES <<- EOM - #include - #include - #include - #include +#include +#include +#include +#include EOM -CC_FLAGS="" +CC_FLAGS="-arch ${ARCH}" else CC_FLAGS="-std=c++17" fi diff --git a/tests/autograd_tests.cpp b/tests/autograd_tests.cpp index e5e9a270a..e87b0ca06 100644 --- a/tests/autograd_tests.cpp +++ b/tests/autograd_tests.cpp @@ -437,10 +437,10 @@ TEST_CASE("test op vjps") { // Test erf { auto out = vjp([](array in) { return erf(in); }, array(inf), array(1.0f)); - CHECK_EQ(out.second.item(), 0.0f); + CHECK_EQ(out.second.item(), doctest::Approx(0.0f)); out = vjp([](array in) { return erf(in); }, array(-inf), array(2.0f)); - CHECK_EQ(out.second.item(), 0.0f); + CHECK_EQ(out.second.item(), doctest::Approx(0.0f)); out = vjp([](array in) { return erf(in); }, array(0.0f), array(1.0f)); CHECK_EQ(out.second.item(), static_cast(M_2_SQRTPI)); diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index a3638cfec..ae830262b 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -1227,7 +1227,7 @@ TEST_CASE("test arithmetic unary ops") { CHECK(array_equal(exp(array({})), array({})).item()); x = array(neginf); - CHECK_EQ(exp(x).item(), 0.0f); + CHECK_EQ(exp(x).item(), doctest::Approx(0.0f)); // Integer input type x = array(2);