Dynamic broadcasting for shapeless compile/export (#1722)

* working towards dynamic broadcast

* shapeless broadcast

* fix build + nits

* use broadcast arrays in quantize matmul

* some cleanup / consistency

* mend

* some comments

* add vjp, jvp for broadcast axes
This commit is contained in:
Awni Hannun
2025-01-09 11:04:24 -08:00
committed by GitHub
parent ec36bfa317
commit 1ccaf80575
20 changed files with 471 additions and 163 deletions

View File

@@ -547,6 +547,31 @@ class GatherMM : public UnaryPrimitive {
void eval(const std::vector<array>& inputs, array& out);
};
class BroadcastAxes : public UnaryPrimitive {
public:
explicit BroadcastAxes(Stream stream, std::vector<int> ignore_axes = {})
: UnaryPrimitive(stream), ignore_axes_(std::move(ignore_axes)) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
DEFINE_GRADS()
DEFINE_PRINT(BroadcastAxes)
bool is_equivalent(const Primitive& other) const override;
static Shape output_shape(
const std::vector<array>& inputs,
const std::vector<int>& ignore_axes);
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
auto state() const {
return ignore_axes_;
}
private:
void eval(const std::vector<array>& inputs, array& out);
std::vector<int> ignore_axes_;
};
class Broadcast : public UnaryPrimitive {
public:
explicit Broadcast(Stream stream, const Shape& shape)
@@ -558,13 +583,13 @@ class Broadcast : public UnaryPrimitive {
DEFINE_VMAP()
DEFINE_GRADS()
DEFINE_PRINT(Broadcast)
static Shape output_shape(const std::vector<array>& inputs);
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
bool is_equivalent(const Primitive& other) const override;
std::vector<int> state() const {
return shape_;
};
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
private:
Shape shape_;