mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Fix neon_fast_exp and add more softmax tests (#1367)
This commit is contained in:
@@ -2927,6 +2927,10 @@ array softmax(
|
||||
const std::vector<int>& axes,
|
||||
bool precise /* = false */,
|
||||
StreamOrDevice s /* = {}*/) {
|
||||
if (a.size() == 0) {
|
||||
return a;
|
||||
}
|
||||
|
||||
if (axes.size() == 1 && (a.ndim() == axes[0] + 1 || axes[0] == -1)) {
|
||||
auto dtype = at_least_float(a.dtype());
|
||||
return array(
|
||||
|
||||
Reference in New Issue
Block a user