mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-23 10:02:12 +08:00
Fix neon_fast_exp and add more softmax tests (#1367)
This commit is contained in:
parent
e6b223df5f
commit
bd47e1f066
@ -70,7 +70,6 @@ inline float16x8_t neon_fast_exp(float16x8_t x) {
|
||||
|
||||
x = vdupq_n_f16(float16_t(1.535336188319500e-4f));
|
||||
x = vfmaq_f16(vdupq_n_f16(float16_t(1.339887440266574e-3f)), x, fpart);
|
||||
x = vfmaq_f16(vdupq_n_f16(float16_t(1.339887440266574e-3f)), x, fpart);
|
||||
x = vfmaq_f16(vdupq_n_f16(float16_t(9.618437357674640e-3f)), x, fpart);
|
||||
x = vfmaq_f16(vdupq_n_f16(float16_t(5.550332471162809e-2f)), x, fpart);
|
||||
x = vfmaq_f16(vdupq_n_f16(float16_t(2.402264791363012e-1f)), x, fpart);
|
||||
|
@ -2927,6 +2927,10 @@ array softmax(
|
||||
const std::vector<int>& axes,
|
||||
bool precise /* = false */,
|
||||
StreamOrDevice s /* = {}*/) {
|
||||
if (a.size() == 0) {
|
||||
return a;
|
||||
}
|
||||
|
||||
if (axes.size() == 1 && (a.ndim() == axes[0] + 1 || axes[0] == -1)) {
|
||||
auto dtype = at_least_float(a.dtype());
|
||||
return array(
|
||||
|
@ -947,12 +947,50 @@ TEST_CASE("test reduction ops") {
|
||||
|
||||
// Test softmax
|
||||
{
|
||||
auto x = array({0., 0., 0., 0.});
|
||||
auto y = array({0.25, 0.25, 0.25, 0.25});
|
||||
CHECK(array_equal(y, softmax(x)).item<bool>());
|
||||
CHECK(array_equal(y, softmax(x, -1)).item<bool>());
|
||||
CHECK(array_equal(y, softmax(x, std::vector<int>{-1})).item<bool>());
|
||||
CHECK(array_equal(y, softmax(x, std::vector<int>{0})).item<bool>());
|
||||
for (auto t : {float16, bfloat16, float32}) {
|
||||
const auto rtol = t == float32 ? 1e-5 : 1e-2;
|
||||
auto x = array({}, t);
|
||||
CHECK(array_equal(x, softmax(x)).item<bool>());
|
||||
|
||||
// all zeros
|
||||
x = array({0., 0., 0., 0.}, t);
|
||||
auto y = array({0.25, 0.25, 0.25, 0.25}, t);
|
||||
CHECK(array_equal(y, softmax(x)).item<bool>());
|
||||
CHECK(array_equal(y, softmax(x, -1)).item<bool>());
|
||||
CHECK(array_equal(y, softmax(x, std::vector<int>{-1})).item<bool>());
|
||||
CHECK(array_equal(y, softmax(x, std::vector<int>{0})).item<bool>());
|
||||
|
||||
auto ones = array(1.0f, t);
|
||||
CHECK(array_equal(ones, sum(softmax(x))).item<bool>());
|
||||
|
||||
// all ones
|
||||
x = array({1., 1., 1., 1.}, t);
|
||||
CHECK(array_equal(y, softmax(x)).item<bool>());
|
||||
CHECK(array_equal(ones, sum(softmax(x))).item<bool>());
|
||||
|
||||
// negative values
|
||||
x = array({-1., -2., -3., -4.}, t);
|
||||
y = array({0.643914, 0.236883, 0.0871443, 0.0320586}, t);
|
||||
CHECK(allclose(y, softmax(x), rtol).item<bool>());
|
||||
CHECK(allclose(ones, sum(softmax(x)), rtol).item<bool>());
|
||||
|
||||
// positive and negative values
|
||||
x = array({1., 0., -1., 0.}, t);
|
||||
y = array({0.534447, 0.196612, 0.0723295, 0.196612}, t);
|
||||
CHECK(allclose(y, softmax(x), rtol).item<bool>());
|
||||
CHECK(allclose(ones, sum(softmax(x)), rtol).item<bool>());
|
||||
|
||||
// large positive values
|
||||
x = array({1000., 1000., 1000.}, t);
|
||||
y = array({0.333333, 0.333333, 0.333333}, t);
|
||||
CHECK(allclose(y, softmax(x)).item<bool>());
|
||||
CHECK(array_equal(ones, sum(softmax(x))).item<bool>());
|
||||
|
||||
// large negative values
|
||||
x = negative(x);
|
||||
CHECK(allclose(y, softmax(x)).item<bool>());
|
||||
CHECK(array_equal(ones, sum(softmax(x))).item<bool>());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user