Fix neon_fast_exp and add more softmax tests (#1367)

This commit is contained in:
Jeethu Rao
2024-08-28 07:42:42 +01:00
committed by GitHub
parent e6b223df5f
commit bd47e1f066
3 changed files with 48 additions and 7 deletions

View File

@@ -947,12 +947,50 @@ TEST_CASE("test reduction ops") {
// Test softmax
{
auto x = array({0., 0., 0., 0.});
auto y = array({0.25, 0.25, 0.25, 0.25});
CHECK(array_equal(y, softmax(x)).item<bool>());
CHECK(array_equal(y, softmax(x, -1)).item<bool>());
CHECK(array_equal(y, softmax(x, std::vector<int>{-1})).item<bool>());
CHECK(array_equal(y, softmax(x, std::vector<int>{0})).item<bool>());
for (auto t : {float16, bfloat16, float32}) {
const auto rtol = t == float32 ? 1e-5 : 1e-2;
auto x = array({}, t);
CHECK(array_equal(x, softmax(x)).item<bool>());
// all zeros
x = array({0., 0., 0., 0.}, t);
auto y = array({0.25, 0.25, 0.25, 0.25}, t);
CHECK(array_equal(y, softmax(x)).item<bool>());
CHECK(array_equal(y, softmax(x, -1)).item<bool>());
CHECK(array_equal(y, softmax(x, std::vector<int>{-1})).item<bool>());
CHECK(array_equal(y, softmax(x, std::vector<int>{0})).item<bool>());
auto ones = array(1.0f, t);
CHECK(array_equal(ones, sum(softmax(x))).item<bool>());
// all ones
x = array({1., 1., 1., 1.}, t);
CHECK(array_equal(y, softmax(x)).item<bool>());
CHECK(array_equal(ones, sum(softmax(x))).item<bool>());
// negative values
x = array({-1., -2., -3., -4.}, t);
y = array({0.643914, 0.236883, 0.0871443, 0.0320586}, t);
CHECK(allclose(y, softmax(x), rtol).item<bool>());
CHECK(allclose(ones, sum(softmax(x)), rtol).item<bool>());
// positive and negative values
x = array({1., 0., -1., 0.}, t);
y = array({0.534447, 0.196612, 0.0723295, 0.196612}, t);
CHECK(allclose(y, softmax(x), rtol).item<bool>());
CHECK(allclose(ones, sum(softmax(x)), rtol).item<bool>());
// large positive values
x = array({1000., 1000., 1000.}, t);
y = array({0.333333, 0.333333, 0.333333}, t);
CHECK(allclose(y, softmax(x)).item<bool>());
CHECK(array_equal(ones, sum(softmax(x))).item<bool>());
// large negative values
x = negative(x);
CHECK(allclose(y, softmax(x)).item<bool>());
CHECK(array_equal(ones, sum(softmax(x))).item<bool>());
}
}
}