// Copyright © 2024 Apple Inc. #include "mlx/primitives.h" namespace mlx::core::fast { // Custom primitive accepts a fallback function which it uses for // transformations. Transformations are virtual so that derived classes may // override the default behavior. class Custom : public Primitive { public: explicit Custom( Stream stream, std::function(std::vector)> fallback) : Primitive(stream), fallback_(fallback) {}; virtual std::pair, std::vector> vmap( const std::vector& inputs, const std::vector& axes) override; virtual std::vector jvp( const std::vector& primals, const std::vector& tangents, const std::vector& argnums) override; virtual std::vector vjp( const std::vector& primals, const std::vector& cotangents, const std::vector& argnums, const std::vector& outputs) override; private: std::function(std::vector)> fallback_; }; class RMSNorm : public Custom { public: RMSNorm( Stream stream, std::function(std::vector)> fallback, float eps) : Custom(stream, fallback), eps_(eps) {}; void eval_cpu(const std::vector& inputs, std::vector& outputs) override { throw std::runtime_error("NYI"); }; void eval_gpu(const std::vector& inputs, std::vector& outputs) override; std::vector vjp( const std::vector& primals, const std::vector& cotangents, const std::vector& argnums, const std::vector& outputs) override; DEFINE_PRINT(RMSNorm) bool is_equivalent(const Primitive& other) const override; private: std::function(std::vector)> fallback_; float eps_; }; class RMSNormVJP : public Custom { public: RMSNormVJP( Stream stream, std::function(std::vector)> fallback, float eps) : Custom(stream, fallback), eps_(eps) {}; void eval_cpu(const std::vector& inputs, std::vector& outputs) override { throw std::runtime_error("NYI"); }; void eval_gpu(const std::vector& inputs, std::vector& outputs) override; DEFINE_PRINT(RMSNormVJP) bool is_equivalent(const Primitive& other) const override; private: std::function(std::vector)> fallback_; float eps_; }; class LayerNorm : public Custom { public: LayerNorm( Stream stream, std::function(std::vector)> fallback, float eps) : Custom(stream, fallback), eps_(eps) {}; void eval_cpu(const std::vector& inputs, std::vector& outputs) override { throw std::runtime_error("NYI"); }; void eval_gpu(const std::vector& inputs, std::vector& outputs) override; std::vector vjp( const std::vector& primals, const std::vector& cotangents, const std::vector& argnums, const std::vector& outputs) override; DEFINE_PRINT(LayerNorm) bool is_equivalent(const Primitive& other) const override; private: std::function(std::vector)> fallback_; float eps_; }; class LayerNormVJP : public Custom { public: LayerNormVJP( Stream stream, std::function(std::vector)> fallback, float eps) : Custom(stream, fallback), eps_(eps) {}; void eval_cpu(const std::vector& inputs, std::vector& outputs) override { throw std::runtime_error("NYI"); }; void eval_gpu(const std::vector& inputs, std::vector& outputs) override; DEFINE_PRINT(LayerNormVJP) bool is_equivalent(const Primitive& other) const override; private: std::function(std::vector)> fallback_; float eps_; }; class RoPE : public Custom { public: RoPE( Stream stream, std::function(std::vector)> fallback, int dims, bool traditional, float base, float scale, int offset, bool forward) : Custom(stream, fallback), dims_(dims), traditional_(traditional), base_(base), scale_(scale), offset_(offset), forward_(forward) {}; void eval_cpu(const std::vector& inputs, std::vector& outputs) override { throw std::runtime_error("NYI"); }; void eval_gpu(const std::vector& inputs, std::vector& outputs) override; std::vector vjp( const std::vector& primals, const std::vector& cotangents, const std::vector& argnums, const std::vector& outputs) override; DEFINE_PRINT(RoPE) bool is_equivalent(const Primitive& other) const override; private: std::function(std::vector)> fallback_; int dims_; bool traditional_; float base_; float scale_; int offset_; bool forward_; }; class ScaledDotProductAttention : public Custom { public: explicit ScaledDotProductAttention( Stream stream, std::function(std::vector)> fallback, const float scale, const bool needs_mask) : Custom(stream, fallback), scale_(scale), needs_mask_(needs_mask) {}; void eval_cpu(const std::vector& inputs, std::vector& outputs) override { throw std::runtime_error("NYI"); }; void eval_gpu(const std::vector& inputs, std::vector& outputs) override { eval_gpu(inputs, outputs[0]); }; void eval_gpu(const std::vector& inputs, array& out); bool is_equivalent(const Primitive& other) const override; DEFINE_PRINT(ScaledDotProductAttention); private: std::function(std::vector)> fallback_; float scale_; bool needs_mask_; }; } // namespace mlx::core::fast