mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Fix CPU sign for unsigned ints (#2024)
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
parent
28f39e9038
commit
5f5770e3a2
@ -86,13 +86,14 @@ struct Sign {
|
||||
template <int N, typename T>
|
||||
Simd<T, N> operator()(Simd<T, N> x) {
|
||||
auto z = Simd<T, N>{0};
|
||||
auto o = Simd<T, N>{1};
|
||||
auto m = Simd<T, N>{-1};
|
||||
if constexpr (std::is_unsigned_v<T>) {
|
||||
return x != z;
|
||||
return simd::select(x == z, z, o);
|
||||
} else if constexpr (std::is_same_v<T, complex64_t>) {
|
||||
return simd::select(x == z, x, Simd<T, N>(x / simd::abs(x)));
|
||||
} else {
|
||||
return simd::select(
|
||||
x < z, Simd<T, N>{-1}, simd::select(x > z, Simd<T, N>{1}, z));
|
||||
return simd::select(x < z, m, simd::select(x > z, o, z));
|
||||
}
|
||||
}
|
||||
SINGLE()
|
||||
|
@ -1241,6 +1241,47 @@ TEST_CASE("test arithmetic unary ops") {
|
||||
// bool
|
||||
x = array({false, true});
|
||||
CHECK(array_equal(sign(x), x).item<bool>());
|
||||
|
||||
// uint64
|
||||
array x_uint64(
|
||||
{uint64_t(0xa11cc311cb6acd70),
|
||||
uint64_t(0x7a375ac3ebb533f3),
|
||||
uint64_t(0x734969adf9d7190c),
|
||||
uint64_t(0xb400515a4f673424)});
|
||||
array expected(
|
||||
{uint64_t(0x0000000000000001),
|
||||
uint64_t(0x0000000000000001),
|
||||
uint64_t(0x0000000000000001),
|
||||
uint64_t(0x0000000000000001)});
|
||||
CHECK(array_equal(sign(x_uint64), expected).item<bool>());
|
||||
|
||||
x_uint64 = array(
|
||||
{uint64_t(0xa11cc311cb6acd70),
|
||||
uint64_t(0x7a375ac3ebb533f3),
|
||||
uint64_t(0x734969adf9d7190c)});
|
||||
expected = array(
|
||||
{uint64_t(0x0000000000000001),
|
||||
uint64_t(0x0000000000000001),
|
||||
uint64_t(0x0000000000000001)});
|
||||
CHECK(array_equal(sign(x_uint64), expected).item<bool>());
|
||||
|
||||
x_uint64 =
|
||||
array({uint64_t(0xa11cc311cb6acd70), uint64_t(0x7a375ac3ebb533f3)});
|
||||
expected =
|
||||
array({uint64_t(0x0000000000000001), uint64_t(0x0000000000000001)});
|
||||
CHECK(array_equal(sign(x_uint64), expected).item<bool>());
|
||||
|
||||
x_uint64 = array({uint64_t(0xa11cc311cb6acd70)});
|
||||
expected = array({uint64_t(0x0000000000000001)});
|
||||
CHECK(array_equal(sign(x_uint64), expected).item<bool>());
|
||||
|
||||
x_uint64 = array({uint64_t(0xffffffffffffffff)});
|
||||
expected = array({uint64_t(0x0000000000000001)});
|
||||
CHECK(array_equal(sign(x_uint64), expected).item<bool>());
|
||||
|
||||
x_uint64 = array({uint64_t(0x0000000000000001)});
|
||||
expected = array({uint64_t(0x0000000000000001)});
|
||||
CHECK(array_equal(sign(x_uint64), expected).item<bool>());
|
||||
}
|
||||
|
||||
constexpr float neginf = -std::numeric_limits<float>::infinity();
|
||||
|
Loading…
Reference in New Issue
Block a user