mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Dynamic slicing (#1741)
* dynamic slice and slice update * python bindings + tests + fix set item * fix compile issue * comment * fix jit
This commit is contained in:
@@ -2057,6 +2057,51 @@ class SliceUpdate : public UnaryPrimitive {
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class DynamicSlice : public UnaryPrimitive {
|
||||
public:
|
||||
explicit DynamicSlice(Stream stream, std::vector<int> axes, Shape slice_size)
|
||||
: UnaryPrimitive(stream),
|
||||
axes_(std::move(axes)),
|
||||
slice_size_(std::move(slice_size)) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
|
||||
DEFINE_VMAP()
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(DynamicSlice)
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
||||
auto state() const {
|
||||
return std::make_pair(axes_, slice_size_);
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<int> axes_;
|
||||
Shape slice_size_;
|
||||
};
|
||||
|
||||
class DynamicSliceUpdate : public UnaryPrimitive {
|
||||
public:
|
||||
explicit DynamicSliceUpdate(Stream stream, std::vector<int> axes)
|
||||
: UnaryPrimitive(stream), axes_(std::move(axes)) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
|
||||
DEFINE_VMAP()
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(DynamicSliceUpdate)
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
auto state() const {
|
||||
return axes_;
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<int> axes_;
|
||||
};
|
||||
|
||||
class Softmax : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Softmax(Stream stream, bool precise)
|
||||
|
||||
Reference in New Issue
Block a user