Compile front-end (#476)

* fix tests for linux

* make a move on compile

* basic compile scaffold works

* compile binding

* clean

* fix

* fix grad, more tests

* basic python tests

* fix segfault on python exit

* compile works with python closures

* fix test

* fix python globals bug, and erase

* simplify

* more cpp tests

* bug fix with move function and compile at exit

* simplify inputs also

* enable and disable compiler

* remove simplify

* simplify tests use compile now

* fix multi-output with compile

* clear output tree from cache when function goes out of scope

* ../python/src/transforms.cpp

* remove closure capture

* comments
This commit is contained in:
Awni Hannun
2024-01-26 13:45:30 -08:00
committed by GitHub
parent 874b739f3c
commit 8fa6b322b9
13 changed files with 1029 additions and 297 deletions

View File

@@ -5,6 +5,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
${CMAKE_CURRENT_SOURCE_DIR}/graph_utils.cpp

View File

@@ -47,6 +47,17 @@ array::array(
std::move(primitive),
inputs)) {}
array::array(
std::vector<int> shape,
Dtype dtype,
std::shared_ptr<Primitive> primitive,
std::vector<array>&& inputs)
: array_desc_(std::make_shared<ArrayDesc>(
std::move(shape),
dtype,
std::move(primitive),
std::move(inputs))) {}
std::vector<array> array::make_arrays(
const std::vector<std::vector<int>>& shapes,
const std::vector<Dtype>& dtypes,
@@ -158,7 +169,22 @@ array::ArrayDesc::ArrayDesc(
dtype(dtype),
primitive(std::move(primitive)),
inputs(inputs) {
std::tie(size, strides) = cum_prod(shape);
std::tie(size, strides) = cum_prod(this->shape);
for (auto& in : inputs) {
is_tracer |= in.is_tracer();
}
}
array::ArrayDesc::ArrayDesc(
std::vector<int>&& shape,
Dtype dtype,
std::shared_ptr<Primitive> primitive,
std::vector<array>&& inputs)
: shape(std::move(shape)),
dtype(dtype),
primitive(std::move(primitive)),
inputs(std::move(inputs)) {
std::tie(size, strides) = cum_prod(this->shape);
for (auto& in : inputs) {
is_tracer |= in.is_tracer();
}

View File

@@ -172,6 +172,12 @@ class array {
std::shared_ptr<Primitive> primitive,
const std::vector<array>& inputs);
array(
std::vector<int> shape,
Dtype dtype,
std::shared_ptr<Primitive> primitive,
std::vector<array>&& inputs);
static std::vector<array> make_arrays(
const std::vector<std::vector<int>>& shapes,
const std::vector<Dtype>& dtypes,
@@ -215,6 +221,11 @@ class array {
return *(array_desc_->primitive);
};
/** A shared pointer to the array's primitive. */
std::shared_ptr<Primitive>& primitive_ptr() const {
return array_desc_->primitive;
};
/** Check if the array has an attached primitive or is a leaf node. */
bool has_primitive() const {
return array_desc_->primitive != nullptr;
@@ -360,6 +371,12 @@ class array {
Dtype dtype,
std::shared_ptr<Primitive> primitive,
const std::vector<array>& inputs);
explicit ArrayDesc(
std::vector<int>&& shape,
Dtype dtype,
std::shared_ptr<Primitive> primitive,
std::vector<array>&& inputs);
};
// The ArrayDesc contains the details of the materialized array including the

440
mlx/compile.cpp Normal file
View File

@@ -0,0 +1,440 @@
// Copyright © 2023-2024 Apple Inc.
#include <cstdlib>
#include <map>
#include <unordered_map>
#include <unordered_set>
#include "mlx/allocator.h"
#include "mlx/primitives.h"
#include "mlx/transforms.h"
#include "mlx/transforms_impl.h"
namespace mlx::core {
namespace detail {
bool& compiler_disabled() {
auto get_val = []() {
if (const char* buff_str = std::getenv("MLX_DISABLE_COMPILE")) {
return true;
} else {
return false;
}
};
static bool compiler_disabled_ = get_val();
return compiler_disabled_;
}
#define MAX_OPS_PER_BUFFER max_ops_per_buffer()
using CompileFn = std::function<std::vector<array>(const std::vector<array>&)>;
using ParentsMap =
std::unordered_map<std::uintptr_t, std::vector<std::pair<array, int>>>;
template <typename T, typename... U>
size_t getAddress(std::function<T(U...)> f) {
typedef T(fnType)(U...);
fnType** fnPointer = f.template target<fnType*>();
if (fnPointer == nullptr) {
throw std::invalid_argument(
"[compile] Cannot compile a non-addressable function.");
}
return (size_t)*fnPointer;
}
struct CompilerCache {
struct CacheEntry {
std::vector<array> inputs;
std::vector<array> outputs;
std::vector<array> tape;
bool empty{true};
};
// Returns a reference to a CacheEntry which can be updated
// by the caller to avoid copying large tapes / inputs / outputs
CacheEntry& find(size_t fun_id, const std::vector<array>& inputs) {
// Try to find the entry
auto [entry_it, inserted] = cache_.insert({fun_id, {}});
auto& entries = entry_it->second;
auto is_match = [](const std::vector<array>& in1,
const std::vector<array>& in2) {
if (in1.size() != in2.size()) {
throw std::runtime_error(
"[compiler] Got different number of inputs to function,"
" this should never happen.");
}
for (int i = 0; i < in1.size(); ++i) {
if (in1[i].shape() != in2[i].shape()) {
return false;
}
if (in1[i].dtype() != in2[i].dtype()) {
return false;
}
}
return true;
};
// Loop over entries and check inputs match i.e. shapes and types must be
// equal. Note this could get really slow if one compiles the same
// function with many different shapes. May want to store entries in a
// more easily searchable structure.
for (auto& entry : entries) {
// Check the inputs match and return if so
if (is_match(inputs, entry.inputs)) {
return entry;
}
}
// Otherwise append a new cache entry
entries.push_back(CacheEntry{});
return entries.back();
};
void erase(size_t fun_id) {
cache_.erase(fun_id);
}
private:
CompilerCache() {
// Make sure the allocator is fully
// initialized before the compiler cache
allocator::allocator();
}
friend CompilerCache& compiler_cache();
std::unordered_map<size_t, std::vector<CacheEntry>> cache_;
};
CompilerCache& compiler_cache() {
static CompilerCache compiler_cache_;
return compiler_cache_;
}
std::pair<std::vector<array>, std::vector<array>> compile_trace(
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
const std::vector<array>& inputs) {
// Set the global tracing flag.
detail::InTracing in_tracing;
// Run the function on placeholder inputs
// to get compute graph
std::vector<array> tracer_inputs;
for (int i = 0; i < inputs.size(); ++i) {
array in(inputs[i].shape(), inputs[i].dtype(), nullptr, {});
in.set_tracer(true);
tracer_inputs.push_back(std::move(in));
}
return {tracer_inputs, fun(tracer_inputs)};
}
// Traverses the graph to build a tape and a map of array ids to their parents
std::pair<std::vector<array>, ParentsMap> compile_dfs(
const std::vector<array>& inputs,
const std::vector<array>& outputs) {
std::function<void(const array&)> recurse;
std::vector<array> tape;
std::unordered_set<std::uintptr_t> input_set;
std::unordered_map<std::uintptr_t, std::vector<std::pair<array, int>>>
parents_map;
for (int i = 0; i < inputs.size(); ++i) {
auto in = inputs[i];
input_set.insert(in.id());
}
// DFS the graph to build the tape, and log parents and scalars
std::unordered_set<std::uintptr_t> cache;
recurse = [&](const array& a) {
auto id = a.id();
if (cache.find(id) != cache.end()) {
return;
}
for (int i = 0; i < a.inputs().size(); i++) {
auto& in = a.inputs()[i];
parents_map[in.id()].push_back({a, i});
for (auto& s : a.siblings()) {
parents_map[in.id()].push_back({s, i});
}
// Don't recurse on inputs (but add them to the tape for the purpose
// of future optimizations)
if (input_set.find(a.id()) == input_set.end()) {
recurse(in);
}
}
cache.insert(id);
for (auto& s : a.siblings()) {
cache.insert(s.id());
}
tape.push_back(a);
};
for (auto& a : outputs) {
recurse(a);
}
return {tape, parents_map};
}
// Simplify the tape. Note, this function modifies in-place both the tape and
// the parents map to remove orphaned arrays
void compile_simplify(
std::vector<array>& tape,
ParentsMap& parents_map,
const std::vector<array>& outputs,
int passes) {
// Helpers to identify identical scalars
std::map<std::pair<uint64_t, Dtype::Val>, array> scalars;
auto is_scalar = [](const array& a) {
return a.is_evaled() && a.ndim() == 0;
};
auto get_scalar_rep = [](const array& a) {
uint64_t v = 0;
int dtype;
switch (a.dtype().size) {
case 1:
v = *a.data<uint8_t>();
break;
case 4:
v = *a.data<uint32_t>();
break;
case 8:
v = *a.data<uint64_t>();
break;
}
return std::make_pair(v, a.dtype().val);
};
for (auto& a : tape) {
if (is_scalar(a)) {
scalars.insert({get_scalar_rep(a), a});
}
}
// Helper that fuses two arrays in the graph by setting the parents of the
// source to point to the destination
auto fuse = [&](array& dst, array& src) {
// Canonicalize the order of the primitives outputs
auto sources = src.outputs();
auto dests = dst.outputs();
// For each src parent, point it to the corresponding dest
for (int i = 0; i < sources.size(); ++i) {
auto src_parents = parents_map.find(sources[i].id());
if (src_parents == parents_map.end()) {
continue;
}
auto& pairs = parents_map[dests[i].id()];
for (auto& parent : src_parents->second) {
parent.first.inputs()[parent.second] = dests[i];
pairs.push_back(parent);
}
// Remove the source from the map to avoid fusing with it again
parents_map.erase(src_parents);
}
};
// Depth-1 array equivalence check.
auto array_equivalent = [](const array& a, const array& b) {
if (!a.has_primitive() || !b.has_primitive()) {
return false;
}
if (a.primitive_id() == b.primitive_id()) {
return false;
}
const auto& pa = a.primitive();
const auto& pb = b.primitive();
if (typeid(pa) != typeid(pb)) {
return false;
}
if (a.inputs().size() != b.inputs().size()) {
return false;
}
for (int i = 0; i < a.inputs().size(); i++) {
if (a.inputs()[i].id() != b.inputs()[i].id()) {
return false;
}
}
return pa.is_equivalent(pb);
};
// Pass 0: fuse scalars
std::vector<array> new_tape;
for (auto& arr : tape) {
// Check if we can fuse scalars
if (is_scalar(arr)) {
auto scalar = scalars.find(get_scalar_rep(arr));
if (scalar->second.id() != arr.id()) {
fuse(scalar->second, arr);
// Don't keep orphaned scalars in the tape
continue;
}
}
new_tape.push_back(std::move(arr));
}
tape = std::move(new_tape);
std::unordered_set<uintptr_t> output_set;
for (auto& o : outputs) {
output_set.insert(o.id());
}
// Pass 1..passes: fuse only keeping non-orphaned arrays in the tape
for (int pass = 0; pass < passes; ++pass) {
for (auto& arr : tape) {
// Helper to check if we can fuse the parents of the
// given array
auto maybe_fuse_parents = [&](auto& a) {
auto parents = parents_map.find(a.id());
if (parents != parents_map.end()) {
auto N = parents->second.size();
std::vector<bool> mask(N, false);
for (int i = 0; i < N; i++) {
if (mask[i]) {
continue;
}
for (int j = i + 1; j < N; j++) {
if (mask[j]) {
continue;
}
auto& src = parents->second[j].first;
auto& dst = parents->second[i].first;
if (src.id() != dst.id() && array_equivalent(src, dst)) {
fuse(dst, src);
mask[j] = true;
}
}
}
// Erase orphaned parents so we don't keep fusing with them
for (int i = N - 1; i > 0; --i) {
if (mask[i]) {
parents->second.erase(parents->second.begin() + i);
}
}
return false;
} else {
return output_set.find(a.id()) == output_set.end();
}
};
bool discard = maybe_fuse_parents(arr);
for (auto& s : arr.siblings()) {
discard &= maybe_fuse_parents(s);
}
// If an array and its siblings have no parents, and none of them are
// outputs, it is safe to remove it from the tape
if (!discard) {
new_tape.push_back(std::move(arr));
}
}
tape = std::move(new_tape);
}
}
std::vector<array> compile_replace(
const std::vector<array>& tape,
const std::vector<array>& trace_inputs,
const std::vector<array>& trace_outputs,
const std::vector<array>& inputs) {
std::unordered_map<uintptr_t, array> trace_to_real;
for (int i = 0; i < inputs.size(); ++i) {
trace_to_real.insert({trace_inputs[i].id(), inputs[i]});
}
for (auto& a : tape) {
// Arrays in the tape without primitives are constants
// and can be used directly
if (!a.has_primitive()) {
trace_to_real.insert({a.id(), a});
} else {
// Find real inputs
std::vector<array> real_inputs;
for (auto& in : a.inputs()) {
real_inputs.push_back(trace_to_real.at(in.id()));
}
if (a.siblings().empty()) {
auto real_a = array(
a.shape(), a.dtype(), a.primitive_ptr(), std::move(real_inputs));
trace_to_real.insert({a.id(), std::move(real_a)});
} else {
// Ensure the order is correct for multi-output primitives
std::vector<std::vector<int>> shapes;
std::vector<Dtype> types;
auto trace_out = a.outputs();
for (auto& o : trace_out) {
shapes.push_back(o.shape());
types.push_back(o.dtype());
}
auto real_out =
array::make_arrays(shapes, types, a.primitive_ptr(), real_inputs);
for (int i = 0; i < trace_out.size(); ++i) {
trace_to_real.insert({trace_out[i].id(), std::move(real_out[i])});
}
}
}
}
std::vector<array> outputs;
for (auto& o : trace_outputs) {
outputs.push_back(trace_to_real.at(o.id()));
}
return outputs;
}
std::function<std::vector<array>(const std::vector<array>&)> compile(
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
size_t fun_id) {
if (compiler_disabled()) {
return fun;
}
return [fun, fun_id](const std::vector<array>& inputs) {
// Find a cache entry with the correct inputs
auto& entry = compiler_cache().find(fun_id, inputs);
// No matching cache entry existed, so compile
if (entry.empty) {
// Mark the entry as not empty since we are about to fill it
entry.empty = false;
// Trace to build the graph
std::tie(entry.inputs, entry.outputs) = compile_trace(fun, inputs);
// DFS the graph and get a tape, and a map of array id to (parent,
// position in parent inputs)
std::unordered_map<uintptr_t, std::vector<std::pair<array, int>>>
parents_map;
std::tie(entry.tape, parents_map) =
compile_dfs(entry.inputs, entry.outputs);
// Simplify the tape
compile_simplify(entry.tape, parents_map, entry.outputs, /* passes */ 3);
// This is a good point to do more optimizations, e.g. kernel fusion to
// generate new primitives. The tape needs to be updated accordingly
}
// At this point we must have a tape, now replace the placeholders
// with real arrays that can be evaluated
return compile_replace(entry.tape, entry.inputs, entry.outputs, inputs);
};
}
void compile_erase(size_t fun_id) {
detail::compiler_cache().erase(fun_id);
}
} // namespace detail
std::function<std::vector<array>(const std::vector<array>&)> compile(
const std::function<std::vector<array>(const std::vector<array>&)>& fun) {
if (detail::compiler_disabled()) {
return fun;
}
auto fun_id = detail::getAddress(fun);
return detail::compile(fun, fun_id);
}
void disable_compile() {
detail::compiler_disabled() = true;
}
void enable_compile() {
detail::compiler_disabled() = false;
}
} // namespace mlx::core

View File

@@ -1,7 +1,6 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include <algorithm>
#include <future>
#include <map>
#include <numeric>
#include <set>
#include <sstream>
@@ -35,169 +34,6 @@ class Synchronizer : public Primitive {
// are currently under a function transformation.
int detail::InTracing::tracing_counter{0};
void simplify(const std::vector<array>& outputs) {
// Some notes about how this function works
//
// Step 1: Traverse the graph and build a tape. During the graph
// traversal we:
// - Build a map of inputs to their parents.
// - Record scalar inputs in a map in order to fuse them.
// Step 2: Process the tape. A node in the tape has inputs and outputs.
// - Scalar inputs are replaced with their canonical scalar
// - We check each inputs output nodes. Every output node that matches
// the current node gets fused into the current node.
std::function<void(const array&)> recurse;
std::queue<array> tape;
std::unordered_set<std::uintptr_t> cache;
std::unordered_map<std::uintptr_t, std::vector<std::pair<array, int>>>
parents_map;
// Helpers to identify identical scalars
std::map<std::pair<uint64_t, Dtype::Val>, array> scalars;
auto is_scalar = [](const array& a) {
return a.is_evaled() && a.ndim() == 0;
};
auto get_scalar_rep = [](const array& a) {
uint64_t v = 0;
int dtype;
switch (a.dtype().size) {
case 1:
v = *a.data<uint8_t>();
break;
case 4:
v = *a.data<uint32_t>();
break;
case 8:
v = *a.data<uint64_t>();
break;
}
return std::make_pair(v, a.dtype().val);
};
// DFS the graph to build the tape, and log parents and scalars
recurse = [&](const array& a) {
auto id = a.id();
if (cache.find(id) != cache.end()) {
return;
}
for (int i = 0; i < a.inputs().size(); i++) {
auto& in = a.inputs()[i];
parents_map[in.id()].push_back({a, i});
for (auto& s : a.siblings()) {
parents_map[in.id()].push_back({s, i});
}
recurse(in);
}
cache.insert(id);
for (auto& s : a.siblings()) {
cache.insert(s.id());
}
tape.push(a);
if (is_scalar(a)) {
scalars.insert({get_scalar_rep(a), a});
}
};
for (auto& a : outputs) {
recurse(a);
}
// Helper that fuses two arrays in the graph by setting the parents of the
// source to point to the destination
auto fuse = [&](array& dst, array& src) {
// Canonicalize the order of the primitives outputs
auto sources = src.outputs();
auto dests = dst.outputs();
// For each src parent, point it to the corresponding dest
for (int i = 0; i < sources.size(); ++i) {
auto src_parents = parents_map.find(sources[i].id());
if (src_parents == parents_map.end()) {
continue;
}
auto& pairs = parents_map[dests[i].id()];
for (auto& parent : src_parents->second) {
parent.first.inputs()[parent.second] = dests[i];
pairs.push_back(parent);
}
// Remove the source from the map to avoid fusing with it again
parents_map.erase(src_parents);
}
};
// Depth-1 array equivalence check.
auto array_equivalent = [](const array& a, const array& b) {
if (!a.has_primitive() || !b.has_primitive()) {
return false;
}
if (a.primitive_id() == b.primitive_id()) {
return false;
}
const auto& pa = a.primitive();
const auto& pb = b.primitive();
if (typeid(pa) != typeid(pb)) {
return false;
}
if (a.inputs().size() != b.inputs().size()) {
return false;
}
for (int i = 0; i < a.inputs().size(); i++) {
if (a.inputs()[i].id() != b.inputs()[i].id()) {
return false;
}
}
return pa.is_equivalent(pb);
};
// Walk the graph
while (!tape.empty()) {
auto arr = std::move(tape.front());
tape.pop();
// Check if we can fuse scalars
if (is_scalar(arr)) {
auto scalar = scalars.find(get_scalar_rep(arr));
if (scalar->second.id() != arr.id()) {
fuse(scalar->second, arr);
arr = scalar->second;
}
}
// Helper to check if we can fuse the parents of the
// given array
auto maybe_fuse_parents = [&](auto& a) {
auto parents = parents_map.find(a.id());
if (parents != parents_map.end()) {
auto N = parents->second.size();
std::vector<bool> mask(N, false);
for (int i = 0; i < N; i++) {
if (mask[i]) {
continue;
}
for (int j = i + 1; j < N; j++) {
if (mask[j]) {
continue;
}
auto& src = parents->second[j].first;
auto& dst = parents->second[i].first;
if (src.id() != dst.id() && array_equivalent(src, dst)) {
fuse(dst, src);
mask[j] = true;
}
}
}
}
};
maybe_fuse_parents(arr);
for (auto& s : arr.siblings()) {
maybe_fuse_parents(s);
}
}
}
void eval(const std::vector<array>& outputs) {
std::function<void(const array&)> recurse;
std::queue<array> tape;

View File

@@ -1,18 +1,25 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#pragma once
#include "array.h"
#include "mlx/array.h"
namespace mlx::core {
/** Fuse equivalent arrays to avoid duplicate execution. */
void simplify(const std::vector<array>& outputs);
// Compile takes a function and returns a new function
std::function<std::vector<array>(const std::vector<array>&)> compile(
const std::function<std::vector<array>(const std::vector<array>&)>& fun);
template <typename... Arrays>
void simplify(Arrays... outputs) {
simplify(std::vector<array>{std::forward<Arrays>(outputs)...});
}
/** Globally disable compilation.
* Setting the environment variable ``MLX_DISABLE_COMPILE`` can also
* be used to disable compilation.
*/
void disable_compile();
/** Globally enable compilation.
* This will override the environment variable ``MLX_DISABLE_COMPILE``.
*/
void enable_compile();
void eval(const std::vector<array>& outputs);

View File

@@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
namespace mlx::core::detail {
@@ -14,6 +14,15 @@ std::vector<array> vmap_replace(
const std::vector<int>& in_axes,
const std::vector<int>& out_axes);
// This is not part of the general C++ API as calling with a bad id is a bad
// idea.
std::function<std::vector<array>(const std::vector<array>&)> compile(
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
size_t fun_id);
// Erase cached compile functions
void compile_erase(size_t fun_id);
// Create an InTracing object during tracing operations to signify to the rest
// of the codebase that we are during tracing so evals should not throw away
// the graph.