Add bitwise ops (#1037)

* bitwise ops

* fix tests
This commit is contained in:
Awni Hannun
2024-04-26 22:03:42 -07:00
committed by GitHub
parent 67d1894759
commit 86f495985b
17 changed files with 568 additions and 58 deletions

View File

@@ -596,7 +596,7 @@ TEST_CASE("test op vjps") {
// Test power
{
auto fun = [](std::vector<array> inputs) {
return std::vector<array>{inputs[0] ^ inputs[1]};
return std::vector<array>{power(inputs[0], inputs[1])};
};
auto out = vjp(fun, {array(4.0f), array(3.0f)}, {array(1.0f)}).second;
CHECK_EQ(out[0].item<float>(), 48.0f);

View File

@@ -2308,29 +2308,26 @@ TEST_CASE("test pad") {
TEST_CASE("test power") {
CHECK_EQ(power(array(1), array(2)).item<int>(), 1);
CHECK_EQ((array(1) ^ 2).item<int>(), 1);
CHECK_EQ((1 ^ array(2)).item<int>(), 1);
CHECK_EQ((array(-1) ^ 2).item<int>(), 1);
CHECK_EQ((array(-1) ^ 3).item<int>(), -1);
CHECK_EQ((power(array(-1), array(2))).item<int>(), 1);
CHECK_EQ((power(array(-1), array(3))).item<int>(), -1);
// TODO Throws but exception not caught from calling thread
// CHECK_THROWS((x^-1).item<int>());
CHECK_EQ((array(true) ^ array(false)).item<bool>(), true);
CHECK_EQ((array(false) ^ array(false)).item<bool>(), true);
CHECK_EQ((array(true) ^ array(true)).item<bool>(), true);
CHECK_EQ((array(false) ^ array(true)).item<bool>(), false);
CHECK_EQ((power(array(true), array(false))).item<bool>(), true);
CHECK_EQ((power(array(false), array(false))).item<bool>(), true);
CHECK_EQ((power(array(true), array(true))).item<bool>(), true);
CHECK_EQ((power(array(false), array(true))).item<bool>(), false);
auto x = array(2.0f);
CHECK_EQ((x ^ 0.5).item<float>(), doctest::Approx(std::pow(2.0f, 0.5f)));
CHECK_EQ((x ^ 2.0f).item<float>(), 4.0f);
CHECK_EQ(
(power(x, array(0.5))).item<float>(),
doctest::Approx(std::pow(2.0f, 0.5f)));
CHECK_EQ(power(x, array(2.0f)).item<float>(), 4.0f);
CHECK(std::isnan((array(-1.0f) ^ 0.5).item<float>()));
CHECK(std::isnan((power(array(-1.0f), array(0.5))).item<float>()));
auto a = complex64_t{0.5, 0.5};
auto b = complex64_t{0.5, 0.5};
auto expected = std::pow(a, b);
auto out = (array(a) ^ array(b)).item<complex64_t>();
auto out = (power(array(a), array(b))).item<complex64_t>();
CHECK(abs(out.real() - expected.real()) < 1e-7);
CHECK(abs(out.imag() - expected.imag()) < 1e-7);
}
@@ -3230,4 +3227,4 @@ TEST_CASE("test meshgrid") {
expected_one = array({1, 2, 3}, {3, 1});
CHECK(array_equal(out[0], expected_zero).item<bool>());
CHECK(array_equal(out[1], expected_one).item<bool>());
}
}