Dynamic slicing (#1741)

* dynamic slice and slice update

* python bindings + tests + fix set item

* fix compile issue

* comment

* fix jit
This commit is contained in:
Awni Hannun
2025-01-07 14:02:16 -08:00
committed by GitHub
parent c9c81d0584
commit 516ded618b
27 changed files with 941 additions and 75 deletions

View File

@@ -2057,6 +2057,51 @@ class SliceUpdate : public UnaryPrimitive {
void eval(const std::vector<array>& inputs, array& out);
};
class DynamicSlice : public UnaryPrimitive {
public:
explicit DynamicSlice(Stream stream, std::vector<int> axes, Shape slice_size)
: UnaryPrimitive(stream),
axes_(std::move(axes)),
slice_size_(std::move(slice_size)) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
DEFINE_GRADS()
DEFINE_PRINT(DynamicSlice)
bool is_equivalent(const Primitive& other) const override;
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
auto state() const {
return std::make_pair(axes_, slice_size_);
}
private:
std::vector<int> axes_;
Shape slice_size_;
};
class DynamicSliceUpdate : public UnaryPrimitive {
public:
explicit DynamicSliceUpdate(Stream stream, std::vector<int> axes)
: UnaryPrimitive(stream), axes_(std::move(axes)) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
DEFINE_GRADS()
DEFINE_PRINT(DynamicSliceUpdate)
bool is_equivalent(const Primitive& other) const override;
DEFINE_INPUT_OUTPUT_SHAPE()
auto state() const {
return axes_;
}
private:
std::vector<int> axes_;
};
class Softmax : public UnaryPrimitive {
public:
explicit Softmax(Stream stream, bool precise)