diff --git a/mlx/backend/accelerate/softmax.cpp b/mlx/backend/accelerate/softmax.cpp index 91d9fe56a..3a2dcd532 100644 --- a/mlx/backend/accelerate/softmax.cpp +++ b/mlx/backend/accelerate/softmax.cpp @@ -70,7 +70,6 @@ inline float16x8_t neon_fast_exp(float16x8_t x) { x = vdupq_n_f16(float16_t(1.535336188319500e-4f)); x = vfmaq_f16(vdupq_n_f16(float16_t(1.339887440266574e-3f)), x, fpart); - x = vfmaq_f16(vdupq_n_f16(float16_t(1.339887440266574e-3f)), x, fpart); x = vfmaq_f16(vdupq_n_f16(float16_t(9.618437357674640e-3f)), x, fpart); x = vfmaq_f16(vdupq_n_f16(float16_t(5.550332471162809e-2f)), x, fpart); x = vfmaq_f16(vdupq_n_f16(float16_t(2.402264791363012e-1f)), x, fpart); diff --git a/mlx/ops.cpp b/mlx/ops.cpp index e77d18b0f..ee9aed212 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -2927,6 +2927,10 @@ array softmax( const std::vector& axes, bool precise /* = false */, StreamOrDevice s /* = {}*/) { + if (a.size() == 0) { + return a; + } + if (axes.size() == 1 && (a.ndim() == axes[0] + 1 || axes[0] == -1)) { auto dtype = at_least_float(a.dtype()); return array( diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index dc86037cb..33d77c626 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -947,12 +947,50 @@ TEST_CASE("test reduction ops") { // Test softmax { - auto x = array({0., 0., 0., 0.}); - auto y = array({0.25, 0.25, 0.25, 0.25}); - CHECK(array_equal(y, softmax(x)).item()); - CHECK(array_equal(y, softmax(x, -1)).item()); - CHECK(array_equal(y, softmax(x, std::vector{-1})).item()); - CHECK(array_equal(y, softmax(x, std::vector{0})).item()); + for (auto t : {float16, bfloat16, float32}) { + const auto rtol = t == float32 ? 1e-5 : 1e-2; + auto x = array({}, t); + CHECK(array_equal(x, softmax(x)).item()); + + // all zeros + x = array({0., 0., 0., 0.}, t); + auto y = array({0.25, 0.25, 0.25, 0.25}, t); + CHECK(array_equal(y, softmax(x)).item()); + CHECK(array_equal(y, softmax(x, -1)).item()); + CHECK(array_equal(y, softmax(x, std::vector{-1})).item()); + CHECK(array_equal(y, softmax(x, std::vector{0})).item()); + + auto ones = array(1.0f, t); + CHECK(array_equal(ones, sum(softmax(x))).item()); + + // all ones + x = array({1., 1., 1., 1.}, t); + CHECK(array_equal(y, softmax(x)).item()); + CHECK(array_equal(ones, sum(softmax(x))).item()); + + // negative values + x = array({-1., -2., -3., -4.}, t); + y = array({0.643914, 0.236883, 0.0871443, 0.0320586}, t); + CHECK(allclose(y, softmax(x), rtol).item()); + CHECK(allclose(ones, sum(softmax(x)), rtol).item()); + + // positive and negative values + x = array({1., 0., -1., 0.}, t); + y = array({0.534447, 0.196612, 0.0723295, 0.196612}, t); + CHECK(allclose(y, softmax(x), rtol).item()); + CHECK(allclose(ones, sum(softmax(x)), rtol).item()); + + // large positive values + x = array({1000., 1000., 1000.}, t); + y = array({0.333333, 0.333333, 0.333333}, t); + CHECK(allclose(y, softmax(x)).item()); + CHECK(array_equal(ones, sum(softmax(x))).item()); + + // large negative values + x = negative(x); + CHECK(allclose(y, softmax(x)).item()); + CHECK(array_equal(ones, sum(softmax(x))).item()); + } } }