mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
Compile front-end (#476)
* fix tests for linux * make a move on compile * basic compile scaffold works * compile binding * clean * fix * fix grad, more tests * basic python tests * fix segfault on python exit * compile works with python closures * fix test * fix python globals bug, and erase * simplify * more cpp tests * bug fix with move function and compile at exit * simplify inputs also * enable and disable compiler * remove simplify * simplify tests use compile now * fix multi-output with compile * clear output tree from cache when function goes out of scope * ../python/src/transforms.cpp * remove closure capture * comments
This commit is contained in:
parent
874b739f3c
commit
8fa6b322b9
@ -5,6 +5,7 @@ target_sources(
|
|||||||
${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/graph_utils.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/graph_utils.cpp
|
||||||
|
@ -47,6 +47,17 @@ array::array(
|
|||||||
std::move(primitive),
|
std::move(primitive),
|
||||||
inputs)) {}
|
inputs)) {}
|
||||||
|
|
||||||
|
array::array(
|
||||||
|
std::vector<int> shape,
|
||||||
|
Dtype dtype,
|
||||||
|
std::shared_ptr<Primitive> primitive,
|
||||||
|
std::vector<array>&& inputs)
|
||||||
|
: array_desc_(std::make_shared<ArrayDesc>(
|
||||||
|
std::move(shape),
|
||||||
|
dtype,
|
||||||
|
std::move(primitive),
|
||||||
|
std::move(inputs))) {}
|
||||||
|
|
||||||
std::vector<array> array::make_arrays(
|
std::vector<array> array::make_arrays(
|
||||||
const std::vector<std::vector<int>>& shapes,
|
const std::vector<std::vector<int>>& shapes,
|
||||||
const std::vector<Dtype>& dtypes,
|
const std::vector<Dtype>& dtypes,
|
||||||
@ -158,7 +169,22 @@ array::ArrayDesc::ArrayDesc(
|
|||||||
dtype(dtype),
|
dtype(dtype),
|
||||||
primitive(std::move(primitive)),
|
primitive(std::move(primitive)),
|
||||||
inputs(inputs) {
|
inputs(inputs) {
|
||||||
std::tie(size, strides) = cum_prod(shape);
|
std::tie(size, strides) = cum_prod(this->shape);
|
||||||
|
for (auto& in : inputs) {
|
||||||
|
is_tracer |= in.is_tracer();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
array::ArrayDesc::ArrayDesc(
|
||||||
|
std::vector<int>&& shape,
|
||||||
|
Dtype dtype,
|
||||||
|
std::shared_ptr<Primitive> primitive,
|
||||||
|
std::vector<array>&& inputs)
|
||||||
|
: shape(std::move(shape)),
|
||||||
|
dtype(dtype),
|
||||||
|
primitive(std::move(primitive)),
|
||||||
|
inputs(std::move(inputs)) {
|
||||||
|
std::tie(size, strides) = cum_prod(this->shape);
|
||||||
for (auto& in : inputs) {
|
for (auto& in : inputs) {
|
||||||
is_tracer |= in.is_tracer();
|
is_tracer |= in.is_tracer();
|
||||||
}
|
}
|
||||||
|
17
mlx/array.h
17
mlx/array.h
@ -172,6 +172,12 @@ class array {
|
|||||||
std::shared_ptr<Primitive> primitive,
|
std::shared_ptr<Primitive> primitive,
|
||||||
const std::vector<array>& inputs);
|
const std::vector<array>& inputs);
|
||||||
|
|
||||||
|
array(
|
||||||
|
std::vector<int> shape,
|
||||||
|
Dtype dtype,
|
||||||
|
std::shared_ptr<Primitive> primitive,
|
||||||
|
std::vector<array>&& inputs);
|
||||||
|
|
||||||
static std::vector<array> make_arrays(
|
static std::vector<array> make_arrays(
|
||||||
const std::vector<std::vector<int>>& shapes,
|
const std::vector<std::vector<int>>& shapes,
|
||||||
const std::vector<Dtype>& dtypes,
|
const std::vector<Dtype>& dtypes,
|
||||||
@ -215,6 +221,11 @@ class array {
|
|||||||
return *(array_desc_->primitive);
|
return *(array_desc_->primitive);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/** A shared pointer to the array's primitive. */
|
||||||
|
std::shared_ptr<Primitive>& primitive_ptr() const {
|
||||||
|
return array_desc_->primitive;
|
||||||
|
};
|
||||||
|
|
||||||
/** Check if the array has an attached primitive or is a leaf node. */
|
/** Check if the array has an attached primitive or is a leaf node. */
|
||||||
bool has_primitive() const {
|
bool has_primitive() const {
|
||||||
return array_desc_->primitive != nullptr;
|
return array_desc_->primitive != nullptr;
|
||||||
@ -360,6 +371,12 @@ class array {
|
|||||||
Dtype dtype,
|
Dtype dtype,
|
||||||
std::shared_ptr<Primitive> primitive,
|
std::shared_ptr<Primitive> primitive,
|
||||||
const std::vector<array>& inputs);
|
const std::vector<array>& inputs);
|
||||||
|
|
||||||
|
explicit ArrayDesc(
|
||||||
|
std::vector<int>&& shape,
|
||||||
|
Dtype dtype,
|
||||||
|
std::shared_ptr<Primitive> primitive,
|
||||||
|
std::vector<array>&& inputs);
|
||||||
};
|
};
|
||||||
|
|
||||||
// The ArrayDesc contains the details of the materialized array including the
|
// The ArrayDesc contains the details of the materialized array including the
|
||||||
|
440
mlx/compile.cpp
Normal file
440
mlx/compile.cpp
Normal file
@ -0,0 +1,440 @@
|
|||||||
|
// Copyright © 2023-2024 Apple Inc.
|
||||||
|
#include <cstdlib>
|
||||||
|
#include <map>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <unordered_set>
|
||||||
|
|
||||||
|
#include "mlx/allocator.h"
|
||||||
|
#include "mlx/primitives.h"
|
||||||
|
#include "mlx/transforms.h"
|
||||||
|
#include "mlx/transforms_impl.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
namespace detail {
|
||||||
|
|
||||||
|
bool& compiler_disabled() {
|
||||||
|
auto get_val = []() {
|
||||||
|
if (const char* buff_str = std::getenv("MLX_DISABLE_COMPILE")) {
|
||||||
|
return true;
|
||||||
|
} else {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
static bool compiler_disabled_ = get_val();
|
||||||
|
return compiler_disabled_;
|
||||||
|
}
|
||||||
|
|
||||||
|
#define MAX_OPS_PER_BUFFER max_ops_per_buffer()
|
||||||
|
|
||||||
|
using CompileFn = std::function<std::vector<array>(const std::vector<array>&)>;
|
||||||
|
using ParentsMap =
|
||||||
|
std::unordered_map<std::uintptr_t, std::vector<std::pair<array, int>>>;
|
||||||
|
|
||||||
|
template <typename T, typename... U>
|
||||||
|
size_t getAddress(std::function<T(U...)> f) {
|
||||||
|
typedef T(fnType)(U...);
|
||||||
|
fnType** fnPointer = f.template target<fnType*>();
|
||||||
|
if (fnPointer == nullptr) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[compile] Cannot compile a non-addressable function.");
|
||||||
|
}
|
||||||
|
return (size_t)*fnPointer;
|
||||||
|
}
|
||||||
|
|
||||||
|
struct CompilerCache {
|
||||||
|
struct CacheEntry {
|
||||||
|
std::vector<array> inputs;
|
||||||
|
std::vector<array> outputs;
|
||||||
|
std::vector<array> tape;
|
||||||
|
bool empty{true};
|
||||||
|
};
|
||||||
|
|
||||||
|
// Returns a reference to a CacheEntry which can be updated
|
||||||
|
// by the caller to avoid copying large tapes / inputs / outputs
|
||||||
|
CacheEntry& find(size_t fun_id, const std::vector<array>& inputs) {
|
||||||
|
// Try to find the entry
|
||||||
|
auto [entry_it, inserted] = cache_.insert({fun_id, {}});
|
||||||
|
auto& entries = entry_it->second;
|
||||||
|
auto is_match = [](const std::vector<array>& in1,
|
||||||
|
const std::vector<array>& in2) {
|
||||||
|
if (in1.size() != in2.size()) {
|
||||||
|
throw std::runtime_error(
|
||||||
|
"[compiler] Got different number of inputs to function,"
|
||||||
|
" this should never happen.");
|
||||||
|
}
|
||||||
|
for (int i = 0; i < in1.size(); ++i) {
|
||||||
|
if (in1[i].shape() != in2[i].shape()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (in1[i].dtype() != in2[i].dtype()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Loop over entries and check inputs match i.e. shapes and types must be
|
||||||
|
// equal. Note this could get really slow if one compiles the same
|
||||||
|
// function with many different shapes. May want to store entries in a
|
||||||
|
// more easily searchable structure.
|
||||||
|
for (auto& entry : entries) {
|
||||||
|
// Check the inputs match and return if so
|
||||||
|
if (is_match(inputs, entry.inputs)) {
|
||||||
|
return entry;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Otherwise append a new cache entry
|
||||||
|
entries.push_back(CacheEntry{});
|
||||||
|
return entries.back();
|
||||||
|
};
|
||||||
|
|
||||||
|
void erase(size_t fun_id) {
|
||||||
|
cache_.erase(fun_id);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
CompilerCache() {
|
||||||
|
// Make sure the allocator is fully
|
||||||
|
// initialized before the compiler cache
|
||||||
|
allocator::allocator();
|
||||||
|
}
|
||||||
|
friend CompilerCache& compiler_cache();
|
||||||
|
std::unordered_map<size_t, std::vector<CacheEntry>> cache_;
|
||||||
|
};
|
||||||
|
|
||||||
|
CompilerCache& compiler_cache() {
|
||||||
|
static CompilerCache compiler_cache_;
|
||||||
|
return compiler_cache_;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::pair<std::vector<array>, std::vector<array>> compile_trace(
|
||||||
|
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
||||||
|
const std::vector<array>& inputs) {
|
||||||
|
// Set the global tracing flag.
|
||||||
|
detail::InTracing in_tracing;
|
||||||
|
|
||||||
|
// Run the function on placeholder inputs
|
||||||
|
// to get compute graph
|
||||||
|
std::vector<array> tracer_inputs;
|
||||||
|
for (int i = 0; i < inputs.size(); ++i) {
|
||||||
|
array in(inputs[i].shape(), inputs[i].dtype(), nullptr, {});
|
||||||
|
in.set_tracer(true);
|
||||||
|
tracer_inputs.push_back(std::move(in));
|
||||||
|
}
|
||||||
|
return {tracer_inputs, fun(tracer_inputs)};
|
||||||
|
}
|
||||||
|
|
||||||
|
// Traverses the graph to build a tape and a map of array ids to their parents
|
||||||
|
std::pair<std::vector<array>, ParentsMap> compile_dfs(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
const std::vector<array>& outputs) {
|
||||||
|
std::function<void(const array&)> recurse;
|
||||||
|
std::vector<array> tape;
|
||||||
|
std::unordered_set<std::uintptr_t> input_set;
|
||||||
|
std::unordered_map<std::uintptr_t, std::vector<std::pair<array, int>>>
|
||||||
|
parents_map;
|
||||||
|
for (int i = 0; i < inputs.size(); ++i) {
|
||||||
|
auto in = inputs[i];
|
||||||
|
input_set.insert(in.id());
|
||||||
|
}
|
||||||
|
|
||||||
|
// DFS the graph to build the tape, and log parents and scalars
|
||||||
|
std::unordered_set<std::uintptr_t> cache;
|
||||||
|
recurse = [&](const array& a) {
|
||||||
|
auto id = a.id();
|
||||||
|
if (cache.find(id) != cache.end()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
for (int i = 0; i < a.inputs().size(); i++) {
|
||||||
|
auto& in = a.inputs()[i];
|
||||||
|
parents_map[in.id()].push_back({a, i});
|
||||||
|
for (auto& s : a.siblings()) {
|
||||||
|
parents_map[in.id()].push_back({s, i});
|
||||||
|
}
|
||||||
|
// Don't recurse on inputs (but add them to the tape for the purpose
|
||||||
|
// of future optimizations)
|
||||||
|
if (input_set.find(a.id()) == input_set.end()) {
|
||||||
|
recurse(in);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cache.insert(id);
|
||||||
|
for (auto& s : a.siblings()) {
|
||||||
|
cache.insert(s.id());
|
||||||
|
}
|
||||||
|
tape.push_back(a);
|
||||||
|
};
|
||||||
|
for (auto& a : outputs) {
|
||||||
|
recurse(a);
|
||||||
|
}
|
||||||
|
return {tape, parents_map};
|
||||||
|
}
|
||||||
|
|
||||||
|
// Simplify the tape. Note, this function modifies in-place both the tape and
|
||||||
|
// the parents map to remove orphaned arrays
|
||||||
|
void compile_simplify(
|
||||||
|
std::vector<array>& tape,
|
||||||
|
ParentsMap& parents_map,
|
||||||
|
const std::vector<array>& outputs,
|
||||||
|
int passes) {
|
||||||
|
// Helpers to identify identical scalars
|
||||||
|
std::map<std::pair<uint64_t, Dtype::Val>, array> scalars;
|
||||||
|
auto is_scalar = [](const array& a) {
|
||||||
|
return a.is_evaled() && a.ndim() == 0;
|
||||||
|
};
|
||||||
|
auto get_scalar_rep = [](const array& a) {
|
||||||
|
uint64_t v = 0;
|
||||||
|
int dtype;
|
||||||
|
switch (a.dtype().size) {
|
||||||
|
case 1:
|
||||||
|
v = *a.data<uint8_t>();
|
||||||
|
break;
|
||||||
|
case 4:
|
||||||
|
v = *a.data<uint32_t>();
|
||||||
|
break;
|
||||||
|
case 8:
|
||||||
|
v = *a.data<uint64_t>();
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
return std::make_pair(v, a.dtype().val);
|
||||||
|
};
|
||||||
|
|
||||||
|
for (auto& a : tape) {
|
||||||
|
if (is_scalar(a)) {
|
||||||
|
scalars.insert({get_scalar_rep(a), a});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper that fuses two arrays in the graph by setting the parents of the
|
||||||
|
// source to point to the destination
|
||||||
|
auto fuse = [&](array& dst, array& src) {
|
||||||
|
// Canonicalize the order of the primitives outputs
|
||||||
|
auto sources = src.outputs();
|
||||||
|
auto dests = dst.outputs();
|
||||||
|
// For each src parent, point it to the corresponding dest
|
||||||
|
for (int i = 0; i < sources.size(); ++i) {
|
||||||
|
auto src_parents = parents_map.find(sources[i].id());
|
||||||
|
if (src_parents == parents_map.end()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
auto& pairs = parents_map[dests[i].id()];
|
||||||
|
for (auto& parent : src_parents->second) {
|
||||||
|
parent.first.inputs()[parent.second] = dests[i];
|
||||||
|
pairs.push_back(parent);
|
||||||
|
}
|
||||||
|
// Remove the source from the map to avoid fusing with it again
|
||||||
|
parents_map.erase(src_parents);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Depth-1 array equivalence check.
|
||||||
|
auto array_equivalent = [](const array& a, const array& b) {
|
||||||
|
if (!a.has_primitive() || !b.has_primitive()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (a.primitive_id() == b.primitive_id()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
const auto& pa = a.primitive();
|
||||||
|
const auto& pb = b.primitive();
|
||||||
|
if (typeid(pa) != typeid(pb)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (a.inputs().size() != b.inputs().size()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < a.inputs().size(); i++) {
|
||||||
|
if (a.inputs()[i].id() != b.inputs()[i].id()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return pa.is_equivalent(pb);
|
||||||
|
};
|
||||||
|
|
||||||
|
// Pass 0: fuse scalars
|
||||||
|
std::vector<array> new_tape;
|
||||||
|
for (auto& arr : tape) {
|
||||||
|
// Check if we can fuse scalars
|
||||||
|
if (is_scalar(arr)) {
|
||||||
|
auto scalar = scalars.find(get_scalar_rep(arr));
|
||||||
|
if (scalar->second.id() != arr.id()) {
|
||||||
|
fuse(scalar->second, arr);
|
||||||
|
// Don't keep orphaned scalars in the tape
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
new_tape.push_back(std::move(arr));
|
||||||
|
}
|
||||||
|
|
||||||
|
tape = std::move(new_tape);
|
||||||
|
|
||||||
|
std::unordered_set<uintptr_t> output_set;
|
||||||
|
for (auto& o : outputs) {
|
||||||
|
output_set.insert(o.id());
|
||||||
|
}
|
||||||
|
// Pass 1..passes: fuse only keeping non-orphaned arrays in the tape
|
||||||
|
for (int pass = 0; pass < passes; ++pass) {
|
||||||
|
for (auto& arr : tape) {
|
||||||
|
// Helper to check if we can fuse the parents of the
|
||||||
|
// given array
|
||||||
|
auto maybe_fuse_parents = [&](auto& a) {
|
||||||
|
auto parents = parents_map.find(a.id());
|
||||||
|
if (parents != parents_map.end()) {
|
||||||
|
auto N = parents->second.size();
|
||||||
|
std::vector<bool> mask(N, false);
|
||||||
|
for (int i = 0; i < N; i++) {
|
||||||
|
if (mask[i]) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
for (int j = i + 1; j < N; j++) {
|
||||||
|
if (mask[j]) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
auto& src = parents->second[j].first;
|
||||||
|
auto& dst = parents->second[i].first;
|
||||||
|
if (src.id() != dst.id() && array_equivalent(src, dst)) {
|
||||||
|
fuse(dst, src);
|
||||||
|
mask[j] = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Erase orphaned parents so we don't keep fusing with them
|
||||||
|
for (int i = N - 1; i > 0; --i) {
|
||||||
|
if (mask[i]) {
|
||||||
|
parents->second.erase(parents->second.begin() + i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
} else {
|
||||||
|
return output_set.find(a.id()) == output_set.end();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
bool discard = maybe_fuse_parents(arr);
|
||||||
|
for (auto& s : arr.siblings()) {
|
||||||
|
discard &= maybe_fuse_parents(s);
|
||||||
|
}
|
||||||
|
// If an array and its siblings have no parents, and none of them are
|
||||||
|
// outputs, it is safe to remove it from the tape
|
||||||
|
if (!discard) {
|
||||||
|
new_tape.push_back(std::move(arr));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tape = std::move(new_tape);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<array> compile_replace(
|
||||||
|
const std::vector<array>& tape,
|
||||||
|
const std::vector<array>& trace_inputs,
|
||||||
|
const std::vector<array>& trace_outputs,
|
||||||
|
const std::vector<array>& inputs) {
|
||||||
|
std::unordered_map<uintptr_t, array> trace_to_real;
|
||||||
|
for (int i = 0; i < inputs.size(); ++i) {
|
||||||
|
trace_to_real.insert({trace_inputs[i].id(), inputs[i]});
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto& a : tape) {
|
||||||
|
// Arrays in the tape without primitives are constants
|
||||||
|
// and can be used directly
|
||||||
|
if (!a.has_primitive()) {
|
||||||
|
trace_to_real.insert({a.id(), a});
|
||||||
|
} else {
|
||||||
|
// Find real inputs
|
||||||
|
std::vector<array> real_inputs;
|
||||||
|
for (auto& in : a.inputs()) {
|
||||||
|
real_inputs.push_back(trace_to_real.at(in.id()));
|
||||||
|
}
|
||||||
|
if (a.siblings().empty()) {
|
||||||
|
auto real_a = array(
|
||||||
|
a.shape(), a.dtype(), a.primitive_ptr(), std::move(real_inputs));
|
||||||
|
trace_to_real.insert({a.id(), std::move(real_a)});
|
||||||
|
} else {
|
||||||
|
// Ensure the order is correct for multi-output primitives
|
||||||
|
std::vector<std::vector<int>> shapes;
|
||||||
|
std::vector<Dtype> types;
|
||||||
|
auto trace_out = a.outputs();
|
||||||
|
for (auto& o : trace_out) {
|
||||||
|
shapes.push_back(o.shape());
|
||||||
|
types.push_back(o.dtype());
|
||||||
|
}
|
||||||
|
auto real_out =
|
||||||
|
array::make_arrays(shapes, types, a.primitive_ptr(), real_inputs);
|
||||||
|
for (int i = 0; i < trace_out.size(); ++i) {
|
||||||
|
trace_to_real.insert({trace_out[i].id(), std::move(real_out[i])});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<array> outputs;
|
||||||
|
for (auto& o : trace_outputs) {
|
||||||
|
outputs.push_back(trace_to_real.at(o.id()));
|
||||||
|
}
|
||||||
|
return outputs;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||||
|
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
||||||
|
size_t fun_id) {
|
||||||
|
if (compiler_disabled()) {
|
||||||
|
return fun;
|
||||||
|
}
|
||||||
|
return [fun, fun_id](const std::vector<array>& inputs) {
|
||||||
|
// Find a cache entry with the correct inputs
|
||||||
|
auto& entry = compiler_cache().find(fun_id, inputs);
|
||||||
|
|
||||||
|
// No matching cache entry existed, so compile
|
||||||
|
if (entry.empty) {
|
||||||
|
// Mark the entry as not empty since we are about to fill it
|
||||||
|
entry.empty = false;
|
||||||
|
// Trace to build the graph
|
||||||
|
std::tie(entry.inputs, entry.outputs) = compile_trace(fun, inputs);
|
||||||
|
|
||||||
|
// DFS the graph and get a tape, and a map of array id to (parent,
|
||||||
|
// position in parent inputs)
|
||||||
|
std::unordered_map<uintptr_t, std::vector<std::pair<array, int>>>
|
||||||
|
parents_map;
|
||||||
|
std::tie(entry.tape, parents_map) =
|
||||||
|
compile_dfs(entry.inputs, entry.outputs);
|
||||||
|
|
||||||
|
// Simplify the tape
|
||||||
|
compile_simplify(entry.tape, parents_map, entry.outputs, /* passes */ 3);
|
||||||
|
|
||||||
|
// This is a good point to do more optimizations, e.g. kernel fusion to
|
||||||
|
// generate new primitives. The tape needs to be updated accordingly
|
||||||
|
}
|
||||||
|
|
||||||
|
// At this point we must have a tape, now replace the placeholders
|
||||||
|
// with real arrays that can be evaluated
|
||||||
|
return compile_replace(entry.tape, entry.inputs, entry.outputs, inputs);
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
void compile_erase(size_t fun_id) {
|
||||||
|
detail::compiler_cache().erase(fun_id);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace detail
|
||||||
|
|
||||||
|
std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||||
|
const std::function<std::vector<array>(const std::vector<array>&)>& fun) {
|
||||||
|
if (detail::compiler_disabled()) {
|
||||||
|
return fun;
|
||||||
|
}
|
||||||
|
auto fun_id = detail::getAddress(fun);
|
||||||
|
return detail::compile(fun, fun_id);
|
||||||
|
}
|
||||||
|
|
||||||
|
void disable_compile() {
|
||||||
|
detail::compiler_disabled() = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
void enable_compile() {
|
||||||
|
detail::compiler_disabled() = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
@ -1,7 +1,6 @@
|
|||||||
// Copyright © 2023 Apple Inc.
|
// Copyright © 2023-2024 Apple Inc.
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <future>
|
#include <future>
|
||||||
#include <map>
|
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
#include <set>
|
#include <set>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
@ -35,169 +34,6 @@ class Synchronizer : public Primitive {
|
|||||||
// are currently under a function transformation.
|
// are currently under a function transformation.
|
||||||
int detail::InTracing::tracing_counter{0};
|
int detail::InTracing::tracing_counter{0};
|
||||||
|
|
||||||
void simplify(const std::vector<array>& outputs) {
|
|
||||||
// Some notes about how this function works
|
|
||||||
//
|
|
||||||
// Step 1: Traverse the graph and build a tape. During the graph
|
|
||||||
// traversal we:
|
|
||||||
// - Build a map of inputs to their parents.
|
|
||||||
// - Record scalar inputs in a map in order to fuse them.
|
|
||||||
// Step 2: Process the tape. A node in the tape has inputs and outputs.
|
|
||||||
// - Scalar inputs are replaced with their canonical scalar
|
|
||||||
// - We check each inputs output nodes. Every output node that matches
|
|
||||||
// the current node gets fused into the current node.
|
|
||||||
std::function<void(const array&)> recurse;
|
|
||||||
std::queue<array> tape;
|
|
||||||
std::unordered_set<std::uintptr_t> cache;
|
|
||||||
std::unordered_map<std::uintptr_t, std::vector<std::pair<array, int>>>
|
|
||||||
parents_map;
|
|
||||||
|
|
||||||
// Helpers to identify identical scalars
|
|
||||||
std::map<std::pair<uint64_t, Dtype::Val>, array> scalars;
|
|
||||||
auto is_scalar = [](const array& a) {
|
|
||||||
return a.is_evaled() && a.ndim() == 0;
|
|
||||||
};
|
|
||||||
auto get_scalar_rep = [](const array& a) {
|
|
||||||
uint64_t v = 0;
|
|
||||||
int dtype;
|
|
||||||
switch (a.dtype().size) {
|
|
||||||
case 1:
|
|
||||||
v = *a.data<uint8_t>();
|
|
||||||
break;
|
|
||||||
case 4:
|
|
||||||
v = *a.data<uint32_t>();
|
|
||||||
break;
|
|
||||||
case 8:
|
|
||||||
v = *a.data<uint64_t>();
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
return std::make_pair(v, a.dtype().val);
|
|
||||||
};
|
|
||||||
|
|
||||||
// DFS the graph to build the tape, and log parents and scalars
|
|
||||||
recurse = [&](const array& a) {
|
|
||||||
auto id = a.id();
|
|
||||||
if (cache.find(id) != cache.end()) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
for (int i = 0; i < a.inputs().size(); i++) {
|
|
||||||
auto& in = a.inputs()[i];
|
|
||||||
parents_map[in.id()].push_back({a, i});
|
|
||||||
for (auto& s : a.siblings()) {
|
|
||||||
parents_map[in.id()].push_back({s, i});
|
|
||||||
}
|
|
||||||
recurse(in);
|
|
||||||
}
|
|
||||||
cache.insert(id);
|
|
||||||
for (auto& s : a.siblings()) {
|
|
||||||
cache.insert(s.id());
|
|
||||||
}
|
|
||||||
|
|
||||||
tape.push(a);
|
|
||||||
if (is_scalar(a)) {
|
|
||||||
scalars.insert({get_scalar_rep(a), a});
|
|
||||||
}
|
|
||||||
};
|
|
||||||
for (auto& a : outputs) {
|
|
||||||
recurse(a);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Helper that fuses two arrays in the graph by setting the parents of the
|
|
||||||
// source to point to the destination
|
|
||||||
auto fuse = [&](array& dst, array& src) {
|
|
||||||
// Canonicalize the order of the primitives outputs
|
|
||||||
auto sources = src.outputs();
|
|
||||||
auto dests = dst.outputs();
|
|
||||||
// For each src parent, point it to the corresponding dest
|
|
||||||
for (int i = 0; i < sources.size(); ++i) {
|
|
||||||
auto src_parents = parents_map.find(sources[i].id());
|
|
||||||
if (src_parents == parents_map.end()) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
auto& pairs = parents_map[dests[i].id()];
|
|
||||||
for (auto& parent : src_parents->second) {
|
|
||||||
parent.first.inputs()[parent.second] = dests[i];
|
|
||||||
pairs.push_back(parent);
|
|
||||||
}
|
|
||||||
// Remove the source from the map to avoid fusing with it again
|
|
||||||
parents_map.erase(src_parents);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// Depth-1 array equivalence check.
|
|
||||||
auto array_equivalent = [](const array& a, const array& b) {
|
|
||||||
if (!a.has_primitive() || !b.has_primitive()) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
if (a.primitive_id() == b.primitive_id()) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
const auto& pa = a.primitive();
|
|
||||||
const auto& pb = b.primitive();
|
|
||||||
if (typeid(pa) != typeid(pb)) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (a.inputs().size() != b.inputs().size()) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
for (int i = 0; i < a.inputs().size(); i++) {
|
|
||||||
if (a.inputs()[i].id() != b.inputs()[i].id()) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return pa.is_equivalent(pb);
|
|
||||||
};
|
|
||||||
|
|
||||||
// Walk the graph
|
|
||||||
while (!tape.empty()) {
|
|
||||||
auto arr = std::move(tape.front());
|
|
||||||
tape.pop();
|
|
||||||
|
|
||||||
// Check if we can fuse scalars
|
|
||||||
if (is_scalar(arr)) {
|
|
||||||
auto scalar = scalars.find(get_scalar_rep(arr));
|
|
||||||
if (scalar->second.id() != arr.id()) {
|
|
||||||
fuse(scalar->second, arr);
|
|
||||||
arr = scalar->second;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Helper to check if we can fuse the parents of the
|
|
||||||
// given array
|
|
||||||
auto maybe_fuse_parents = [&](auto& a) {
|
|
||||||
auto parents = parents_map.find(a.id());
|
|
||||||
if (parents != parents_map.end()) {
|
|
||||||
auto N = parents->second.size();
|
|
||||||
std::vector<bool> mask(N, false);
|
|
||||||
for (int i = 0; i < N; i++) {
|
|
||||||
if (mask[i]) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
for (int j = i + 1; j < N; j++) {
|
|
||||||
if (mask[j]) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
auto& src = parents->second[j].first;
|
|
||||||
auto& dst = parents->second[i].first;
|
|
||||||
if (src.id() != dst.id() && array_equivalent(src, dst)) {
|
|
||||||
fuse(dst, src);
|
|
||||||
mask[j] = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
maybe_fuse_parents(arr);
|
|
||||||
for (auto& s : arr.siblings()) {
|
|
||||||
maybe_fuse_parents(s);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void eval(const std::vector<array>& outputs) {
|
void eval(const std::vector<array>& outputs) {
|
||||||
std::function<void(const array&)> recurse;
|
std::function<void(const array&)> recurse;
|
||||||
std::queue<array> tape;
|
std::queue<array> tape;
|
||||||
|
@ -1,18 +1,25 @@
|
|||||||
// Copyright © 2023 Apple Inc.
|
// Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "array.h"
|
#include "mlx/array.h"
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
/** Fuse equivalent arrays to avoid duplicate execution. */
|
// Compile takes a function and returns a new function
|
||||||
void simplify(const std::vector<array>& outputs);
|
std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||||
|
const std::function<std::vector<array>(const std::vector<array>&)>& fun);
|
||||||
|
|
||||||
template <typename... Arrays>
|
/** Globally disable compilation.
|
||||||
void simplify(Arrays... outputs) {
|
* Setting the environment variable ``MLX_DISABLE_COMPILE`` can also
|
||||||
simplify(std::vector<array>{std::forward<Arrays>(outputs)...});
|
* be used to disable compilation.
|
||||||
}
|
*/
|
||||||
|
void disable_compile();
|
||||||
|
|
||||||
|
/** Globally enable compilation.
|
||||||
|
* This will override the environment variable ``MLX_DISABLE_COMPILE``.
|
||||||
|
*/
|
||||||
|
void enable_compile();
|
||||||
|
|
||||||
void eval(const std::vector<array>& outputs);
|
void eval(const std::vector<array>& outputs);
|
||||||
|
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
// Copyright © 2023 Apple Inc.
|
// Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
namespace mlx::core::detail {
|
namespace mlx::core::detail {
|
||||||
|
|
||||||
@ -14,6 +14,15 @@ std::vector<array> vmap_replace(
|
|||||||
const std::vector<int>& in_axes,
|
const std::vector<int>& in_axes,
|
||||||
const std::vector<int>& out_axes);
|
const std::vector<int>& out_axes);
|
||||||
|
|
||||||
|
// This is not part of the general C++ API as calling with a bad id is a bad
|
||||||
|
// idea.
|
||||||
|
std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||||
|
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
||||||
|
size_t fun_id);
|
||||||
|
|
||||||
|
// Erase cached compile functions
|
||||||
|
void compile_erase(size_t fun_id);
|
||||||
|
|
||||||
// Create an InTracing object during tracing operations to signify to the rest
|
// Create an InTracing object during tracing operations to signify to the rest
|
||||||
// of the codebase that we are during tracing so evals should not throw away
|
// of the codebase that we are during tracing so evals should not throw away
|
||||||
// the graph.
|
// the graph.
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
// Copyright © 2023 Apple Inc.
|
// Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
#include <pybind11/functional.h>
|
#include <pybind11/functional.h>
|
||||||
#include <pybind11/pybind11.h>
|
#include <pybind11/pybind11.h>
|
||||||
#include <pybind11/stl.h>
|
#include <pybind11/stl.h>
|
||||||
@ -163,6 +162,19 @@ py::object tree_unflatten(
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
py::object tree_unflatten_none(
|
||||||
|
py::object tree,
|
||||||
|
const std::vector<array>& values,
|
||||||
|
int index = 0) {
|
||||||
|
return tree_map(tree, [&](py::handle obj) {
|
||||||
|
if (py::isinstance<py::none>(obj)) {
|
||||||
|
return py::cast(values[index++]);
|
||||||
|
} else {
|
||||||
|
return py::cast<py::object>(obj);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
auto validate_argnums_argnames(
|
auto validate_argnums_argnames(
|
||||||
const std::optional<IntOrVec>& argnums,
|
const std::optional<IntOrVec>& argnums,
|
||||||
const StrOrVec& argnames) {
|
const StrOrVec& argnames) {
|
||||||
@ -437,6 +449,58 @@ auto py_vmap(
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::unordered_map<size_t, py::object>& tree_cache() {
|
||||||
|
// This map is used to Cache the tree structure of the outputs
|
||||||
|
static std::unordered_map<size_t, py::object> tree_cache_;
|
||||||
|
return tree_cache_;
|
||||||
|
}
|
||||||
|
|
||||||
|
struct PyCompiledFun {
|
||||||
|
py::function fun;
|
||||||
|
size_t fun_id;
|
||||||
|
|
||||||
|
PyCompiledFun(const py::function& fun)
|
||||||
|
: fun(fun), fun_id(reinterpret_cast<size_t>(fun.ptr())) {}
|
||||||
|
|
||||||
|
PyCompiledFun(const PyCompiledFun&) = delete;
|
||||||
|
PyCompiledFun& operator=(const PyCompiledFun&) = delete;
|
||||||
|
PyCompiledFun& operator=(PyCompiledFun&& other) = delete;
|
||||||
|
PyCompiledFun(PyCompiledFun&& other)
|
||||||
|
: fun(std::move(other.fun)), fun_id(reinterpret_cast<size_t>(fun.ptr())) {
|
||||||
|
other.fun_id = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
py::object operator()(const py::args& args) {
|
||||||
|
auto compile_fun = [this, &args](const std::vector<array>& a) {
|
||||||
|
// Call the python function
|
||||||
|
py::object py_outputs = this->fun(*tree_unflatten(args, a));
|
||||||
|
|
||||||
|
// Flatten the outputs
|
||||||
|
auto outputs = tree_flatten(py_outputs, true);
|
||||||
|
|
||||||
|
py_outputs =
|
||||||
|
tree_map(py_outputs, [](const py::handle& x) { return py::none(); });
|
||||||
|
tree_cache().insert({this->fun_id, py_outputs});
|
||||||
|
return outputs;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Inputs must be array or tree of arrays
|
||||||
|
auto inputs = tree_flatten(args, true);
|
||||||
|
|
||||||
|
// Compile and call
|
||||||
|
auto outputs = detail::compile(compile_fun, fun_id)(inputs);
|
||||||
|
|
||||||
|
// Put the outputs back in the container
|
||||||
|
py::object py_outputs = tree_cache().at(fun_id);
|
||||||
|
return tree_unflatten_none(py_outputs, outputs);
|
||||||
|
};
|
||||||
|
|
||||||
|
~PyCompiledFun() {
|
||||||
|
tree_cache().erase(fun_id);
|
||||||
|
detail::compile_erase(fun_id);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
void init_transforms(py::module_& m) {
|
void init_transforms(py::module_& m) {
|
||||||
py::options options;
|
py::options options;
|
||||||
options.disable_function_signatures();
|
options.disable_function_signatures();
|
||||||
@ -679,45 +743,6 @@ void init_transforms(py::module_& m) {
|
|||||||
Returns:
|
Returns:
|
||||||
function: The vectorized function.
|
function: The vectorized function.
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
|
||||||
"simplify",
|
|
||||||
[](const py::args& args) {
|
|
||||||
std::vector<array> arrays = tree_flatten(args);
|
|
||||||
simplify(arrays);
|
|
||||||
},
|
|
||||||
R"pbdoc(
|
|
||||||
simplify(*args) -> None
|
|
||||||
|
|
||||||
Simplify the graph that computes the arrays.
|
|
||||||
|
|
||||||
Run a few fast graph simplification operations to reuse computation and
|
|
||||||
reduce memory consumption. This function is meant to be run every time
|
|
||||||
so its overhead should be small, approximately 1ms for a graph with a
|
|
||||||
few thousand nodes.
|
|
||||||
|
|
||||||
.. code-block:: python
|
|
||||||
|
|
||||||
import mlx.core as mx
|
|
||||||
|
|
||||||
def foo(x):
|
|
||||||
y = x @ x
|
|
||||||
z = x @ x
|
|
||||||
return y + z
|
|
||||||
|
|
||||||
x = mx.ones((10, 10))
|
|
||||||
y = foo(x)
|
|
||||||
z = foo(x)
|
|
||||||
|
|
||||||
# Computes the matmul twice
|
|
||||||
mx.eval(y)
|
|
||||||
|
|
||||||
# Computes the matmul once
|
|
||||||
mx.simplify(z)
|
|
||||||
mx.eval(z)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
args: Any number of arrays and/or trees of arrays to be simplified.
|
|
||||||
)pbdoc");
|
|
||||||
m.def(
|
m.def(
|
||||||
"export_to_dot",
|
"export_to_dot",
|
||||||
[](py::object file, const py::args& args) {
|
[](py::object file, const py::args& args) {
|
||||||
@ -736,4 +761,46 @@ void init_transforms(py::module_& m) {
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
"file"_a);
|
"file"_a);
|
||||||
|
m.def(
|
||||||
|
"compile",
|
||||||
|
[](const py::function& fun) {
|
||||||
|
return py::cpp_function(PyCompiledFun{fun});
|
||||||
|
},
|
||||||
|
"fun"_a,
|
||||||
|
R"pbdoc(
|
||||||
|
compile(fun: function) -> function
|
||||||
|
|
||||||
|
Returns a compiled function which produces the same output as ``fun``.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fun (function): A function which takes a variable number of
|
||||||
|
:class:`array` or trees of :class:`array` and returns
|
||||||
|
a variable number of :class:`array` or trees of :class:`array`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
function: A compiled function which has the same input arguments
|
||||||
|
as ``fun`` and returns the the same output(s).
|
||||||
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"disable_compile",
|
||||||
|
&disable_compile,
|
||||||
|
R"pbdoc(
|
||||||
|
disable_compile() -> None
|
||||||
|
|
||||||
|
Globally disable compilation. Setting the environment variable
|
||||||
|
``MLX_DISABLE_COMPILE`` can also be used to disable compilation.
|
||||||
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"enable_compile",
|
||||||
|
&enable_compile,
|
||||||
|
R"pbdoc(
|
||||||
|
enable_compiler() -> None
|
||||||
|
|
||||||
|
Globally enable compilation. This will override the environment
|
||||||
|
variable ``MLX_DISABLE_COMPILE`` if set.
|
||||||
|
)pbdoc");
|
||||||
|
|
||||||
|
// Register static Python object cleanup before the interpreter exits
|
||||||
|
auto atexit = py::module_::import("atexit");
|
||||||
|
atexit.attr("register")(py::cpp_function([]() { tree_cache().clear(); }));
|
||||||
}
|
}
|
||||||
|
195
python/tests/test_compile.py
Normal file
195
python/tests/test_compile.py
Normal file
@ -0,0 +1,195 @@
|
|||||||
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
|
import io
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx_tests
|
||||||
|
|
||||||
|
|
||||||
|
class TestCompile(mlx_tests.MLXTestCase):
|
||||||
|
def test_simple_compile(self):
|
||||||
|
def fun(x, y):
|
||||||
|
return x + y
|
||||||
|
|
||||||
|
compiled_fn = mx.compile(fun)
|
||||||
|
compiled_fn = mx.compile(fun)
|
||||||
|
x = mx.array(1.0)
|
||||||
|
y = mx.array(1.0)
|
||||||
|
out = compiled_fn(x, y)
|
||||||
|
self.assertEqual(out.item(), 2.0)
|
||||||
|
|
||||||
|
# Try again
|
||||||
|
out = compiled_fn(x, y)
|
||||||
|
self.assertEqual(out.item(), 2.0)
|
||||||
|
|
||||||
|
# Change sizes
|
||||||
|
x = mx.array([1.0, 2.0])
|
||||||
|
out = compiled_fn(x, y)
|
||||||
|
self.assertTrue(mx.array_equal(out, mx.array([2.0, 3.0])))
|
||||||
|
|
||||||
|
y = mx.array([1.0, 2.0])
|
||||||
|
out = compiled_fn(x, y)
|
||||||
|
self.assertTrue(mx.array_equal(out, mx.array([2.0, 4.0])))
|
||||||
|
|
||||||
|
# Change types
|
||||||
|
x = mx.array([1, 2], mx.int32)
|
||||||
|
y = mx.array([1, 2], mx.int32)
|
||||||
|
out = compiled_fn(x, y)
|
||||||
|
self.assertEqual(out.dtype, mx.int32)
|
||||||
|
self.assertTrue(mx.array_equal(out, mx.array([2, 4])))
|
||||||
|
|
||||||
|
def test_compile_grad(self):
|
||||||
|
def loss_fn(x):
|
||||||
|
return mx.exp(x).sum()
|
||||||
|
|
||||||
|
grad_fn = mx.grad(loss_fn)
|
||||||
|
|
||||||
|
x = mx.array([0.5, -0.5, 1.2])
|
||||||
|
dfdx = grad_fn(x)
|
||||||
|
compile_grad_fn = mx.compile(grad_fn)
|
||||||
|
c_dfdx = grad_fn(x)
|
||||||
|
|
||||||
|
self.assertTrue(mx.allclose(c_dfdx, dfdx))
|
||||||
|
|
||||||
|
# Run it again without calling compile
|
||||||
|
c_dfdx = compile_grad_fn(x)
|
||||||
|
self.assertTrue(mx.allclose(c_dfdx, dfdx))
|
||||||
|
|
||||||
|
# Run it again with calling compile
|
||||||
|
c_dfdx = mx.compile(grad_fn)(x)
|
||||||
|
self.assertTrue(mx.allclose(c_dfdx, dfdx))
|
||||||
|
|
||||||
|
# Value and grad
|
||||||
|
def loss_fn(x):
|
||||||
|
return mx.exp(x).sum(), mx.sin(x)
|
||||||
|
|
||||||
|
val_and_grad_fn = mx.value_and_grad(loss_fn)
|
||||||
|
(loss, val), dfdx = val_and_grad_fn(x)
|
||||||
|
(c_loss, c_val), c_dfdx = mx.compile(val_and_grad_fn)(x)
|
||||||
|
|
||||||
|
self.assertTrue(mx.allclose(c_dfdx, dfdx))
|
||||||
|
self.assertTrue(mx.allclose(c_loss, loss))
|
||||||
|
self.assertTrue(mx.allclose(c_val, val))
|
||||||
|
|
||||||
|
def test_compile_inputs_with_primitives(self):
|
||||||
|
x = mx.array([1, 2, 3])
|
||||||
|
y = mx.array([1, 2, 3])
|
||||||
|
for _ in range(5):
|
||||||
|
x = x + y
|
||||||
|
y = y + 1
|
||||||
|
|
||||||
|
def fun(x, y):
|
||||||
|
return x * y
|
||||||
|
|
||||||
|
out = fun(x, y)
|
||||||
|
|
||||||
|
x = mx.array([1, 2, 3])
|
||||||
|
y = mx.array([1, 2, 3])
|
||||||
|
for _ in range(5):
|
||||||
|
x = x + y
|
||||||
|
y = y + 1
|
||||||
|
|
||||||
|
c_out = mx.compile(fun)(x, y)
|
||||||
|
self.assertTrue(mx.array_equal(out, c_out))
|
||||||
|
|
||||||
|
# Try again
|
||||||
|
c_out = mx.compile(fun)(x, y)
|
||||||
|
self.assertTrue(mx.array_equal(out, c_out))
|
||||||
|
|
||||||
|
def test_compile_with_closure(self):
|
||||||
|
x = mx.array(1)
|
||||||
|
|
||||||
|
def closure(y):
|
||||||
|
return x + y
|
||||||
|
|
||||||
|
compiled = mx.compile(closure)
|
||||||
|
out = compiled(mx.array(1))
|
||||||
|
self.assertEqual(out.item(), 2)
|
||||||
|
|
||||||
|
# Try again
|
||||||
|
out = compiled(mx.array(1))
|
||||||
|
self.assertEqual(out.item(), 2)
|
||||||
|
|
||||||
|
# Change the shape of the enclosed variable
|
||||||
|
x = mx.array([1, 2])
|
||||||
|
out = compiled(mx.array(1))
|
||||||
|
|
||||||
|
# We still get the original input (closures are not updated)
|
||||||
|
self.assertEqual(out.item(), 2)
|
||||||
|
|
||||||
|
# Try with a tree of enclosed variables
|
||||||
|
x = {"a": mx.array(1), "b": mx.array(2)}
|
||||||
|
|
||||||
|
def closure(y):
|
||||||
|
return x["a"] + y + x["b"]
|
||||||
|
|
||||||
|
compiled = mx.compile(closure)
|
||||||
|
out = compiled(mx.array(1))
|
||||||
|
self.assertEqual(out.item(), 4)
|
||||||
|
|
||||||
|
# Change the shape of one input
|
||||||
|
x["a"] = mx.array([4, 5])
|
||||||
|
out = compiled(mx.array(1))
|
||||||
|
self.assertEqual(out.item(), 4)
|
||||||
|
|
||||||
|
x["b"] = mx.array([-6, -8])
|
||||||
|
out = compiled(mx.array(1))
|
||||||
|
self.assertEqual(out.item(), 4)
|
||||||
|
|
||||||
|
# Enclosed variable is not evaluated yet
|
||||||
|
x = mx.array(1)
|
||||||
|
x = x + x
|
||||||
|
|
||||||
|
def closure(y):
|
||||||
|
return x + y
|
||||||
|
|
||||||
|
compiled = mx.compile(closure)
|
||||||
|
out = compiled(mx.array(2))
|
||||||
|
self.assertEqual(out.item(), 4)
|
||||||
|
|
||||||
|
# And again
|
||||||
|
out = compiled(mx.array(2))
|
||||||
|
self.assertEqual(out.item(), 4)
|
||||||
|
|
||||||
|
def test_function_creates_array(self):
|
||||||
|
def fun(x):
|
||||||
|
return x + mx.array(1)
|
||||||
|
|
||||||
|
cfun = mx.compile(fun)
|
||||||
|
out = cfun(mx.array(3))
|
||||||
|
self.assertEqual(out.item(), 4)
|
||||||
|
|
||||||
|
# And again
|
||||||
|
out = cfun(mx.array(3))
|
||||||
|
self.assertEqual(out.item(), 4)
|
||||||
|
|
||||||
|
def test_enable_disable(self):
|
||||||
|
def fun(x):
|
||||||
|
y = x + 1
|
||||||
|
z = x + 1
|
||||||
|
return y + z
|
||||||
|
|
||||||
|
def count_prims(outputs):
|
||||||
|
buf = io.StringIO()
|
||||||
|
mx.export_to_dot(buf, outputs)
|
||||||
|
buf.seek(0)
|
||||||
|
return len([l for l in buf.read().split() if "label" in l])
|
||||||
|
|
||||||
|
x = mx.array(1.0)
|
||||||
|
cfun = mx.compile(fun)
|
||||||
|
n_compiled = count_prims(cfun(x))
|
||||||
|
|
||||||
|
# Check disabled
|
||||||
|
mx.disable_compile()
|
||||||
|
n_uncompiled = count_prims(cfun(x))
|
||||||
|
self.assertTrue(n_compiled < n_uncompiled)
|
||||||
|
|
||||||
|
# Check renabled
|
||||||
|
mx.enable_compile()
|
||||||
|
n_enable_compiled = count_prims(cfun(x))
|
||||||
|
self.assertEqual(n_compiled, n_enable_compiled)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
@ -20,11 +20,11 @@ target_sources(tests PRIVATE
|
|||||||
arg_reduce_tests.cpp
|
arg_reduce_tests.cpp
|
||||||
autograd_tests.cpp
|
autograd_tests.cpp
|
||||||
blas_tests.cpp
|
blas_tests.cpp
|
||||||
|
compile_tests.cpp
|
||||||
creations_tests.cpp
|
creations_tests.cpp
|
||||||
device_tests.cpp
|
device_tests.cpp
|
||||||
eval_tests.cpp
|
eval_tests.cpp
|
||||||
fft_tests.cpp
|
fft_tests.cpp
|
||||||
graph_optimize_tests.cpp
|
|
||||||
load_tests.cpp
|
load_tests.cpp
|
||||||
ops_tests.cpp
|
ops_tests.cpp
|
||||||
random_tests.cpp
|
random_tests.cpp
|
||||||
|
213
tests/compile_tests.cpp
Normal file
213
tests/compile_tests.cpp
Normal file
@ -0,0 +1,213 @@
|
|||||||
|
// Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
|
#include "doctest/doctest.h"
|
||||||
|
|
||||||
|
#include "mlx/mlx.h"
|
||||||
|
|
||||||
|
using namespace mlx::core;
|
||||||
|
|
||||||
|
std::vector<array> simple_fun(const std::vector<array>& inputs) {
|
||||||
|
return std::vector<array>{inputs[0] + inputs[1]};
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test simple compile") {
|
||||||
|
auto compfn = compile(simple_fun);
|
||||||
|
auto out = compfn({array(1.0f), array(2.0f)})[0];
|
||||||
|
CHECK_EQ(out.item<float>(), 3.0f);
|
||||||
|
|
||||||
|
out = compfn({array(1.0f), array(2.0f)})[0];
|
||||||
|
CHECK_EQ(out.item<float>(), 3.0f);
|
||||||
|
|
||||||
|
// Change the shapes
|
||||||
|
out = compfn({array({1.0f, 2.0f}), array(2.0f)})[0];
|
||||||
|
CHECK(array_equal(out, array({3.0f, 4.0f})).item<bool>());
|
||||||
|
|
||||||
|
out = compfn({array(2.0f), array({1.0f, 2.0f})})[0];
|
||||||
|
CHECK(array_equal(out, array({3.0f, 4.0f})).item<bool>());
|
||||||
|
|
||||||
|
// Change the types
|
||||||
|
out = compfn({array(2, int32), array({1.0f, 2.0f})})[0];
|
||||||
|
CHECK(array_equal(out, array({3.0f, 4.0f})).item<bool>());
|
||||||
|
|
||||||
|
out = compfn({array(2.0f), array({1, 2}, int32)})[0];
|
||||||
|
CHECK(array_equal(out, array({3.0f, 4.0f})).item<bool>());
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<array> grad_fun(const std::vector<array>& inputs) {
|
||||||
|
auto loss = [](std::vector<array> ins) { return exp(ins[0] + ins[1]); };
|
||||||
|
return grad(loss, {0, 1})(inputs);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test compile with grad") {
|
||||||
|
auto x = array(1.0f);
|
||||||
|
auto y = array(1.0f);
|
||||||
|
auto grads_expected = grad_fun({x, y});
|
||||||
|
auto grads_compile = compile(grad_fun)({x, y});
|
||||||
|
CHECK_EQ(grads_compile[0].item<float>(), grads_expected[0].item<float>());
|
||||||
|
CHECK_EQ(grads_compile[1].item<float>(), grads_expected[1].item<float>());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test compile inputs with primitive") {
|
||||||
|
auto [k1, k2] = random::split(random::key(0));
|
||||||
|
auto x = random::uniform({5, 5}, k1);
|
||||||
|
auto y = random::uniform({5, 5}, k2);
|
||||||
|
auto expected = simple_fun({x, y})[0];
|
||||||
|
|
||||||
|
x = random::uniform({5, 5}, k1);
|
||||||
|
y = random::uniform({5, 5}, k2);
|
||||||
|
auto out = compile(simple_fun)({x, y})[0];
|
||||||
|
CHECK(array_equal(expected, out).item<bool>());
|
||||||
|
|
||||||
|
// Same thing twice
|
||||||
|
out = compile(simple_fun)({x, y})[0];
|
||||||
|
CHECK(array_equal(expected, out).item<bool>());
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<array> fun_creats_array(const std::vector<array>& inputs) {
|
||||||
|
return {inputs[0] + array(1.0f)};
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test compile with created array") {
|
||||||
|
auto cfun = compile(fun_creats_array);
|
||||||
|
auto out = cfun({array(2.0f)});
|
||||||
|
CHECK_EQ(out[0].item<float>(), 3.0f);
|
||||||
|
|
||||||
|
// Try again
|
||||||
|
out = cfun({array(2.0f)});
|
||||||
|
CHECK_EQ(out[0].item<float>(), 3.0f);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<array> inner_fun(const std::vector<array>& inputs) {
|
||||||
|
return {array(2) * inputs[0]};
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<array> outer_fun(const std::vector<array>& inputs) {
|
||||||
|
auto x = inputs[0] + inputs[1];
|
||||||
|
auto y = compile(inner_fun)({x})[0];
|
||||||
|
return {x + y};
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test nested compile") {
|
||||||
|
auto cfun = compile(outer_fun);
|
||||||
|
auto out = cfun({array(1), array(2)})[0];
|
||||||
|
CHECK_EQ(out.item<int>(), 9);
|
||||||
|
|
||||||
|
// Try again
|
||||||
|
out = cfun({array(1), array(2)})[0];
|
||||||
|
CHECK_EQ(out.item<int>(), 9);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test enable and disable compile") {
|
||||||
|
CHECK_THROWS(compile(nullptr));
|
||||||
|
disable_compile();
|
||||||
|
compile(nullptr);
|
||||||
|
enable_compile();
|
||||||
|
CHECK_THROWS(compile(nullptr));
|
||||||
|
}
|
||||||
|
|
||||||
|
auto add_scalars(const std::vector<array>&) {
|
||||||
|
auto a = array(-1.0f);
|
||||||
|
auto b = array(-1.0f);
|
||||||
|
return std::vector<array>{abs(a), abs(b)};
|
||||||
|
};
|
||||||
|
|
||||||
|
auto max_scalars(const std::vector<array>&) {
|
||||||
|
auto a = array({-1.0f, 2.0f});
|
||||||
|
auto b = maximum(a, array(0.0f));
|
||||||
|
auto c = maximum(-a, array(0.0f));
|
||||||
|
auto d = b + c;
|
||||||
|
return std::vector<array>{b, c, d};
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_CASE("test simplify scalars") {
|
||||||
|
{
|
||||||
|
auto cfun = compile(add_scalars);
|
||||||
|
auto out = cfun({});
|
||||||
|
auto c = out[0];
|
||||||
|
auto d = out[1];
|
||||||
|
CHECK(c.inputs()[0].id() == d.inputs()[0].id());
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
auto a = array({-1.0f, 2.0f});
|
||||||
|
auto out = compile(max_scalars)({a});
|
||||||
|
auto b = out[0];
|
||||||
|
auto c = out[1];
|
||||||
|
auto d = out[2];
|
||||||
|
CHECK(b.inputs()[1].id() == c.inputs()[1].id());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auto exp_two(const std::vector<array>& inputs) {
|
||||||
|
auto a = inputs[0];
|
||||||
|
return std::vector<array>{exp(a) + exp(a)};
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_CASE("test simplify") {
|
||||||
|
auto a = array({1.0f, 2.0f});
|
||||||
|
auto b = compile(exp_two)({a})[0];
|
||||||
|
CHECK(b.inputs()[0].id() == b.inputs()[1].id());
|
||||||
|
}
|
||||||
|
|
||||||
|
auto add_diff(const std::vector<array>& inputs) {
|
||||||
|
auto a = inputs[0];
|
||||||
|
return std::vector<array>{cos(a) + sin(a)};
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_CASE("test no simplify") {
|
||||||
|
auto a = array({1.0f, 2.0f});
|
||||||
|
auto b = compile(add_diff)({a})[0];
|
||||||
|
CHECK(b.inputs()[0].id() != b.inputs()[1].id());
|
||||||
|
}
|
||||||
|
|
||||||
|
auto multi_one(const std::vector<array>&) {
|
||||||
|
auto a = array(1.0);
|
||||||
|
auto b = array(2.0);
|
||||||
|
auto c = divmod(a, b);
|
||||||
|
auto d = divmod(a, b);
|
||||||
|
auto e = c[0] + d[0];
|
||||||
|
auto f = c[1] + d[1];
|
||||||
|
return std::vector<array>{e, f};
|
||||||
|
}
|
||||||
|
|
||||||
|
auto multi_two(const std::vector<array>&) {
|
||||||
|
auto a = array(1.0);
|
||||||
|
auto b = array(1.0);
|
||||||
|
auto c = divmod(a, b);
|
||||||
|
return std::vector<array>{c};
|
||||||
|
}
|
||||||
|
|
||||||
|
auto multi_three(const std::vector<array>&) {
|
||||||
|
auto a = array(1.0);
|
||||||
|
auto b = array(2.0);
|
||||||
|
auto c = divmod(a, b);
|
||||||
|
auto d = divmod(a, b);
|
||||||
|
auto e = stack({c[0], c[1], d[0], d[1]});
|
||||||
|
return std::vector<array>{e};
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test simplify multi output") {
|
||||||
|
{
|
||||||
|
auto out = compile(multi_one)({});
|
||||||
|
auto e = out[0];
|
||||||
|
auto f = out[1];
|
||||||
|
CHECK_EQ(e.inputs()[0].id(), e.inputs()[1].id());
|
||||||
|
CHECK_EQ(f.inputs()[0].id(), f.inputs()[1].id());
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
auto c = compile(multi_two)({});
|
||||||
|
CHECK_EQ(c[0].inputs()[0].id(), c[0].inputs()[1].id());
|
||||||
|
CHECK_EQ(c[0].inputs()[0].id(), c[1].inputs()[0].id());
|
||||||
|
CHECK_EQ(c[1].inputs()[0].id(), c[1].inputs()[1].id());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make sure the output order of multi-output primitives
|
||||||
|
// is respected in simplification
|
||||||
|
{
|
||||||
|
auto e = compile(multi_three)({})[0];
|
||||||
|
CHECK(array_equal(e, array({0.0f, 1.0f, 0.0f, 1.0f})).item<bool>());
|
||||||
|
CHECK_EQ(e.inputs()[0].id(), e.inputs()[2].id());
|
||||||
|
CHECK_EQ(e.inputs()[1].id(), e.inputs()[3].id());
|
||||||
|
}
|
||||||
|
}
|
@ -1,80 +0,0 @@
|
|||||||
// Copyright © 2023 Apple Inc.
|
|
||||||
|
|
||||||
#include "doctest/doctest.h"
|
|
||||||
|
|
||||||
#include "mlx/mlx.h"
|
|
||||||
|
|
||||||
using namespace mlx::core;
|
|
||||||
|
|
||||||
TEST_CASE("test simplify scalars") {
|
|
||||||
{
|
|
||||||
auto a = array(-1.0f);
|
|
||||||
auto b = array(-1.0f);
|
|
||||||
auto c = abs(a);
|
|
||||||
auto d = abs(b);
|
|
||||||
simplify({c, d});
|
|
||||||
CHECK(c.inputs()[0].id() == d.inputs()[0].id());
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
|
||||||
auto a = array({-1.0f, 2.0f});
|
|
||||||
auto b = maximum(a, array(0.0f));
|
|
||||||
auto c = maximum(-a, array(0.0f));
|
|
||||||
auto d = b + c;
|
|
||||||
simplify({d});
|
|
||||||
CHECK(b.inputs()[1].id() == c.inputs()[1].id());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_CASE("test simplify") {
|
|
||||||
auto a = array({1.0f, 2.0f});
|
|
||||||
auto b = exp(a) + exp(a);
|
|
||||||
simplify(b);
|
|
||||||
CHECK(b.inputs()[0].id() == b.inputs()[1].id());
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_CASE("test no simplify") {
|
|
||||||
auto a = array({1.0f, 2.0f});
|
|
||||||
auto b = cos(a) + sin(a);
|
|
||||||
simplify(b);
|
|
||||||
CHECK(b.inputs()[0].id() != b.inputs()[1].id());
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_CASE("test simplify multi output") {
|
|
||||||
{
|
|
||||||
auto a = array(1.0);
|
|
||||||
auto b = array(2.0);
|
|
||||||
auto c = divmod(a, b);
|
|
||||||
auto d = divmod(a, b);
|
|
||||||
auto e = c[0] + d[0];
|
|
||||||
auto f = c[1] + d[1];
|
|
||||||
|
|
||||||
simplify({e, f});
|
|
||||||
CHECK_EQ(e.inputs()[0].id(), e.inputs()[1].id());
|
|
||||||
CHECK_EQ(f.inputs()[0].id(), f.inputs()[1].id());
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
|
||||||
auto a = array(1.0);
|
|
||||||
auto b = array(1.0);
|
|
||||||
auto c = divmod(a, b);
|
|
||||||
simplify(c);
|
|
||||||
CHECK_EQ(c[0].inputs()[0].id(), c[0].inputs()[1].id());
|
|
||||||
CHECK_EQ(c[0].inputs()[0].id(), c[1].inputs()[0].id());
|
|
||||||
CHECK_EQ(c[1].inputs()[0].id(), c[1].inputs()[1].id());
|
|
||||||
}
|
|
||||||
|
|
||||||
// Make sure the output order of multi-output primitives
|
|
||||||
// is respected in simplification
|
|
||||||
{
|
|
||||||
auto a = array(1.0);
|
|
||||||
auto b = array(2.0);
|
|
||||||
auto c = divmod(a, b);
|
|
||||||
auto d = divmod(a, b);
|
|
||||||
auto e = stack({c[0], c[1], d[0], d[1]});
|
|
||||||
simplify(e);
|
|
||||||
CHECK(array_equal(e, array({0.0f, 1.0f, 0.0f, 1.0f})).item<bool>());
|
|
||||||
CHECK_EQ(e.inputs()[0].id(), e.inputs()[2].id());
|
|
||||||
CHECK_EQ(e.inputs()[1].id(), e.inputs()[3].id());
|
|
||||||
}
|
|
||||||
}
|
|
@ -511,6 +511,7 @@ TEST_CASE("test is inf") {
|
|||||||
CHECK_FALSE(isinf(x).item<bool>());
|
CHECK_FALSE(isinf(x).item<bool>());
|
||||||
|
|
||||||
auto inf = std::numeric_limits<float>::infinity();
|
auto inf = std::numeric_limits<float>::infinity();
|
||||||
|
|
||||||
array y(inf);
|
array y(inf);
|
||||||
CHECK(isinf(y).item<bool>());
|
CHECK(isinf(y).item<bool>());
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user