mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Dynamic broadcasting for shapeless compile/export (#1722)
* working towards dynamic broadcast * shapeless broadcast * fix build + nits * use broadcast arrays in quantize matmul * some cleanup / consistency * mend * some comments * add vjp, jvp for broadcast axes
This commit is contained in:
@@ -547,6 +547,31 @@ class GatherMM : public UnaryPrimitive {
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class BroadcastAxes : public UnaryPrimitive {
|
||||
public:
|
||||
explicit BroadcastAxes(Stream stream, std::vector<int> ignore_axes = {})
|
||||
: UnaryPrimitive(stream), ignore_axes_(std::move(ignore_axes)) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
|
||||
DEFINE_VMAP()
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(BroadcastAxes)
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
static Shape output_shape(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& ignore_axes);
|
||||
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
||||
auto state() const {
|
||||
return ignore_axes_;
|
||||
}
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
std::vector<int> ignore_axes_;
|
||||
};
|
||||
|
||||
class Broadcast : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Broadcast(Stream stream, const Shape& shape)
|
||||
@@ -558,13 +583,13 @@ class Broadcast : public UnaryPrimitive {
|
||||
DEFINE_VMAP()
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(Broadcast)
|
||||
static Shape output_shape(const std::vector<array>& inputs);
|
||||
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
std::vector<int> state() const {
|
||||
return shape_;
|
||||
};
|
||||
|
||||
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
||||
|
||||
private:
|
||||
Shape shape_;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user