From c3965fc5ee665083714b391f249ec9029895f216 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 16 Feb 2024 19:16:39 -0800 Subject: [PATCH] Separate fast ops and primitives (#699) --- mlx/backend/common/rope.cpp | 3 +- mlx/backend/metal/rope.cpp | 3 +- mlx/backend/no_metal/primitives.cpp | 2 +- mlx/fast.cpp | 2 + mlx/fast.h | 66 +--------------------------- mlx/fast_primitives.h | 68 +++++++++++++++++++++++++++++ 6 files changed, 74 insertions(+), 70 deletions(-) create mode 100644 mlx/fast_primitives.h diff --git a/mlx/backend/common/rope.cpp b/mlx/backend/common/rope.cpp index c0c2bba8e..15b5de7e5 100644 --- a/mlx/backend/common/rope.cpp +++ b/mlx/backend/common/rope.cpp @@ -1,7 +1,6 @@ // Copyright © 2023-2024 Apple Inc. -#include "mlx/fast.h" -#include "mlx/primitives.h" +#include "mlx/fast_primitives.h" namespace mlx::core::fast { diff --git a/mlx/backend/metal/rope.cpp b/mlx/backend/metal/rope.cpp index 29295f3ac..fdea57985 100644 --- a/mlx/backend/metal/rope.cpp +++ b/mlx/backend/metal/rope.cpp @@ -1,8 +1,7 @@ // Copyright © 2023-2024 Apple Inc. #include "mlx/backend/metal/utils.h" -#include "mlx/fast.h" -#include "mlx/primitives.h" +#include "mlx/fast_primitives.h" namespace mlx::core::fast { diff --git a/mlx/backend/no_metal/primitives.cpp b/mlx/backend/no_metal/primitives.cpp index bd4026e2c..8e66f56b3 100644 --- a/mlx/backend/no_metal/primitives.cpp +++ b/mlx/backend/no_metal/primitives.cpp @@ -1,7 +1,7 @@ // Copyright © 2023-2024 Apple Inc. #include "mlx/primitives.h" -#include "mlx/fast.h" +#include "mlx/fast_primitives.h" #define NO_GPU_MULTI(func) \ void func::eval_gpu( \ diff --git a/mlx/fast.cpp b/mlx/fast.cpp index 96d4f03ce..ee28138f1 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -1,6 +1,8 @@ // Copyright © 2023-2024 Apple Inc. 
#include "mlx/fast.h" +#include "mlx/fast_primitives.h" +#include "mlx/ops.h" #include "mlx/transforms.h" namespace mlx::core::fast { diff --git a/mlx/fast.h b/mlx/fast.h index 5deac0cdb..48ac90a5a 100644 --- a/mlx/fast.h +++ b/mlx/fast.h @@ -2,40 +2,10 @@ #pragma once -#include "mlx/ops.h" -#include "mlx/primitives.h" +#include "mlx/utils.h" namespace mlx::core::fast { -// Custom primitive accepts a fallback function which it uses for -// transformations. Transformations are virtual so that derived classes may to -// override the default behavior -class Custom : public Primitive { - public: - explicit Custom( - Stream stream, - std::function(std::vector)> fallback) - : Primitive(stream), fallback_(fallback){}; - - virtual std::pair, std::vector> vmap( - const std::vector& inputs, - const std::vector& axes) override; - - virtual std::vector jvp( - const std::vector& primals, - const std::vector& tangents, - const std::vector& argnums) override; - - virtual std::vector vjp( - const std::vector& primals, - const std::vector& cotangents, - const std::vector& argnums, - const std::vector& outputs) override; - - private: - std::function(std::vector)> fallback_; -}; - array rope( const array& x, int dims, @@ -45,38 +15,4 @@ array rope( int offset, StreamOrDevice s /* = {} */); -class RoPE : public Custom { - public: - RoPE( - Stream stream, - std::function(std::vector)> fallback, - int dims, - bool traditional, - float base, - float scale, - int offset) - : Custom(stream, fallback), - dims_(dims), - traditional_(traditional), - base_(base), - scale_(scale), - offset_(offset){}; - - void eval_cpu(const std::vector& inputs, std::vector& outputs) - override; - void eval_gpu(const std::vector& inputs, std::vector& outputs) - override; - - DEFINE_PRINT(RoPE) - bool is_equivalent(const Primitive& other) const override; - - private: - std::function(std::vector)> fallback_; - int dims_; - bool traditional_; - float base_; - float scale_; - int offset_; -}; - } // namespace 
mlx::core::fast
diff --git a/mlx/fast_primitives.h b/mlx/fast_primitives.h
new file mode 100644
index 000000000..acb5d0046
--- /dev/null
+++ b/mlx/fast_primitives.h
@@ -0,0 +1,80 @@
+#pragma once
+
+#include "mlx/primitives.h"
+
+namespace mlx::core::fast {
+
+// Custom primitive accepts a fallback function which it uses for
+// transformations. Transformations are virtual so that derived classes may
+// override the default behavior.
+class Custom : public Primitive {
+ public:
+  // Forwards `stream` to the Primitive base and stores `fallback`.
+  explicit Custom(
+      Stream stream,
+      std::function<std::vector<array>(std::vector<array>)> fallback)
+      : Primitive(stream), fallback_(fallback) {}
+
+  virtual std::pair<std::vector<array>, std::vector<int>> vmap(
+      const std::vector<array>& inputs,
+      const std::vector<int>& axes) override;
+
+  virtual std::vector<array> jvp(
+      const std::vector<array>& primals,
+      const std::vector<array>& tangents,
+      const std::vector<int>& argnums) override;
+
+  virtual std::vector<array> vjp(
+      const std::vector<array>& primals,
+      const std::vector<array>& cotangents,
+      const std::vector<int>& argnums,
+      const std::vector<array>& outputs) override;
+
+ private:
+  // Fallback function supplied at construction.
+  std::function<std::vector<array>(std::vector<array>)> fallback_;
+};
+
+// Custom primitive carrying the dims/traditional/base/scale/offset settings
+// passed to its constructor. eval_cpu and eval_gpu are only declared in this
+// header.
+class RoPE : public Custom {
+ public:
+  // Passes `stream` and `fallback` to Custom and records the remaining
+  // arguments in the matching members.
+  RoPE(
+      Stream stream,
+      std::function<std::vector<array>(std::vector<array>)> fallback,
+      int dims,
+      bool traditional,
+      float base,
+      float scale,
+      int offset)
+      : Custom(stream, fallback),
+        dims_(dims),
+        traditional_(traditional),
+        base_(base),
+        scale_(scale),
+        offset_(offset) {}
+
+  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override;
+  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override;
+
+  DEFINE_PRINT(RoPE)
+  bool is_equivalent(const Primitive& other) const override;
+
+ private:
+  // NOTE(review): this member is separate from the private
+  // Custom::fallback_ and is not set by the constructor above, so it stays
+  // default-constructed; confirm whether it is still needed.
+  std::function<std::vector<array>(std::vector<array>)> fallback_;
+  int dims_;
+  bool traditional_;
+  float base_;
+  float scale_;
+  int offset_;
+};
+
+} // namespace mlx::core::fast