Shapeless compilation for some graphs (#687)

* shapeless compilation for some graphs

* update compile benchmark

* default compile a few activations

* buffer donation

* bugfix

* shapeless fix

* update tests to work for cpu and gpu fusion

* test kwargs

* add kwargs to compile

* Recompile when python arguments change

* no compile for tanh

* some constant tests

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
Awni Hannun
2024-02-19 21:43:54 -08:00
committed by GitHub
parent d0fda82595
commit 5798256fcf
14 changed files with 645 additions and 113 deletions

View File

@@ -37,7 +37,7 @@ std::string build_lib_name(
os << "C";
print_constant(constant_hasher, x);
} else {
os << ((x.size() == 1) ? "S" : "V");
os << (is_scalar(x) ? "S" : "V");
}
}
os << "_";
@@ -122,10 +122,6 @@ std::string get_type_string(Dtype d) {
}
}
inline bool is_scalar(const array& x) {
return x.size() == 1;
};
// Return a pointer to a compiled function
void* compile(
const std::string& kernel_name,
@@ -358,7 +354,7 @@ void Compiled::eval_cpu(
bool all_col_contig = true;
int non_scalar_inputs = 0;
for (auto& x : inputs) {
if (x.size() == 1) {
if (is_scalar(x)) {
continue;
}
non_scalar_inputs++;
@@ -385,7 +381,7 @@ void Compiled::eval_cpu(
auto& x = inputs[i];
args.push_back((void*)x.data<void>());
if (contiguous || x.size() <= 1) {
if (contiguous || is_scalar(x)) {
continue;
}
@@ -458,7 +454,7 @@ void Compiled::eval_cpu(
// - Donatable
// - Correct size
// - Not a constant
if (in.flags().contiguous && in.size() > 1 && in.is_donatable() &&
if (in.flags().contiguous && !is_scalar(in) && in.is_donatable() &&
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
outputs[o++].copy_shared_buffer(in);
}

View File

@@ -49,4 +49,8 @@ void print_complex_constant(std::ostream& os, const array& x) {
void print_constant(std::ostream& os, const array& x);
inline bool is_scalar(const array& x) {
return x.ndim() == 0;
}
} // namespace mlx::core

View File

@@ -31,9 +31,6 @@ inline void build_kernel(
return constant_ids.find(x.id()) != constant_ids.end();
};
// For scalar we shouldn't do the indexing things, just read at 0
auto is_scalar = [](const array& x) { return x.size() == 1; };
NodeNamer namer;
bool add_indices = false;
int cnt = 0;
@@ -226,8 +223,7 @@ void Compiled::eval_gpu(
/* ndim = */ 0,
/* dynamic_dims = */ true);
kernel_source_ = kernel.str();
lib = d.get_library(kernel_lib_, kernel_source_);
lib = d.get_library(kernel_lib_, kernel.str());
}
// Figure out which kernel we are using
@@ -235,7 +231,7 @@ void Compiled::eval_gpu(
bool contiguous = true;
for (auto& x : inputs) {
if ((!x.flags().row_contiguous || x.shape() != output_shape) &&
x.size() > 1) {
!is_scalar(x)) {
contiguous = false;
break;
}
@@ -256,7 +252,7 @@ void Compiled::eval_gpu(
auto& x = inputs[i];
// Skip scalar inputs.
if (x.size() <= 1) {
if (is_scalar(x)) {
continue;
}
@@ -311,7 +307,7 @@ void Compiled::eval_gpu(
}
auto& x = inputs[i];
set_array_buffer(compute_encoder, x, cnt++);
if (!contiguous && x.size() > 1) {
if (!contiguous && !is_scalar(x)) {
compute_encoder->setBytes(
strides[stride_idx].data(),
strides[stride_idx].size() * sizeof(size_t),

View File

@@ -13,7 +13,7 @@
namespace mlx::core {
constexpr int max_compile_depth = 10;
constexpr int max_compile_depth = 11;
bool is_unary(const Primitive& p) {
return (
@@ -55,19 +55,20 @@ bool is_noop(const Primitive& p) {
return typeid(p) == typeid(Copy) || typeid(p) == typeid(StopGradient);
}
bool is_reduction(const Primitive& p) {
return typeid(p) == typeid(Reduce) || typeid(p) == typeid(ArgReduce);
}
bool is_fusable(const Primitive& p) {
return is_unary(p) || is_binary(p) || is_broadcast(p) || is_noop(p);
}
namespace detail {
std::vector<array> compile_replace(
const std::vector<array>& tape,
const std::vector<array>& trace_inputs,
const std::vector<array>& trace_outputs,
const std::vector<array>& inputs);
} // namespace detail
bool allows_shapeless(const Primitive& p) {
return typeid(p) == typeid(Compiled) || is_unary(p) || is_binary(p) ||
is_noop(p) || is_reduction(p) || typeid(p) == typeid(Softmax) ||
typeid(p) == typeid(Sort) || typeid(p) == typeid(ArgSort) ||
typeid(p) == typeid(ArgPartition) || typeid(p) == typeid(Partition);
}
Compiled::Compiled(
Stream stream,
@@ -123,6 +124,23 @@ void Compiled::print(std::ostream& os) {
}
}
std::vector<std::vector<int>> Compiled::output_shapes(
const std::vector<array>& inputs) {
size_t nd = 0;
for (auto& in : inputs) {
nd = std::max(nd, in.ndim());
}
std::vector<int> out_shape(nd, 0);
for (auto& in : inputs) {
auto dd = nd - in.ndim();
for (auto i = dd; i < nd; ++i) {
out_shape[i] = std::max(out_shape[i], in.shape()[i - dd]);
}
}
// All outputs have the same shape
return std::vector<std::vector<int>>(outputs_.size(), out_shape);
}
namespace detail {
CompileMode& compile_mode() {
@@ -180,21 +198,30 @@ struct CompilerCache {
std::vector<array> outputs;
std::vector<array> tape;
bool empty{true};
std::vector<uint64_t> constants;
};
// Returns a reference to a CacheEntry which can be updated
// by the caller to avoid copying large tapes / inputs / outputs
CacheEntry& find(size_t fun_id, const std::vector<array>& inputs) {
CacheEntry& find(
size_t fun_id,
const std::vector<array>& inputs,
bool shapeless,
const std::vector<uint64_t>& constants) {
// Try to find the entry
auto [entry_it, inserted] = cache_.insert({fun_id, {}});
auto& entries = entry_it->second;
auto is_match = [](const std::vector<array>& in1,
const std::vector<array>& in2) {
auto is_match = [shapeless](
const std::vector<array>& in1,
const std::vector<array>& in2) {
if (in1.size() != in2.size()) {
return false;
}
for (int i = 0; i < in1.size(); ++i) {
if (in1[i].shape() != in2[i].shape()) {
if (in1[i].ndim() != in2[i].ndim()) {
return false;
}
if (!shapeless && in1[i].shape() != in2[i].shape()) {
return false;
}
if (in1[i].dtype() != in2[i].dtype()) {
@@ -210,7 +237,7 @@ struct CompilerCache {
// more easily searchable structure.
for (auto& entry : entries) {
// Check the inputs match and return if so
if (is_match(inputs, entry.inputs)) {
if (is_match(inputs, entry.inputs) && constants == entry.constants) {
return entry;
}
}
@@ -651,7 +678,8 @@ std::vector<array> compile_replace(
const std::vector<array>& tape,
const std::vector<array>& trace_inputs,
const std::vector<array>& trace_outputs,
const std::vector<array>& inputs) {
const std::vector<array>& inputs,
bool shapeless) {
std::unordered_map<uintptr_t, array> trace_to_real;
for (int i = 0; i < inputs.size(); ++i) {
trace_to_real.insert({trace_inputs[i].id(), inputs[i]});
@@ -669,18 +697,29 @@ std::vector<array> compile_replace(
real_inputs.push_back(trace_to_real.at(in.id()));
}
if (a.siblings().empty()) {
auto shape =
shapeless ? a.primitive().output_shapes(real_inputs)[0] : a.shape();
auto real_a = array(
a.shape(), a.dtype(), a.primitive_ptr(), std::move(real_inputs));
std::move(shape),
a.dtype(),
a.primitive_ptr(),
std::move(real_inputs));
trace_to_real.insert({a.id(), std::move(real_a)});
} else {
// Ensure the order is correct for multi-output primitives
std::vector<std::vector<int>> shapes;
std::vector<Dtype> types;
auto trace_out = a.outputs();
for (auto& o : trace_out) {
shapes.push_back(o.shape());
types.push_back(o.dtype());
}
std::vector<std::vector<int>> shapes;
if (shapeless) {
shapes = a.primitive().output_shapes(real_inputs);
} else {
for (auto& o : trace_out) {
shapes.push_back(o.shape());
}
}
auto real_out =
array::make_arrays(shapes, types, a.primitive_ptr(), real_inputs);
for (int i = 0; i < trace_out.size(); ++i) {
@@ -697,13 +736,34 @@ std::vector<array> compile_replace(
return outputs;
}
void compile_validate_shapeless(const std::vector<array>& tape) {
for (auto& t : tape) {
if (!t.has_primitive()) {
continue;
}
auto& p = t.primitive();
if (allows_shapeless(p)) {
continue;
}
std::ostringstream msg;
msg << "[compile] Cannot compile primitive ";
p.print(msg);
msg << " with shapeless enabled.";
throw std::invalid_argument(msg.str());
}
}
std::function<std::vector<array>(const std::vector<array>&)> compile(
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
size_t fun_id) {
size_t fun_id,
bool shapeless /* = false */,
std::vector<uint64_t> constants /* = {} */) {
if (compile_mode() == CompileMode::disabled) {
return fun;
}
return [fun, fun_id](const std::vector<array>& inputs) {
return [fun, fun_id, shapeless, constants = std::move(constants)](
const std::vector<array>& inputs) {
// If the inputs are tracers, trace the original graph
if (std::any_of(inputs.begin(), inputs.end(), [](auto& in) {
return in.is_tracer();
@@ -712,12 +772,14 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
}
// Find a cache entry with the correct inputs
auto& entry = compiler_cache().find(fun_id, inputs);
auto& entry = compiler_cache().find(fun_id, inputs, shapeless, constants);
// No matching cache entry existed, so compile
if (entry.empty) {
// Mark the entry as not empty since we are about to fill it
entry.empty = false;
// Set the constants
entry.constants = std::move(constants);
// Trace to build the graph
std::tie(entry.inputs, entry.outputs) = compile_trace(fun, inputs);
@@ -739,11 +801,16 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
if (compile_mode() != CompileMode::no_fuse) {
compile_fuse(entry.tape, parents_map, entry.inputs, entry.outputs);
}
if (shapeless) {
compile_validate_shapeless(entry.tape);
}
}
// At this point we must have a tape, now replace the placeholders
// with real arrays that can be evaluated
return compile_replace(entry.tape, entry.inputs, entry.outputs, inputs);
return compile_replace(
entry.tape, entry.inputs, entry.outputs, inputs, shapeless);
};
}
@@ -754,12 +821,13 @@ void compile_erase(size_t fun_id) {
} // namespace detail
std::function<std::vector<array>(const std::vector<array>&)> compile(
const std::function<std::vector<array>(const std::vector<array>&)>& fun) {
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
bool shapeless /* false */) {
if (detail::compile_mode() == CompileMode::disabled) {
return fun;
}
auto fun_id = detail::getAddress(fun);
return detail::compile(fun, fun_id);
return detail::compile(fun, fun_id, shapeless);
}
void disable_compile() {

View File

@@ -8,9 +8,10 @@ namespace mlx::core {
enum class CompileMode { disabled, no_simplify, no_fuse, enabled };
// Compile takes a function and returns a new function
/** Compile takes a function and returns a compiled function. */
std::function<std::vector<array>(const std::vector<array>&)> compile(
const std::function<std::vector<array>(const std::vector<array>&)>& fun);
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
bool shapeless = false);
/** Globally disable compilation.
* Setting the environment variable ``MLX_DISABLE_COMPILE`` can also

View File

@@ -71,6 +71,15 @@ std::pair<std::vector<array>, std::vector<int>> Primitive::vmap(
throw std::invalid_argument("Primitive's vmap not implemented.");
};
std::vector<std::vector<int>> Primitive::output_shapes(
const std::vector<array>&) {
std::ostringstream msg;
msg << "[Primitive::output_shapes] ";
this->print(msg);
msg << " cannot infer output shapes.";
throw std::invalid_argument(msg.str());
};
std::vector<array> Abs::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
@@ -383,6 +392,13 @@ std::pair<std::vector<array>, std::vector<int>> ArgSort::vmap(
return {{argsort(inputs[0], axis_ + (axes[0] <= axis_), stream())}, axes};
}
std::vector<std::vector<int>> ArgReduce::output_shapes(
const std::vector<array>& inputs) {
auto out_shape = inputs[0].shape();
out_shape[axis_] = 1;
return {out_shape};
}
bool ArgSort::is_equivalent(const Primitive& other) const {
const ArgSort& r_other = static_cast<const ArgSort&>(other);
return axis_ == r_other.axis_;
@@ -2202,6 +2218,15 @@ bool Reduce::is_equivalent(const Primitive& other) const {
return reduce_type_ == r_other.reduce_type_ && axes_ == r_other.axes_;
}
std::vector<std::vector<int>> Reduce::output_shapes(
const std::vector<array>& inputs) {
std::vector<int> out_shape = inputs[0].shape();
for (auto i : axes_) {
out_shape[i] = 1;
}
return {out_shape};
}
std::vector<array> Round::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,

View File

@@ -36,6 +36,12 @@
return true; \
}
#define DEFINE_INPUT_OUTPUT_SHAPE() \
std::vector<std::vector<int>> output_shapes( \
const std::vector<array>& inputs) override { \
return {inputs[0].shape()}; \
};
namespace mlx::core {
// Abstract base class
@@ -102,6 +108,11 @@ class Primitive {
return false;
}
/** Get the output shapes of the primitive. This is not required to be
* implemented by derived classes, in which case it will throw. */
virtual std::vector<std::vector<int>> output_shapes(
const std::vector<array>& inputs);
virtual ~Primitive() = default;
Primitive(const Primitive& other) = delete;
Primitive(Primitive&& other) = delete;
@@ -152,6 +163,7 @@ class Abs : public UnaryPrimitive {
DEFINE_GRADS()
DEFINE_PRINT(Abs)
DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
@@ -168,6 +180,7 @@ class Add : public UnaryPrimitive {
DEFINE_GRADS()
DEFINE_PRINT(Add)
DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
@@ -226,6 +239,7 @@ class ArcCos : public UnaryPrimitive {
DEFINE_GRADS()
DEFINE_PRINT(ArcCos)
DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
@@ -242,6 +256,7 @@ class ArcCosh : public UnaryPrimitive {
DEFINE_GRADS()
DEFINE_PRINT(ArcCosh)
DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
@@ -258,6 +273,7 @@ class ArcSin : public UnaryPrimitive {
DEFINE_GRADS()
DEFINE_PRINT(ArcSin)
DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
@@ -274,6 +290,7 @@ class ArcSinh : public UnaryPrimitive {
DEFINE_GRADS()
DEFINE_PRINT(ArcSinh)
DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
@@ -290,6 +307,7 @@ class ArcTan : public UnaryPrimitive {
DEFINE_GRADS()
DEFINE_PRINT(ArcTan)
DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
@@ -306,6 +324,7 @@ class ArcTanh : public UnaryPrimitive {
DEFINE_GRADS()
DEFINE_PRINT(ArcTanh)
DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
@@ -321,6 +340,7 @@ class ArgPartition : public UnaryPrimitive {
DEFINE_VMAP()
DEFINE_PRINT(ArgPartition)
DEFINE_INPUT_OUTPUT_SHAPE()
bool is_equivalent(const Primitive& other) const override;
private:
@@ -346,6 +366,8 @@ class ArgReduce : public UnaryPrimitive {
DEFINE_VMAP()
DEFINE_PRINT(ArgReduce)
bool is_equivalent(const Primitive& other) const override;
std::vector<std::vector<int>> output_shapes(
const std::vector<array>& inputs) override;
private:
ReduceType reduce_type_;
@@ -364,6 +386,7 @@ class ArgSort : public UnaryPrimitive {
DEFINE_VMAP()
DEFINE_PRINT(ArgSort)
DEFINE_INPUT_OUTPUT_SHAPE()
bool is_equivalent(const Primitive& other) const override;
private:
@@ -383,6 +406,7 @@ class AsType : public UnaryPrimitive {
DEFINE_VMAP()
DEFINE_GRADS()
DEFINE_PRINT(AsType)
DEFINE_INPUT_OUTPUT_SHAPE()
bool is_equivalent(const Primitive& other) const override;
private:
@@ -448,6 +472,7 @@ class Ceil : public UnaryPrimitive {
DEFINE_GRADS()
DEFINE_PRINT(Ceil)
DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
@@ -478,15 +503,14 @@ class Compiled : public Primitive {
DEFINE_VMAP()
DEFINE_GRADS()
std::vector<std::vector<int>> output_shapes(
const std::vector<array>& inputs) override;
void print(std::ostream& os) override;
bool is_equivalent(const Primitive& other) const override;
std::string metal_lib_name() const {
std::string lib_name() const {
return kernel_lib_;
}
std::string metal_lib_source() const {
return kernel_source_;
}
private:
const std::vector<array> inputs_;
@@ -495,7 +519,6 @@ class Compiled : public Primitive {
const std::unordered_set<uintptr_t> constant_ids_;
std::string kernel_lib_;
std::string kernel_source_;
};
class Concatenate : public UnaryPrimitive {
@@ -563,6 +586,7 @@ class Copy : public UnaryPrimitive {
DEFINE_GRADS()
DEFINE_PRINT(Copy)
DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
@@ -579,6 +603,7 @@ class Cos : public UnaryPrimitive {
DEFINE_GRADS()
DEFINE_PRINT(Cos)
DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
@@ -595,6 +620,7 @@ class Cosh : public UnaryPrimitive {
DEFINE_GRADS()
DEFINE_PRINT(Cosh)
DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
@@ -665,6 +691,7 @@ class Divide : public UnaryPrimitive {
DEFINE_GRADS()
DEFINE_PRINT(Divide)
DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
@@ -683,6 +710,10 @@ class DivMod : public Primitive {
DEFINE_GRADS()
DEFINE_PRINT(DivMod)
DEFINE_DEFAULT_IS_EQUIVALENT()
std::vector<std::vector<int>> output_shapes(
const std::vector<array>& inputs) override {
return std::vector{inputs[0].shape(), inputs[0].shape()};
};
private:
void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
@@ -699,6 +730,7 @@ class Remainder : public UnaryPrimitive {
DEFINE_GRADS()
DEFINE_PRINT(Remainder)
DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
@@ -715,6 +747,7 @@ class Equal : public UnaryPrimitive {
DEFINE_VMAP()
DEFINE_GRADS()
DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
void print(std::ostream& os) override {
if (equal_nan_) {
@@ -740,6 +773,7 @@ class Erf : public UnaryPrimitive {
DEFINE_GRADS()
DEFINE_PRINT(Erf)
DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
@@ -756,6 +790,7 @@ class ErfInv : public UnaryPrimitive {
DEFINE_GRADS()
DEFINE_PRINT(ErfInv)
DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
@@ -772,6 +807,7 @@ class Exp : public UnaryPrimitive {
DEFINE_GRADS()
DEFINE_PRINT(Exp)
DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
@@ -814,6 +850,7 @@ class Floor : public UnaryPrimitive {
DEFINE_GRADS()
DEFINE_PRINT(Floor)
DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
@@ -868,6 +905,7 @@ class Greater : public UnaryPrimitive {
DEFINE_GRADS()
DEFINE_PRINT(Greater)
DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
@@ -884,6 +922,7 @@ class GreaterEqual : public UnaryPrimitive {
DEFINE_GRADS()
DEFINE_PRINT(GreaterEqual)
DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
@@ -900,6 +939,7 @@ class Less : public UnaryPrimitive {
DEFINE_GRADS()
DEFINE_PRINT(Less)
DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
@@ -916,6 +956,7 @@ class LessEqual : public UnaryPrimitive {
DEFINE_GRADS()
DEFINE_PRINT(LessEqual)
DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
@@ -958,6 +999,7 @@ class Log : public UnaryPrimitive {
DEFINE_VMAP()
DEFINE_GRADS()
DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
void print(std::ostream& os) override {
switch (base_) {
@@ -988,6 +1030,7 @@ class Log1p : public UnaryPrimitive {
DEFINE_VMAP()
DEFINE_GRADS()
DEFINE_PRINT(Log1p)
DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
@@ -1004,6 +1047,7 @@ class LogicalNot : public UnaryPrimitive {
DEFINE_GRADS()
DEFINE_PRINT(LogicalNot)
DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
@@ -1020,6 +1064,7 @@ class LogicalAnd : public UnaryPrimitive {
DEFINE_GRADS()
DEFINE_PRINT(LogicalAnd)
DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
@@ -1036,6 +1081,7 @@ class LogicalOr : public UnaryPrimitive {
DEFINE_GRADS()
DEFINE_PRINT(LogicalOr)
DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
@@ -1052,6 +1098,7 @@ class LogAddExp : public UnaryPrimitive {
DEFINE_GRADS()
DEFINE_PRINT(LogAddExp)
DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
@@ -1085,6 +1132,7 @@ class Maximum : public UnaryPrimitive {
DEFINE_GRADS()
DEFINE_PRINT(Maximum)
DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
@@ -1101,6 +1149,7 @@ class Minimum : public UnaryPrimitive {
DEFINE_GRADS()
DEFINE_PRINT(Minimum)
DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
@@ -1117,6 +1166,7 @@ class Multiply : public UnaryPrimitive {
DEFINE_GRADS()
DEFINE_PRINT(Multiply)
DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
@@ -1133,6 +1183,7 @@ class Negative : public UnaryPrimitive {
DEFINE_GRADS()
DEFINE_PRINT(Negative)
DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
@@ -1149,6 +1200,7 @@ class NotEqual : public UnaryPrimitive {
DEFINE_GRADS()
DEFINE_PRINT(NotEqual)
DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
@@ -1193,6 +1245,7 @@ class Partition : public UnaryPrimitive {
DEFINE_VMAP()
DEFINE_GRADS()
DEFINE_PRINT(Partition)
DEFINE_INPUT_OUTPUT_SHAPE()
bool is_equivalent(const Primitive& other) const override;
private:
@@ -1213,6 +1266,7 @@ class Power : public UnaryPrimitive {
DEFINE_GRADS()
DEFINE_PRINT(Power)
DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
@@ -1305,6 +1359,9 @@ class Reduce : public UnaryPrimitive {
const std::vector<int>& argnums,
const std::vector<array>& outputs) override;
std::vector<std::vector<int>> output_shapes(
const std::vector<array>& inputs) override;
void print(std::ostream& os) override {
switch (reduce_type_) {
case And:
@@ -1347,6 +1404,7 @@ class Round : public UnaryPrimitive {
DEFINE_GRADS()
DEFINE_PRINT(Round)
DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
@@ -1455,6 +1513,7 @@ class Sigmoid : public UnaryPrimitive {
DEFINE_GRADS()
DEFINE_PRINT(Sigmoid)
DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
@@ -1471,6 +1530,7 @@ class Sign : public UnaryPrimitive {
DEFINE_GRADS()
DEFINE_PRINT(Sign)
DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
@@ -1487,6 +1547,7 @@ class Sin : public UnaryPrimitive {
DEFINE_GRADS()
DEFINE_PRINT(Sin)
DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
@@ -1503,6 +1564,7 @@ class Sinh : public UnaryPrimitive {
DEFINE_GRADS()
DEFINE_PRINT(Sinh)
DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
@@ -1547,6 +1609,7 @@ class Softmax : public UnaryPrimitive {
DEFINE_GRADS()
DEFINE_PRINT(Softmax)
DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
@@ -1563,6 +1626,7 @@ class Sort : public UnaryPrimitive {
DEFINE_VMAP()
DEFINE_GRADS()
DEFINE_PRINT(Sort)
DEFINE_INPUT_OUTPUT_SHAPE()
bool is_equivalent(const Primitive& other) const override;
private:
@@ -1604,6 +1668,7 @@ class Square : public UnaryPrimitive {
DEFINE_GRADS()
DEFINE_PRINT(Square)
DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
@@ -1619,6 +1684,7 @@ class Sqrt : public UnaryPrimitive {
DEFINE_VMAP()
DEFINE_GRADS()
DEFINE_INPUT_OUTPUT_SHAPE()
bool is_equivalent(const Primitive& other) const override;
void print(std::ostream& os) override {
@@ -1644,6 +1710,7 @@ class StopGradient : public UnaryPrimitive {
DEFINE_VMAP()
DEFINE_PRINT(StopGradient)
DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
@@ -1660,6 +1727,7 @@ class Subtract : public UnaryPrimitive {
DEFINE_GRADS()
DEFINE_PRINT(Subtract)
DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
@@ -1676,6 +1744,7 @@ class Tan : public UnaryPrimitive {
DEFINE_GRADS()
DEFINE_PRINT(Tan)
DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
@@ -1692,6 +1761,7 @@ class Tanh : public UnaryPrimitive {
DEFINE_GRADS()
DEFINE_PRINT(Tanh)
DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);

View File

@@ -18,7 +18,9 @@ std::vector<array> vmap_replace(
// idea.
std::function<std::vector<array>(const std::vector<array>&)> compile(
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
size_t fun_id);
size_t fun_id,
bool shapeless = false,
std::vector<uint64_t> constants = {});
// Erase cached compile functions
void compile_erase(size_t fun_id);