shapeless slice update and broadcast when possible (#1727)

This commit is contained in:
Awni Hannun
2024-12-23 11:25:15 -08:00
committed by GitHub
parent 0308e9af71
commit ebfe64b92d
6 changed files with 43 additions and 99 deletions

View File

@@ -533,6 +533,8 @@ class Broadcast : public UnaryPrimitive {
DEFINE_PRINT(Broadcast)
bool is_equivalent(const Primitive& other) const override;
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
private:
Shape shape_;
@@ -1943,6 +1945,7 @@ class SliceUpdate : public UnaryPrimitive {
DEFINE_GRADS()
DEFINE_PRINT(SliceUpdate)
bool is_equivalent(const Primitive& other) const override;
DEFINE_INPUT_OUTPUT_SHAPE()
private:
Shape start_indices_;