Fix neon_fast_exp and add more softmax tests (#1367)

This commit is contained in:
Jeethu Rao
2024-08-28 07:42:42 +01:00
committed by GitHub
parent e6b223df5f
commit bd47e1f066
3 changed files with 48 additions and 7 deletions

View File

@@ -70,7 +70,6 @@ inline float16x8_t neon_fast_exp(float16x8_t x) {
x = vdupq_n_f16(float16_t(1.535336188319500e-4f));
x = vfmaq_f16(vdupq_n_f16(float16_t(1.339887440266574e-3f)), x, fpart);
x = vfmaq_f16(vdupq_n_f16(float16_t(1.339887440266574e-3f)), x, fpart);
x = vfmaq_f16(vdupq_n_f16(float16_t(9.618437357674640e-3f)), x, fpart);
x = vfmaq_f16(vdupq_n_f16(float16_t(5.550332471162809e-2f)), x, fpart);
x = vfmaq_f16(vdupq_n_f16(float16_t(2.402264791363012e-1f)), x, fpart);

View File

@@ -2927,6 +2927,10 @@ array softmax(
const std::vector<int>& axes,
bool precise /* = false */,
StreamOrDevice s /* = {}*/) {
if (a.size() == 0) {
return a;
}
if (axes.size() == 1 && (a.ndim() == axes[0] + 1 || axes[0] == -1)) {
auto dtype = at_least_float(a.dtype());
return array(