More primitives for compiling with shapeless (#1653)

* more shapeless and more Shape

* more shape

* fix

* fix
This commit is contained in:
Awni Hannun
2024-12-06 11:29:18 -08:00
committed by GitHub
parent 95c4a2e3af
commit d0b6cb0425
5 changed files with 160 additions and 81 deletions

View File

@@ -36,10 +36,10 @@
return true; \
}
#define DEFINE_INPUT_OUTPUT_SHAPE() \
std::vector<std::vector<int>> output_shapes( \
const std::vector<array>& inputs) override { \
return {inputs[0].shape()}; \
#define DEFINE_INPUT_OUTPUT_SHAPE() \
std::vector<Shape> output_shapes(const std::vector<array>& inputs) \
override { \
return {inputs[0].shape()}; \
}
namespace mlx::core {
@@ -110,8 +110,7 @@ class Primitive {
/** Get the output shapes of the primitive. This is not required to be
* implemented by derived classes, in which case it will throw. */
virtual std::vector<std::vector<int>> output_shapes(
const std::vector<array>& inputs);
virtual std::vector<Shape> output_shapes(const std::vector<array>& inputs);
virtual ~Primitive() = default;
Primitive(const Primitive& other) = delete;
@@ -220,6 +219,7 @@ class Arange : public UnaryPrimitive {
DEFINE_PRINT(Arange)
bool is_equivalent(const Primitive& other) const override;
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
private:
double start_;
@@ -386,8 +386,7 @@ class ArgReduce : public UnaryPrimitive {
DEFINE_GRADS()
DEFINE_PRINT(ArgReduce)
bool is_equivalent(const Primitive& other) const override;
std::vector<std::vector<int>> output_shapes(
const std::vector<array>& inputs) override;
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
private:
ReduceType reduce_type_;
@@ -437,11 +436,7 @@ class AsType : public UnaryPrimitive {
class AsStrided : public UnaryPrimitive {
public:
explicit AsStrided(
Stream stream,
std::vector<int> shape,
std::vector<size_t> strides,
size_t offset)
explicit AsStrided(Stream stream, Shape shape, Strides strides, size_t offset)
: UnaryPrimitive(stream),
shape_(std::move(shape)),
strides_(std::move(strides)),
@@ -455,8 +450,8 @@ class AsStrided : public UnaryPrimitive {
bool is_equivalent(const Primitive& other) const override;
private:
std::vector<int> shape_;
std::vector<size_t> strides_;
Shape shape_;
Strides strides_;
size_t offset_;
void eval(const std::vector<array>& inputs, array& out);
@@ -527,7 +522,7 @@ class GatherMM : public UnaryPrimitive {
class Broadcast : public UnaryPrimitive {
public:
explicit Broadcast(Stream stream, const std::vector<int>& shape)
explicit Broadcast(Stream stream, const Shape& shape)
: UnaryPrimitive(stream), shape_(shape) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override;
@@ -539,7 +534,7 @@ class Broadcast : public UnaryPrimitive {
bool is_equivalent(const Primitive& other) const override;
private:
std::vector<int> shape_;
Shape shape_;
void eval(const std::vector<array>& inputs, array& out);
};
@@ -586,8 +581,7 @@ class Compiled : public Primitive {
DEFINE_VMAP()
DEFINE_GRADS()
std::vector<std::vector<int>> output_shapes(
const std::vector<array>& inputs) override;
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
void print(std::ostream& os) override;
bool is_equivalent(const Primitive& other) const override;
@@ -616,6 +610,7 @@ class Concatenate : public UnaryPrimitive {
DEFINE_GRADS()
DEFINE_PRINT(Concatenate)
bool is_equivalent(const Primitive& other) const override;
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
private:
int axis_;
@@ -853,8 +848,7 @@ class DivMod : public Primitive {
DEFINE_GRADS()
DEFINE_PRINT(DivMod)
DEFINE_DEFAULT_IS_EQUIVALENT()
std::vector<std::vector<int>> output_shapes(
const std::vector<array>& inputs) override {
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override {
return std::vector{inputs[0].shape(), inputs[0].shape()};
}
@@ -1063,6 +1057,7 @@ class Gather : public UnaryPrimitive {
DEFINE_GRADS()
DEFINE_PRINT(Gather)
bool is_equivalent(const Primitive& other) const override;
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
private:
void eval(const std::vector<array>& inputs, array& out);
@@ -1339,6 +1334,7 @@ class Matmul : public UnaryPrimitive {
DEFINE_VMAP()
DEFINE_PRINT(Matmul)
DEFINE_DEFAULT_IS_EQUIVALENT()
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
};
class Maximum : public UnaryPrimitive {
@@ -1444,8 +1440,7 @@ class NumberOfElements : public UnaryPrimitive {
DEFINE_VMAP()
DEFINE_PRINT(NumberOfElements)
bool is_equivalent(const Primitive& other) const override;
std::vector<std::vector<int>> output_shapes(
const std::vector<array>& inputs) override {
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override {
return {{}};
}
@@ -1542,6 +1537,7 @@ class QuantizedMatmul : public UnaryPrimitive {
DEFINE_GRADS()
DEFINE_PRINT(QuantizedMatmul)
bool is_equivalent(const Primitive& other) const override;
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
private:
int group_size_;
@@ -1577,7 +1573,7 @@ class GatherQMM : public UnaryPrimitive {
class RandomBits : public UnaryPrimitive {
public:
explicit RandomBits(Stream stream, const std::vector<int>& shape, int width)
explicit RandomBits(Stream stream, const Shape& shape, int width)
: UnaryPrimitive(stream), shape_(shape), width_(width) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override;
@@ -1588,7 +1584,7 @@ class RandomBits : public UnaryPrimitive {
bool is_equivalent(const Primitive& other) const override;
private:
std::vector<int> shape_;
Shape shape_;
int width_;
void eval(const std::vector<array>& inputs, array& out);
@@ -1610,7 +1606,7 @@ class Real : public UnaryPrimitive {
class Reshape : public UnaryPrimitive {
public:
explicit Reshape(Stream stream, const std::vector<int>& shape)
explicit Reshape(Stream stream, const Shape& shape)
: UnaryPrimitive(stream), shape_(shape) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override;
@@ -1622,16 +1618,16 @@ class Reshape : public UnaryPrimitive {
bool is_equivalent(const Primitive& other) const override;
private:
std::vector<int> shape_;
Shape shape_;
void eval(const std::vector<array>& inputs, array& out);
std::pair<bool, std::vector<size_t>> prepare_reshape(
static std::pair<bool, Strides> prepare_reshape(
const array& in,
const array& out);
void shared_buffer_reshape(
static void shared_buffer_reshape(
const array& in,
const std::vector<size_t>& out_strides,
const Strides& out_strides,
array& out);
};
@@ -1656,8 +1652,7 @@ class Reduce : public UnaryPrimitive {
const std::vector<int>& argnums,
const std::vector<array>& outputs) override;
std::vector<std::vector<int>> output_shapes(
const std::vector<array>& inputs) override;
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
void print(std::ostream& os) override {
switch (reduce_type_) {
@@ -2141,6 +2136,7 @@ class Transpose : public UnaryPrimitive {
DEFINE_GRADS()
DEFINE_PRINT(Transpose)
bool is_equivalent(const Primitive& other) const override;
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
private:
std::vector<int> axes_;
@@ -2230,24 +2226,9 @@ class Eigh : public Primitive {
DEFINE_VMAP()
DEFINE_PRINT(Eigh)
std::vector<std::vector<int>> output_shapes(
const std::vector<array>& inputs) override {
auto shape = inputs[0].shape();
shape.pop_back(); // Remove last dimension for eigenvalues
if (compute_eigenvectors_) {
return {shape, inputs[0].shape()}; // Eigenvalues and eigenvectors
} else {
return {shape}; // Only eigenvalues
}
}
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
bool is_equivalent(const Primitive& other) const override {
if (auto* p = dynamic_cast<const Eigh*>(&other)) {
return uplo_ == p->uplo_ &&
compute_eigenvectors_ == p->compute_eigenvectors_;
}
return false;
}
bool is_equivalent(const Primitive& other) const override;
private:
void eval(const std::vector<array>& inputs, std::vector<array>& outputs);