2024-02-17 11:16:39 +08:00
|
|
|
#include <functional>
#include <utility>
#include <vector>

#include "mlx/primitives.h"
|
|
|
|
|
|
|
|
namespace mlx::core::fast {
|
|
|
|
|
|
|
|
// Custom primitive accepts a fallback function which it uses for
|
2024-02-17 22:54:32 +08:00
|
|
|
// transformations. Transformations are virtual so that derived classes may
|
|
|
|
// override the default behavior.
|
2024-02-17 11:16:39 +08:00
|
|
|
class Custom : public Primitive {
|
|
|
|
public:
|
|
|
|
explicit Custom(
|
|
|
|
Stream stream,
|
|
|
|
std::function<std::vector<array>(std::vector<array>)> fallback)
|
|
|
|
: Primitive(stream), fallback_(fallback){};
|
|
|
|
|
|
|
|
virtual std::pair<std::vector<array>, std::vector<int>> vmap(
|
|
|
|
const std::vector<array>& inputs,
|
|
|
|
const std::vector<int>& axes) override;
|
|
|
|
|
|
|
|
virtual std::vector<array> jvp(
|
|
|
|
const std::vector<array>& primals,
|
|
|
|
const std::vector<array>& tangents,
|
|
|
|
const std::vector<int>& argnums) override;
|
|
|
|
|
|
|
|
virtual std::vector<array> vjp(
|
|
|
|
const std::vector<array>& primals,
|
|
|
|
const std::vector<array>& cotangents,
|
|
|
|
const std::vector<int>& argnums,
|
|
|
|
const std::vector<array>& outputs) override;
|
|
|
|
|
|
|
|
private:
|
|
|
|
std::function<std::vector<array>(std::vector<array>)> fallback_;
|
|
|
|
};
|
|
|
|
|
|
|
|
class RoPE : public Custom {
|
|
|
|
public:
|
|
|
|
RoPE(
|
|
|
|
Stream stream,
|
|
|
|
std::function<std::vector<array>(std::vector<array>)> fallback,
|
|
|
|
int dims,
|
|
|
|
bool traditional,
|
|
|
|
float base,
|
|
|
|
float scale,
|
|
|
|
int offset)
|
|
|
|
: Custom(stream, fallback),
|
|
|
|
dims_(dims),
|
|
|
|
traditional_(traditional),
|
|
|
|
base_(base),
|
|
|
|
scale_(scale),
|
|
|
|
offset_(offset){};
|
|
|
|
|
|
|
|
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
|
|
|
override;
|
|
|
|
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
|
|
|
override;
|
|
|
|
|
|
|
|
DEFINE_PRINT(RoPE)
|
|
|
|
bool is_equivalent(const Primitive& other) const override;
|
|
|
|
|
|
|
|
private:
|
|
|
|
std::function<std::vector<array>(std::vector<array>)> fallback_;
|
|
|
|
int dims_;
|
|
|
|
bool traditional_;
|
|
|
|
float base_;
|
|
|
|
float scale_;
|
|
|
|
int offset_;
|
|
|
|
};
|
|
|
|
|
2024-03-05 13:06:11 +08:00
|
|
|
class ScaledDotProductAttention : public Custom {
|
|
|
|
public:
|
|
|
|
explicit ScaledDotProductAttention(
|
|
|
|
Stream stream,
|
|
|
|
std::function<std::vector<array>(std::vector<array>)> fallback,
|
|
|
|
const float scale,
|
|
|
|
const bool needs_mask)
|
|
|
|
: Custom(stream, fallback), scale_(scale), needs_mask_(needs_mask){};
|
|
|
|
|
|
|
|
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
|
|
|
override {
|
|
|
|
outputs[0] = fallback_(inputs)[0];
|
|
|
|
};
|
|
|
|
|
|
|
|
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
|
|
|
override {
|
|
|
|
eval_gpu(inputs, outputs[0]);
|
|
|
|
};
|
|
|
|
|
|
|
|
void eval_gpu(const std::vector<array>& inputs, array& out);
|
|
|
|
bool is_equivalent(const Primitive& other) const override;
|
|
|
|
|
|
|
|
DEFINE_PRINT(ScaledDotProductAttention)
|
|
|
|
|
|
|
|
private:
|
|
|
|
std::function<std::vector<array>(std::vector<array>)> fallback_;
|
|
|
|
float scale_;
|
|
|
|
bool needs_mask_;
|
|
|
|
};
|
|
|
|
|
2024-02-17 11:16:39 +08:00
|
|
|
} // namespace mlx::core::fast
|