diff --git a/mlx/backend/cpu/unary_ops.h b/mlx/backend/cpu/unary_ops.h index aab1e0de2..633230658 100644 --- a/mlx/backend/cpu/unary_ops.h +++ b/mlx/backend/cpu/unary_ops.h @@ -86,13 +86,14 @@ struct Sign { template Simd operator()(Simd x) { auto z = Simd{0}; + auto o = Simd{1}; + auto m = Simd{-1}; if constexpr (std::is_unsigned_v) { - return x != z; + return simd::select(x == z, z, o); } else if constexpr (std::is_same_v) { return simd::select(x == z, x, Simd(x / simd::abs(x))); } else { - return simd::select( - x < z, Simd{-1}, simd::select(x > z, Simd{1}, z)); + return simd::select(x < z, m, simd::select(x > z, o, z)); } } SINGLE() diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index b8bc8b45e..356515702 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -1241,6 +1241,47 @@ TEST_CASE("test arithmetic unary ops") { // bool x = array({false, true}); CHECK(array_equal(sign(x), x).item()); + + // uint64 + array x_uint64( + {uint64_t(0xa11cc311cb6acd70), + uint64_t(0x7a375ac3ebb533f3), + uint64_t(0x734969adf9d7190c), + uint64_t(0xb400515a4f673424)}); + array expected( + {uint64_t(0x0000000000000001), + uint64_t(0x0000000000000001), + uint64_t(0x0000000000000001), + uint64_t(0x0000000000000001)}); + CHECK(array_equal(sign(x_uint64), expected).item()); + + x_uint64 = array( + {uint64_t(0xa11cc311cb6acd70), + uint64_t(0x7a375ac3ebb533f3), + uint64_t(0x734969adf9d7190c)}); + expected = array( + {uint64_t(0x0000000000000001), + uint64_t(0x0000000000000001), + uint64_t(0x0000000000000001)}); + CHECK(array_equal(sign(x_uint64), expected).item()); + + x_uint64 = + array({uint64_t(0xa11cc311cb6acd70), uint64_t(0x7a375ac3ebb533f3)}); + expected = + array({uint64_t(0x0000000000000001), uint64_t(0x0000000000000001)}); + CHECK(array_equal(sign(x_uint64), expected).item()); + + x_uint64 = array({uint64_t(0xa11cc311cb6acd70)}); + expected = array({uint64_t(0x0000000000000001)}); + CHECK(array_equal(sign(x_uint64), expected).item()); + + x_uint64 = array({uint64_t(0xffffffffffffffff)}); + expected = array({uint64_t(0x0000000000000001)}); + CHECK(array_equal(sign(x_uint64), expected).item()); + + x_uint64 = array({uint64_t(0x0000000000000001)}); + expected = array({uint64_t(0x0000000000000001)}); + CHECK(array_equal(sign(x_uint64), expected).item()); } constexpr float neginf = -std::numeric_limits::infinity();