mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
More primitives for compiling with shapeless (#1653)
* more shapeless and more Shape * more shape * fix * fix
This commit is contained in:
@@ -36,10 +36,10 @@
|
||||
return true; \
|
||||
}
|
||||
|
||||
#define DEFINE_INPUT_OUTPUT_SHAPE() \
|
||||
std::vector<std::vector<int>> output_shapes( \
|
||||
const std::vector<array>& inputs) override { \
|
||||
return {inputs[0].shape()}; \
|
||||
#define DEFINE_INPUT_OUTPUT_SHAPE() \
|
||||
std::vector<Shape> output_shapes(const std::vector<array>& inputs) \
|
||||
override { \
|
||||
return {inputs[0].shape()}; \
|
||||
}
|
||||
|
||||
namespace mlx::core {
|
||||
@@ -110,8 +110,7 @@ class Primitive {
|
||||
|
||||
/** Get the output shapes of the primitive. This is not required to be
|
||||
* implemented by derived classes, in which case it will throw. */
|
||||
virtual std::vector<std::vector<int>> output_shapes(
|
||||
const std::vector<array>& inputs);
|
||||
virtual std::vector<Shape> output_shapes(const std::vector<array>& inputs);
|
||||
|
||||
virtual ~Primitive() = default;
|
||||
Primitive(const Primitive& other) = delete;
|
||||
@@ -220,6 +219,7 @@ class Arange : public UnaryPrimitive {
|
||||
|
||||
DEFINE_PRINT(Arange)
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
||||
|
||||
private:
|
||||
double start_;
|
||||
@@ -386,8 +386,7 @@ class ArgReduce : public UnaryPrimitive {
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(ArgReduce)
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
std::vector<std::vector<int>> output_shapes(
|
||||
const std::vector<array>& inputs) override;
|
||||
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
||||
|
||||
private:
|
||||
ReduceType reduce_type_;
|
||||
@@ -437,11 +436,7 @@ class AsType : public UnaryPrimitive {
|
||||
|
||||
class AsStrided : public UnaryPrimitive {
|
||||
public:
|
||||
explicit AsStrided(
|
||||
Stream stream,
|
||||
std::vector<int> shape,
|
||||
std::vector<size_t> strides,
|
||||
size_t offset)
|
||||
explicit AsStrided(Stream stream, Shape shape, Strides strides, size_t offset)
|
||||
: UnaryPrimitive(stream),
|
||||
shape_(std::move(shape)),
|
||||
strides_(std::move(strides)),
|
||||
@@ -455,8 +450,8 @@ class AsStrided : public UnaryPrimitive {
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
|
||||
private:
|
||||
std::vector<int> shape_;
|
||||
std::vector<size_t> strides_;
|
||||
Shape shape_;
|
||||
Strides strides_;
|
||||
size_t offset_;
|
||||
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
@@ -527,7 +522,7 @@ class GatherMM : public UnaryPrimitive {
|
||||
|
||||
class Broadcast : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Broadcast(Stream stream, const std::vector<int>& shape)
|
||||
explicit Broadcast(Stream stream, const Shape& shape)
|
||||
: UnaryPrimitive(stream), shape_(shape) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
@@ -539,7 +534,7 @@ class Broadcast : public UnaryPrimitive {
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
|
||||
private:
|
||||
std::vector<int> shape_;
|
||||
Shape shape_;
|
||||
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
@@ -586,8 +581,7 @@ class Compiled : public Primitive {
|
||||
|
||||
DEFINE_VMAP()
|
||||
DEFINE_GRADS()
|
||||
std::vector<std::vector<int>> output_shapes(
|
||||
const std::vector<array>& inputs) override;
|
||||
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
||||
void print(std::ostream& os) override;
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
|
||||
@@ -616,6 +610,7 @@ class Concatenate : public UnaryPrimitive {
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(Concatenate)
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
||||
|
||||
private:
|
||||
int axis_;
|
||||
@@ -853,8 +848,7 @@ class DivMod : public Primitive {
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(DivMod)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
std::vector<std::vector<int>> output_shapes(
|
||||
const std::vector<array>& inputs) override {
|
||||
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override {
|
||||
return std::vector{inputs[0].shape(), inputs[0].shape()};
|
||||
}
|
||||
|
||||
@@ -1063,6 +1057,7 @@ class Gather : public UnaryPrimitive {
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(Gather)
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
@@ -1339,6 +1334,7 @@ class Matmul : public UnaryPrimitive {
|
||||
DEFINE_VMAP()
|
||||
DEFINE_PRINT(Matmul)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
||||
};
|
||||
|
||||
class Maximum : public UnaryPrimitive {
|
||||
@@ -1444,8 +1440,7 @@ class NumberOfElements : public UnaryPrimitive {
|
||||
DEFINE_VMAP()
|
||||
DEFINE_PRINT(NumberOfElements)
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
std::vector<std::vector<int>> output_shapes(
|
||||
const std::vector<array>& inputs) override {
|
||||
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override {
|
||||
return {{}};
|
||||
}
|
||||
|
||||
@@ -1542,6 +1537,7 @@ class QuantizedMatmul : public UnaryPrimitive {
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(QuantizedMatmul)
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
||||
|
||||
private:
|
||||
int group_size_;
|
||||
@@ -1577,7 +1573,7 @@ class GatherQMM : public UnaryPrimitive {
|
||||
|
||||
class RandomBits : public UnaryPrimitive {
|
||||
public:
|
||||
explicit RandomBits(Stream stream, const std::vector<int>& shape, int width)
|
||||
explicit RandomBits(Stream stream, const Shape& shape, int width)
|
||||
: UnaryPrimitive(stream), shape_(shape), width_(width) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
@@ -1588,7 +1584,7 @@ class RandomBits : public UnaryPrimitive {
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
|
||||
private:
|
||||
std::vector<int> shape_;
|
||||
Shape shape_;
|
||||
int width_;
|
||||
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
@@ -1610,7 +1606,7 @@ class Real : public UnaryPrimitive {
|
||||
|
||||
class Reshape : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Reshape(Stream stream, const std::vector<int>& shape)
|
||||
explicit Reshape(Stream stream, const Shape& shape)
|
||||
: UnaryPrimitive(stream), shape_(shape) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
@@ -1622,16 +1618,16 @@ class Reshape : public UnaryPrimitive {
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
|
||||
private:
|
||||
std::vector<int> shape_;
|
||||
Shape shape_;
|
||||
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
|
||||
std::pair<bool, std::vector<size_t>> prepare_reshape(
|
||||
static std::pair<bool, Strides> prepare_reshape(
|
||||
const array& in,
|
||||
const array& out);
|
||||
void shared_buffer_reshape(
|
||||
static void shared_buffer_reshape(
|
||||
const array& in,
|
||||
const std::vector<size_t>& out_strides,
|
||||
const Strides& out_strides,
|
||||
array& out);
|
||||
};
|
||||
|
||||
@@ -1656,8 +1652,7 @@ class Reduce : public UnaryPrimitive {
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>& outputs) override;
|
||||
|
||||
std::vector<std::vector<int>> output_shapes(
|
||||
const std::vector<array>& inputs) override;
|
||||
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
||||
|
||||
void print(std::ostream& os) override {
|
||||
switch (reduce_type_) {
|
||||
@@ -2141,6 +2136,7 @@ class Transpose : public UnaryPrimitive {
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(Transpose)
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
||||
|
||||
private:
|
||||
std::vector<int> axes_;
|
||||
@@ -2230,24 +2226,9 @@ class Eigh : public Primitive {
|
||||
DEFINE_VMAP()
|
||||
DEFINE_PRINT(Eigh)
|
||||
|
||||
std::vector<std::vector<int>> output_shapes(
|
||||
const std::vector<array>& inputs) override {
|
||||
auto shape = inputs[0].shape();
|
||||
shape.pop_back(); // Remove last dimension for eigenvalues
|
||||
if (compute_eigenvectors_) {
|
||||
return {shape, inputs[0].shape()}; // Eigenvalues and eigenvectors
|
||||
} else {
|
||||
return {shape}; // Only eigenvalues
|
||||
}
|
||||
}
|
||||
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
||||
|
||||
bool is_equivalent(const Primitive& other) const override {
|
||||
if (auto* p = dynamic_cast<const Eigh*>(&other)) {
|
||||
return uplo_ == p->uplo_ &&
|
||||
compute_eigenvectors_ == p->compute_eigenvectors_;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
|
||||
|
||||
Reference in New Issue
Block a user