Custom logsumexp (#2028)

* initial custom logsumexp

* more tests

* comments + fix
This commit is contained in:
Awni Hannun
2025-03-31 07:36:55 -07:00
committed by GitHub
parent ec2854b13a
commit de5f38fd48
27 changed files with 590 additions and 255 deletions

View File

@@ -1350,6 +1350,20 @@ class LogAddExp : public UnaryPrimitive {
DEFINE_INPUT_OUTPUT_SHAPE()
};
class LogSumExp : public UnaryPrimitive {
public:
explicit LogSumExp(Stream stream) : UnaryPrimitive(stream) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
DEFINE_GRADS()
DEFINE_PRINT(LogSumExp)
DEFINE_DEFAULT_IS_EQUIVALENT()
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
};
class Matmul : public UnaryPrimitive {
public:
explicit Matmul(Stream stream) : UnaryPrimitive(stream) {}