mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 18:28:12 +08:00
Shapeless compilation for some graphs (#687)
* shapeless compilation for some graphs * update compile benchmark * default compile a few activations * buffer donation * bugfix * shapeless fix * update tests to work for cpu and gpu fusion * test kwargs * add kwargs to compile * Recompile when python arguments change * no compile for tanh * some constant tests --------- Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
@@ -37,7 +37,7 @@ std::string build_lib_name(
|
||||
os << "C";
|
||||
print_constant(constant_hasher, x);
|
||||
} else {
|
||||
os << ((x.size() == 1) ? "S" : "V");
|
||||
os << (is_scalar(x) ? "S" : "V");
|
||||
}
|
||||
}
|
||||
os << "_";
|
||||
@@ -122,10 +122,6 @@ std::string get_type_string(Dtype d) {
|
||||
}
|
||||
}
|
||||
|
||||
inline bool is_scalar(const array& x) {
|
||||
return x.size() == 1;
|
||||
};
|
||||
|
||||
// Return a pointer to a compiled function
|
||||
void* compile(
|
||||
const std::string& kernel_name,
|
||||
@@ -358,7 +354,7 @@ void Compiled::eval_cpu(
|
||||
bool all_col_contig = true;
|
||||
int non_scalar_inputs = 0;
|
||||
for (auto& x : inputs) {
|
||||
if (x.size() == 1) {
|
||||
if (is_scalar(x)) {
|
||||
continue;
|
||||
}
|
||||
non_scalar_inputs++;
|
||||
@@ -385,7 +381,7 @@ void Compiled::eval_cpu(
|
||||
auto& x = inputs[i];
|
||||
args.push_back((void*)x.data<void>());
|
||||
|
||||
if (contiguous || x.size() <= 1) {
|
||||
if (contiguous || is_scalar(x)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -458,7 +454,7 @@ void Compiled::eval_cpu(
|
||||
// - Donatable
|
||||
// - Correct size
|
||||
// - Not a constant
|
||||
if (in.flags().contiguous && in.size() > 1 && in.is_donatable() &&
|
||||
if (in.flags().contiguous && !is_scalar(in) && in.is_donatable() &&
|
||||
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
|
||||
outputs[o++].copy_shared_buffer(in);
|
||||
}
|
||||
|
@@ -49,4 +49,8 @@ void print_complex_constant(std::ostream& os, const array& x) {
|
||||
|
||||
void print_constant(std::ostream& os, const array& x);
|
||||
|
||||
inline bool is_scalar(const array& x) {
|
||||
return x.ndim() == 0;
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -31,9 +31,6 @@ inline void build_kernel(
|
||||
return constant_ids.find(x.id()) != constant_ids.end();
|
||||
};
|
||||
|
||||
// For scalar we shouldn't do the indexing things, just read at 0
|
||||
auto is_scalar = [](const array& x) { return x.size() == 1; };
|
||||
|
||||
NodeNamer namer;
|
||||
bool add_indices = false;
|
||||
int cnt = 0;
|
||||
@@ -226,8 +223,7 @@ void Compiled::eval_gpu(
|
||||
/* ndim = */ 0,
|
||||
/* dynamic_dims = */ true);
|
||||
|
||||
kernel_source_ = kernel.str();
|
||||
lib = d.get_library(kernel_lib_, kernel_source_);
|
||||
lib = d.get_library(kernel_lib_, kernel.str());
|
||||
}
|
||||
|
||||
// Figure out which kernel we are using
|
||||
@@ -235,7 +231,7 @@ void Compiled::eval_gpu(
|
||||
bool contiguous = true;
|
||||
for (auto& x : inputs) {
|
||||
if ((!x.flags().row_contiguous || x.shape() != output_shape) &&
|
||||
x.size() > 1) {
|
||||
!is_scalar(x)) {
|
||||
contiguous = false;
|
||||
break;
|
||||
}
|
||||
@@ -256,7 +252,7 @@ void Compiled::eval_gpu(
|
||||
auto& x = inputs[i];
|
||||
|
||||
// Skip scalar inputs.
|
||||
if (x.size() <= 1) {
|
||||
if (is_scalar(x)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -311,7 +307,7 @@ void Compiled::eval_gpu(
|
||||
}
|
||||
auto& x = inputs[i];
|
||||
set_array_buffer(compute_encoder, x, cnt++);
|
||||
if (!contiguous && x.size() > 1) {
|
||||
if (!contiguous && !is_scalar(x)) {
|
||||
compute_encoder->setBytes(
|
||||
strides[stride_idx].data(),
|
||||
strides[stride_idx].size() * sizeof(size_t),
|
||||
|
118
mlx/compile.cpp
118
mlx/compile.cpp
@@ -13,7 +13,7 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
constexpr int max_compile_depth = 10;
|
||||
constexpr int max_compile_depth = 11;
|
||||
|
||||
bool is_unary(const Primitive& p) {
|
||||
return (
|
||||
@@ -55,19 +55,20 @@ bool is_noop(const Primitive& p) {
|
||||
return typeid(p) == typeid(Copy) || typeid(p) == typeid(StopGradient);
|
||||
}
|
||||
|
||||
bool is_reduction(const Primitive& p) {
|
||||
return typeid(p) == typeid(Reduce) || typeid(p) == typeid(ArgReduce);
|
||||
}
|
||||
|
||||
bool is_fusable(const Primitive& p) {
|
||||
return is_unary(p) || is_binary(p) || is_broadcast(p) || is_noop(p);
|
||||
}
|
||||
|
||||
namespace detail {
|
||||
|
||||
std::vector<array> compile_replace(
|
||||
const std::vector<array>& tape,
|
||||
const std::vector<array>& trace_inputs,
|
||||
const std::vector<array>& trace_outputs,
|
||||
const std::vector<array>& inputs);
|
||||
|
||||
} // namespace detail
|
||||
bool allows_shapeless(const Primitive& p) {
|
||||
return typeid(p) == typeid(Compiled) || is_unary(p) || is_binary(p) ||
|
||||
is_noop(p) || is_reduction(p) || typeid(p) == typeid(Softmax) ||
|
||||
typeid(p) == typeid(Sort) || typeid(p) == typeid(ArgSort) ||
|
||||
typeid(p) == typeid(ArgPartition) || typeid(p) == typeid(Partition);
|
||||
}
|
||||
|
||||
Compiled::Compiled(
|
||||
Stream stream,
|
||||
@@ -123,6 +124,23 @@ void Compiled::print(std::ostream& os) {
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::vector<int>> Compiled::output_shapes(
|
||||
const std::vector<array>& inputs) {
|
||||
size_t nd = 0;
|
||||
for (auto& in : inputs) {
|
||||
nd = std::max(nd, in.ndim());
|
||||
}
|
||||
std::vector<int> out_shape(nd, 0);
|
||||
for (auto& in : inputs) {
|
||||
auto dd = nd - in.ndim();
|
||||
for (auto i = dd; i < nd; ++i) {
|
||||
out_shape[i] = std::max(out_shape[i], in.shape()[i - dd]);
|
||||
}
|
||||
}
|
||||
// All outputs have the same shape
|
||||
return std::vector<std::vector<int>>(outputs_.size(), out_shape);
|
||||
}
|
||||
|
||||
namespace detail {
|
||||
|
||||
CompileMode& compile_mode() {
|
||||
@@ -180,21 +198,30 @@ struct CompilerCache {
|
||||
std::vector<array> outputs;
|
||||
std::vector<array> tape;
|
||||
bool empty{true};
|
||||
std::vector<uint64_t> constants;
|
||||
};
|
||||
|
||||
// Returns a reference to a CacheEntry which can be updated
|
||||
// by the caller to avoid copying large tapes / inputs / outputs
|
||||
CacheEntry& find(size_t fun_id, const std::vector<array>& inputs) {
|
||||
CacheEntry& find(
|
||||
size_t fun_id,
|
||||
const std::vector<array>& inputs,
|
||||
bool shapeless,
|
||||
const std::vector<uint64_t>& constants) {
|
||||
// Try to find the entry
|
||||
auto [entry_it, inserted] = cache_.insert({fun_id, {}});
|
||||
auto& entries = entry_it->second;
|
||||
auto is_match = [](const std::vector<array>& in1,
|
||||
const std::vector<array>& in2) {
|
||||
auto is_match = [shapeless](
|
||||
const std::vector<array>& in1,
|
||||
const std::vector<array>& in2) {
|
||||
if (in1.size() != in2.size()) {
|
||||
return false;
|
||||
}
|
||||
for (int i = 0; i < in1.size(); ++i) {
|
||||
if (in1[i].shape() != in2[i].shape()) {
|
||||
if (in1[i].ndim() != in2[i].ndim()) {
|
||||
return false;
|
||||
}
|
||||
if (!shapeless && in1[i].shape() != in2[i].shape()) {
|
||||
return false;
|
||||
}
|
||||
if (in1[i].dtype() != in2[i].dtype()) {
|
||||
@@ -210,7 +237,7 @@ struct CompilerCache {
|
||||
// more easily searchable structure.
|
||||
for (auto& entry : entries) {
|
||||
// Check the inputs match and return if so
|
||||
if (is_match(inputs, entry.inputs)) {
|
||||
if (is_match(inputs, entry.inputs) && constants == entry.constants) {
|
||||
return entry;
|
||||
}
|
||||
}
|
||||
@@ -651,7 +678,8 @@ std::vector<array> compile_replace(
|
||||
const std::vector<array>& tape,
|
||||
const std::vector<array>& trace_inputs,
|
||||
const std::vector<array>& trace_outputs,
|
||||
const std::vector<array>& inputs) {
|
||||
const std::vector<array>& inputs,
|
||||
bool shapeless) {
|
||||
std::unordered_map<uintptr_t, array> trace_to_real;
|
||||
for (int i = 0; i < inputs.size(); ++i) {
|
||||
trace_to_real.insert({trace_inputs[i].id(), inputs[i]});
|
||||
@@ -669,18 +697,29 @@ std::vector<array> compile_replace(
|
||||
real_inputs.push_back(trace_to_real.at(in.id()));
|
||||
}
|
||||
if (a.siblings().empty()) {
|
||||
auto shape =
|
||||
shapeless ? a.primitive().output_shapes(real_inputs)[0] : a.shape();
|
||||
auto real_a = array(
|
||||
a.shape(), a.dtype(), a.primitive_ptr(), std::move(real_inputs));
|
||||
std::move(shape),
|
||||
a.dtype(),
|
||||
a.primitive_ptr(),
|
||||
std::move(real_inputs));
|
||||
trace_to_real.insert({a.id(), std::move(real_a)});
|
||||
} else {
|
||||
// Ensure the order is correct for multi-output primitives
|
||||
std::vector<std::vector<int>> shapes;
|
||||
std::vector<Dtype> types;
|
||||
auto trace_out = a.outputs();
|
||||
for (auto& o : trace_out) {
|
||||
shapes.push_back(o.shape());
|
||||
types.push_back(o.dtype());
|
||||
}
|
||||
std::vector<std::vector<int>> shapes;
|
||||
if (shapeless) {
|
||||
shapes = a.primitive().output_shapes(real_inputs);
|
||||
} else {
|
||||
for (auto& o : trace_out) {
|
||||
shapes.push_back(o.shape());
|
||||
}
|
||||
}
|
||||
auto real_out =
|
||||
array::make_arrays(shapes, types, a.primitive_ptr(), real_inputs);
|
||||
for (int i = 0; i < trace_out.size(); ++i) {
|
||||
@@ -697,13 +736,34 @@ std::vector<array> compile_replace(
|
||||
return outputs;
|
||||
}
|
||||
|
||||
void compile_validate_shapeless(const std::vector<array>& tape) {
|
||||
for (auto& t : tape) {
|
||||
if (!t.has_primitive()) {
|
||||
continue;
|
||||
}
|
||||
auto& p = t.primitive();
|
||||
if (allows_shapeless(p)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
std::ostringstream msg;
|
||||
msg << "[compile] Cannot compile primitive ";
|
||||
p.print(msg);
|
||||
msg << " with shapeless enabled.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
}
|
||||
|
||||
std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
||||
size_t fun_id) {
|
||||
size_t fun_id,
|
||||
bool shapeless /* = false */,
|
||||
std::vector<uint64_t> constants /* = {} */) {
|
||||
if (compile_mode() == CompileMode::disabled) {
|
||||
return fun;
|
||||
}
|
||||
return [fun, fun_id](const std::vector<array>& inputs) {
|
||||
return [fun, fun_id, shapeless, constants = std::move(constants)](
|
||||
const std::vector<array>& inputs) {
|
||||
// If the inputs are tracers, trace the original graph
|
||||
if (std::any_of(inputs.begin(), inputs.end(), [](auto& in) {
|
||||
return in.is_tracer();
|
||||
@@ -712,12 +772,14 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||
}
|
||||
|
||||
// Find a cache entry with the correct inputs
|
||||
auto& entry = compiler_cache().find(fun_id, inputs);
|
||||
auto& entry = compiler_cache().find(fun_id, inputs, shapeless, constants);
|
||||
|
||||
// No matching cache entry existed, so compile
|
||||
if (entry.empty) {
|
||||
// Mark the entry as not empty since we are about to fill it
|
||||
entry.empty = false;
|
||||
// Set the constants
|
||||
entry.constants = std::move(constants);
|
||||
// Trace to build the graph
|
||||
std::tie(entry.inputs, entry.outputs) = compile_trace(fun, inputs);
|
||||
|
||||
@@ -739,11 +801,16 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||
if (compile_mode() != CompileMode::no_fuse) {
|
||||
compile_fuse(entry.tape, parents_map, entry.inputs, entry.outputs);
|
||||
}
|
||||
|
||||
if (shapeless) {
|
||||
compile_validate_shapeless(entry.tape);
|
||||
}
|
||||
}
|
||||
|
||||
// At this point we must have a tape, now replace the placeholders
|
||||
// with real arrays that can be evaluated
|
||||
return compile_replace(entry.tape, entry.inputs, entry.outputs, inputs);
|
||||
return compile_replace(
|
||||
entry.tape, entry.inputs, entry.outputs, inputs, shapeless);
|
||||
};
|
||||
}
|
||||
|
||||
@@ -754,12 +821,13 @@ void compile_erase(size_t fun_id) {
|
||||
} // namespace detail
|
||||
|
||||
std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||
const std::function<std::vector<array>(const std::vector<array>&)>& fun) {
|
||||
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
||||
bool shapeless /* false */) {
|
||||
if (detail::compile_mode() == CompileMode::disabled) {
|
||||
return fun;
|
||||
}
|
||||
auto fun_id = detail::getAddress(fun);
|
||||
return detail::compile(fun, fun_id);
|
||||
return detail::compile(fun, fun_id, shapeless);
|
||||
}
|
||||
|
||||
void disable_compile() {
|
||||
|
@@ -8,9 +8,10 @@ namespace mlx::core {
|
||||
|
||||
enum class CompileMode { disabled, no_simplify, no_fuse, enabled };
|
||||
|
||||
// Compile takes a function and returns a new function
|
||||
/** Compile takes a function and returns a compiled function. */
|
||||
std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||
const std::function<std::vector<array>(const std::vector<array>&)>& fun);
|
||||
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
||||
bool shapeless = false);
|
||||
|
||||
/** Globally disable compilation.
|
||||
* Setting the environment variable ``MLX_DISABLE_COMPILE`` can also
|
||||
|
@@ -71,6 +71,15 @@ std::pair<std::vector<array>, std::vector<int>> Primitive::vmap(
|
||||
throw std::invalid_argument("Primitive's vmap not implemented.");
|
||||
};
|
||||
|
||||
std::vector<std::vector<int>> Primitive::output_shapes(
|
||||
const std::vector<array>&) {
|
||||
std::ostringstream msg;
|
||||
msg << "[Primitive::output_shapes] ";
|
||||
this->print(msg);
|
||||
msg << " cannot infer output shapes.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
};
|
||||
|
||||
std::vector<array> Abs::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
@@ -383,6 +392,13 @@ std::pair<std::vector<array>, std::vector<int>> ArgSort::vmap(
|
||||
return {{argsort(inputs[0], axis_ + (axes[0] <= axis_), stream())}, axes};
|
||||
}
|
||||
|
||||
std::vector<std::vector<int>> ArgReduce::output_shapes(
|
||||
const std::vector<array>& inputs) {
|
||||
auto out_shape = inputs[0].shape();
|
||||
out_shape[axis_] = 1;
|
||||
return {out_shape};
|
||||
}
|
||||
|
||||
bool ArgSort::is_equivalent(const Primitive& other) const {
|
||||
const ArgSort& r_other = static_cast<const ArgSort&>(other);
|
||||
return axis_ == r_other.axis_;
|
||||
@@ -2202,6 +2218,15 @@ bool Reduce::is_equivalent(const Primitive& other) const {
|
||||
return reduce_type_ == r_other.reduce_type_ && axes_ == r_other.axes_;
|
||||
}
|
||||
|
||||
std::vector<std::vector<int>> Reduce::output_shapes(
|
||||
const std::vector<array>& inputs) {
|
||||
std::vector<int> out_shape = inputs[0].shape();
|
||||
for (auto i : axes_) {
|
||||
out_shape[i] = 1;
|
||||
}
|
||||
return {out_shape};
|
||||
}
|
||||
|
||||
std::vector<array> Round::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
|
@@ -36,6 +36,12 @@
|
||||
return true; \
|
||||
}
|
||||
|
||||
#define DEFINE_INPUT_OUTPUT_SHAPE() \
|
||||
std::vector<std::vector<int>> output_shapes( \
|
||||
const std::vector<array>& inputs) override { \
|
||||
return {inputs[0].shape()}; \
|
||||
};
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
// Abstract base class
|
||||
@@ -102,6 +108,11 @@ class Primitive {
|
||||
return false;
|
||||
}
|
||||
|
||||
/** Get the output shapes of the primitive. This is not required to be
|
||||
* implemented by derived classes, in which case it will throw. */
|
||||
virtual std::vector<std::vector<int>> output_shapes(
|
||||
const std::vector<array>& inputs);
|
||||
|
||||
virtual ~Primitive() = default;
|
||||
Primitive(const Primitive& other) = delete;
|
||||
Primitive(Primitive&& other) = delete;
|
||||
@@ -152,6 +163,7 @@ class Abs : public UnaryPrimitive {
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(Abs)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
@@ -168,6 +180,7 @@ class Add : public UnaryPrimitive {
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(Add)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
@@ -226,6 +239,7 @@ class ArcCos : public UnaryPrimitive {
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(ArcCos)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
@@ -242,6 +256,7 @@ class ArcCosh : public UnaryPrimitive {
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(ArcCosh)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
@@ -258,6 +273,7 @@ class ArcSin : public UnaryPrimitive {
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(ArcSin)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
@@ -274,6 +290,7 @@ class ArcSinh : public UnaryPrimitive {
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(ArcSinh)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
@@ -290,6 +307,7 @@ class ArcTan : public UnaryPrimitive {
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(ArcTan)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
@@ -306,6 +324,7 @@ class ArcTanh : public UnaryPrimitive {
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(ArcTanh)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
@@ -321,6 +340,7 @@ class ArgPartition : public UnaryPrimitive {
|
||||
|
||||
DEFINE_VMAP()
|
||||
DEFINE_PRINT(ArgPartition)
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
|
||||
private:
|
||||
@@ -346,6 +366,8 @@ class ArgReduce : public UnaryPrimitive {
|
||||
DEFINE_VMAP()
|
||||
DEFINE_PRINT(ArgReduce)
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
std::vector<std::vector<int>> output_shapes(
|
||||
const std::vector<array>& inputs) override;
|
||||
|
||||
private:
|
||||
ReduceType reduce_type_;
|
||||
@@ -364,6 +386,7 @@ class ArgSort : public UnaryPrimitive {
|
||||
|
||||
DEFINE_VMAP()
|
||||
DEFINE_PRINT(ArgSort)
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
|
||||
private:
|
||||
@@ -383,6 +406,7 @@ class AsType : public UnaryPrimitive {
|
||||
DEFINE_VMAP()
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(AsType)
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
|
||||
private:
|
||||
@@ -448,6 +472,7 @@ class Ceil : public UnaryPrimitive {
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(Ceil)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
@@ -478,15 +503,14 @@ class Compiled : public Primitive {
|
||||
|
||||
DEFINE_VMAP()
|
||||
DEFINE_GRADS()
|
||||
std::vector<std::vector<int>> output_shapes(
|
||||
const std::vector<array>& inputs) override;
|
||||
void print(std::ostream& os) override;
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
|
||||
std::string metal_lib_name() const {
|
||||
std::string lib_name() const {
|
||||
return kernel_lib_;
|
||||
}
|
||||
std::string metal_lib_source() const {
|
||||
return kernel_source_;
|
||||
}
|
||||
|
||||
private:
|
||||
const std::vector<array> inputs_;
|
||||
@@ -495,7 +519,6 @@ class Compiled : public Primitive {
|
||||
const std::unordered_set<uintptr_t> constant_ids_;
|
||||
|
||||
std::string kernel_lib_;
|
||||
std::string kernel_source_;
|
||||
};
|
||||
|
||||
class Concatenate : public UnaryPrimitive {
|
||||
@@ -563,6 +586,7 @@ class Copy : public UnaryPrimitive {
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(Copy)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
@@ -579,6 +603,7 @@ class Cos : public UnaryPrimitive {
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(Cos)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
@@ -595,6 +620,7 @@ class Cosh : public UnaryPrimitive {
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(Cosh)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
@@ -665,6 +691,7 @@ class Divide : public UnaryPrimitive {
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(Divide)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
@@ -683,6 +710,10 @@ class DivMod : public Primitive {
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(DivMod)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
std::vector<std::vector<int>> output_shapes(
|
||||
const std::vector<array>& inputs) override {
|
||||
return std::vector{inputs[0].shape(), inputs[0].shape()};
|
||||
};
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
|
||||
@@ -699,6 +730,7 @@ class Remainder : public UnaryPrimitive {
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(Remainder)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
@@ -715,6 +747,7 @@ class Equal : public UnaryPrimitive {
|
||||
DEFINE_VMAP()
|
||||
DEFINE_GRADS()
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
void print(std::ostream& os) override {
|
||||
if (equal_nan_) {
|
||||
@@ -740,6 +773,7 @@ class Erf : public UnaryPrimitive {
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(Erf)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
@@ -756,6 +790,7 @@ class ErfInv : public UnaryPrimitive {
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(ErfInv)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
@@ -772,6 +807,7 @@ class Exp : public UnaryPrimitive {
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(Exp)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
@@ -814,6 +850,7 @@ class Floor : public UnaryPrimitive {
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(Floor)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
@@ -868,6 +905,7 @@ class Greater : public UnaryPrimitive {
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(Greater)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
@@ -884,6 +922,7 @@ class GreaterEqual : public UnaryPrimitive {
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(GreaterEqual)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
@@ -900,6 +939,7 @@ class Less : public UnaryPrimitive {
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(Less)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
@@ -916,6 +956,7 @@ class LessEqual : public UnaryPrimitive {
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(LessEqual)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
@@ -958,6 +999,7 @@ class Log : public UnaryPrimitive {
|
||||
DEFINE_VMAP()
|
||||
DEFINE_GRADS()
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
void print(std::ostream& os) override {
|
||||
switch (base_) {
|
||||
@@ -988,6 +1030,7 @@ class Log1p : public UnaryPrimitive {
|
||||
DEFINE_VMAP()
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(Log1p)
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
@@ -1004,6 +1047,7 @@ class LogicalNot : public UnaryPrimitive {
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(LogicalNot)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
@@ -1020,6 +1064,7 @@ class LogicalAnd : public UnaryPrimitive {
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(LogicalAnd)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
@@ -1036,6 +1081,7 @@ class LogicalOr : public UnaryPrimitive {
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(LogicalOr)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
@@ -1052,6 +1098,7 @@ class LogAddExp : public UnaryPrimitive {
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(LogAddExp)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
@@ -1085,6 +1132,7 @@ class Maximum : public UnaryPrimitive {
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(Maximum)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
@@ -1101,6 +1149,7 @@ class Minimum : public UnaryPrimitive {
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(Minimum)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
@@ -1117,6 +1166,7 @@ class Multiply : public UnaryPrimitive {
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(Multiply)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
@@ -1133,6 +1183,7 @@ class Negative : public UnaryPrimitive {
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(Negative)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
@@ -1149,6 +1200,7 @@ class NotEqual : public UnaryPrimitive {
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(NotEqual)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
@@ -1193,6 +1245,7 @@ class Partition : public UnaryPrimitive {
|
||||
DEFINE_VMAP()
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(Partition)
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
|
||||
private:
|
||||
@@ -1213,6 +1266,7 @@ class Power : public UnaryPrimitive {
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(Power)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
@@ -1305,6 +1359,9 @@ class Reduce : public UnaryPrimitive {
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>& outputs) override;
|
||||
|
||||
std::vector<std::vector<int>> output_shapes(
|
||||
const std::vector<array>& inputs) override;
|
||||
|
||||
void print(std::ostream& os) override {
|
||||
switch (reduce_type_) {
|
||||
case And:
|
||||
@@ -1347,6 +1404,7 @@ class Round : public UnaryPrimitive {
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(Round)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
@@ -1455,6 +1513,7 @@ class Sigmoid : public UnaryPrimitive {
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(Sigmoid)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
@@ -1471,6 +1530,7 @@ class Sign : public UnaryPrimitive {
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(Sign)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
@@ -1487,6 +1547,7 @@ class Sin : public UnaryPrimitive {
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(Sin)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
@@ -1503,6 +1564,7 @@ class Sinh : public UnaryPrimitive {
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(Sinh)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
@@ -1547,6 +1609,7 @@ class Softmax : public UnaryPrimitive {
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(Softmax)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
@@ -1563,6 +1626,7 @@ class Sort : public UnaryPrimitive {
|
||||
DEFINE_VMAP()
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(Sort)
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
|
||||
private:
|
||||
@@ -1604,6 +1668,7 @@ class Square : public UnaryPrimitive {
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(Square)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
@@ -1619,6 +1684,7 @@ class Sqrt : public UnaryPrimitive {
|
||||
|
||||
DEFINE_VMAP()
|
||||
DEFINE_GRADS()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
|
||||
void print(std::ostream& os) override {
|
||||
@@ -1644,6 +1710,7 @@ class StopGradient : public UnaryPrimitive {
|
||||
DEFINE_VMAP()
|
||||
DEFINE_PRINT(StopGradient)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
@@ -1660,6 +1727,7 @@ class Subtract : public UnaryPrimitive {
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(Subtract)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
@@ -1676,6 +1744,7 @@ class Tan : public UnaryPrimitive {
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(Tan)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
@@ -1692,6 +1761,7 @@ class Tanh : public UnaryPrimitive {
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(Tanh)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
|
@@ -18,7 +18,9 @@ std::vector<array> vmap_replace(
|
||||
// idea.
|
||||
std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
||||
size_t fun_id);
|
||||
size_t fun_id,
|
||||
bool shapeless = false,
|
||||
std::vector<uint64_t> constants = {});
|
||||
|
||||
// Erase cached compile functions
|
||||
void compile_erase(size_t fun_id);
|
||||
|
Reference in New Issue
Block a user