mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Custom logsumexp (#2028)
* initial custom logsumexp * more tests * comments + fix
This commit is contained in:
@@ -1350,6 +1350,20 @@ class LogAddExp : public UnaryPrimitive {
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
};
|
||||
|
||||
class LogSumExp : public UnaryPrimitive {
|
||||
public:
|
||||
explicit LogSumExp(Stream stream) : UnaryPrimitive(stream) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
|
||||
DEFINE_VMAP()
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(LogSumExp)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
||||
};
|
||||
|
||||
class Matmul : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Matmul(Stream stream) : UnaryPrimitive(stream) {}
|
||||
|
||||
Reference in New Issue
Block a user