Shapeless compilation for some graphs (#687)

* shapeless compilation for some graphs

* update compile benchmark

* default compile a few activations

* buffer donation

* bugfix

* shapeless fix

* update tests to work for cpu and gpu fusion

* test kwargs

* add kwargs to compile

* Recompile when python arguments change

* no compile for tanh

* some constant tests

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
Awni Hannun
2024-02-19 21:43:54 -08:00
committed by GitHub
parent d0fda82595
commit 5798256fcf
14 changed files with 645 additions and 113 deletions

View File

@@ -13,7 +13,7 @@
namespace mlx::core {
constexpr int max_compile_depth = 10;
constexpr int max_compile_depth = 11;
bool is_unary(const Primitive& p) {
return (
@@ -55,19 +55,20 @@ bool is_noop(const Primitive& p) {
return typeid(p) == typeid(Copy) || typeid(p) == typeid(StopGradient);
}
bool is_reduction(const Primitive& p) {
return typeid(p) == typeid(Reduce) || typeid(p) == typeid(ArgReduce);
}
bool is_fusable(const Primitive& p) {
return is_unary(p) || is_binary(p) || is_broadcast(p) || is_noop(p);
}
namespace detail {
std::vector<array> compile_replace(
const std::vector<array>& tape,
const std::vector<array>& trace_inputs,
const std::vector<array>& trace_outputs,
const std::vector<array>& inputs);
} // namespace detail
bool allows_shapeless(const Primitive& p) {
return typeid(p) == typeid(Compiled) || is_unary(p) || is_binary(p) ||
is_noop(p) || is_reduction(p) || typeid(p) == typeid(Softmax) ||
typeid(p) == typeid(Sort) || typeid(p) == typeid(ArgSort) ||
typeid(p) == typeid(ArgPartition) || typeid(p) == typeid(Partition);
}
Compiled::Compiled(
Stream stream,
@@ -123,6 +124,23 @@ void Compiled::print(std::ostream& os) {
}
}
std::vector<std::vector<int>> Compiled::output_shapes(
const std::vector<array>& inputs) {
size_t nd = 0;
for (auto& in : inputs) {
nd = std::max(nd, in.ndim());
}
std::vector<int> out_shape(nd, 0);
for (auto& in : inputs) {
auto dd = nd - in.ndim();
for (auto i = dd; i < nd; ++i) {
out_shape[i] = std::max(out_shape[i], in.shape()[i - dd]);
}
}
// All outputs have the same shape
return std::vector<std::vector<int>>(outputs_.size(), out_shape);
}
namespace detail {
CompileMode& compile_mode() {
@@ -180,21 +198,30 @@ struct CompilerCache {
std::vector<array> outputs;
std::vector<array> tape;
bool empty{true};
std::vector<uint64_t> constants;
};
// Returns a reference to a CacheEntry which can be updated
// by the caller to avoid copying large tapes / inputs / outputs
CacheEntry& find(size_t fun_id, const std::vector<array>& inputs) {
CacheEntry& find(
size_t fun_id,
const std::vector<array>& inputs,
bool shapeless,
const std::vector<uint64_t>& constants) {
// Try to find the entry
auto [entry_it, inserted] = cache_.insert({fun_id, {}});
auto& entries = entry_it->second;
auto is_match = [](const std::vector<array>& in1,
const std::vector<array>& in2) {
auto is_match = [shapeless](
const std::vector<array>& in1,
const std::vector<array>& in2) {
if (in1.size() != in2.size()) {
return false;
}
for (int i = 0; i < in1.size(); ++i) {
if (in1[i].shape() != in2[i].shape()) {
if (in1[i].ndim() != in2[i].ndim()) {
return false;
}
if (!shapeless && in1[i].shape() != in2[i].shape()) {
return false;
}
if (in1[i].dtype() != in2[i].dtype()) {
@@ -210,7 +237,7 @@ struct CompilerCache {
// more easily searchable structure.
for (auto& entry : entries) {
// Check the inputs match and return if so
if (is_match(inputs, entry.inputs)) {
if (is_match(inputs, entry.inputs) && constants == entry.constants) {
return entry;
}
}
@@ -651,7 +678,8 @@ std::vector<array> compile_replace(
const std::vector<array>& tape,
const std::vector<array>& trace_inputs,
const std::vector<array>& trace_outputs,
const std::vector<array>& inputs) {
const std::vector<array>& inputs,
bool shapeless) {
std::unordered_map<uintptr_t, array> trace_to_real;
for (int i = 0; i < inputs.size(); ++i) {
trace_to_real.insert({trace_inputs[i].id(), inputs[i]});
@@ -669,18 +697,29 @@ std::vector<array> compile_replace(
real_inputs.push_back(trace_to_real.at(in.id()));
}
if (a.siblings().empty()) {
auto shape =
shapeless ? a.primitive().output_shapes(real_inputs)[0] : a.shape();
auto real_a = array(
a.shape(), a.dtype(), a.primitive_ptr(), std::move(real_inputs));
std::move(shape),
a.dtype(),
a.primitive_ptr(),
std::move(real_inputs));
trace_to_real.insert({a.id(), std::move(real_a)});
} else {
// Ensure the order is correct for multi-output primitives
std::vector<std::vector<int>> shapes;
std::vector<Dtype> types;
auto trace_out = a.outputs();
for (auto& o : trace_out) {
shapes.push_back(o.shape());
types.push_back(o.dtype());
}
std::vector<std::vector<int>> shapes;
if (shapeless) {
shapes = a.primitive().output_shapes(real_inputs);
} else {
for (auto& o : trace_out) {
shapes.push_back(o.shape());
}
}
auto real_out =
array::make_arrays(shapes, types, a.primitive_ptr(), real_inputs);
for (int i = 0; i < trace_out.size(); ++i) {
@@ -697,13 +736,34 @@ std::vector<array> compile_replace(
return outputs;
}
void compile_validate_shapeless(const std::vector<array>& tape) {
for (auto& t : tape) {
if (!t.has_primitive()) {
continue;
}
auto& p = t.primitive();
if (allows_shapeless(p)) {
continue;
}
std::ostringstream msg;
msg << "[compile] Cannot compile primitive ";
p.print(msg);
msg << " with shapeless enabled.";
throw std::invalid_argument(msg.str());
}
}
std::function<std::vector<array>(const std::vector<array>&)> compile(
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
size_t fun_id) {
size_t fun_id,
bool shapeless /* = false */,
std::vector<uint64_t> constants /* = {} */) {
if (compile_mode() == CompileMode::disabled) {
return fun;
}
return [fun, fun_id](const std::vector<array>& inputs) {
return [fun, fun_id, shapeless, constants = std::move(constants)](
const std::vector<array>& inputs) {
// If the inputs are tracers, trace the original graph
if (std::any_of(inputs.begin(), inputs.end(), [](auto& in) {
return in.is_tracer();
@@ -712,12 +772,14 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
}
// Find a cache entry with the correct inputs
auto& entry = compiler_cache().find(fun_id, inputs);
auto& entry = compiler_cache().find(fun_id, inputs, shapeless, constants);
// No matching cache entry existed, so compile
if (entry.empty) {
// Mark the entry as not empty since we are about to fill it
entry.empty = false;
// Set the constants
entry.constants = std::move(constants);
// Trace to build the graph
std::tie(entry.inputs, entry.outputs) = compile_trace(fun, inputs);
@@ -739,11 +801,16 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
if (compile_mode() != CompileMode::no_fuse) {
compile_fuse(entry.tape, parents_map, entry.inputs, entry.outputs);
}
if (shapeless) {
compile_validate_shapeless(entry.tape);
}
}
// At this point we must have a tape, now replace the placeholders
// with real arrays that can be evaluated
return compile_replace(entry.tape, entry.inputs, entry.outputs, inputs);
return compile_replace(
entry.tape, entry.inputs, entry.outputs, inputs, shapeless);
};
}
@@ -754,12 +821,13 @@ void compile_erase(size_t fun_id) {
} // namespace detail
std::function<std::vector<array>(const std::vector<array>&)> compile(
const std::function<std::vector<array>(const std::vector<array>&)>& fun) {
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
bool shapeless /* false */) {
if (detail::compile_mode() == CompileMode::disabled) {
return fun;
}
auto fun_id = detail::getAddress(fun);
return detail::compile(fun, fun_id);
return detail::compile(fun, fun_id, shapeless);
}
void disable_compile() {