mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Export / import functions to / from a file (#1642)
* export and import functions * refactor + works for few primitives * nit * allow primitives with state * nit * nit * simplify serialize / deserialize * fix for constants * python bindings * maybe fix serialize failure case * add example * more primitives, training kind of works * same result for python and c++ * some fixes * fix export * template it up * some simplification * rebase * allow kwargs and multiple functions * exporter * more primitives for exporting * deal with endianness * handle invalid stream * add docstring
This commit is contained in:
148
mlx/primitives.h
148
mlx/primitives.h
@@ -203,6 +203,9 @@ class AddMM : public UnaryPrimitive {
|
||||
DEFINE_PRINT(AddMM)
|
||||
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
std::pair<float, float> state() const {
|
||||
return {alpha_, beta_};
|
||||
};
|
||||
|
||||
private:
|
||||
const float alpha_;
|
||||
@@ -220,6 +223,9 @@ class Arange : public UnaryPrimitive {
|
||||
DEFINE_PRINT(Arange)
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
||||
std::tuple<double, double, double> state() const {
|
||||
return {start_, stop_, step_};
|
||||
};
|
||||
|
||||
private:
|
||||
double start_;
|
||||
@@ -361,6 +367,9 @@ class ArgPartition : public UnaryPrimitive {
|
||||
DEFINE_PRINT(ArgPartition)
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
std::pair<int, int> state() const {
|
||||
return {kth_, axis_};
|
||||
};
|
||||
|
||||
private:
|
||||
int kth_;
|
||||
@@ -387,6 +396,9 @@ class ArgReduce : public UnaryPrimitive {
|
||||
DEFINE_PRINT(ArgReduce)
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
||||
std::pair<ReduceType, int> state() const {
|
||||
return {reduce_type_, axis_};
|
||||
};
|
||||
|
||||
private:
|
||||
ReduceType reduce_type_;
|
||||
@@ -407,6 +419,9 @@ class ArgSort : public UnaryPrimitive {
|
||||
DEFINE_PRINT(ArgSort)
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
int state() const {
|
||||
return axis_;
|
||||
};
|
||||
|
||||
private:
|
||||
int axis_;
|
||||
@@ -427,6 +442,9 @@ class AsType : public UnaryPrimitive {
|
||||
DEFINE_PRINT(AsType)
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
// Returns dtype_ by value.
// NOTE(review): presumably read by the export/import serializer — confirm.
Dtype state() const {
  return dtype_;
}
|
||||
|
||||
private:
|
||||
Dtype dtype_;
|
||||
@@ -448,6 +466,9 @@ class AsStrided : public UnaryPrimitive {
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(AsStrided)
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
// Bundles shape_, strides_ and offset_ into a tuple and returns it by value.
// The return type is deduced; std::make_tuple stores a decayed copy of each
// member, so the caller receives an independent snapshot.
auto state() const {
|
||||
return std::make_tuple(shape_, strides_, offset_);
|
||||
}
|
||||
|
||||
private:
|
||||
Shape shape_;
|
||||
@@ -472,6 +493,9 @@ class BitwiseBinary : public UnaryPrimitive {
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
void print(std::ostream& os) override;
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
// Returns op_ by value. With a plain `auto` return type the deduced type
// drops references and cv-qualifiers, so this is always a copy.
auto state() const {
|
||||
return op_;
|
||||
}
|
||||
|
||||
private:
|
||||
Op op_;
|
||||
@@ -493,6 +517,9 @@ class BlockMaskedMM : public UnaryPrimitive {
|
||||
|
||||
DEFINE_PRINT(BlockMaskedMM)
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
// Returns block_size_ by value (return type deduced from the member).
auto state() const {
|
||||
return block_size_;
|
||||
}
|
||||
|
||||
private:
|
||||
int block_size_;
|
||||
@@ -532,6 +559,9 @@ class Broadcast : public UnaryPrimitive {
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(Broadcast)
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
std::vector<int> state() const {
|
||||
return shape_;
|
||||
};
|
||||
|
||||
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
||||
|
||||
@@ -613,6 +643,9 @@ class Concatenate : public UnaryPrimitive {
|
||||
DEFINE_PRINT(Concatenate)
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
||||
// Returns axis_ by value (return type deduced from the member).
auto state() const {
|
||||
return axis_;
|
||||
}
|
||||
|
||||
private:
|
||||
int axis_;
|
||||
@@ -684,6 +717,15 @@ class Convolution : public UnaryPrimitive {
|
||||
|
||||
DEFINE_PRINT(Convolution)
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
// Returns a tuple of six members, in this order: padding_, kernel_strides_,
// kernel_dilation_, input_dilation_, groups_, flip_.
// std::make_tuple stores decayed copies, so the members are copied out.
// NOTE(review): any consumer that unpacks this positionally depends on the
// order above — confirm before reordering.
auto state() const {
|
||||
return std::make_tuple(
|
||||
padding_,
|
||||
kernel_strides_,
|
||||
kernel_dilation_,
|
||||
input_dilation_,
|
||||
groups_,
|
||||
flip_);
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<int> padding_;
|
||||
@@ -912,6 +954,9 @@ class Equal : public UnaryPrimitive {
|
||||
os << "Equal";
|
||||
}
|
||||
}
|
||||
auto state() const {
|
||||
return equal_nan_;
|
||||
};
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
@@ -1001,6 +1046,9 @@ class ExpandDims : public UnaryPrimitive {
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
|
||||
static Shape output_shape(const array& input, const std::vector<int>& axes);
|
||||
// Returns axes_ by value. Plain `auto` deduction never yields a reference,
// so the member is copied on every call.
auto state() const {
|
||||
return axes_;
|
||||
}
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
@@ -1024,6 +1072,9 @@ class FFT : public UnaryPrimitive {
|
||||
DEFINE_PRINT(FFT)
|
||||
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
// Returns (axes_, inverse_, real_) as a tuple of copies, in that order.
auto state() const {
|
||||
return std::make_tuple(axes_, inverse_, real_);
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<size_t> axes_;
|
||||
@@ -1048,6 +1099,9 @@ class Flatten : public UnaryPrimitive {
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
|
||||
static Shape output_shape(const array& input, int start_axis, int end_axis);
|
||||
// Returns (start_axis_, end_axis_) as a pair of copies.
auto state() const {
|
||||
return std::make_pair(start_axis_, end_axis_);
|
||||
}
|
||||
|
||||
private:
|
||||
int start_axis_;
|
||||
@@ -1103,6 +1157,9 @@ class Gather : public UnaryPrimitive {
|
||||
DEFINE_PRINT(Gather)
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
||||
// Returns copies of axes_ and slice_sizes_ as a pair of int vectors, in that
// order.
std::pair<std::vector<int>, std::vector<int>> state() const {
|
||||
return {axes_, slice_sizes_};
|
||||
}
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
@@ -1158,6 +1215,9 @@ class Hadamard : public UnaryPrimitive {
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
// Returns scale_ by value (return type deduced from the member).
auto state() const {
|
||||
return scale_;
|
||||
}
|
||||
|
||||
private:
|
||||
float scale_;
|
||||
@@ -1260,6 +1320,10 @@ class Log : public UnaryPrimitive {
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
// Returns base_, the value the print() switch below this accessor branches
// on, by value.
// NOTE(review): presumably read by the export/import serializer — confirm.
Base state() const {
  return base_;
}
|
||||
|
||||
void print(std::ostream& os) override {
|
||||
switch (base_) {
|
||||
case e:
|
||||
@@ -1488,6 +1552,9 @@ class NumberOfElements : public UnaryPrimitive {
|
||||
// Reports exactly one output shape, built from an empty initializer list;
// `inputs` is not consulted.
// NOTE(review): an empty Shape presumably denotes a scalar output — confirm.
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override {
|
||||
return {{}};
|
||||
}
|
||||
// Returns (axes_, inverted_, dtype_) as a tuple of copies, in that order.
std::tuple<std::vector<int>, bool, Dtype> state() const {
|
||||
return {axes_, inverted_, dtype_};
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<int> axes_;
|
||||
@@ -1516,6 +1583,9 @@ class Pad : public UnaryPrimitive {
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(Pad)
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
// Returns (axes_, low_pad_size_, high_pad_size_) as a tuple of copies, in
// that order.
auto state() const {
|
||||
return std::make_tuple(axes_, low_pad_size_, high_pad_size_);
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<int> axes_;
|
||||
@@ -1538,6 +1608,9 @@ class Partition : public UnaryPrimitive {
|
||||
DEFINE_PRINT(Partition)
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
auto state() const {
|
||||
return std::make_pair(kth_, axis_);
|
||||
};
|
||||
|
||||
private:
|
||||
int kth_;
|
||||
@@ -1583,6 +1656,9 @@ class QuantizedMatmul : public UnaryPrimitive {
|
||||
DEFINE_PRINT(QuantizedMatmul)
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
||||
// Returns (group_size_, bits_, transpose_) as a tuple of copies, in that
// order.
auto state() const {
|
||||
return std::make_tuple(group_size_, bits_, transpose_);
|
||||
}
|
||||
|
||||
private:
|
||||
int group_size_;
|
||||
@@ -1607,6 +1683,9 @@ class GatherQMM : public UnaryPrimitive {
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(GatherQMM)
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
// Returns (group_size_, bits_, transpose_) as a tuple of copies, in that
// order.
auto state() const {
|
||||
return std::make_tuple(group_size_, bits_, transpose_);
|
||||
}
|
||||
|
||||
private:
|
||||
int group_size_;
|
||||
@@ -1627,6 +1706,9 @@ class RandomBits : public UnaryPrimitive {
|
||||
DEFINE_VMAP()
|
||||
DEFINE_PRINT(RandomBits)
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
std::pair<std::vector<int>, int> state() const {
|
||||
return {shape_, width_};
|
||||
};
|
||||
|
||||
private:
|
||||
Shape shape_;
|
||||
@@ -1661,6 +1743,9 @@ class Reshape : public UnaryPrimitive {
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(Reshape)
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
std::vector<int> state() const {
|
||||
return shape_;
|
||||
};
|
||||
|
||||
private:
|
||||
Shape shape_;
|
||||
@@ -1712,6 +1797,9 @@ class Reduce : public UnaryPrimitive {
|
||||
}
|
||||
}
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
std::pair<ReduceType, std::vector<int>> state() const {
|
||||
return {reduce_type_, axes_};
|
||||
};
|
||||
|
||||
private:
|
||||
ReduceType reduce_type_;
|
||||
@@ -1777,6 +1865,9 @@ class Scan : public UnaryPrimitive {
|
||||
}
|
||||
}
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
// Returns (reduce_type_, axis_, reverse_, inclusive_) as a tuple of copies,
// in that order.
auto state() const {
|
||||
return std::make_tuple(reduce_type_, axis_, reverse_, inclusive_);
|
||||
}
|
||||
|
||||
private:
|
||||
ReduceType reduce_type_;
|
||||
@@ -1823,6 +1914,9 @@ class Scatter : public UnaryPrimitive {
|
||||
}
|
||||
}
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
std::pair<ReduceType, std::vector<int>> state() const {
|
||||
return {reduce_type_, axes_};
|
||||
};
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
@@ -1917,6 +2011,9 @@ class Slice : public UnaryPrimitive {
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(Slice)
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
// Returns (start_indices_, end_indices_, strides_) as a tuple of copies, in
// that order.
auto state() const {
|
||||
return std::make_tuple(start_indices_, end_indices_, strides_);
|
||||
}
|
||||
|
||||
private:
|
||||
Shape start_indices_;
|
||||
@@ -1946,6 +2043,9 @@ class SliceUpdate : public UnaryPrimitive {
|
||||
DEFINE_PRINT(SliceUpdate)
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
// Returns (start_indices_, end_indices_, strides_) as a tuple of copies, in
// that order.
auto state() const {
|
||||
return std::make_tuple(start_indices_, end_indices_, strides_);
|
||||
}
|
||||
|
||||
private:
|
||||
Shape start_indices_;
|
||||
@@ -1969,6 +2069,9 @@ class Softmax : public UnaryPrimitive {
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
auto state() const {
|
||||
return precise_;
|
||||
};
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
@@ -1988,6 +2091,9 @@ class Sort : public UnaryPrimitive {
|
||||
DEFINE_PRINT(Sort)
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
// Returns axis_ by value (return type deduced from the member).
auto state() const {
|
||||
return axis_;
|
||||
}
|
||||
|
||||
private:
|
||||
int axis_;
|
||||
@@ -2009,6 +2115,9 @@ class Split : public Primitive {
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(Split)
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
std::pair<std::vector<int>, int> state() const {
|
||||
return {indices_, axis_};
|
||||
};
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
|
||||
@@ -2046,6 +2155,9 @@ class Sqrt : public UnaryPrimitive {
|
||||
DEFINE_GRADS()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
// Returns recip_, the flag that print() below this accessor also tests, by
// value.
auto state() const {
|
||||
return recip_;
|
||||
}
|
||||
|
||||
void print(std::ostream& os) override {
|
||||
if (recip_) {
|
||||
@@ -2109,6 +2221,9 @@ class Squeeze : public UnaryPrimitive {
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
|
||||
static Shape output_shape(const array& input, const std::vector<int>& axes);
|
||||
auto state() const {
|
||||
return axes_;
|
||||
};
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
@@ -2164,6 +2279,9 @@ class Unflatten : public UnaryPrimitive {
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
|
||||
static Shape output_shape(const array& input, int axis, const Shape& shape);
|
||||
// Returns (axis_, shape_) as a pair of copies, in that order.
auto state() const {
|
||||
return std::make_pair(axis_, shape_);
|
||||
}
|
||||
|
||||
private:
|
||||
int axis_;
|
||||
@@ -2171,21 +2289,6 @@ class Unflatten : public UnaryPrimitive {
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
// Unary primitive that carries no parameters beyond the stream it is
// constructed with.
// NOTE(review): the enclosing hunk shrinks from 21 to 6 lines, so this class
// appears to be the part deleted by the commit — confirm before relying on it.
class Uniform : public UnaryPrimitive {
|
||||
public:
|
||||
// Forwards `stream` to the UnaryPrimitive base; no other state is stored.
explicit Uniform(Stream stream) : UnaryPrimitive(stream) {}
|
||||
|
||||
// Per-backend evaluation entry points; defined outside this view.
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
|
||||
// Project macros; their expansions are not visible in this view.
DEFINE_VMAP()
|
||||
DEFINE_PRINT(Uniform)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
|
||||
private:
|
||||
// NOTE(review): presumably the shared implementation behind eval_cpu and
// eval_gpu — confirm in the .cpp.
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class View : public UnaryPrimitive {
|
||||
public:
|
||||
explicit View(Stream stream, Dtype dtype)
|
||||
@@ -2197,6 +2300,9 @@ class View : public UnaryPrimitive {
|
||||
DEFINE_VMAP()
|
||||
void print(std::ostream& os) override;
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
// Returns dtype_ by value (return type deduced from the member).
auto state() const {
|
||||
return dtype_;
|
||||
}
|
||||
|
||||
private:
|
||||
Dtype dtype_;
|
||||
@@ -2215,6 +2321,9 @@ class Transpose : public UnaryPrimitive {
|
||||
DEFINE_PRINT(Transpose)
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
||||
std::vector<int> state() const {
|
||||
return axes_;
|
||||
};
|
||||
|
||||
private:
|
||||
std::vector<int> axes_;
|
||||
@@ -2266,6 +2375,9 @@ class Inverse : public UnaryPrimitive {
|
||||
|
||||
DEFINE_VMAP()
|
||||
DEFINE_PRINT(Inverse)
|
||||
// Returns (tri_, upper_) as a pair of copies, in that order.
auto state() const {
|
||||
return std::make_pair(tri_, upper_);
|
||||
}
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& output);
|
||||
@@ -2280,6 +2392,9 @@ class Cholesky : public UnaryPrimitive {
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
// Returns upper_ by value (return type deduced from the member).
auto state() const {
|
||||
return upper_;
|
||||
}
|
||||
|
||||
DEFINE_VMAP()
|
||||
DEFINE_PRINT(Cholesky)
|
||||
@@ -2307,6 +2422,9 @@ class Eigh : public Primitive {
|
||||
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
||||
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
// Returns (uplo_, compute_eigenvectors_) as a pair of copies, in that order.
auto state() const {
|
||||
return std::make_pair(uplo_, compute_eigenvectors_);
|
||||
}
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
|
||||
|
||||
Reference in New Issue
Block a user