Fix logsumexp edge case (#740)

* fix logsumexp

* fix inf constant

* also fix power grad

* fix ternary dispatch
This commit is contained in:
Awni Hannun
2024-02-25 08:39:55 -08:00
committed by GitHub
parent ac02cf33bd
commit e6418781ab
12 changed files with 112 additions and 64 deletions

View File

@@ -2,16 +2,6 @@
#include "mlx/backend/metal/kernels/binary.h"
template <typename T, typename U, typename Op>
[[kernel]] void binary_op_s2s(
device const T* a,
device const T* b,
device U* c,
uint index [[thread_position_in_grid]]) {
c[index] = Op()(a[0], b[0]);
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_op_ss(
device const T* a,