mlx/mlx/primitives.h

1537 lines
40 KiB
C
Raw Normal View History

2023-12-01 03:12:53 +08:00
// Copyright © 2023 Apple Inc.
2023-11-30 02:30:41 +08:00
#pragma once
#include "mlx/array.h"
#include "mlx/device.h"
#include "mlx/io/load.h"
#include "mlx/stream.h"
/**
 * Declares the `vmap` override inside a primitive's class body.
 *
 * The expansion already ends with `;`, so call sites write `DEFINE_VMAP()`
 * with no trailing semicolon. Only the declaration is produced; the body is
 * defined elsewhere.
 */
#define DEFINE_VMAP() \
  std::pair<std::vector<array>, std::vector<int>> vmap( \
      const std::vector<array>& inputs, \
      const std::vector<int>& axes) override;
2023-11-30 02:30:41 +08:00
/**
 * Declares the `jvp` and `vjp` overrides inside a primitive's class body.
 *
 * Every line of the macro must end in a backslash continuation; a line
 * without one terminates the definition early. The expansion ends with `;`,
 * so call sites write `DEFINE_GRADS()` with no trailing semicolon. Only the
 * declarations are produced; the bodies are defined elsewhere.
 */
#define DEFINE_GRADS() \
  std::vector<array> jvp( \
      const std::vector<array>& primals, \
      const std::vector<array>& tangents, \
      const std::vector<int>& argnums) override; \
  \
  std::vector<array> vjp( \
      const std::vector<array>& primals, \
      const std::vector<array>& cotangents, \
      const std::vector<int>& argnums) override;
/**
 * Defines an inline `print` override that streams the literal token passed
 * as PRIMITIVE (stringized with `#`) to the given output stream.
 */
#define DEFINE_PRINT(PRIMITIVE) \
  void print(std::ostream& out) override { \
    out << #PRIMITIVE; \
  }
/**
 * Defines an inline `is_equivalent` override that unconditionally returns
 * true. The argument is never inspected, so this only suits primitives that
 * have nothing to compare.
 */
#define DEFINE_DEFAULT_IS_EQUIVALENT() \
  bool is_equivalent(const Primitive& /* other */) const override { \
    return true; \
  }
namespace mlx::core {
// Abstract base class
class Primitive {
public:
explicit Primitive(Stream stream) : stream_(stream) {}
/** The device the primitive will run on. */
const Device& device() {
return stream().device;
}
/** The stream the primitive will run on. */
const Stream& stream() {
return stream_;
}
/**
* A primitive must know how to evaluate itself on
* the CPU/GPU for the given inputs and populate the output arrays.
2023-11-30 02:30:41 +08:00
*
Spelling (#342) * spelling: accumulates Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: across Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: additional Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: against Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: among Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: array Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: at least Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: available Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: axes Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: basically Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: bfloat Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: bounds Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: broadcast Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: buffer Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: class Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: coefficients Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: collision Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: combinations Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: committing Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: computation Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: consider Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: constructing Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: 
conversions Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: correctly Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: corresponding Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: declaration Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: default Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: dependency Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: destination Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: destructor Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: dimensions Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: divided Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: element-wise Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: elements Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: endianness Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: equivalent Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: explicitly Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: github Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: indices Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: irregularly Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: memory Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: metallib Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: negative Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: notable Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: optional 
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: otherwise Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: overridden Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: partially Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: partition Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: perform Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: perturbations Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: positively Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: primitive Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: repeat Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: repeats Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: respect Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: respectively Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: result Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: rounding Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: separate Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: skipping Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: structure Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: the Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: transpose Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: unnecessary Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: unneeded Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: unsupported Signed-off-by: Josh Soref 
<2119212+jsoref@users.noreply.github.com> --------- Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
2024-01-02 13:08:17 +08:00
* To avoid unnecessary allocations, the evaluation function
2023-11-30 02:30:41 +08:00
* is responsible for allocating space for the array.
*/
virtual void eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) = 0;
virtual void eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) = 0;
2023-11-30 02:30:41 +08:00
/**
* The Jacobian-vector product.
*/
virtual std::vector<array> jvp(
2023-11-30 02:30:41 +08:00
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums);
/**
* The vector-Jacobian product.
*/
virtual std::vector<array> vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
2023-11-30 02:30:41 +08:00
const std::vector<int>& argnums);
/**
2023-12-16 05:46:50 +08:00
* The primitive must know how to vectorize itself across
* the given axes. The output is a pair containing the output arrays
* representing the vectorized computation and the axes which
* corresponds to the vectorized dimensions of each output.
2023-11-30 02:30:41 +08:00
*/
virtual std::pair<std::vector<array>, std::vector<int>> vmap(
2023-11-30 02:30:41 +08:00
const std::vector<array>& inputs,
const std::vector<int>& axes);
/** Print the primitive. */
virtual void print(std::ostream& os) = 0;
Spelling (#342) * spelling: accumulates Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: across Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: additional Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: against Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: among Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: array Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: at least Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: available Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: axes Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: basically Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: bfloat Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: bounds Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: broadcast Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: buffer Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: class Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: coefficients Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: collision Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: combinations Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: committing Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: computation Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: consider Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: constructing Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: 
conversions Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: correctly Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: corresponding Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: declaration Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: default Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: dependency Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: destination Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: destructor Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: dimensions Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: divided Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: element-wise Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: elements Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: endianness Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: equivalent Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: explicitly Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: github Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: indices Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: irregularly Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: memory Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: metallib Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: negative Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: notable Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: optional 
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: otherwise Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: overridden Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: partially Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: partition Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: perform Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: perturbations Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: positively Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: primitive Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: repeat Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: repeats Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: respect Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: respectively Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: result Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: rounding Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: separate Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: skipping Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: structure Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: the Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: transpose Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: unnecessary Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: unneeded Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: unsupported Signed-off-by: Josh Soref 
<2119212+jsoref@users.noreply.github.com> --------- Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
2024-01-02 13:08:17 +08:00
/** Equivalence check defaults to false unless overridden by the primitive */
2023-11-30 02:30:41 +08:00
virtual bool is_equivalent(const Primitive& other) const {
return false;
}
virtual ~Primitive() = default;
Primitive(const Primitive& other) = delete;
Primitive(Primitive&& other) = delete;
Primitive& operator=(const Primitive& other) = delete;
Primitive& operator=(Primitive&& other) = delete;
private:
// Every primitive stores the stream it should run in
Stream stream_;
};
class UnaryPrimitive : public Primitive {
/**
* An abstract base class for a primitive with a single output.
*/
2023-11-30 02:30:41 +08:00
public:
explicit UnaryPrimitive(Stream stream) : Primitive(stream) {}
2023-11-30 02:30:41 +08:00
virtual void eval_cpu(const std::vector<array>& inputs, array& output) = 0;
virtual void eval_gpu(const std::vector<array>& inputs, array& output) = 0;
2023-11-30 02:30:41 +08:00
inline void eval_cpu(
2023-11-30 02:30:41 +08:00
const std::vector<array>& inputs,
std::vector<array>& outputs) override {
eval_cpu(inputs, outputs[0]);
}
inline void eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) override {
eval_gpu(inputs, outputs[0]);
}
2023-11-30 02:30:41 +08:00
virtual ~UnaryPrimitive() = default;
UnaryPrimitive(const UnaryPrimitive& other) = delete;
UnaryPrimitive(UnaryPrimitive&& other) = delete;
UnaryPrimitive& operator=(const UnaryPrimitive& other) = delete;
UnaryPrimitive& operator=(UnaryPrimitive&& other) = delete;
};
class Abs : public UnaryPrimitive {
public:
explicit Abs(Stream stream) : UnaryPrimitive(stream){};
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
2023-11-30 02:30:41 +08:00
DEFINE_GRADS()
DEFINE_PRINT(Abs)
DEFINE_DEFAULT_IS_EQUIVALENT()
private:
void eval(const std::vector<array>& inputs, array& out);
};
class Add : public UnaryPrimitive {
2023-11-30 02:30:41 +08:00
public:
explicit Add(Stream stream) : UnaryPrimitive(stream){};
2023-11-30 02:30:41 +08:00
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
2023-11-30 02:30:41 +08:00
DEFINE_GRADS()
DEFINE_PRINT(Add)
DEFINE_DEFAULT_IS_EQUIVALENT()
private:
void eval(const std::vector<array>& inputs, array& out);
};
class Arange : public UnaryPrimitive {
2023-11-30 02:30:41 +08:00
public:
explicit Arange(Stream stream, double start, double stop, double step)
: UnaryPrimitive(stream), start_(start), stop_(stop), step_(step){};
2023-11-30 02:30:41 +08:00
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_PRINT(Arange)
bool is_equivalent(const Primitive& other) const override;
private:
double start_;
double stop_;
double step_;
void eval(const std::vector<array>& inputs, array& out);
};
class ArcCos : public UnaryPrimitive {
2023-11-30 02:30:41 +08:00
public:
explicit ArcCos(Stream stream) : UnaryPrimitive(stream){};
2023-11-30 02:30:41 +08:00
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
2023-11-30 02:30:41 +08:00
DEFINE_GRADS()
DEFINE_PRINT(ArcCos)
DEFINE_DEFAULT_IS_EQUIVALENT()
private:
void eval(const std::vector<array>& inputs, array& out);
};
class ArcCosh : public UnaryPrimitive {
2023-11-30 02:30:41 +08:00
public:
explicit ArcCosh(Stream stream) : UnaryPrimitive(stream){};
2023-11-30 02:30:41 +08:00
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
2023-11-30 02:30:41 +08:00
DEFINE_GRADS()
DEFINE_PRINT(ArcCosh)
DEFINE_DEFAULT_IS_EQUIVALENT()
private:
void eval(const std::vector<array>& inputs, array& out);
};
class ArcSin : public UnaryPrimitive {
2023-11-30 02:30:41 +08:00
public:
explicit ArcSin(Stream stream) : UnaryPrimitive(stream){};
2023-11-30 02:30:41 +08:00
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
2023-11-30 02:30:41 +08:00
DEFINE_GRADS()
DEFINE_PRINT(ArcSin)
DEFINE_DEFAULT_IS_EQUIVALENT()
private:
void eval(const std::vector<array>& inputs, array& out);
};
class ArcSinh : public UnaryPrimitive {
2023-11-30 02:30:41 +08:00
public:
explicit ArcSinh(Stream stream) : UnaryPrimitive(stream){};
2023-11-30 02:30:41 +08:00
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
2023-11-30 02:30:41 +08:00
DEFINE_GRADS()
DEFINE_PRINT(ArcSinh)
DEFINE_DEFAULT_IS_EQUIVALENT()
private:
void eval(const std::vector<array>& inputs, array& out);
};
class ArcTan : public UnaryPrimitive {
2023-11-30 02:30:41 +08:00
public:
explicit ArcTan(Stream stream) : UnaryPrimitive(stream){};
2023-11-30 02:30:41 +08:00
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
2023-11-30 02:30:41 +08:00
DEFINE_GRADS()
DEFINE_PRINT(ArcTan)
DEFINE_DEFAULT_IS_EQUIVALENT()
private:
void eval(const std::vector<array>& inputs, array& out);
};
class ArcTanh : public UnaryPrimitive {
2023-11-30 02:30:41 +08:00
public:
explicit ArcTanh(Stream stream) : UnaryPrimitive(stream){};
2023-11-30 02:30:41 +08:00
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
2023-11-30 02:30:41 +08:00
DEFINE_GRADS()
DEFINE_PRINT(ArcTanh)
DEFINE_DEFAULT_IS_EQUIVALENT()
private:
void eval(const std::vector<array>& inputs, array& out);
};
class ArgPartition : public UnaryPrimitive {
2023-11-30 02:30:41 +08:00
public:
explicit ArgPartition(Stream stream, int kth, int axis)
: UnaryPrimitive(stream), kth_(kth), axis_(axis){};
2023-11-30 02:30:41 +08:00
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
2023-11-30 02:30:41 +08:00
DEFINE_PRINT(ArgPartition)
bool is_equivalent(const Primitive& other) const override;
private:
int kth_;
int axis_;
void eval(const std::vector<array>& inputs, array& out);
};
class ArgReduce : public UnaryPrimitive {
2023-11-30 02:30:41 +08:00
public:
enum ReduceType {
ArgMin,
ArgMax,
};
explicit ArgReduce(Stream stream, ReduceType reduce_type, int axis)
: UnaryPrimitive(stream), reduce_type_(reduce_type), axis_(axis){};
2023-11-30 02:30:41 +08:00
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_PRINT(ArgReduce)
bool is_equivalent(const Primitive& other) const override;
private:
ReduceType reduce_type_;
int axis_;
void eval(const std::vector<array>& inputs, array& out);
};
class ArgSort : public UnaryPrimitive {
2023-11-30 02:30:41 +08:00
public:
explicit ArgSort(Stream stream, int axis)
: UnaryPrimitive(stream), axis_(axis){};
2023-11-30 02:30:41 +08:00
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
2023-11-30 02:30:41 +08:00
DEFINE_PRINT(ArgSort)
bool is_equivalent(const Primitive& other) const override;
private:
int axis_;
void eval(const std::vector<array>& inputs, array& out);
};
class AsType : public UnaryPrimitive {
2023-11-30 02:30:41 +08:00
public:
explicit AsType(Stream stream, Dtype dtype)
: UnaryPrimitive(stream), dtype_(dtype){};
2023-11-30 02:30:41 +08:00
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
2023-11-30 02:30:41 +08:00
DEFINE_GRADS()
DEFINE_PRINT(AsType)
bool is_equivalent(const Primitive& other) const override;
private:
Dtype dtype_;
void eval(const std::vector<array>& inputs, array& out);
};
class AsStrided : public UnaryPrimitive {
2023-11-30 02:30:41 +08:00
public:
explicit AsStrided(
Stream stream,
const std::vector<int>& shape,
const std::vector<size_t>& strides,
size_t offset)
: UnaryPrimitive(stream),
shape_(shape),
strides_(strides),
offset_(offset){};
2023-11-30 02:30:41 +08:00
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_GRADS()
DEFINE_PRINT(AsStrided)
bool is_equivalent(const Primitive& other) const override;
private:
std::vector<int> shape_;
std::vector<size_t> strides_;
size_t offset_;
void eval(const std::vector<array>& inputs, array& out);
};
class Broadcast : public UnaryPrimitive {
2023-11-30 02:30:41 +08:00
public:
explicit Broadcast(Stream stream, const std::vector<int>& shape)
: UnaryPrimitive(stream), shape_(shape){};
2023-11-30 02:30:41 +08:00
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
2023-11-30 02:30:41 +08:00
DEFINE_GRADS()
DEFINE_PRINT(Broadcast)
bool is_equivalent(const Primitive& other) const override;
private:
std::vector<int> shape_;
void eval(const std::vector<array>& inputs, array& out);
};
/**
 * Stateless single-output primitive `Ceil`.
 *
 * NOTE(review): presumably element-wise rounding up; the
 * eval/vmap/jvp/vjp bodies are defined outside this header.
 */
class Ceil : public UnaryPrimitive {
 public:
  explicit Ceil(Stream stream) : UnaryPrimitive(stream) {}

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  DEFINE_VMAP()
  DEFINE_GRADS()
  DEFINE_PRINT(Ceil)
  // No parameters to compare, so equivalence is unconditionally true.
  DEFINE_DEFAULT_IS_EQUIVALENT()

 private:
  // Declared here, defined elsewhere.
  void eval(const std::vector<array>& inputs, array& out);
};
class Concatenate : public UnaryPrimitive {
2023-11-30 02:30:41 +08:00
public:
explicit Concatenate(Stream stream, int axis)
: UnaryPrimitive(stream), axis_(axis){};
2023-11-30 02:30:41 +08:00
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
2023-11-30 02:30:41 +08:00
DEFINE_GRADS()
DEFINE_PRINT(Concatenate)
bool is_equivalent(const Primitive& other) const override;
private:
int axis_;
void eval(const std::vector<array>& inputs, array& out);
};
class Convolution : public UnaryPrimitive {
2023-11-30 02:30:41 +08:00
public:
explicit Convolution(
Stream stream,
const std::vector<int>& padding,
const std::vector<int>& kernel_strides,
const std::vector<int>& kernel_dilation,
const std::vector<int>& input_dilation)
: UnaryPrimitive(stream),
2023-11-30 02:30:41 +08:00
padding_(padding),
kernel_strides_(kernel_strides),
kernel_dilation_(kernel_dilation),
input_dilation_(input_dilation){};
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
std::vector<array> vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
2023-11-30 02:30:41 +08:00
const std::vector<int>& argnums) override;
DEFINE_PRINT(Convolution)
bool is_equivalent(const Primitive& other) const override;
private:
std::vector<int> padding_;
std::vector<int> kernel_strides_;
std::vector<int> kernel_dilation_;
std::vector<int> input_dilation_;
void eval(const std::vector<array>& inputs, array& out);
};
class Copy : public UnaryPrimitive {
2023-11-30 02:30:41 +08:00
public:
explicit Copy(Stream stream) : UnaryPrimitive(stream){};
2023-11-30 02:30:41 +08:00
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
2023-11-30 02:30:41 +08:00
DEFINE_GRADS()
DEFINE_PRINT(Copy)
DEFINE_DEFAULT_IS_EQUIVALENT()
private:
void eval(const std::vector<array>& inputs, array& out);
};
class Cos : public UnaryPrimitive {
2023-11-30 02:30:41 +08:00
public:
explicit Cos(Stream stream) : UnaryPrimitive(stream){};
2023-11-30 02:30:41 +08:00
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
2023-11-30 02:30:41 +08:00
DEFINE_GRADS()
DEFINE_PRINT(Cos)
DEFINE_DEFAULT_IS_EQUIVALENT()
private:
void eval(const std::vector<array>& inputs, array& out);
};
class Cosh : public UnaryPrimitive {
2023-11-30 02:30:41 +08:00
public:
explicit Cosh(Stream stream) : UnaryPrimitive(stream){};
2023-11-30 02:30:41 +08:00
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
2023-11-30 02:30:41 +08:00
DEFINE_GRADS()
DEFINE_PRINT(Cosh)
DEFINE_DEFAULT_IS_EQUIVALENT()
private:
void eval(const std::vector<array>& inputs, array& out);
};
class Divide : public UnaryPrimitive {
2023-11-30 02:30:41 +08:00
public:
explicit Divide(Stream stream) : UnaryPrimitive(stream){};
2023-11-30 02:30:41 +08:00
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
2023-11-30 02:30:41 +08:00
DEFINE_GRADS()
DEFINE_PRINT(Divide)
DEFINE_DEFAULT_IS_EQUIVALENT()
private:
void eval(const std::vector<array>& inputs, array& out);
};
/**
 * Stateless primitive `DivMod`.
 *
 * Unlike most primitives in this header it derives from Primitive directly
 * and overrides the vector-of-outputs eval overloads.
 * NOTE(review): presumably it yields a quotient and a remainder as separate
 * outputs -- confirm against the implementation.
 */
class DivMod : public Primitive {
 public:
  explicit DivMod(Stream stream) : Primitive(stream) {}

  void eval_cpu(
      const std::vector<array>& inputs,
      std::vector<array>& outputs) override;
  void eval_gpu(
      const std::vector<array>& inputs,
      std::vector<array>& outputs) override;

  DEFINE_VMAP()
  DEFINE_GRADS()
  DEFINE_PRINT(DivMod)
  // No parameters to compare, so equivalence is unconditionally true.
  DEFINE_DEFAULT_IS_EQUIVALENT()

 private:
  // Declared here, defined elsewhere.
  void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
};
/**
 * Stateless single-output primitive `Remainder`.
 *
 * NOTE(review): presumably the element-wise remainder of its inputs; the
 * eval/vmap/jvp/vjp bodies are defined outside this header.
 */
class Remainder : public UnaryPrimitive {
 public:
  explicit Remainder(Stream stream) : UnaryPrimitive(stream) {}

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  DEFINE_VMAP()
  DEFINE_GRADS()
  DEFINE_PRINT(Remainder)
  // No parameters to compare, so equivalence is unconditionally true.
  DEFINE_DEFAULT_IS_EQUIVALENT()

 private:
  // Declared here, defined elsewhere.
  void eval(const std::vector<array>& inputs, array& out);
};
class Equal : public UnaryPrimitive {
2023-11-30 02:30:41 +08:00
public:
explicit Equal(Stream stream, bool equal_nan = false)
: UnaryPrimitive(stream), equal_nan_(equal_nan){};
2023-11-30 02:30:41 +08:00
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
2023-11-30 02:30:41 +08:00
DEFINE_GRADS()
DEFINE_PRINT(Equal)
DEFINE_DEFAULT_IS_EQUIVALENT()
private:
void eval(const std::vector<array>& inputs, array& out);
bool equal_nan_;
};
class Erf : public UnaryPrimitive {
2023-11-30 02:30:41 +08:00
public:
explicit Erf(Stream stream) : UnaryPrimitive(stream){};
2023-11-30 02:30:41 +08:00
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
2023-11-30 02:30:41 +08:00
DEFINE_GRADS()
DEFINE_PRINT(Erf)
DEFINE_DEFAULT_IS_EQUIVALENT()
private:
void eval(const std::vector<array>& inputs, array& out);
};
class ErfInv : public UnaryPrimitive {
2023-11-30 02:30:41 +08:00
public:
explicit ErfInv(Stream stream) : UnaryPrimitive(stream){};
2023-11-30 02:30:41 +08:00
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
2023-11-30 02:30:41 +08:00
DEFINE_GRADS()
DEFINE_PRINT(ErfInv)
DEFINE_DEFAULT_IS_EQUIVALENT()
private:
void eval(const std::vector<array>& inputs, array& out);
};
class Exp : public UnaryPrimitive {
2023-11-30 02:30:41 +08:00
public:
explicit Exp(Stream stream) : UnaryPrimitive(stream){};
2023-11-30 02:30:41 +08:00
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
2023-11-30 02:30:41 +08:00
DEFINE_GRADS()
DEFINE_PRINT(Exp)
DEFINE_DEFAULT_IS_EQUIVALENT()
private:
void eval(const std::vector<array>& inputs, array& out);
};
class FFT : public UnaryPrimitive {
2023-11-30 02:30:41 +08:00
public:
explicit FFT(
Stream stream,
const std::vector<size_t>& axes,
bool inverse,
bool real)
: UnaryPrimitive(stream), axes_(axes), inverse_(inverse), real_(real){};
2023-11-30 02:30:41 +08:00
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
2023-11-30 02:30:41 +08:00
DEFINE_GRADS()
DEFINE_PRINT(FFT)
bool is_equivalent(const Primitive& other) const override;
private:
std::vector<size_t> axes_;
bool inverse_;
bool real_;
void eval(const std::vector<array>& inputs, array& out);
};
/**
 * Stateless single-output primitive `Floor`.
 *
 * NOTE(review): presumably element-wise rounding down; the
 * eval/vmap/jvp/vjp bodies are defined outside this header.
 */
class Floor : public UnaryPrimitive {
 public:
  explicit Floor(Stream stream) : UnaryPrimitive(stream) {}

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  DEFINE_VMAP()
  DEFINE_GRADS()
  DEFINE_PRINT(Floor)
  // No parameters to compare, so equivalence is unconditionally true.
  DEFINE_DEFAULT_IS_EQUIVALENT()

 private:
  // Declared here, defined elsewhere.
  void eval(const std::vector<array>& inputs, array& out);
};
class Full : public UnaryPrimitive {
2023-11-30 02:30:41 +08:00
public:
explicit Full(Stream stream) : UnaryPrimitive(stream){};
2023-11-30 02:30:41 +08:00
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
2023-11-30 02:30:41 +08:00
DEFINE_GRADS()
DEFINE_PRINT(Full)
DEFINE_DEFAULT_IS_EQUIVALENT()
private:
void eval(const std::vector<array>& inputs, array& out);
};
class Gather : public UnaryPrimitive {
2023-11-30 02:30:41 +08:00
public:
explicit Gather(
Stream stream,
const std::vector<int>& axes,
const std::vector<int>& slice_sizes)
: UnaryPrimitive(stream), axes_(axes), slice_sizes_(slice_sizes){};
2023-11-30 02:30:41 +08:00
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
2023-11-30 02:30:41 +08:00
DEFINE_GRADS()
DEFINE_PRINT(Gather)
bool is_equivalent(const Primitive& other) const override;
private:
void eval(const std::vector<array>& inputs, array& out);
std::vector<int> axes_;
std::vector<int> slice_sizes_;
};
class Greater : public UnaryPrimitive {
2023-11-30 02:30:41 +08:00
public:
explicit Greater(Stream stream) : UnaryPrimitive(stream){};
2023-11-30 02:30:41 +08:00
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
2023-11-30 02:30:41 +08:00
DEFINE_GRADS()
DEFINE_PRINT(Greater)
DEFINE_DEFAULT_IS_EQUIVALENT()
private:
void eval(const std::vector<array>& inputs, array& out);
};
class GreaterEqual : public UnaryPrimitive {
2023-11-30 02:30:41 +08:00
public:
explicit GreaterEqual(Stream stream) : UnaryPrimitive(stream){};
2023-11-30 02:30:41 +08:00
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
2023-11-30 02:30:41 +08:00
DEFINE_GRADS()
DEFINE_PRINT(GreaterEqual)
DEFINE_DEFAULT_IS_EQUIVALENT()
private:
void eval(const std::vector<array>& inputs, array& out);
};
class Less : public UnaryPrimitive {
2023-11-30 02:30:41 +08:00
public:
explicit Less(Stream stream) : UnaryPrimitive(stream){};
2023-11-30 02:30:41 +08:00
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
2023-11-30 02:30:41 +08:00
DEFINE_GRADS()
DEFINE_PRINT(Less)
DEFINE_DEFAULT_IS_EQUIVALENT()
private:
void eval(const std::vector<array>& inputs, array& out);
};
class LessEqual : public UnaryPrimitive {
2023-11-30 02:30:41 +08:00
public:
explicit LessEqual(Stream stream) : UnaryPrimitive(stream){};
2023-11-30 02:30:41 +08:00
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
2023-11-30 02:30:41 +08:00
DEFINE_GRADS()
DEFINE_PRINT(LessEqual)
DEFINE_DEFAULT_IS_EQUIVALENT()
private:
void eval(const std::vector<array>& inputs, array& out);
};
class Load : public UnaryPrimitive {
2023-11-30 02:30:41 +08:00
public:
explicit Load(
Stream stream,
std::shared_ptr<io::Reader> reader,
size_t offset,
bool swap_endianness = false)
: UnaryPrimitive(stream),
2023-11-30 02:30:41 +08:00
reader_(reader),
offset_(offset),
swap_endianness_(swap_endianness){};
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_PRINT(Load)
private:
void eval(const std::vector<array>& inputs, array& out);
std::shared_ptr<io::Reader> reader_;
size_t offset_;
bool swap_endianness_;
};
class Log : public UnaryPrimitive {
2023-11-30 02:30:41 +08:00
public:
enum Base { two, ten, e };
explicit Log(Stream stream, Base base)
: UnaryPrimitive(stream), base_(base){};
2023-11-30 02:30:41 +08:00
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
2023-11-30 02:30:41 +08:00
DEFINE_GRADS()
DEFINE_PRINT(Log)
DEFINE_DEFAULT_IS_EQUIVALENT()
private:
Base base_;
void eval(const std::vector<array>& inputs, array& out);
};
class Log1p : public UnaryPrimitive {
2023-11-30 02:30:41 +08:00
public:
explicit Log1p(Stream stream) : UnaryPrimitive(stream){};
2023-11-30 02:30:41 +08:00
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
2023-11-30 02:30:41 +08:00
DEFINE_GRADS()
DEFINE_PRINT(Log1p)
private:
void eval(const std::vector<array>& inputs, array& out);
};
class LogicalNot : public UnaryPrimitive {
2023-11-30 02:30:41 +08:00
public:
explicit LogicalNot(Stream stream) : UnaryPrimitive(stream){};
2023-11-30 02:30:41 +08:00
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
2023-11-30 02:30:41 +08:00
DEFINE_GRADS()
DEFINE_PRINT(LogicalNot)
DEFINE_DEFAULT_IS_EQUIVALENT()
private:
void eval(const std::vector<array>& inputs, array& out);
};
/**
 * Primitive that prints as "LogicalAnd" and writes a single output array.
 *
 * It carries no parameters besides the stream, and is_equivalent()
 * unconditionally returns true.
 */
class LogicalAnd : public UnaryPrimitive {
 public:
  explicit LogicalAnd(Stream stream) : UnaryPrimitive(stream) {}

  /** Evaluate on the CPU / GPU for `inputs`, populating `out`. */
  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  DEFINE_VMAP()
  DEFINE_GRADS()
  DEFINE_PRINT(LogicalAnd)
  DEFINE_DEFAULT_IS_EQUIVALENT()

 private:
  /** Private evaluation helper; defined outside this header. */
  void eval(const std::vector<array>& inputs, array& out);
};
/**
 * Primitive that prints as "LogicalOr" and writes a single output array.
 *
 * It carries no parameters besides the stream, and is_equivalent()
 * unconditionally returns true.
 */
class LogicalOr : public UnaryPrimitive {
 public:
  explicit LogicalOr(Stream stream) : UnaryPrimitive(stream) {}

  /** Evaluate on the CPU / GPU for `inputs`, populating `out`. */
  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  DEFINE_VMAP()
  DEFINE_GRADS()
  DEFINE_PRINT(LogicalOr)
  DEFINE_DEFAULT_IS_EQUIVALENT()

 private:
  /** Private evaluation helper; defined outside this header. */
  void eval(const std::vector<array>& inputs, array& out);
};
class LogAddExp : public UnaryPrimitive {
2023-11-30 02:30:41 +08:00
public:
explicit LogAddExp(Stream stream) : UnaryPrimitive(stream){};
2023-11-30 02:30:41 +08:00
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
2023-11-30 02:30:41 +08:00
DEFINE_GRADS()
DEFINE_PRINT(LogAddExp)
DEFINE_DEFAULT_IS_EQUIVALENT()
private:
void eval(const std::vector<array>& inputs, array& out);
};
class Matmul : public UnaryPrimitive {
2023-11-30 02:30:41 +08:00
public:
explicit Matmul(Stream stream) : UnaryPrimitive(stream){};
2023-11-30 02:30:41 +08:00
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
std::vector<array> vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
2023-11-30 02:30:41 +08:00
const std::vector<int>& argnums) override;
DEFINE_PRINT(Matmul)
DEFINE_DEFAULT_IS_EQUIVALENT()
};
class Maximum : public UnaryPrimitive {
2023-11-30 02:30:41 +08:00
public:
explicit Maximum(Stream stream) : UnaryPrimitive(stream){};
2023-11-30 02:30:41 +08:00
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
2023-11-30 02:30:41 +08:00
DEFINE_GRADS()
DEFINE_PRINT(Maximum)
DEFINE_DEFAULT_IS_EQUIVALENT()
private:
void eval(const std::vector<array>& inputs, array& out);
};
class Minimum : public UnaryPrimitive {
2023-11-30 02:30:41 +08:00
public:
explicit Minimum(Stream stream) : UnaryPrimitive(stream){};
2023-11-30 02:30:41 +08:00
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
2023-11-30 02:30:41 +08:00
DEFINE_GRADS()
DEFINE_PRINT(Minimum)
DEFINE_DEFAULT_IS_EQUIVALENT()
private:
void eval(const std::vector<array>& inputs, array& out);
};
class Multiply : public UnaryPrimitive {
2023-11-30 02:30:41 +08:00
public:
explicit Multiply(Stream stream) : UnaryPrimitive(stream){};
2023-11-30 02:30:41 +08:00
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
2023-11-30 02:30:41 +08:00
DEFINE_GRADS()
DEFINE_PRINT(Multiply)
DEFINE_DEFAULT_IS_EQUIVALENT()
private:
void eval(const std::vector<array>& inputs, array& out);
};
class Negative : public UnaryPrimitive {
2023-11-30 02:30:41 +08:00
public:
explicit Negative(Stream stream) : UnaryPrimitive(stream){};
2023-11-30 02:30:41 +08:00
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
2023-11-30 02:30:41 +08:00
DEFINE_GRADS()
DEFINE_PRINT(Negative)
DEFINE_DEFAULT_IS_EQUIVALENT()
private:
void eval(const std::vector<array>& inputs, array& out);
};
class NotEqual : public UnaryPrimitive {
2023-11-30 02:30:41 +08:00
public:
explicit NotEqual(Stream stream) : UnaryPrimitive(stream){};
2023-11-30 02:30:41 +08:00
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
2023-11-30 02:30:41 +08:00
DEFINE_GRADS()
DEFINE_PRINT(NotEqual)
DEFINE_DEFAULT_IS_EQUIVALENT()
private:
void eval(const std::vector<array>& inputs, array& out);
};
class Pad : public UnaryPrimitive {
2023-11-30 02:30:41 +08:00
public:
explicit Pad(
Stream stream,
const std::vector<int>& axes,
const std::vector<int>& low_pad_size,
const std::vector<int>& high_pad_size)
: UnaryPrimitive(stream),
2023-11-30 02:30:41 +08:00
axes_(axes),
low_pad_size_(low_pad_size),
high_pad_size_(high_pad_size){};
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
2023-11-30 02:30:41 +08:00
DEFINE_GRADS()
DEFINE_PRINT(Pad)
bool is_equivalent(const Primitive& other) const override;
private:
std::vector<int> axes_;
std::vector<int> low_pad_size_;
std::vector<int> high_pad_size_;
void eval(const std::vector<array>& inputs, array& out);
};
class Partition : public UnaryPrimitive {
2023-11-30 02:30:41 +08:00
public:
explicit Partition(Stream stream, int kth, int axis)
: UnaryPrimitive(stream), kth_(kth), axis_(axis){};
2023-11-30 02:30:41 +08:00
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
2023-11-30 02:30:41 +08:00
DEFINE_GRADS()
DEFINE_PRINT(Partition)
bool is_equivalent(const Primitive& other) const override;
private:
int kth_;
int axis_;
void eval(const std::vector<array>& inputs, array& out);
};
class Power : public UnaryPrimitive {
2023-11-30 02:30:41 +08:00
public:
explicit Power(Stream stream) : UnaryPrimitive(stream){};
2023-11-30 02:30:41 +08:00
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
2023-11-30 02:30:41 +08:00
DEFINE_GRADS()
DEFINE_PRINT(Power)
DEFINE_DEFAULT_IS_EQUIVALENT()
private:
void eval(const std::vector<array>& inputs, array& out);
};
/**
 * Primitive that prints as "QuantizedMatmul" and writes a single output
 * array.
 *
 * The group size, bit width and transpose flag are captured at
 * construction. is_equivalent() is defined out of line.
 */
class QuantizedMatmul : public UnaryPrimitive {
 public:
  explicit QuantizedMatmul(
      Stream stream,
      int group_size,
      int bits,
      bool transpose)
      : UnaryPrimitive(stream),
        group_size_(group_size),
        bits_(bits),
        transpose_(transpose) {}

  /** Evaluate on the CPU / GPU for `inputs`, populating `out`. */
  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  DEFINE_VMAP()
  DEFINE_GRADS()
  DEFINE_PRINT(QuantizedMatmul)
  bool is_equivalent(const Primitive& other) const override;

 private:
  int group_size_;
  int bits_;
  bool transpose_;

  /** Private evaluation helper; defined outside this header. */
  void eval(const std::vector<array>& inputs, array& out);
};
class RandomBits : public UnaryPrimitive {
2023-11-30 02:30:41 +08:00
public:
explicit RandomBits(Stream stream, const std::vector<int>& shape, int width)
: UnaryPrimitive(stream), shape_(shape), width_(width){};
2023-11-30 02:30:41 +08:00
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
2023-11-30 02:30:41 +08:00
DEFINE_PRINT(RandomBits)
bool is_equivalent(const Primitive& other) const override;
private:
std::vector<int> shape_;
int width_;
void eval(const std::vector<array>& inputs, array& out);
};
class Reshape : public UnaryPrimitive {
2023-11-30 02:30:41 +08:00
public:
explicit Reshape(Stream stream, const std::vector<int>& shape)
: UnaryPrimitive(stream), shape_(shape){};
2023-11-30 02:30:41 +08:00
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
2023-11-30 02:30:41 +08:00
DEFINE_GRADS()
DEFINE_PRINT(Reshape)
bool is_equivalent(const Primitive& other) const override;
private:
std::vector<int> shape_;
void eval(const std::vector<array>& inputs, array& out);
};
class Reduce : public UnaryPrimitive {
2023-11-30 02:30:41 +08:00
public:
enum ReduceType { And, Or, Sum, Prod, Min, Max };
explicit Reduce(
Stream stream,
ReduceType reduce_type,
const std::vector<int>& axes)
: UnaryPrimitive(stream), reduce_type_(reduce_type), axes_(axes){};
2023-11-30 02:30:41 +08:00
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
2023-11-30 02:30:41 +08:00
std::vector<array> vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
2023-11-30 02:30:41 +08:00
const std::vector<int>& argnums) override;
void print(std::ostream& os) override {
switch (reduce_type_) {
case And:
os << "And";
case Or:
os << "And";
break;
case Sum:
os << "Sum";
break;
case Prod:
os << "Prod";
break;
case Min:
os << "Min";
break;
case Max:
os << "Max";
break;
}
os << " Reduce";
}
bool is_equivalent(const Primitive& other) const override;
private:
ReduceType reduce_type_;
std::vector<int> axes_;
void eval(const std::vector<array>& inputs, array& out);
};
class Round : public UnaryPrimitive {
2023-12-19 03:32:48 +08:00
public:
explicit Round(Stream stream) : UnaryPrimitive(stream){};
2023-12-19 03:32:48 +08:00
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
2023-12-19 03:32:48 +08:00
DEFINE_GRADS()
DEFINE_PRINT(Round)
DEFINE_DEFAULT_IS_EQUIVALENT()
private:
void eval(const std::vector<array>& inputs, array& out);
};
class Scan : public UnaryPrimitive {
2023-11-30 02:30:41 +08:00
public:
enum ReduceType { Max, Min, Sum, Prod };
explicit Scan(
Stream stream,
ReduceType reduce_type,
int axis,
bool reverse,
bool inclusive)
: UnaryPrimitive(stream),
2023-11-30 02:30:41 +08:00
reduce_type_(reduce_type),
axis_(axis),
reverse_(reverse),
inclusive_(inclusive){};
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
2023-11-30 02:30:41 +08:00
DEFINE_GRADS();
2023-11-30 02:30:41 +08:00
void print(std::ostream& os) override {
os << "Cum";
switch (reduce_type_) {
case Sum:
os << "Sum";
break;
case Prod:
os << "Prod";
break;
case Min:
os << "Min";
break;
case Max:
os << "Max";
break;
}
os << " Reduce";
}
bool is_equivalent(const Primitive& other) const override;
private:
ReduceType reduce_type_;
int axis_;
bool reverse_;
bool inclusive_;
void eval(const std::vector<array>& inputs, array& out);
};
class Scatter : public UnaryPrimitive {
2023-11-30 02:30:41 +08:00
public:
enum ReduceType { Max, Min, Sum, Prod, None };
explicit Scatter(
Stream stream,
ReduceType reduce_type,
const std::vector<int>& axes)
: UnaryPrimitive(stream), reduce_type_(reduce_type), axes_(axes){};
2023-11-30 02:30:41 +08:00
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_PRINT(Scatter)
bool is_equivalent(const Primitive& other) const override;
private:
void eval(const std::vector<array>& inputs, array& out);
ReduceType reduce_type_;
std::vector<int> axes_;
};
class Sigmoid : public UnaryPrimitive {
2023-11-30 02:30:41 +08:00
public:
explicit Sigmoid(Stream stream) : UnaryPrimitive(stream){};
2023-11-30 02:30:41 +08:00
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
2023-11-30 02:30:41 +08:00
DEFINE_GRADS()
DEFINE_PRINT(Sigmoid)
DEFINE_DEFAULT_IS_EQUIVALENT()
private:
void eval(const std::vector<array>& inputs, array& out);
};
class Sign : public UnaryPrimitive {
2023-11-30 02:30:41 +08:00
public:
explicit Sign(Stream stream) : UnaryPrimitive(stream){};
2023-11-30 02:30:41 +08:00
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
2023-11-30 02:30:41 +08:00
DEFINE_GRADS()
DEFINE_PRINT(Sign)
DEFINE_DEFAULT_IS_EQUIVALENT()
private:
void eval(const std::vector<array>& inputs, array& out);
};
class Sin : public UnaryPrimitive {
2023-11-30 02:30:41 +08:00
public:
explicit Sin(Stream stream) : UnaryPrimitive(stream){};
2023-11-30 02:30:41 +08:00
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
2023-11-30 02:30:41 +08:00
DEFINE_GRADS()
DEFINE_PRINT(Sin)
DEFINE_DEFAULT_IS_EQUIVALENT()
private:
void eval(const std::vector<array>& inputs, array& out);
};
class Sinh : public UnaryPrimitive {
2023-11-30 02:30:41 +08:00
public:
explicit Sinh(Stream stream) : UnaryPrimitive(stream){};
2023-11-30 02:30:41 +08:00
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
2023-11-30 02:30:41 +08:00
DEFINE_GRADS()
DEFINE_PRINT(Sinh)
DEFINE_DEFAULT_IS_EQUIVALENT()
private:
void eval(const std::vector<array>& inputs, array& out);
};
class Slice : public UnaryPrimitive {
2023-11-30 02:30:41 +08:00
public:
explicit Slice(
Stream stream,
const std::vector<int>& start_indices,
const std::vector<int>& end_indices,
const std::vector<int>& strides)
: UnaryPrimitive(stream),
2023-11-30 02:30:41 +08:00
start_indices_(start_indices),
end_indices_(end_indices),
strides_(strides){};
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
2023-11-30 02:30:41 +08:00
DEFINE_GRADS()
DEFINE_PRINT(Slice)
bool is_equivalent(const Primitive& other) const override;
private:
std::vector<int> start_indices_;
std::vector<int> end_indices_;
std::vector<int> strides_;
void eval(const std::vector<array>& inputs, array& out);
};
class Softmax : public UnaryPrimitive {
2023-11-30 02:30:41 +08:00
public:
explicit Softmax(Stream stream) : UnaryPrimitive(stream){};
2023-11-30 02:30:41 +08:00
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
2023-11-30 02:30:41 +08:00
DEFINE_GRADS()
DEFINE_PRINT(Softmax)
DEFINE_DEFAULT_IS_EQUIVALENT()
private:
void eval(const std::vector<array>& inputs, array& out);
};
class Sort : public UnaryPrimitive {
2023-11-30 02:30:41 +08:00
public:
explicit Sort(Stream stream, int axis)
: UnaryPrimitive(stream), axis_(axis){};
2023-11-30 02:30:41 +08:00
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
2023-11-30 02:30:41 +08:00
DEFINE_GRADS()
DEFINE_PRINT(Sort)
bool is_equivalent(const Primitive& other) const override;
private:
int axis_;
void eval(const std::vector<array>& inputs, array& out);
};
class Square : public UnaryPrimitive {
2023-11-30 02:30:41 +08:00
public:
explicit Square(Stream stream) : UnaryPrimitive(stream){};
2023-11-30 02:30:41 +08:00
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
2023-11-30 02:30:41 +08:00
DEFINE_GRADS()
DEFINE_PRINT(Square)
DEFINE_DEFAULT_IS_EQUIVALENT()
private:
void eval(const std::vector<array>& inputs, array& out);
};
class Sqrt : public UnaryPrimitive {
2023-11-30 02:30:41 +08:00
public:
explicit Sqrt(Stream stream, bool recip = false)
: UnaryPrimitive(stream), recip_(recip){};
2023-11-30 02:30:41 +08:00
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
2023-11-30 02:30:41 +08:00
DEFINE_GRADS()
DEFINE_PRINT(Sqrt)
bool is_equivalent(const Primitive& other) const override;
private:
void eval(const std::vector<array>& inputs, array& out);
bool recip_;
};
class StopGradient : public UnaryPrimitive {
2023-11-30 02:30:41 +08:00
public:
explicit StopGradient(Stream stream) : UnaryPrimitive(stream){};
2023-11-30 02:30:41 +08:00
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
2023-11-30 02:30:41 +08:00
DEFINE_PRINT(StopGradient)
DEFINE_DEFAULT_IS_EQUIVALENT()
private:
void eval(const std::vector<array>& inputs, array& out);
};
class Subtract : public UnaryPrimitive {
2023-11-30 02:30:41 +08:00
public:
explicit Subtract(Stream stream) : UnaryPrimitive(stream){};
2023-11-30 02:30:41 +08:00
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
2023-11-30 02:30:41 +08:00
DEFINE_GRADS()
DEFINE_PRINT(Subtract)
DEFINE_DEFAULT_IS_EQUIVALENT()
private:
void eval(const std::vector<array>& inputs, array& out);
};
class Tan : public UnaryPrimitive {
2023-11-30 02:30:41 +08:00
public:
explicit Tan(Stream stream) : UnaryPrimitive(stream){};
2023-11-30 02:30:41 +08:00
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
2023-11-30 02:30:41 +08:00
DEFINE_GRADS()
DEFINE_PRINT(Tan)
DEFINE_DEFAULT_IS_EQUIVALENT()
private:
void eval(const std::vector<array>& inputs, array& out);
};
class Tanh : public UnaryPrimitive {
2023-11-30 02:30:41 +08:00
public:
explicit Tanh(Stream stream) : UnaryPrimitive(stream){};
2023-11-30 02:30:41 +08:00
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
2023-11-30 02:30:41 +08:00
DEFINE_GRADS()
DEFINE_PRINT(Tanh)
DEFINE_DEFAULT_IS_EQUIVALENT()
private:
void eval(const std::vector<array>& inputs, array& out);
};
class Uniform : public UnaryPrimitive {
2023-11-30 02:30:41 +08:00
public:
explicit Uniform(Stream stream) : UnaryPrimitive(stream){};
2023-11-30 02:30:41 +08:00
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
2023-11-30 02:30:41 +08:00
DEFINE_PRINT(Uniform)
DEFINE_DEFAULT_IS_EQUIVALENT()
private:
void eval(const std::vector<array>& inputs, array& out);
};
class Transpose : public UnaryPrimitive {
2023-11-30 02:30:41 +08:00
public:
explicit Transpose(Stream stream, const std::vector<int>& axes)
: UnaryPrimitive(stream), axes_(axes){};
2023-11-30 02:30:41 +08:00
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
2023-11-30 02:30:41 +08:00
DEFINE_GRADS()
DEFINE_PRINT(Transpose)
bool is_equivalent(const Primitive& other) const override;
private:
std::vector<int> axes_;
void eval(const std::vector<array>& inputs, array& out);
};
} // namespace mlx::core