diff --git a/mlx/compile.cpp b/mlx/compile.cpp index 1f9556346..741385844 100644 --- a/mlx/compile.cpp +++ b/mlx/compile.cpp @@ -1,5 +1,4 @@ // Copyright © 2023-2024 Apple Inc. - #include #include #include @@ -8,6 +7,7 @@ #include "mlx/allocator.h" #include "mlx/compile.h" #include "mlx/compile_impl.h" +#include "mlx/fast_primitives.h" #include "mlx/primitives.h" #include "mlx/transforms.h" #include "mlx/transforms_impl.h" @@ -73,11 +73,18 @@ bool is_fusable(const Primitive& p) { } bool allows_shapeless(const Primitive& p) { - return typeid(p) == typeid(Compiled) || is_unary(p) || is_binary(p) || - is_noop(p) || is_reduction(p) || typeid(p) == typeid(Softmax) || - typeid(p) == typeid(Sort) || typeid(p) == typeid(ArgSort) || - typeid(p) == typeid(ArgPartition) || typeid(p) == typeid(Partition) || - typeid(p) == typeid(Select) || typeid(p) == typeid(NumberOfElements); + return typeid(p) == typeid(Arange) || typeid(p) == typeid(Compiled) || + is_unary(p) || is_binary(p) || is_noop(p) || is_reduction(p) || + typeid(p) == typeid(Softmax) || typeid(p) == typeid(Sort) || + typeid(p) == typeid(ArgSort) || typeid(p) == typeid(ArgPartition) || + typeid(p) == typeid(Partition) || typeid(p) == typeid(Select) || + typeid(p) == typeid(NumberOfElements) || typeid(p) == typeid(Gather) || + typeid(p) == typeid(Transpose) || typeid(p) == typeid(Concatenate) || + typeid(p) == typeid(Matmul) || typeid(p) == typeid(QuantizedMatmul) || + typeid(p) == typeid(fast::AffineQuantize) || + typeid(p) == typeid(fast::LayerNorm) || + typeid(p) == typeid(fast::RMSNorm) || typeid(p) == typeid(fast::RoPE) || + typeid(p) == typeid(fast::ScaledDotProductAttention); } Compiled::Compiled( @@ -93,23 +100,23 @@ Compiled::Compiled( constant_ids_(std::move(constant_ids)) {} std::vector Compiled::vjp( - const std::vector& primals, - const std::vector& cotangents, - const std::vector& argnums, - const std::vector& outputs) { + const std::vector&, + const std::vector&, + const std::vector&, + const std::vector&) { throw std::runtime_error("[Compiled] Cannot vjp primitive."); } std::vector Compiled::jvp( - const std::vector& primals, - const std::vector& tangents, - const std::vector& argnums) { + const std::vector&, + const std::vector&, + const std::vector&) { throw std::runtime_error("[Compiled] Cannot jvp primitive."); } std::pair, std::vector> Compiled::vmap( - const std::vector& inputs, - const std::vector& axes) { + const std::vector&, + const std::vector&) { throw std::runtime_error("[Compiled] Cannot vmap primitive."); } @@ -134,13 +141,12 @@ void Compiled::print(std::ostream& os) { } } -std::vector> Compiled::output_shapes( - const std::vector& inputs) { +std::vector Compiled::output_shapes(const std::vector& inputs) { size_t nd = 0; for (auto& in : inputs) { nd = std::max(nd, in.ndim()); } - std::vector out_shape(nd, 0); + Shape out_shape(nd, 0); for (auto& in : inputs) { auto dd = nd - in.ndim(); for (auto i = dd; i < nd; ++i) { @@ -148,7 +154,7 @@ std::vector> Compiled::output_shapes( } } // All outputs have the same shape - return std::vector>(outputs_.size(), out_shape); + return std::vector(outputs_.size(), out_shape); } namespace detail { @@ -553,14 +559,12 @@ void compile_fuse( // - Collect inputs to the new compiled primitive // - Add fusable primitives to a tape in the correct order - std::function&)> - recurse; + std::function recurse; std::unordered_set cache; recurse = [&](const array& a, int depth, const Stream& s, - const std::vector& shape) { + const Shape& shape) { if (cache.find(a.id()) != cache.end()) { return; } @@ -667,7 +671,7 @@ void compile_fuse( } old_outputs.push_back(arr); - std::vector> shapes; + std::vector shapes; std::vector types; for (auto& o : old_outputs) { if (o.shape() != old_outputs.back().shape()) { @@ -771,7 +775,7 @@ std::vector compile_replace( for (auto& o : trace_out) { types.push_back(o.dtype()); } - std::vector> shapes; + std::vector shapes; if (shapeless) { shapes = a.primitive().output_shapes(real_inputs); } else { diff --git a/mlx/fast.cpp b/mlx/fast.cpp index 5a04af0bd..f047f8027 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -915,6 +915,31 @@ array affine_dequantize( return fallback({w, scales, biases})[0]; } +bool AffineQuantize::is_equivalent(const Primitive& other) const { + const AffineQuantize& p_other = static_cast(other); + return ( + p_other.group_size_ == group_size_ && p_other.bits_ == bits_ && + p_other.dequantize_ == dequantize_); +} + +std::vector AffineQuantize::output_shapes( + const std::vector& inputs) { + auto& w = inputs[0]; + if (dequantize_) { + auto out_size = w.shape(-1) * 32 / bits_; + auto out_shape = w.shape(); + out_shape.back() = out_size; + return {std::move(out_shape)}; + } else { + auto wq_shape = w.shape(); + wq_shape.back() = w.shape(-1) * bits_ / 32; + auto sshape = w.shape(); + sshape.back() = w.shape(-1) / group_size_; + auto bshape = sshape; + return {std::move(wq_shape), std::move(sshape), std::move(bshape)}; + } +} + std::string write_signature( std::string func_name, const std::string& header, diff --git a/mlx/fast_primitives.h b/mlx/fast_primitives.h index 30db282ff..0ec316327 100644 --- a/mlx/fast_primitives.h +++ b/mlx/fast_primitives.h @@ -58,6 +58,7 @@ class RMSNorm : public Custom { DEFINE_PRINT(RMSNorm) bool is_equivalent(const Primitive& other) const override; + DEFINE_INPUT_OUTPUT_SHAPE() private: std::function(std::vector)> fallback_; @@ -110,6 +111,7 @@ class LayerNorm : public Custom { DEFINE_PRINT(LayerNorm) bool is_equivalent(const Primitive& other) const override; + DEFINE_INPUT_OUTPUT_SHAPE() private: std::function(std::vector)> fallback_; @@ -173,6 +175,7 @@ class RoPE : public Custom { DEFINE_PRINT(RoPE) bool is_equivalent(const Primitive& other) const override; + DEFINE_INPUT_OUTPUT_SHAPE() private: std::function(std::vector)> fallback_; @@ -207,6 +210,7 @@ class ScaledDotProductAttention : public Custom { bool is_equivalent(const Primitive& other) const override; DEFINE_PRINT(ScaledDotProductAttention); + DEFINE_INPUT_OUTPUT_SHAPE() private: std::function(std::vector)> fallback_; @@ -235,6 +239,9 @@ class AffineQuantize : public Custom { DEFINE_PRINT(AffineQuantize); + bool is_equivalent(const Primitive& other) const override; + std::vector output_shapes(const std::vector& inputs) override; + private: std::function(std::vector)> fallback_; int group_size_; diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 9d9ecd588..06be5f2a5 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -267,6 +267,11 @@ bool Arange::is_equivalent(const Primitive& other) const { step_ == a_other.step_); } +std::vector Arange::output_shapes(const std::vector&) { + auto real_size = std::ceil((stop_ - start_) / step_); + return {{std::max(static_cast(real_size), 0)}}; +} + std::vector ArcCos::vjp( const std::vector& primals, const std::vector& cotangents, @@ -534,11 +539,10 @@ std::pair, std::vector> ArgSort::vmap( return {{argsort(inputs[0], axis_ + axis_left, stream())}, axes}; } -std::vector> ArgReduce::output_shapes( - const std::vector& inputs) { +std::vector ArgReduce::output_shapes(const std::vector& inputs) { auto out_shape = inputs[0].shape(); out_shape[axis_] = 1; - return {out_shape}; + return {std::move(out_shape)}; } bool ArgSort::is_equivalent(const Primitive& other) const { @@ -787,6 +791,23 @@ std::pair, std::vector> Eigh::vmap( return {outputs, std::vector(outputs.size(), ax)}; } +std::vector Eigh::output_shapes(const std::vector& inputs) { + auto shape = inputs[0].shape(); + shape.pop_back(); // Remove last dimension for eigenvalues + if (compute_eigenvectors_) { + return { + std::move(shape), inputs[0].shape()}; // Eigenvalues and eigenvectors + } else { + return {std::move(shape)}; // Only eigenvalues + } +} + +bool Eigh::is_equivalent(const Primitive& other) const { + auto& e_other = static_cast(other); + return uplo_ == e_other.uplo_ && + compute_eigenvectors_ == e_other.compute_eigenvectors_; +} + std::vector Concatenate::vjp( const std::vector& primals, const std::vector& cotangents, @@ -881,6 +902,15 @@ bool Concatenate::is_equivalent(const Primitive& other) const { return axis_ == c_other.axis_; } +std::vector Concatenate::output_shapes( + const std::vector& inputs) { + auto shape = inputs[0].shape(); + for (int i = 1; i < inputs.size(); ++i) { + shape[axis_] += inputs[i].shape(axis_); + } + return {std::move(shape)}; +} + std::pair, std::vector> Conjugate::vmap( const std::vector& inputs, const std::vector& axes) { @@ -1811,6 +1841,15 @@ bool Gather::is_equivalent(const Primitive& other) const { return axes_ == g_other.axes_ && slice_sizes_ == g_other.slice_sizes_; } +std::vector Gather::output_shapes(const std::vector& inputs) { + Shape out_shape; + if (inputs.size() > 1) { + out_shape = inputs[0].shape(); + } + out_shape.insert(out_shape.end(), slice_sizes_.begin(), slice_sizes_.end()); + return {std::move(out_shape)}; +} + std::pair, std::vector> Greater::vmap( const std::vector& inputs, const std::vector& axes) { @@ -2184,6 +2223,12 @@ std::pair, std::vector> Matmul::vmap( return {{matmul(a, b, stream())}, {0}}; } +std::vector Matmul::output_shapes(const std::vector& inputs) { + auto out_shape = inputs[0].shape(); + out_shape.back() = inputs[1].shape(-1); + return {std::move(out_shape)}; +} + std::vector Maximum::vjp( const std::vector& primals, const std::vector& cotangents, @@ -2608,6 +2653,15 @@ bool QuantizedMatmul::is_equivalent(const Primitive& other) const { transpose_ == qm_other.transpose_; } +std::vector QuantizedMatmul::output_shapes( + const std::vector& inputs) { + auto& w = inputs[1]; + int w_outer_dims = (transpose_) ? w.shape(-2) : w.shape(-1) * 32 / bits_; + auto out_shape = inputs[0].shape(); + out_shape.back() = w_outer_dims; + return {std::move(out_shape)}; +} + std::pair, std::vector> GatherQMM::vmap( const std::vector& inputs, const std::vector& axes) { @@ -2937,13 +2991,12 @@ bool Reduce::is_equivalent(const Primitive& other) const { return reduce_type_ == r_other.reduce_type_ && axes_ == r_other.axes_; } -std::vector> Reduce::output_shapes( - const std::vector& inputs) { - std::vector out_shape = inputs[0].shape(); +std::vector Reduce::output_shapes(const std::vector& inputs) { + auto out_shape = inputs[0].shape(); for (auto i : axes_) { out_shape[i] = 1; } - return {out_shape}; + return {std::move(out_shape)}; } std::vector Round::vjp( @@ -4209,6 +4262,15 @@ bool Transpose::is_equivalent(const Primitive& other) const { return axes_ == t_other.axes_; } +std::vector Transpose::output_shapes(const std::vector& inputs) { + auto& in = inputs[0]; + Shape shape(in.ndim(), 0); + for (int i = 0; i < axes_.size(); ++i) { + shape[i] = in.shape()[axes_[i]]; + } + return {std::move(shape)}; +} + std::pair, std::vector> NumberOfElements::vmap( const std::vector& inputs, const std::vector& axes) { diff --git a/mlx/primitives.h b/mlx/primitives.h index 13022db24..0b0359a1b 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -36,10 +36,10 @@ return true; \ } -#define DEFINE_INPUT_OUTPUT_SHAPE() \ - std::vector> output_shapes( \ - const std::vector& inputs) override { \ - return {inputs[0].shape()}; \ +#define DEFINE_INPUT_OUTPUT_SHAPE() \ + std::vector output_shapes(const std::vector& inputs) \ + override { \ + return {inputs[0].shape()}; \ } namespace mlx::core { @@ -110,8 +110,7 @@ class Primitive { /** Get the output shapes of the primitive. This is not required to be * implemented by derived classes, in which case it will throw. */ - virtual std::vector> output_shapes( - const std::vector& inputs); + virtual std::vector output_shapes(const std::vector& inputs); virtual ~Primitive() = default; Primitive(const Primitive& other) = delete; @@ -220,6 +219,7 @@ class Arange : public UnaryPrimitive { DEFINE_PRINT(Arange) bool is_equivalent(const Primitive& other) const override; + std::vector output_shapes(const std::vector& inputs) override; private: double start_; @@ -386,8 +386,7 @@ class ArgReduce : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(ArgReduce) bool is_equivalent(const Primitive& other) const override; - std::vector> output_shapes( - const std::vector& inputs) override; + std::vector output_shapes(const std::vector& inputs) override; private: ReduceType reduce_type_; @@ -437,11 +436,7 @@ class AsType : public UnaryPrimitive { class AsStrided : public UnaryPrimitive { public: - explicit AsStrided( - Stream stream, - std::vector shape, - std::vector strides, - size_t offset) + explicit AsStrided(Stream stream, Shape shape, Strides strides, size_t offset) : UnaryPrimitive(stream), shape_(std::move(shape)), strides_(std::move(strides)), @@ -455,8 +450,8 @@ class AsStrided : public UnaryPrimitive { bool is_equivalent(const Primitive& other) const override; private: - std::vector shape_; - std::vector strides_; + Shape shape_; + Strides strides_; size_t offset_; void eval(const std::vector& inputs, array& out); @@ -527,7 +522,7 @@ class GatherMM : public UnaryPrimitive { class Broadcast : public UnaryPrimitive { public: - explicit Broadcast(Stream stream, const std::vector& shape) + explicit Broadcast(Stream stream, const Shape& shape) : UnaryPrimitive(stream), shape_(shape) {} void eval_cpu(const std::vector& inputs, array& out) override; @@ -539,7 +534,7 @@ class Broadcast : public UnaryPrimitive { bool is_equivalent(const Primitive& other) const override; private: - std::vector shape_; + Shape shape_; void eval(const std::vector& inputs, array& out); }; @@ -586,8 +581,7 @@ class Compiled : public Primitive { DEFINE_VMAP() DEFINE_GRADS() - std::vector> output_shapes( - const std::vector& inputs) override; + std::vector output_shapes(const std::vector& inputs) override; void print(std::ostream& os) override; bool is_equivalent(const Primitive& other) const override; @@ -616,6 +610,7 @@ class Concatenate : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Concatenate) bool is_equivalent(const Primitive& other) const override; + std::vector output_shapes(const std::vector& inputs) override; private: int axis_; @@ -853,8 +848,7 @@ class DivMod : public Primitive { DEFINE_GRADS() DEFINE_PRINT(DivMod) DEFINE_DEFAULT_IS_EQUIVALENT() - std::vector> output_shapes( - const std::vector& inputs) override { + std::vector output_shapes(const std::vector& inputs) override { return std::vector{inputs[0].shape(), inputs[0].shape()}; } @@ -1063,6 +1057,7 @@ class Gather : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Gather) bool is_equivalent(const Primitive& other) const override; + std::vector output_shapes(const std::vector& inputs) override; private: void eval(const std::vector& inputs, array& out); @@ -1339,6 +1334,7 @@ class Matmul : public UnaryPrimitive { DEFINE_VMAP() DEFINE_PRINT(Matmul) DEFINE_DEFAULT_IS_EQUIVALENT() + std::vector output_shapes(const std::vector& inputs) override; }; class Maximum : public UnaryPrimitive { @@ -1444,8 +1440,7 @@ class NumberOfElements : public UnaryPrimitive { DEFINE_VMAP() DEFINE_PRINT(NumberOfElements) bool is_equivalent(const Primitive& other) const override; - std::vector> output_shapes( - const std::vector& inputs) override { + std::vector output_shapes(const std::vector& inputs) override { return {{}}; } @@ -1542,6 +1537,7 @@ class QuantizedMatmul : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(QuantizedMatmul) bool is_equivalent(const Primitive& other) const override; + std::vector output_shapes(const std::vector& inputs) override; private: int group_size_; @@ -1577,7 +1573,7 @@ class GatherQMM : public UnaryPrimitive { class RandomBits : public UnaryPrimitive { public: - explicit RandomBits(Stream stream, const std::vector& shape, int width) + explicit RandomBits(Stream stream, const Shape& shape, int width) : UnaryPrimitive(stream), shape_(shape), width_(width) {} void eval_cpu(const std::vector& inputs, array& out) override; @@ -1588,7 +1584,7 @@ class RandomBits : public UnaryPrimitive { bool is_equivalent(const Primitive& other) const override; private: - std::vector shape_; + Shape shape_; int width_; void eval(const std::vector& inputs, array& out); @@ -1610,7 +1606,7 @@ class Real : public UnaryPrimitive { class Reshape : public UnaryPrimitive { public: - explicit Reshape(Stream stream, const std::vector& shape) + explicit Reshape(Stream stream, const Shape& shape) : UnaryPrimitive(stream), shape_(shape) {} void eval_cpu(const std::vector& inputs, array& out) override; @@ -1622,16 +1618,16 @@ class Reshape : public UnaryPrimitive { bool is_equivalent(const Primitive& other) const override; private: - std::vector shape_; + Shape shape_; void eval(const std::vector& inputs, array& out); - std::pair> prepare_reshape( + static std::pair prepare_reshape( const array& in, const array& out); - void shared_buffer_reshape( + static void shared_buffer_reshape( const array& in, - const std::vector& out_strides, + const Strides& out_strides, array& out); }; @@ -1656,8 +1652,7 @@ class Reduce : public UnaryPrimitive { const std::vector& argnums, const std::vector& outputs) override; - std::vector> output_shapes( - const std::vector& inputs) override; + std::vector output_shapes(const std::vector& inputs) override; void print(std::ostream& os) override { switch (reduce_type_) { @@ -2141,6 +2136,7 @@ class Transpose : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Transpose) bool is_equivalent(const Primitive& other) const override; + std::vector output_shapes(const std::vector& inputs) override; private: std::vector axes_; @@ -2230,24 +2226,9 @@ class Eigh : public Primitive { DEFINE_VMAP() DEFINE_PRINT(Eigh) - std::vector> output_shapes( - const std::vector& inputs) override { - auto shape = inputs[0].shape(); - shape.pop_back(); // Remove last dimension for eigenvalues - if (compute_eigenvectors_) { - return {shape, inputs[0].shape()}; // Eigenvalues and eigenvectors - } else { - return {shape}; // Only eigenvalues - } - } + std::vector output_shapes(const std::vector& inputs) override; - bool is_equivalent(const Primitive& other) const override { - if (auto* p = dynamic_cast(&other)) { - return uplo_ == p->uplo_ && - compute_eigenvectors_ == p->compute_eigenvectors_; - } - return false; - } + bool is_equivalent(const Primitive& other) const override; private: void eval(const std::vector& inputs, std::vector& outputs);