mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
shapeless slice update and broadcast when possible (#1727)
This commit is contained in:
@@ -533,6 +533,8 @@ class Broadcast : public UnaryPrimitive {
|
||||
DEFINE_PRINT(Broadcast)
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
|
||||
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
||||
|
||||
private:
|
||||
Shape shape_;
|
||||
|
||||
@@ -1943,6 +1945,7 @@ class SliceUpdate : public UnaryPrimitive {
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(SliceUpdate)
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
Shape start_indices_;
|
||||
|
||||
Reference in New Issue
Block a user