16      std::function<std::vector<array>(std::vector<array>)> fallback)
 
 
   19  virtual std::pair<std::vector<array>, std::vector<int>> 
vmap(
 
   20      const std::vector<array>& inputs,
 
   21      const std::vector<int>& axes) 
override;
 
   23  virtual std::vector<array> 
jvp(
 
   24      const std::vector<array>& primals,
 
   25      const std::vector<array>& tangents,
 
   26      const std::vector<int>& argnums) 
override;
 
   28  virtual std::vector<array> 
vjp(
 
   29      const std::vector<array>& primals,
 
   30      const std::vector<array>& cotangents,
 
   31      const std::vector<int>& argnums,
 
   32      const std::vector<array>& outputs) 
override;
 
   35  std::function<std::vector<array>(std::vector<array>)> fallback_;
 
 
   42      std::function<std::vector<array>(std::vector<array>)> fallback,
 
 
   46  void eval_cpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
 
   48    throw std::runtime_error(
"NYI");
 
 
   50  void eval_gpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
 
   53  std::vector<array> 
vjp(
 
   54      const std::vector<array>& primals,
 
   55      const std::vector<array>& cotangents,
 
   56      const std::vector<int>& argnums,
 
   57      const std::vector<array>& outputs) 
override;
 
   63  std::function<std::vector<array>(std::vector<array>)> fallback_;
 
 
   71      std::function<std::vector<array>(std::vector<array>)> fallback,
 
 
   75  void eval_cpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
 
   77    throw std::runtime_error(
"NYI");
 
 
   79  void eval_gpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
 
   86  std::function<std::vector<array>(std::vector<array>)> fallback_;
 
 
   94      std::function<std::vector<array>(std::vector<array>)> fallback,
 
 
   98  void eval_cpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
 
  100    throw std::runtime_error(
"NYI");
 
 
  102  void eval_gpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
 
  106      const std::vector<array>& primals,
 
  107      const std::vector<array>& cotangents,
 
  108      const std::vector<int>& argnums,
 
  109      const std::vector<array>& outputs) 
override;
 
  115  std::function<std::vector<array>(std::vector<array>)> fallback_;
 
 
  123      std::function<std::vector<array>(std::vector<array>)> fallback,
 
 
  127  void eval_cpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
 
  129    throw std::runtime_error(
"NYI");
 
 
  131  void eval_gpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
 
  138  std::function<std::vector<array>(std::vector<array>)> fallback_;
 
 
  146      std::function<std::vector<array>(std::vector<array>)> fallback,
 
  155        traditional_(traditional),
 
 
  161  void eval_cpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
 
  163    throw std::runtime_error(
"NYI");
 
 
  165  void eval_gpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
 
  169      const std::vector<array>& primals,
 
  170      const std::vector<array>& cotangents,
 
  171      const std::vector<int>& argnums,
 
  172      const std::vector<array>& outputs) 
override;
 
  178  std::function<std::vector<array>(std::vector<array>)> fallback_;
 
 
  191      std::function<std::vector<array>(std::vector<array>)> fallback,
 
  193      const bool needs_mask)
 
  194      : 
Custom(
stream, fallback), scale_(scale), needs_mask_(needs_mask) {}
 
 
  196  void eval_cpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
 
  198    throw std::runtime_error(
"NYI");
 
 
  201  void eval_gpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
 
 
  212  std::function<std::vector<array>(std::vector<array>)> fallback_;
 
 
  221      std::function<std::vector<array>(std::vector<array>)> fallback,
 
  226        group_size_(group_size),
 
 
  230  void eval_cpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
 
  232    throw std::runtime_error(
"NYI");
 
 
  235  void eval_gpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
 
  241  std::function<std::vector<array>(std::vector<array>)> fallback_;
 
 
  259      std::tuple<int, int, int> grid,
 
  260      std::tuple<int, int, int> threadgroup,
 
  261      std::vector<CustomKernelShapeInfo> shape_infos,
 
  262      bool ensure_row_contiguous,
 
  263      std::optional<float> init_value)
 
  265        source_(
std::move(source)),
 
  266        name_(
std::move(name)),
 
  268        threadgroup_(threadgroup),
 
  269        shape_infos_(
std::move(shape_infos)),
 
  270        ensure_row_contiguous_(ensure_row_contiguous),
 
  271        init_value_(init_value) {}
 
 
  273  void eval_cpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
 
  275    throw std::runtime_error(
"Custom Metal kernels only run on GPU.");
 
 
  278  void eval_gpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
 
  286  std::tuple<int, int, int> grid_;
 
  287  std::tuple<int, int, int> threadgroup_;
 
  288  std::vector<CustomKernelShapeInfo> shape_infos_;
 
  289  bool ensure_row_contiguous_;
 
  290  std::optional<float> init_value_;
 
 
Definition primitives.h:48
 
const Stream & stream()
The stream the primitive will run on.
Definition primitives.h:58
 
virtual bool is_equivalent(const Primitive &other) const
Equivalence check defaults to false unless overridden by the primitive.
Definition primitives.h:107
 
Definition fast_primitives.h:217
 
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
Definition fast_primitives.h:230
 
DEFINE_PRINT(AffineQuantize)
 
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
 
AffineQuantize(Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback, int group_size, int bits, bool dequantize)
Definition fast_primitives.h:219
 
Definition fast_primitives.h:12
 
Custom(Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback)
Definition fast_primitives.h:14
 
virtual std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap(const std::vector< array > &inputs, const std::vector< int > &axes) override
The primitive must know how to vectorize itself across the given axes.
 
virtual std::vector< array > jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
The Jacobian-vector product.
 
Definition fast_primitives.h:253
 
DEFINE_PRINT(CustomKernel)
 
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
 
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
Definition fast_primitives.h:273
 
CustomKernel(Stream stream, std::string name, std::string source, std::tuple< int, int, int > grid, std::tuple< int, int, int > threadgroup, std::vector< CustomKernelShapeInfo > shape_infos, bool ensure_row_contiguous, std::optional< float > init_value)
Definition fast_primitives.h:255
 
Definition fast_primitives.h:90
 
DEFINE_PRINT(LayerNorm) bool is_equivalent(const Primitive &other) const override
 
LayerNorm(Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback, float eps)
Definition fast_primitives.h:92
 
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
Definition fast_primitives.h:98
 
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
 
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
 
Definition fast_primitives.h:119
 
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
Definition fast_primitives.h:127
 
LayerNormVJP(Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback, float eps)
Definition fast_primitives.h:121
 
DEFINE_PRINT(LayerNormVJP) bool is_equivalent(const Primitive &other) const override
 
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
 
Definition fast_primitives.h:38
 
RMSNorm(Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback, float eps)
Definition fast_primitives.h:40
 
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
Definition fast_primitives.h:46
 
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
 
DEFINE_PRINT(RMSNorm) bool is_equivalent(const Primitive &other) const override
 
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
 
Definition fast_primitives.h:67
 
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
 
DEFINE_PRINT(RMSNormVJP) bool is_equivalent(const Primitive &other) const override
 
RMSNormVJP(Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback, float eps)
Definition fast_primitives.h:69
 
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
Definition fast_primitives.h:75
 
Definition fast_primitives.h:142
 
RoPE(Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback, int dims, bool traditional, float base, float scale, int offset, bool forward)
Definition fast_primitives.h:144
 
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
Definition fast_primitives.h:161
 
DEFINE_PRINT(RoPE) bool is_equivalent(const Primitive &other) const override
 
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
 
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
 
Definition fast_primitives.h:187
 
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
Definition fast_primitives.h:201
 
ScaledDotProductAttention(Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback, const float scale, const bool needs_mask)
Definition fast_primitives.h:189
 
DEFINE_PRINT(ScaledDotProductAttention)
 
void eval_gpu(const std::vector< array > &inputs, array &out)
 
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
Definition fast_primitives.h:196
 
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
 
array std(const array &a, bool keepdims, int ddof=0, StreamOrDevice s={})
Computes the standard deviation of the elements of an array.
 
array dequantize(const array &w, const array &scales, const array &biases, int group_size=64, int bits=4, StreamOrDevice s={})
Dequantize a matrix produced by quantize()
 
Definition fast_primitives.h:247
 
bool strides
Definition fast_primitives.h:249
 
bool shape
Definition fast_primitives.h:248
 
bool ndim
Definition fast_primitives.h:250