// Copyright © 2024 Apple Inc.

#include <optional>
#include <variant>

#include "mlx/primitives.h"

namespace mlx::core::fast {

// Custom primitive accepts a fallback function which it uses for
// transformations. Transformations are virtual so that derived classes may
// override the default behavior.
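//
// For example, the default vjp transform can be obtained by differentiating
// the fallback directly (a sketch, assuming mlx::core::vjp; the actual
// implementation lives in the corresponding source file):
//
//   auto [outs, vjps] = mlx::core::vjp(
//       [this](const std::vector<array>& in) { return fallback_(in); },
//       primals,
//       cotangents);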
class Custom : public Primitive {
 public:
  explicit Custom(
      Stream stream,
      std::function<std::vector<array>(std::vector<array>)> fallback)
      : Primitive(stream), fallback_(std::move(fallback)) {}

  virtual std::pair<std::vector<array>, std::vector<int>> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) override;

  virtual std::vector<array> jvp(
      const std::vector<array>& primals,
      const std::vector<array>& tangents,
      const std::vector<int>& argnums) override;

  virtual std::vector<array> vjp(
      const std::vector<array>& primals,
      const std::vector<array>& cotangents,
      const std::vector<int>& argnums,
      const std::vector<array>& outputs) override;

 protected:
  std::function<std::vector<array>(std::vector<array>)> fallback_;
};

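// Fused RMS normalization over the feature axis:
//   y = w * x / sqrt(mean(x^2) + eps)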
class RMSNorm : public Custom {
 public:
  RMSNorm(
      Stream stream,
      std::function<std::vector<array>(std::vector<array>)> fallback,
      float eps)
      : Custom(stream, std::move(fallback)), eps_(eps) {}

  static bool use_fallback(Stream stream);

  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override {
    throw std::runtime_error("NYI");
  }
  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override;

  std::vector<array> vjp(
      const std::vector<array>& primals,
      const std::vector<array>& cotangents,
      const std::vector<int>& argnums,
      const std::vector<array>& outputs) override;

  DEFINE_NAME(RMSNorm)
  bool is_equivalent(const Primitive& other) const override;
  DEFINE_INPUT_OUTPUT_SHAPE()

  auto state() const {
    return std::make_pair(nullptr, eps_);
  }

 private:
  float eps_;
};

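// Gradient (VJP) of RMSNorm, emitted by RMSNorm::vjp.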
class RMSNormVJP : public Custom {
 public:
  RMSNormVJP(
      Stream stream,
      std::function<std::vector<array>(std::vector<array>)> fallback,
      float eps)
      : Custom(stream, std::move(fallback)), eps_(eps) {}

  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override {
    throw std::runtime_error("NYI");
  }
  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override;

  DEFINE_NAME(RMSNormVJP)
  bool is_equivalent(const Primitive& other) const override;
  auto state() const {
    return std::make_pair(nullptr, eps_);
  }

 private:
  float eps_;
};

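// Fused layer normalization over the feature axis:
//   y = w * (x - mean(x)) / sqrt(var(x) + eps) + b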
class LayerNorm : public Custom {
 public:
  LayerNorm(
      Stream stream,
      std::function<std::vector<array>(std::vector<array>)> fallback,
      float eps)
      : Custom(stream, std::move(fallback)), eps_(eps) {}

  static bool use_fallback(Stream s);

  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override {
    throw std::runtime_error("NYI");
  }
  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override;

  std::vector<array> vjp(
      const std::vector<array>& primals,
      const std::vector<array>& cotangents,
      const std::vector<int>& argnums,
      const std::vector<array>& outputs) override;

  DEFINE_NAME(LayerNorm)
  bool is_equivalent(const Primitive& other) const override;
  DEFINE_INPUT_OUTPUT_SHAPE()
  auto state() const {
    return std::make_pair(nullptr, eps_);
  }

 private:
  float eps_;
};

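// Gradient (VJP) of LayerNorm, emitted by LayerNorm::vjp.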
class LayerNormVJP : public Custom {
 public:
  LayerNormVJP(
      Stream stream,
      std::function<std::vector<array>(std::vector<array>)> fallback,
      float eps)
      : Custom(stream, std::move(fallback)), eps_(eps) {}

  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override {
    throw std::runtime_error("NYI");
  }
  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override;

  DEFINE_NAME(LayerNormVJP)
  bool is_equivalent(const Primitive& other) const override;
  auto state() const {
    return std::make_pair(nullptr, eps_);
  }

 private:
  float eps_;
};

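// Rotary positional embedding. Rotates the first `dims` feature dimensions
// with frequency base `base` and position scale `scale`; `traditional`
// selects the interleaved-pair variant. `forward` is likely false for the
// inverse rotation used in the VJP.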
class RoPE : public Custom {
 public:
  RoPE(
      Stream stream,
      std::function<std::vector<array>(std::vector<array>)> fallback,
      int dims,
      bool traditional,
      float base,
      float scale,
      bool forward)
      : Custom(stream, std::move(fallback)),
        dims_(dims),
        traditional_(traditional),
        base_(base),
        scale_(scale),
        forward_(forward) {}

  static bool use_fallback(Stream s);

  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override {
    throw std::runtime_error("NYI");
  }
  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override;

  std::vector<array> vjp(
      const std::vector<array>& primals,
      const std::vector<array>& cotangents,
      const std::vector<int>& argnums,
      const std::vector<array>& outputs) override;

  DEFINE_NAME(RoPE)
  bool is_equivalent(const Primitive& other) const override;
  DEFINE_INPUT_OUTPUT_SHAPE()
  auto state() const {
    return std::make_tuple(
        nullptr, dims_, traditional_, base_, scale_, forward_);
  }

 private:
  int dims_;
  bool traditional_;
  float base_;
  float scale_;
  bool forward_;
};

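// Fused scaled dot-product attention:
//   out = softmax(scale * q @ k^T + mask) @ v
// with optional causal masking, attention sinks, and an optional logsumexp
// output used by the VJP. use_fallback reports whether the fused kernel
// supports the given configuration.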
class ScaledDotProductAttention : public Custom {
 public:
  ScaledDotProductAttention(
      Stream stream,
      std::function<std::vector<array>(std::vector<array>)> fallback,
      float scale,
      bool do_causal,
      bool has_sinks,
      bool output_logsumexp)
      : Custom(stream, std::move(fallback)),
        scale_(scale),
        do_causal_(do_causal),
        has_sinks_(has_sinks),
        output_logsumexp_(output_logsumexp) {}

  static bool use_fallback(
      const array& q,
      const array& k,
      const array& v,
      bool has_mask,
      bool has_arr_mask,
      bool do_causal,
      bool is_training,
      bool output_logsumexp,
      Stream s);
  static bool supports_bool_mask();

  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override {
    throw std::runtime_error("NYI");
  }

  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override;

  std::vector<array> vjp(
      const std::vector<array>& primals,
      const std::vector<array>& cotangents,
      const std::vector<int>& argnums,
      const std::vector<array>& outputs) override;

  bool is_equivalent(const Primitive& other) const override;

  DEFINE_NAME(ScaledDotProductAttention);
  DEFINE_INPUT_OUTPUT_SHAPE()
  auto state() const {
    return std::make_tuple(
        nullptr, scale_, do_causal_, has_sinks_, output_logsumexp_);
  }

 private:
  float scale_;
  bool do_causal_;
  bool has_sinks_;
  bool output_logsumexp_;
};

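// Gradient (VJP) of the fused scaled dot-product attention.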
class ScaledDotProductAttentionVJP : public Custom {
 public:
  ScaledDotProductAttentionVJP(
      Stream stream,
      std::function<std::vector<array>(std::vector<array>)> fallback,
      float scale,
      bool do_causal,
      bool has_sinks)
      : Custom(stream, std::move(fallback)),
        scale_(scale),
        do_causal_(do_causal),
        has_sinks_(has_sinks) {}

  static bool use_fallback(const array& q, Stream s);

  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override {
    throw std::runtime_error("NYI");
  }

  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override;

  DEFINE_NAME(ScaledDotProductAttentionVJP);
  bool is_equivalent(const Primitive& other) const override;
  auto state() const {
    return std::make_tuple(nullptr, scale_, do_causal_, has_sinks_);
  }

 private:
  float scale_;
  bool do_causal_;
  bool has_sinks_;
};

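// Element-wise conversion to 8-bit floating point (to_fp8 == true) or back to
// a higher-precision float type (to_fp8 == false).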
class ConvertFP8 : public Primitive {
 public:
  explicit ConvertFP8(Stream stream, bool to_fp8)
      : Primitive(stream), to_fp8_(to_fp8) {}

  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override;

  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override;

  const char* name() const override {
    if (to_fp8_) {
      return "ToFP8";
    } else {
      return "FromFP8";
    }
  }
  bool state() const {
    return to_fp8_;
  };

  bool is_equivalent(const Primitive& other) const override;
  DEFINE_INPUT_OUTPUT_SHAPE();

 private:
  bool to_fp8_;
};

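// Fused quantization/dequantization with the given group size, bit width, and
// quantization mode; `dequantize` selects the inverse direction. The packed
// output shapes are reported by output_shapes.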
class Quantize : public Custom {
 public:
  explicit Quantize(
      Stream stream,
      std::function<std::vector<array>(std::vector<array>)> fallback,
      int group_size,
      int bits,
      QuantizationMode mode,
      bool dequantize)
      : Custom(stream, std::move(fallback)),
        group_size_(group_size),
        bits_(bits),
        mode_(mode),
        dequantize_(dequantize) {}

  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override;

  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override;

  DEFINE_NAME(Quantize);

  bool is_equivalent(const Primitive& other) const override;
  std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
  auto state() const {
    return std::make_tuple(nullptr, group_size_, bits_, mode_, dequantize_);
  }

 private:
  int group_size_;
  int bits_;
  QuantizationMode mode_;
  bool dequantize_;
};

using ScalarArg = std::variant<bool, int, float>;

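// Runs a user-provided GPU kernel with the given grid and threadgroup sizes
// (this is what backs the fast custom-kernel API, e.g. mx.fast.metal_kernel).
// Each entry of shape_infos likely records whether the kernel body needs the
// corresponding input's shape, strides, and ndim.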
class CustomKernel : public Primitive {
 public:
  CustomKernel(
      Stream stream,
      std::string name,
      std::string source,
      std::tuple<int, int, int> grid,
      std::tuple<int, int, int> threadgroup,
      std::vector<std::tuple<bool, bool, bool>> shape_infos,
      bool ensure_row_contiguous,
      std::optional<float> init_value,
      std::vector<ScalarArg> scalar_arguments,
      bool is_precompiled,
      int shared_memory)
      : Primitive(stream),
        name_(std::move(name)),
        source_(std::move(source)),
        grid_(grid),
        threadgroup_(threadgroup),
        shape_infos_(std::move(shape_infos)),
        ensure_row_contiguous_(ensure_row_contiguous),
        init_value_(init_value),
        scalar_arguments_(std::move(scalar_arguments)),
        is_precompiled_(is_precompiled),
        shared_memory_(shared_memory) {}

  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override {
    throw std::runtime_error("Custom kernels only run on GPU.");
  }

  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override;

  DEFINE_NAME(CustomKernel);
  auto state() const {
    return std::make_tuple(
        name_,
        source_,
        grid_,
        threadgroup_,
        shape_infos_,
        ensure_row_contiguous_,
        init_value_,
        scalar_arguments_,
        is_precompiled_,
        shared_memory_);
  }

 private:
  std::string name_;
  std::string source_;
  std::tuple<int, int, int> grid_;
  std::tuple<int, int, int> threadgroup_;
  std::vector<std::tuple<bool, bool, bool>> shape_infos_;
  bool ensure_row_contiguous_;
  std::optional<float> init_value_;
  std::vector<ScalarArg> scalar_arguments_;
  bool is_precompiled_;
  int shared_memory_;
};

} // namespace mlx::core::fast