mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 04:24:36 +08:00
@@ -596,7 +596,7 @@ TEST_CASE("test op vjps") {
|
||||
// Test power
|
||||
{
|
||||
auto fun = [](std::vector<array> inputs) {
|
||||
return std::vector<array>{inputs[0] ^ inputs[1]};
|
||||
return std::vector<array>{power(inputs[0], inputs[1])};
|
||||
};
|
||||
auto out = vjp(fun, {array(4.0f), array(3.0f)}, {array(1.0f)}).second;
|
||||
CHECK_EQ(out[0].item<float>(), 48.0f);
|
||||
|
@@ -2308,29 +2308,26 @@ TEST_CASE("test pad") {
|
||||
|
||||
TEST_CASE("test power") {
|
||||
CHECK_EQ(power(array(1), array(2)).item<int>(), 1);
|
||||
CHECK_EQ((array(1) ^ 2).item<int>(), 1);
|
||||
CHECK_EQ((1 ^ array(2)).item<int>(), 1);
|
||||
CHECK_EQ((array(-1) ^ 2).item<int>(), 1);
|
||||
CHECK_EQ((array(-1) ^ 3).item<int>(), -1);
|
||||
CHECK_EQ((power(array(-1), array(2))).item<int>(), 1);
|
||||
CHECK_EQ((power(array(-1), array(3))).item<int>(), -1);
|
||||
|
||||
// TODO Throws but exception not caught from calling thread
|
||||
// CHECK_THROWS((x^-1).item<int>());
|
||||
|
||||
CHECK_EQ((array(true) ^ array(false)).item<bool>(), true);
|
||||
CHECK_EQ((array(false) ^ array(false)).item<bool>(), true);
|
||||
CHECK_EQ((array(true) ^ array(true)).item<bool>(), true);
|
||||
CHECK_EQ((array(false) ^ array(true)).item<bool>(), false);
|
||||
CHECK_EQ((power(array(true), array(false))).item<bool>(), true);
|
||||
CHECK_EQ((power(array(false), array(false))).item<bool>(), true);
|
||||
CHECK_EQ((power(array(true), array(true))).item<bool>(), true);
|
||||
CHECK_EQ((power(array(false), array(true))).item<bool>(), false);
|
||||
|
||||
auto x = array(2.0f);
|
||||
CHECK_EQ((x ^ 0.5).item<float>(), doctest::Approx(std::pow(2.0f, 0.5f)));
|
||||
CHECK_EQ((x ^ 2.0f).item<float>(), 4.0f);
|
||||
CHECK_EQ(
|
||||
(power(x, array(0.5))).item<float>(),
|
||||
doctest::Approx(std::pow(2.0f, 0.5f)));
|
||||
CHECK_EQ(power(x, array(2.0f)).item<float>(), 4.0f);
|
||||
|
||||
CHECK(std::isnan((array(-1.0f) ^ 0.5).item<float>()));
|
||||
CHECK(std::isnan((power(array(-1.0f), array(0.5))).item<float>()));
|
||||
|
||||
auto a = complex64_t{0.5, 0.5};
|
||||
auto b = complex64_t{0.5, 0.5};
|
||||
auto expected = std::pow(a, b);
|
||||
auto out = (array(a) ^ array(b)).item<complex64_t>();
|
||||
auto out = (power(array(a), array(b))).item<complex64_t>();
|
||||
CHECK(abs(out.real() - expected.real()) < 1e-7);
|
||||
CHECK(abs(out.imag() - expected.imag()) < 1e-7);
|
||||
}
|
||||
@@ -3230,4 +3227,4 @@ TEST_CASE("test meshgrid") {
|
||||
expected_one = array({1, 2, 3}, {3, 1});
|
||||
CHECK(array_equal(out[0], expected_zero).item<bool>());
|
||||
CHECK(array_equal(out[1], expected_one).item<bool>());
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user