Export / import functions to / from a file (#1642)

* export and import functions

* refactor + works for few primitives

* nit

* allow primitives with state

* nit

* nit

* simplify serialize / deserialize

* fix for constants

* python bindings

* maybe fix serialize failure case

* add example

* more primitives, training kind of works

* same result for python and c++

* some fixes

* fix export

* template it up

* some simplificatoin

* rebase

* allow kwargs and multiple functions

* exporter

* more primitives for exporting

* deal with endianness

* handle invalid stream

* add docstring
This commit is contained in:
Awni Hannun
2024-12-24 11:19:13 -08:00
committed by GitHub
parent 935c8c4bb1
commit 4ba0c24a8f
35 changed files with 2239 additions and 90 deletions

View File

@@ -203,6 +203,9 @@ class AddMM : public UnaryPrimitive {
DEFINE_PRINT(AddMM)
bool is_equivalent(const Primitive& other) const override;
std::pair<float, float> state() const {
return {alpha_, beta_};
};
private:
const float alpha_;
@@ -220,6 +223,9 @@ class Arange : public UnaryPrimitive {
DEFINE_PRINT(Arange)
bool is_equivalent(const Primitive& other) const override;
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
std::tuple<double, double, double> state() const {
return {start_, stop_, step_};
};
private:
double start_;
@@ -361,6 +367,9 @@ class ArgPartition : public UnaryPrimitive {
DEFINE_PRINT(ArgPartition)
DEFINE_INPUT_OUTPUT_SHAPE()
bool is_equivalent(const Primitive& other) const override;
std::pair<int, int> state() const {
return {kth_, axis_};
};
private:
int kth_;
@@ -387,6 +396,9 @@ class ArgReduce : public UnaryPrimitive {
DEFINE_PRINT(ArgReduce)
bool is_equivalent(const Primitive& other) const override;
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
std::pair<ReduceType, int> state() const {
return {reduce_type_, axis_};
};
private:
ReduceType reduce_type_;
@@ -407,6 +419,9 @@ class ArgSort : public UnaryPrimitive {
DEFINE_PRINT(ArgSort)
DEFINE_INPUT_OUTPUT_SHAPE()
bool is_equivalent(const Primitive& other) const override;
int state() const {
return axis_;
};
private:
int axis_;
@@ -427,6 +442,9 @@ class AsType : public UnaryPrimitive {
DEFINE_PRINT(AsType)
DEFINE_INPUT_OUTPUT_SHAPE()
bool is_equivalent(const Primitive& other) const override;
Dtype state() const {
return dtype_;
};
private:
Dtype dtype_;
@@ -448,6 +466,9 @@ class AsStrided : public UnaryPrimitive {
DEFINE_GRADS()
DEFINE_PRINT(AsStrided)
bool is_equivalent(const Primitive& other) const override;
auto state() const {
return std::make_tuple(shape_, strides_, offset_);
}
private:
Shape shape_;
@@ -472,6 +493,9 @@ class BitwiseBinary : public UnaryPrimitive {
bool is_equivalent(const Primitive& other) const override;
void print(std::ostream& os) override;
DEFINE_INPUT_OUTPUT_SHAPE()
auto state() const {
return op_;
}
private:
Op op_;
@@ -493,6 +517,9 @@ class BlockMaskedMM : public UnaryPrimitive {
DEFINE_PRINT(BlockMaskedMM)
bool is_equivalent(const Primitive& other) const override;
auto state() const {
return block_size_;
}
private:
int block_size_;
@@ -532,6 +559,9 @@ class Broadcast : public UnaryPrimitive {
DEFINE_GRADS()
DEFINE_PRINT(Broadcast)
bool is_equivalent(const Primitive& other) const override;
std::vector<int> state() const {
return shape_;
};
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
@@ -613,6 +643,9 @@ class Concatenate : public UnaryPrimitive {
DEFINE_PRINT(Concatenate)
bool is_equivalent(const Primitive& other) const override;
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
auto state() const {
return axis_;
}
private:
int axis_;
@@ -684,6 +717,15 @@ class Convolution : public UnaryPrimitive {
DEFINE_PRINT(Convolution)
bool is_equivalent(const Primitive& other) const override;
auto state() const {
return std::make_tuple(
padding_,
kernel_strides_,
kernel_dilation_,
input_dilation_,
groups_,
flip_);
}
private:
std::vector<int> padding_;
@@ -912,6 +954,9 @@ class Equal : public UnaryPrimitive {
os << "Equal";
}
}
auto state() const {
return equal_nan_;
};
private:
void eval(const std::vector<array>& inputs, array& out);
@@ -1001,6 +1046,9 @@ class ExpandDims : public UnaryPrimitive {
bool is_equivalent(const Primitive& other) const override;
static Shape output_shape(const array& input, const std::vector<int>& axes);
auto state() const {
return axes_;
}
private:
void eval(const std::vector<array>& inputs, array& out);
@@ -1024,6 +1072,9 @@ class FFT : public UnaryPrimitive {
DEFINE_PRINT(FFT)
bool is_equivalent(const Primitive& other) const override;
auto state() const {
return std::make_tuple(axes_, inverse_, real_);
}
private:
std::vector<size_t> axes_;
@@ -1048,6 +1099,9 @@ class Flatten : public UnaryPrimitive {
bool is_equivalent(const Primitive& other) const override;
static Shape output_shape(const array& input, int start_axis, int end_axis);
auto state() const {
return std::make_pair(start_axis_, end_axis_);
}
private:
int start_axis_;
@@ -1103,6 +1157,9 @@ class Gather : public UnaryPrimitive {
DEFINE_PRINT(Gather)
bool is_equivalent(const Primitive& other) const override;
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
std::pair<std::vector<int>, std::vector<int>> state() const {
return {axes_, slice_sizes_};
}
private:
void eval(const std::vector<array>& inputs, array& out);
@@ -1158,6 +1215,9 @@ class Hadamard : public UnaryPrimitive {
DEFINE_INPUT_OUTPUT_SHAPE()
bool is_equivalent(const Primitive& other) const override;
auto state() const {
return scale_;
}
private:
float scale_;
@@ -1260,6 +1320,10 @@ class Log : public UnaryPrimitive {
DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
Base state() const {
return base_;
};
void print(std::ostream& os) override {
switch (base_) {
case e:
@@ -1488,6 +1552,9 @@ class NumberOfElements : public UnaryPrimitive {
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override {
return {{}};
}
std::tuple<std::vector<int>, bool, Dtype> state() const {
return {axes_, inverted_, dtype_};
}
private:
std::vector<int> axes_;
@@ -1516,6 +1583,9 @@ class Pad : public UnaryPrimitive {
DEFINE_GRADS()
DEFINE_PRINT(Pad)
bool is_equivalent(const Primitive& other) const override;
auto state() const {
return std::make_tuple(axes_, low_pad_size_, high_pad_size_);
}
private:
std::vector<int> axes_;
@@ -1538,6 +1608,9 @@ class Partition : public UnaryPrimitive {
DEFINE_PRINT(Partition)
DEFINE_INPUT_OUTPUT_SHAPE()
bool is_equivalent(const Primitive& other) const override;
auto state() const {
return std::make_pair(kth_, axis_);
};
private:
int kth_;
@@ -1583,6 +1656,9 @@ class QuantizedMatmul : public UnaryPrimitive {
DEFINE_PRINT(QuantizedMatmul)
bool is_equivalent(const Primitive& other) const override;
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
auto state() const {
return std::make_tuple(group_size_, bits_, transpose_);
}
private:
int group_size_;
@@ -1607,6 +1683,9 @@ class GatherQMM : public UnaryPrimitive {
DEFINE_GRADS()
DEFINE_PRINT(GatherQMM)
bool is_equivalent(const Primitive& other) const override;
auto state() const {
return std::make_tuple(group_size_, bits_, transpose_);
}
private:
int group_size_;
@@ -1627,6 +1706,9 @@ class RandomBits : public UnaryPrimitive {
DEFINE_VMAP()
DEFINE_PRINT(RandomBits)
bool is_equivalent(const Primitive& other) const override;
std::pair<std::vector<int>, int> state() const {
return {shape_, width_};
};
private:
Shape shape_;
@@ -1661,6 +1743,9 @@ class Reshape : public UnaryPrimitive {
DEFINE_GRADS()
DEFINE_PRINT(Reshape)
bool is_equivalent(const Primitive& other) const override;
std::vector<int> state() const {
return shape_;
};
private:
Shape shape_;
@@ -1712,6 +1797,9 @@ class Reduce : public UnaryPrimitive {
}
}
bool is_equivalent(const Primitive& other) const override;
std::pair<ReduceType, std::vector<int>> state() const {
return {reduce_type_, axes_};
};
private:
ReduceType reduce_type_;
@@ -1777,6 +1865,9 @@ class Scan : public UnaryPrimitive {
}
}
bool is_equivalent(const Primitive& other) const override;
auto state() const {
return std::make_tuple(reduce_type_, axis_, reverse_, inclusive_);
}
private:
ReduceType reduce_type_;
@@ -1823,6 +1914,9 @@ class Scatter : public UnaryPrimitive {
}
}
bool is_equivalent(const Primitive& other) const override;
std::pair<ReduceType, std::vector<int>> state() const {
return {reduce_type_, axes_};
};
private:
void eval(const std::vector<array>& inputs, array& out);
@@ -1917,6 +2011,9 @@ class Slice : public UnaryPrimitive {
DEFINE_GRADS()
DEFINE_PRINT(Slice)
bool is_equivalent(const Primitive& other) const override;
auto state() const {
return std::make_tuple(start_indices_, end_indices_, strides_);
}
private:
Shape start_indices_;
@@ -1946,6 +2043,9 @@ class SliceUpdate : public UnaryPrimitive {
DEFINE_PRINT(SliceUpdate)
bool is_equivalent(const Primitive& other) const override;
DEFINE_INPUT_OUTPUT_SHAPE()
auto state() const {
return std::make_tuple(start_indices_, end_indices_, strides_);
}
private:
Shape start_indices_;
@@ -1969,6 +2069,9 @@ class Softmax : public UnaryPrimitive {
DEFINE_INPUT_OUTPUT_SHAPE()
bool is_equivalent(const Primitive& other) const override;
auto state() const {
return precise_;
};
private:
void eval(const std::vector<array>& inputs, array& out);
@@ -1988,6 +2091,9 @@ class Sort : public UnaryPrimitive {
DEFINE_PRINT(Sort)
DEFINE_INPUT_OUTPUT_SHAPE()
bool is_equivalent(const Primitive& other) const override;
auto state() const {
return axis_;
}
private:
int axis_;
@@ -2009,6 +2115,9 @@ class Split : public Primitive {
DEFINE_GRADS()
DEFINE_PRINT(Split)
bool is_equivalent(const Primitive& other) const override;
std::pair<std::vector<int>, int> state() const {
return {indices_, axis_};
};
private:
void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
@@ -2046,6 +2155,9 @@ class Sqrt : public UnaryPrimitive {
DEFINE_GRADS()
DEFINE_INPUT_OUTPUT_SHAPE()
bool is_equivalent(const Primitive& other) const override;
auto state() const {
return recip_;
}
void print(std::ostream& os) override {
if (recip_) {
@@ -2109,6 +2221,9 @@ class Squeeze : public UnaryPrimitive {
bool is_equivalent(const Primitive& other) const override;
static Shape output_shape(const array& input, const std::vector<int>& axes);
auto state() const {
return axes_;
};
private:
void eval(const std::vector<array>& inputs, array& out);
@@ -2164,6 +2279,9 @@ class Unflatten : public UnaryPrimitive {
bool is_equivalent(const Primitive& other) const override;
static Shape output_shape(const array& input, int axis, const Shape& shape);
auto state() const {
return std::make_pair(axis_, shape_);
}
private:
int axis_;
@@ -2171,21 +2289,6 @@ class Unflatten : public UnaryPrimitive {
void eval(const std::vector<array>& inputs, array& out);
};
class Uniform : public UnaryPrimitive {
public:
explicit Uniform(Stream stream) : UnaryPrimitive(stream) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
DEFINE_PRINT(Uniform)
DEFINE_DEFAULT_IS_EQUIVALENT()
private:
void eval(const std::vector<array>& inputs, array& out);
};
class View : public UnaryPrimitive {
public:
explicit View(Stream stream, Dtype dtype)
@@ -2197,6 +2300,9 @@ class View : public UnaryPrimitive {
DEFINE_VMAP()
void print(std::ostream& os) override;
bool is_equivalent(const Primitive& other) const override;
auto state() const {
return dtype_;
}
private:
Dtype dtype_;
@@ -2215,6 +2321,9 @@ class Transpose : public UnaryPrimitive {
DEFINE_PRINT(Transpose)
bool is_equivalent(const Primitive& other) const override;
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
std::vector<int> state() const {
return axes_;
};
private:
std::vector<int> axes_;
@@ -2266,6 +2375,9 @@ class Inverse : public UnaryPrimitive {
DEFINE_VMAP()
DEFINE_PRINT(Inverse)
auto state() const {
return std::make_pair(tri_, upper_);
}
private:
void eval(const std::vector<array>& inputs, array& output);
@@ -2280,6 +2392,9 @@ class Cholesky : public UnaryPrimitive {
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
auto state() const {
return upper_;
}
DEFINE_VMAP()
DEFINE_PRINT(Cholesky)
@@ -2307,6 +2422,9 @@ class Eigh : public Primitive {
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
bool is_equivalent(const Primitive& other) const override;
auto state() const {
return std::make_pair(uplo_, compute_eigenvectors_);
}
private:
void eval(const std::vector<array>& inputs, std::vector<array>& outputs);