Fix x86 tests (#1691)

* fix x86 tests

* comment
This commit is contained in:
Awni Hannun 2024-12-11 07:47:18 -08:00 committed by GitHub
parent 4f9b60dd53
commit f3dfa36a3a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 10 additions and 9 deletions

View File

@ -10,7 +10,7 @@ add_custom_command(
COMMAND COMMAND
/bin/bash ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh /bin/bash ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp ${COMPILER} ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp ${COMPILER}
${PROJECT_SOURCE_DIR} ${CLANG} ${PROJECT_SOURCE_DIR} ${CLANG} ${CMAKE_SYSTEM_PROCESSOR}
DEPENDS make_compiled_preamble.sh DEPENDS make_compiled_preamble.sh
compiled_preamble.h compiled_preamble.h
${PROJECT_SOURCE_DIR}/mlx/types/half_types.h ${PROJECT_SOURCE_DIR}/mlx/types/half_types.h

View File

@ -10,15 +10,16 @@ OUTPUT_FILE=$1
GCC=$2 GCC=$2
SRCDIR=$3 SRCDIR=$3
CLANG=$4 CLANG=$4
ARCH=$5
if [ "$CLANG" = "TRUE" ]; then if [ "$CLANG" = "TRUE" ]; then
read -r -d '' INCLUDES <<- EOM read -r -d '' INCLUDES <<- EOM
#include <cmath> #include <cmath>
#include <complex> #include <complex>
#include <cstdint> #include <cstdint>
#include <vector> #include <vector>
EOM EOM
CC_FLAGS="" CC_FLAGS="-arch ${ARCH}"
else else
CC_FLAGS="-std=c++17" CC_FLAGS="-std=c++17"
fi fi

View File

@ -437,10 +437,10 @@ TEST_CASE("test op vjps") {
// Test erf // Test erf
{ {
auto out = vjp([](array in) { return erf(in); }, array(inf), array(1.0f)); auto out = vjp([](array in) { return erf(in); }, array(inf), array(1.0f));
CHECK_EQ(out.second.item<float>(), 0.0f); CHECK_EQ(out.second.item<float>(), doctest::Approx(0.0f));
out = vjp([](array in) { return erf(in); }, array(-inf), array(2.0f)); out = vjp([](array in) { return erf(in); }, array(-inf), array(2.0f));
CHECK_EQ(out.second.item<float>(), 0.0f); CHECK_EQ(out.second.item<float>(), doctest::Approx(0.0f));
out = vjp([](array in) { return erf(in); }, array(0.0f), array(1.0f)); out = vjp([](array in) { return erf(in); }, array(0.0f), array(1.0f));
CHECK_EQ(out.second.item<float>(), static_cast<float>(M_2_SQRTPI)); CHECK_EQ(out.second.item<float>(), static_cast<float>(M_2_SQRTPI));

View File

@ -1227,7 +1227,7 @@ TEST_CASE("test arithmetic unary ops") {
CHECK(array_equal(exp(array({})), array({})).item<bool>()); CHECK(array_equal(exp(array({})), array({})).item<bool>());
x = array(neginf); x = array(neginf);
CHECK_EQ(exp(x).item<float>(), 0.0f); CHECK_EQ(exp(x).item<float>(), doctest::Approx(0.0f));
// Integer input type // Integer input type
x = array(2); x = array(2);