mlx/mlx/fast_primitives.h
Jagrit Digani 9adcd1a650
Support fused masking in Attention (#1924)
* Update API to allow mask='causal' in fast::sdpa

* Add fallback

* Update steel::AttnParams

* Fix typo

* WIP, basic causal

* Update tests

* Update benchmarking

* Update masking loop limits

* Add bool masking and update tests

* Update additive mask

* Update benchmarks

* Update benchmarks

* Update tests

* Update for bfloat error

* Update early exit

* Add random seed to tests
2025-03-20 11:01:32 -07:00

319 lines
8.6 KiB
C++

// Copyright © 2024 Apple Inc.
#include <functional>
#include <optional>
#include <utility>

#include "mlx/primitives.h"
namespace mlx::core::fast {
// Custom primitive accepts a fallback function which it uses for
// transformations. Transformations are virtual so that derived classes may
// override the default behavior.
class Custom : public Primitive {
 public:
  // `fallback` is a graph-level implementation built from existing
  // primitives; the default transformation rules below are defined in
  // terms of it.
  explicit Custom(
      Stream stream,
      std::function<std::vector<array>(std::vector<array>)> fallback)
      // Move rather than copy: copying a std::function may allocate.
      : Primitive(stream), fallback_(std::move(fallback)) {}

  virtual std::pair<std::vector<array>, std::vector<int>> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) override;

  virtual std::vector<array> jvp(
      const std::vector<array>& primals,
      const std::vector<array>& tangents,
      const std::vector<int>& argnums) override;

  virtual std::vector<array> vjp(
      const std::vector<array>& primals,
      const std::vector<array>& cotangents,
      const std::vector<int>& argnums,
      const std::vector<array>& outputs) override;

 private:
  std::function<std::vector<array>(std::vector<array>)> fallback_;
};
class RMSNorm : public Custom {
public:
RMSNorm(
Stream stream,
std::function<std::vector<array>(std::vector<array>)> fallback,
float eps)
: Custom(stream, fallback), eps_(eps) {}
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override {
throw std::runtime_error("NYI");
}
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
std::vector<array> vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>& outputs) override;
DEFINE_PRINT(RMSNorm)
bool is_equivalent(const Primitive& other) const override;
DEFINE_INPUT_OUTPUT_SHAPE()
auto state() const {
return std::make_pair(nullptr, eps_);
}
private:
std::function<std::vector<array>(std::vector<array>)> fallback_;
float eps_;
};
// Backward pass of the fused RMS normalization.
class RMSNormVJP : public Custom {
 public:
  RMSNormVJP(
      Stream stream,
      std::function<std::vector<array>(std::vector<array>)> fallback,
      float eps)
      : Custom(stream, fallback),
        // Fix: this member shadows the base-class `fallback_` and was
        // previously left default-constructed (an empty std::function);
        // invoking it would throw std::bad_function_call.
        fallback_(std::move(fallback)),
        eps_(eps) {}

  // CPU path is not implemented for the fused primitive.
  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override {
    throw std::runtime_error("NYI");
  }
  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override;

  DEFINE_PRINT(RMSNormVJP)
  bool is_equivalent(const Primitive& other) const override;

  // Serializable state; nullptr stands in for the non-serializable fallback.
  auto state() const {
    return std::make_pair(nullptr, eps_);
  }

 private:
  std::function<std::vector<array>(std::vector<array>)> fallback_;
  float eps_;
};
// Fused layer normalization with a graph fallback for transformations.
class LayerNorm : public Custom {
 public:
  LayerNorm(
      Stream stream,
      std::function<std::vector<array>(std::vector<array>)> fallback,
      float eps)
      : Custom(stream, fallback),
        // Fix: this member shadows the base-class `fallback_` and was
        // previously left default-constructed (an empty std::function);
        // invoking it would throw std::bad_function_call.
        fallback_(std::move(fallback)),
        eps_(eps) {}

  // CPU path is not implemented for the fused primitive.
  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override {
    throw std::runtime_error("NYI");
  }
  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override;

  std::vector<array> vjp(
      const std::vector<array>& primals,
      const std::vector<array>& cotangents,
      const std::vector<int>& argnums,
      const std::vector<array>& outputs) override;

  DEFINE_PRINT(LayerNorm)
  bool is_equivalent(const Primitive& other) const override;
  DEFINE_INPUT_OUTPUT_SHAPE()

  // Serializable state; nullptr stands in for the non-serializable fallback.
  auto state() const {
    return std::make_pair(nullptr, eps_);
  }

 private:
  std::function<std::vector<array>(std::vector<array>)> fallback_;
  float eps_; // presumably the numerical-stability epsilon — TODO confirm
};
// Backward pass of the fused layer normalization.
class LayerNormVJP : public Custom {
 public:
  LayerNormVJP(
      Stream stream,
      std::function<std::vector<array>(std::vector<array>)> fallback,
      float eps)
      : Custom(stream, fallback),
        // Fix: this member shadows the base-class `fallback_` and was
        // previously left default-constructed (an empty std::function);
        // invoking it would throw std::bad_function_call.
        fallback_(std::move(fallback)),
        eps_(eps) {}

  // CPU path is not implemented for the fused primitive.
  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override {
    throw std::runtime_error("NYI");
  }
  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override;

  DEFINE_PRINT(LayerNormVJP)
  bool is_equivalent(const Primitive& other) const override;

  // Serializable state; nullptr stands in for the non-serializable fallback.
  auto state() const {
    return std::make_pair(nullptr, eps_);
  }

 private:
  std::function<std::vector<array>(std::vector<array>)> fallback_;
  float eps_;
};
class RoPE : public Custom {
public:
RoPE(
Stream stream,
std::function<std::vector<array>(std::vector<array>)> fallback,
int dims,
bool traditional,
float base,
float scale,
bool forward)
: Custom(stream, fallback),
dims_(dims),
traditional_(traditional),
base_(base),
scale_(scale),
forward_(forward) {}
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override {
throw std::runtime_error("NYI");
}
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
std::vector<array> vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>& outputs) override;
DEFINE_PRINT(RoPE)
bool is_equivalent(const Primitive& other) const override;
DEFINE_INPUT_OUTPUT_SHAPE()
auto state() const {
return std::make_tuple(
nullptr, dims_, traditional_, base_, scale_, forward_);
}
private:
std::function<std::vector<array>(std::vector<array>)> fallback_;
int dims_;
bool traditional_;
float base_;
float scale_;
bool forward_;
};
class ScaledDotProductAttention : public Custom {
public:
explicit ScaledDotProductAttention(
Stream stream,
std::function<std::vector<array>(std::vector<array>)> fallback,
const float scale,
const bool do_causal)
: Custom(stream, fallback), scale_(scale), do_causal_(do_causal) {}
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override {
throw std::runtime_error("NYI");
}
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override {
eval_gpu(inputs, outputs[0]);
}
void eval_gpu(const std::vector<array>& inputs, array& out);
bool is_equivalent(const Primitive& other) const override;
DEFINE_PRINT(ScaledDotProductAttention);
DEFINE_INPUT_OUTPUT_SHAPE()
auto state() const {
return std::make_tuple(nullptr, scale_, do_causal_);
}
private:
std::function<std::vector<array>(std::vector<array>)> fallback_;
float scale_;
bool do_causal_;
};
// Affine (de)quantization; `dequantize` selects the direction.
class AffineQuantize : public Custom {
 public:
  explicit AffineQuantize(
      Stream stream,
      std::function<std::vector<array>(std::vector<array>)> fallback,
      int group_size,
      int bits,
      bool dequantize)
      : Custom(stream, fallback),
        // Fix: this member shadows the base-class `fallback_` and was
        // previously left default-constructed (an empty std::function);
        // invoking it would throw std::bad_function_call.
        fallback_(std::move(fallback)),
        group_size_(group_size),
        bits_(bits),
        dequantize_(dequantize) {}

  // Unlike the other fused primitives, both backends are implemented
  // (definitions live in the corresponding backend sources).
  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override;
  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override;

  DEFINE_PRINT(AffineQuantize);
  bool is_equivalent(const Primitive& other) const override;
  std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;

  // Serializable state; nullptr stands in for the non-serializable fallback.
  auto state() const {
    return std::make_tuple(nullptr, group_size_, bits_, dequantize_);
  }

 private:
  std::function<std::vector<array>(std::vector<array>)> fallback_;
  int group_size_; // number of elements sharing a scale/bias group
  int bits_; // bits per quantized element
  bool dequantize_; // true: dequantize, false: quantize
};
// Per-input flags for a CustomKernel: which implicit shape arguments the
// kernel source references — presumably used when binding kernel buffers;
// TODO confirm against the Metal launch code.
struct CustomKernelShapeInfo {
bool shape = false;
bool strides = false;
bool ndim = false;
};
// A user-supplied Metal kernel wrapped as a primitive. Holds the kernel
// source text plus its launch configuration; it has no CPU implementation.
class CustomKernel : public Primitive {
public:
// `name`/`source`: kernel identifier and Metal source text.
// `grid`/`threadgroup`: 3-D launch dimensions.
// `shape_infos`: per-input flags for implicit shape/stride/ndim arguments.
// `ensure_row_contiguous`: NOTE(review) — presumably forces inputs to be
//   copied to row-contiguous layout before launch; confirm in eval_gpu.
// `init_value`: optional value the output is pre-filled with when set.
CustomKernel(
Stream stream,
std::string name,
std::string source,
std::tuple<int, int, int> grid,
std::tuple<int, int, int> threadgroup,
std::vector<CustomKernelShapeInfo> shape_infos,
bool ensure_row_contiguous,
std::optional<float> init_value)
: Primitive(stream),
source_(std::move(source)),
name_(std::move(name)),
grid_(grid),
threadgroup_(threadgroup),
shape_infos_(std::move(shape_infos)),
ensure_row_contiguous_(ensure_row_contiguous),
init_value_(init_value) {}
// GPU-only primitive: the CPU path always throws.
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override {
throw std::runtime_error("Custom Metal kernels only run on GPU.");
}
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
DEFINE_PRINT(CustomKernel);
private:
std::string source_;
std::string name_;
std::tuple<int, int, int> grid_;
std::tuple<int, int, int> threadgroup_;
std::vector<CustomKernelShapeInfo> shape_infos_;
bool ensure_row_contiguous_;
std::optional<float> init_value_;
};
} // namespace mlx::core::fast