mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 20:46:46 +08:00
Compile primitive (#571)
* Compiled primitive with basic binary, unary graph-level fusion
This commit is contained in:
parent
31fea3758e
commit
d75ae52ecd
@ -33,10 +33,12 @@ DEFAULT(ArgSort)
|
|||||||
DEFAULT(AsStrided)
|
DEFAULT(AsStrided)
|
||||||
DEFAULT(Broadcast)
|
DEFAULT(Broadcast)
|
||||||
DEFAULT(Ceil)
|
DEFAULT(Ceil)
|
||||||
|
DEFAULT_MULTI(Compiled)
|
||||||
DEFAULT(Concatenate)
|
DEFAULT(Concatenate)
|
||||||
DEFAULT(Copy)
|
DEFAULT(Copy)
|
||||||
DEFAULT_MULTI(CustomVJP)
|
DEFAULT_MULTI(CustomVJP)
|
||||||
DEFAULT_MULTI(Depends)
|
DEFAULT_MULTI(Depends)
|
||||||
|
DEFAULT_MULTI(DivMod)
|
||||||
DEFAULT(Equal)
|
DEFAULT(Equal)
|
||||||
DEFAULT(Erf)
|
DEFAULT(Erf)
|
||||||
DEFAULT(ErfInv)
|
DEFAULT(ErfInv)
|
||||||
@ -57,6 +59,7 @@ DEFAULT(Minimum)
|
|||||||
DEFAULT(NotEqual)
|
DEFAULT(NotEqual)
|
||||||
DEFAULT(Pad)
|
DEFAULT(Pad)
|
||||||
DEFAULT(Partition)
|
DEFAULT(Partition)
|
||||||
|
DEFAULT_MULTI(QRF)
|
||||||
DEFAULT(RandomBits)
|
DEFAULT(RandomBits)
|
||||||
DEFAULT(Reshape)
|
DEFAULT(Reshape)
|
||||||
DEFAULT(Round)
|
DEFAULT(Round)
|
||||||
@ -68,8 +71,6 @@ DEFAULT_MULTI(Split)
|
|||||||
DEFAULT(Sort)
|
DEFAULT(Sort)
|
||||||
DEFAULT(StopGradient)
|
DEFAULT(StopGradient)
|
||||||
DEFAULT(Transpose)
|
DEFAULT(Transpose)
|
||||||
DEFAULT_MULTI(DivMod)
|
|
||||||
DEFAULT_MULTI(QRF)
|
|
||||||
|
|
||||||
void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
|
void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
|
@ -3,6 +3,7 @@ target_sources(
|
|||||||
PRIVATE
|
PRIVATE
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp
|
||||||
|
59
mlx/backend/common/compiled.cpp
Normal file
59
mlx/backend/common/compiled.cpp
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
// Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
|
#include <queue>
|
||||||
|
|
||||||
|
#include "mlx/primitives.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
// Build the real tape
|
||||||
|
std::pair<std::queue<array>, std::vector<array>> trace_to_real(
|
||||||
|
const std::vector<array>& trace_tape,
|
||||||
|
const std::vector<array>& trace_inputs,
|
||||||
|
const std::vector<array>& trace_outputs,
|
||||||
|
const std::vector<array>& inputs) {
|
||||||
|
std::unordered_map<uintptr_t, array> trace_to_real;
|
||||||
|
for (int i = 0; i < inputs.size(); ++i) {
|
||||||
|
trace_to_real.insert({trace_inputs[i].id(), inputs[i]});
|
||||||
|
}
|
||||||
|
std::queue<array> tape;
|
||||||
|
for (auto& a : trace_tape) {
|
||||||
|
// Find real inputs
|
||||||
|
std::vector<array> real_inputs;
|
||||||
|
for (auto& in : a.inputs()) {
|
||||||
|
real_inputs.push_back(trace_to_real.at(in.id()));
|
||||||
|
}
|
||||||
|
tape.push(
|
||||||
|
array(a.shape(), a.dtype(), a.primitive_ptr(), std::move(real_inputs)));
|
||||||
|
trace_to_real.insert({a.id(), tape.back()});
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<array> outputs;
|
||||||
|
for (auto& o : trace_outputs) {
|
||||||
|
outputs.push_back(trace_to_real.at(o.id()));
|
||||||
|
}
|
||||||
|
return {tape, outputs};
|
||||||
|
}
|
||||||
|
|
||||||
|
void Compiled::eval(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
std::vector<array>& outputs) {
|
||||||
|
// Make the a real tape from the tracers
|
||||||
|
auto [tape, real_outputs] = trace_to_real(tape_, inputs_, outputs_, inputs);
|
||||||
|
|
||||||
|
// Run the tape
|
||||||
|
while (!tape.empty()) {
|
||||||
|
auto a = std::move(tape.front());
|
||||||
|
tape.pop();
|
||||||
|
auto outputs = a.outputs();
|
||||||
|
a.primitive().eval_cpu(a.inputs(), outputs);
|
||||||
|
a.detach();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy results into outputs
|
||||||
|
for (int o = 0; o < real_outputs.size(); ++o) {
|
||||||
|
outputs[o].copy_shared_buffer(real_outputs[o]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
@ -41,7 +41,9 @@ DEFAULT(ArgSort)
|
|||||||
DEFAULT(AsType)
|
DEFAULT(AsType)
|
||||||
DEFAULT(AsStrided)
|
DEFAULT(AsStrided)
|
||||||
DEFAULT(Broadcast)
|
DEFAULT(Broadcast)
|
||||||
|
DEFAULT_MULTI(DivMod)
|
||||||
DEFAULT(Ceil)
|
DEFAULT(Ceil)
|
||||||
|
DEFAULT_MULTI(Compiled)
|
||||||
DEFAULT(Concatenate)
|
DEFAULT(Concatenate)
|
||||||
DEFAULT(Convolution)
|
DEFAULT(Convolution)
|
||||||
DEFAULT(Copy)
|
DEFAULT(Copy)
|
||||||
@ -78,6 +80,7 @@ DEFAULT(NotEqual)
|
|||||||
DEFAULT(Pad)
|
DEFAULT(Pad)
|
||||||
DEFAULT(Partition)
|
DEFAULT(Partition)
|
||||||
DEFAULT(Power)
|
DEFAULT(Power)
|
||||||
|
DEFAULT_MULTI(QRF)
|
||||||
DEFAULT(QuantizedMatmul)
|
DEFAULT(QuantizedMatmul)
|
||||||
DEFAULT(RandomBits)
|
DEFAULT(RandomBits)
|
||||||
DEFAULT(Reduce)
|
DEFAULT(Reduce)
|
||||||
@ -100,8 +103,6 @@ DEFAULT(Subtract)
|
|||||||
DEFAULT(Tan)
|
DEFAULT(Tan)
|
||||||
DEFAULT(Tanh)
|
DEFAULT(Tanh)
|
||||||
DEFAULT(Transpose)
|
DEFAULT(Transpose)
|
||||||
DEFAULT_MULTI(DivMod)
|
|
||||||
DEFAULT_MULTI(QRF)
|
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
@ -2,6 +2,7 @@ target_sources(
|
|||||||
mlx
|
mlx
|
||||||
PRIVATE
|
PRIVATE
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||||
|
44
mlx/backend/metal/compiled.cpp
Normal file
44
mlx/backend/metal/compiled.cpp
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
// Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/metal/device.h"
|
||||||
|
#include "mlx/primitives.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
void Compiled::eval_gpu(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
std::vector<array>& outputs) {
|
||||||
|
// Just a fall-back to the original tape for now
|
||||||
|
std::unordered_map<uintptr_t, array> trace_to_real;
|
||||||
|
for (int i = 0; i < inputs.size(); ++i) {
|
||||||
|
trace_to_real.insert({inputs_[i].id(), inputs[i]});
|
||||||
|
}
|
||||||
|
for (int i = 0; i < outputs.size(); ++i) {
|
||||||
|
trace_to_real.insert({outputs_[i].id(), outputs[i]});
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto& a : tape_) {
|
||||||
|
std::vector<array> p_inputs;
|
||||||
|
for (auto& in : a.inputs()) {
|
||||||
|
p_inputs.push_back(trace_to_real.at(in.id()));
|
||||||
|
}
|
||||||
|
// If a is an output get it from the map, otherwise create it
|
||||||
|
// NB this is safe as long as no multi-output sub primitves are allowed
|
||||||
|
// in Compiled
|
||||||
|
std::vector<array> p_outputs;
|
||||||
|
if (auto it = trace_to_real.find(a.id()); it != trace_to_real.end()) {
|
||||||
|
p_outputs.push_back(it->second);
|
||||||
|
} else {
|
||||||
|
p_outputs.push_back(array(a.shape(), a.dtype(), a.primitive_ptr(), {}));
|
||||||
|
trace_to_real.insert({a.id(), p_outputs[0]});
|
||||||
|
}
|
||||||
|
a.primitive().eval_gpu(p_inputs, p_outputs);
|
||||||
|
}
|
||||||
|
auto& s = stream();
|
||||||
|
auto& d = metal::device(s.device);
|
||||||
|
auto command_buffer = d.get_command_buffer(s.index);
|
||||||
|
command_buffer->addCompletedHandler(
|
||||||
|
[trace_to_real](MTL::CommandBuffer*) mutable {});
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
@ -32,6 +32,7 @@ NO_GPU(AsType)
|
|||||||
NO_GPU(AsStrided)
|
NO_GPU(AsStrided)
|
||||||
NO_GPU(Broadcast)
|
NO_GPU(Broadcast)
|
||||||
NO_GPU(Ceil)
|
NO_GPU(Ceil)
|
||||||
|
NO_GPU_MULTI(Compiled)
|
||||||
NO_GPU(Concatenate)
|
NO_GPU(Concatenate)
|
||||||
NO_GPU(Convolution)
|
NO_GPU(Convolution)
|
||||||
NO_GPU(Copy)
|
NO_GPU(Copy)
|
||||||
@ -40,6 +41,7 @@ NO_GPU(Cosh)
|
|||||||
NO_GPU_MULTI(CustomVJP)
|
NO_GPU_MULTI(CustomVJP)
|
||||||
NO_GPU_MULTI(Depends)
|
NO_GPU_MULTI(Depends)
|
||||||
NO_GPU(Divide)
|
NO_GPU(Divide)
|
||||||
|
NO_GPU_MULTI(DivMod)
|
||||||
NO_GPU(Remainder)
|
NO_GPU(Remainder)
|
||||||
NO_GPU(Equal)
|
NO_GPU(Equal)
|
||||||
NO_GPU(Erf)
|
NO_GPU(Erf)
|
||||||
@ -69,6 +71,7 @@ NO_GPU(NotEqual)
|
|||||||
NO_GPU(Pad)
|
NO_GPU(Pad)
|
||||||
NO_GPU(Partition)
|
NO_GPU(Partition)
|
||||||
NO_GPU(Power)
|
NO_GPU(Power)
|
||||||
|
NO_GPU_MULTI(QRF)
|
||||||
NO_GPU(QuantizedMatmul)
|
NO_GPU(QuantizedMatmul)
|
||||||
NO_GPU(RandomBits)
|
NO_GPU(RandomBits)
|
||||||
NO_GPU(Reduce)
|
NO_GPU(Reduce)
|
||||||
@ -91,6 +94,5 @@ NO_GPU(Subtract)
|
|||||||
NO_GPU(Tan)
|
NO_GPU(Tan)
|
||||||
NO_GPU(Tanh)
|
NO_GPU(Tanh)
|
||||||
NO_GPU(Transpose)
|
NO_GPU(Transpose)
|
||||||
NO_GPU_MULTI(DivMod)
|
|
||||||
NO_GPU_MULTI(QRF)
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
468
mlx/compile.cpp
468
mlx/compile.cpp
@ -1,36 +1,198 @@
|
|||||||
// Copyright © 2023-2024 Apple Inc.
|
// Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
#include <cstdlib>
|
#include <cstdlib>
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include <unordered_set>
|
#include <unordered_set>
|
||||||
|
|
||||||
#include "mlx/allocator.h"
|
#include "mlx/allocator.h"
|
||||||
|
#include "mlx/compile.h"
|
||||||
#include "mlx/primitives.h"
|
#include "mlx/primitives.h"
|
||||||
#include "mlx/transforms.h"
|
#include "mlx/transforms.h"
|
||||||
#include "mlx/transforms_impl.h"
|
#include "mlx/transforms_impl.h"
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
namespace detail {
|
constexpr int max_compile_depth = 6;
|
||||||
|
|
||||||
bool& compiler_disabled() {
|
bool is_unary(const Primitive& p) {
|
||||||
auto get_val = []() {
|
return (
|
||||||
if (const char* buff_str = std::getenv("MLX_DISABLE_COMPILE")) {
|
typeid(p) == typeid(Abs) || typeid(p) == typeid(ArcCos) ||
|
||||||
return true;
|
typeid(p) == typeid(ArcCosh) || typeid(p) == typeid(ArcSin) ||
|
||||||
} else {
|
typeid(p) == typeid(ArcSinh) || typeid(p) == typeid(ArcTan) ||
|
||||||
return false;
|
typeid(p) == typeid(ArcTanh) || typeid(p) == typeid(AsType) ||
|
||||||
}
|
typeid(p) == typeid(Ceil) || typeid(p) == typeid(Cos) ||
|
||||||
};
|
typeid(p) == typeid(Cosh) || typeid(p) == typeid(Remainder) ||
|
||||||
static bool compiler_disabled_ = get_val();
|
typeid(p) == typeid(Erf) || typeid(p) == typeid(ErfInv) ||
|
||||||
return compiler_disabled_;
|
typeid(p) == typeid(Exp) || typeid(p) == typeid(Floor) ||
|
||||||
|
typeid(p) == typeid(Log) || typeid(p) == typeid(Log1p) ||
|
||||||
|
typeid(p) == typeid(LogicalNot) || typeid(p) == typeid(Negative) ||
|
||||||
|
typeid(p) == typeid(Round) || typeid(p) == typeid(Sigmoid) ||
|
||||||
|
typeid(p) == typeid(Sign) || typeid(p) == typeid(Sin) ||
|
||||||
|
typeid(p) == typeid(Sinh) || typeid(p) == typeid(Square) ||
|
||||||
|
typeid(p) == typeid(Sqrt) || typeid(p) == typeid(Tan) ||
|
||||||
|
typeid(p) == typeid(Tanh));
|
||||||
}
|
}
|
||||||
|
|
||||||
#define MAX_OPS_PER_BUFFER max_ops_per_buffer()
|
bool is_binary(const Primitive& p) {
|
||||||
|
return (
|
||||||
|
typeid(p) == typeid(Add) || typeid(p) == typeid(Divide) ||
|
||||||
|
typeid(p) == typeid(Equal) || typeid(p) == typeid(Greater) ||
|
||||||
|
typeid(p) == typeid(GreaterEqual) || typeid(p) == typeid(Less) ||
|
||||||
|
typeid(p) == typeid(LessEqual) || typeid(p) == typeid(LogicalNot) ||
|
||||||
|
typeid(p) == typeid(LogicalAnd) || typeid(p) == typeid(LogicalOr) ||
|
||||||
|
typeid(p) == typeid(LogAddExp) || typeid(p) == typeid(Maximum) ||
|
||||||
|
typeid(p) == typeid(Minimum) || typeid(p) == typeid(Multiply) ||
|
||||||
|
typeid(p) == typeid(NotEqual) || typeid(p) == typeid(Power) ||
|
||||||
|
typeid(p) == typeid(Subtract));
|
||||||
|
}
|
||||||
|
|
||||||
|
bool is_broadcast(const Primitive& p) {
|
||||||
|
return typeid(p) == typeid(Broadcast);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool is_noop(const Primitive& p) {
|
||||||
|
return typeid(p) == typeid(Copy) || typeid(p) == typeid(StopGradient);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool is_fusable(const Primitive& p) {
|
||||||
|
return is_unary(p) || is_binary(p) || is_broadcast(p) || is_noop(p);
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace detail {
|
||||||
|
|
||||||
|
std::vector<array> compile_replace(
|
||||||
|
const std::vector<array>& tape,
|
||||||
|
const std::vector<array>& trace_inputs,
|
||||||
|
const std::vector<array>& trace_outputs,
|
||||||
|
const std::vector<array>& inputs);
|
||||||
|
|
||||||
|
} // namespace detail
|
||||||
|
|
||||||
|
// Construct a Compiled primitive from a traced sub-graph.
//
//   stream        stream the primitive is attached to
//   inputs        tracer arrays feeding the sub-graph
//   outputs       tracer arrays produced by the sub-graph
//   tape          the traced arrays of the sub-graph, in evaluation order
//   constant_ids  ids of the inputs that are treated as constants
//
// All containers are taken by value and moved into the members.
Compiled::Compiled(
    Stream stream,
    std::vector<array> inputs,
    std::vector<array> outputs,
    std::vector<array> tape,
    std::unordered_set<uintptr_t> constant_ids)
    : Primitive(stream),
      inputs_(std::move(inputs)),
      outputs_(std::move(outputs)),
      tape_(std::move(tape)),
      constant_ids_(std::move(constant_ids)) {}
|
||||||
|
|
||||||
|
std::vector<array> Compiled::vjp(
|
||||||
|
const std::vector<array>& primals,
|
||||||
|
const std::vector<array>& cotangents,
|
||||||
|
const std::vector<int>& argnums,
|
||||||
|
const std::vector<array>& outputs) {
|
||||||
|
auto fun = [this](const std::vector<array>& inputs) {
|
||||||
|
return detail::compile_replace(tape_, inputs_, outputs_, inputs);
|
||||||
|
};
|
||||||
|
|
||||||
|
auto [_, vjps] = mlx::core::vjp(fun, primals, cotangents);
|
||||||
|
std::vector<array> vjp_outs;
|
||||||
|
for (int i = 0, j = 0; i < vjps.size(); ++i) {
|
||||||
|
if (i < argnums.size() && i == argnums[j]) {
|
||||||
|
vjp_outs.push_back(vjps[i]);
|
||||||
|
j++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return vjp_outs;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<array> Compiled::jvp(
|
||||||
|
const std::vector<array>& primals,
|
||||||
|
const std::vector<array>& tangents,
|
||||||
|
const std::vector<int>& argnums) {
|
||||||
|
auto fun = [this](const std::vector<array>& inputs) {
|
||||||
|
return detail::compile_replace(tape_, inputs_, outputs_, inputs);
|
||||||
|
};
|
||||||
|
|
||||||
|
auto [_, jvps] = mlx::core::jvp(fun, primals, tangents);
|
||||||
|
std::vector<array> jvp_outs;
|
||||||
|
for (int i = 0, j = 0; i < jvps.size(); ++i) {
|
||||||
|
if (i < argnums.size() && i == argnums[j]) {
|
||||||
|
jvp_outs.push_back(jvps[i]);
|
||||||
|
j++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return jvp_outs;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::pair<std::vector<array>, std::vector<int>> Compiled::vmap(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
const std::vector<int>& axes) {
|
||||||
|
auto fun = [this](const std::vector<array>& inputs) {
|
||||||
|
return detail::compile_replace(tape_, inputs_, outputs_, inputs);
|
||||||
|
};
|
||||||
|
|
||||||
|
auto outputs = mlx::core::vmap(fun, axes)(inputs);
|
||||||
|
auto out_axes = std::vector<int>(outputs.size(), 0);
|
||||||
|
return {outputs, out_axes};
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Compiled::is_equivalent(const Primitive& other) const {
|
||||||
|
const Compiled& a_other = static_cast<const Compiled&>(other);
|
||||||
|
return std::equal(
|
||||||
|
tape_.begin(),
|
||||||
|
tape_.end(),
|
||||||
|
a_other.tape_.begin(),
|
||||||
|
a_other.tape_.end(),
|
||||||
|
[](const array& a1, const array& a2) {
|
||||||
|
auto& p1 = a1.primitive();
|
||||||
|
auto& p2 = a2.primitive();
|
||||||
|
return typeid(p1) == typeid(p2) && p1.is_equivalent(p2);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
void Compiled::print(std::ostream& os) {
|
||||||
|
os << "Compiled";
|
||||||
|
for (auto& a : tape_) {
|
||||||
|
a.primitive().print(os);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace detail {
|
||||||
|
|
||||||
|
// Process-wide compile mode.
//
// Initialized once, on first call: CompileMode::disabled if the
// MLX_DISABLE_COMPILE environment variable is set (to any value), otherwise
// CompileMode::enabled. The returned reference lets callers change the mode.
CompileMode& compile_mode() {
  static CompileMode compile_mode_ =
      (std::getenv("MLX_DISABLE_COMPILE") != nullptr) ? CompileMode::disabled
                                                      : CompileMode::enabled;
  return compile_mode_;
}
|
||||||
|
|
||||||
using CompileFn = std::function<std::vector<array>(const std::vector<array>&)>;
|
using CompileFn = std::function<std::vector<array>(const std::vector<array>&)>;
|
||||||
using ParentsMap =
|
using ParentsMap =
|
||||||
std::unordered_map<std::uintptr_t, std::vector<std::pair<array, int>>>;
|
std::unordered_map<std::uintptr_t, std::vector<std::pair<array, int>>>;
|
||||||
|
|
||||||
|
// Helper that merges two arrays in the graph by setting the parents of the
|
||||||
|
// source to point to the destination
|
||||||
|
void merge(array& dst, array& src, ParentsMap& parents_map) {
|
||||||
|
// Canonicalize the order of the primitives outputs
|
||||||
|
auto sources = src.outputs();
|
||||||
|
auto dests = dst.outputs();
|
||||||
|
// For each src parent, point it to the corresponding dst
|
||||||
|
for (int i = 0; i < sources.size(); ++i) {
|
||||||
|
auto src_parents = parents_map.find(sources[i].id());
|
||||||
|
if (src_parents == parents_map.end()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
auto& pairs = parents_map[dests[i].id()];
|
||||||
|
for (auto& parent : src_parents->second) {
|
||||||
|
parent.first.inputs()[parent.second] = dests[i];
|
||||||
|
pairs.push_back(parent);
|
||||||
|
}
|
||||||
|
// Remove the source from the map to avoid fusing with it again
|
||||||
|
parents_map.erase(src_parents);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
template <typename T, typename... U>
|
template <typename T, typename... U>
|
||||||
size_t getAddress(std::function<T(U...)> f) {
|
size_t getAddress(std::function<T(U...)> f) {
|
||||||
typedef T(fnType)(U...);
|
typedef T(fnType)(U...);
|
||||||
@ -59,9 +221,10 @@ struct CompilerCache {
|
|||||||
auto is_match = [](const std::vector<array>& in1,
|
auto is_match = [](const std::vector<array>& in1,
|
||||||
const std::vector<array>& in2) {
|
const std::vector<array>& in2) {
|
||||||
if (in1.size() != in2.size()) {
|
if (in1.size() != in2.size()) {
|
||||||
throw std::runtime_error(
|
std::ostringstream msg;
|
||||||
"[compiler] Got different number of inputs to function,"
|
msg << "[compiler] Unexpected number of inputs to compiled function:"
|
||||||
" this should never happen.");
|
<< " expected " << in2.size() << " got " << in1.size() << ".";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
for (int i = 0; i < in1.size(); ++i) {
|
for (int i = 0; i < in1.size(); ++i) {
|
||||||
if (in1[i].shape() != in2[i].shape()) {
|
if (in1[i].shape() != in2[i].shape()) {
|
||||||
@ -205,28 +368,6 @@ void compile_simplify(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Helper that fuses two arrays in the graph by setting the parents of the
|
|
||||||
// source to point to the destination
|
|
||||||
auto fuse = [&](array& dst, array& src) {
|
|
||||||
// Canonicalize the order of the primitives outputs
|
|
||||||
auto sources = src.outputs();
|
|
||||||
auto dests = dst.outputs();
|
|
||||||
// For each src parent, point it to the corresponding dest
|
|
||||||
for (int i = 0; i < sources.size(); ++i) {
|
|
||||||
auto src_parents = parents_map.find(sources[i].id());
|
|
||||||
if (src_parents == parents_map.end()) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
auto& pairs = parents_map[dests[i].id()];
|
|
||||||
for (auto& parent : src_parents->second) {
|
|
||||||
parent.first.inputs()[parent.second] = dests[i];
|
|
||||||
pairs.push_back(parent);
|
|
||||||
}
|
|
||||||
// Remove the source from the map to avoid fusing with it again
|
|
||||||
parents_map.erase(src_parents);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// Depth-1 array equivalence check.
|
// Depth-1 array equivalence check.
|
||||||
auto array_equivalent = [](const array& a, const array& b) {
|
auto array_equivalent = [](const array& a, const array& b) {
|
||||||
if (!a.has_primitive() || !b.has_primitive()) {
|
if (!a.has_primitive() || !b.has_primitive()) {
|
||||||
@ -254,33 +395,32 @@ void compile_simplify(
|
|||||||
return pa.is_equivalent(pb);
|
return pa.is_equivalent(pb);
|
||||||
};
|
};
|
||||||
|
|
||||||
// Pass 0: fuse scalars
|
// Merge scalars
|
||||||
std::vector<array> new_tape;
|
std::vector<array> new_tape;
|
||||||
for (auto& arr : tape) {
|
for (auto& arr : tape) {
|
||||||
// Check if we can fuse scalars
|
// Check if we can merge scalars
|
||||||
if (is_scalar(arr)) {
|
if (is_scalar(arr)) {
|
||||||
auto scalar = scalars.find(get_scalar_rep(arr));
|
auto scalar = scalars.find(get_scalar_rep(arr));
|
||||||
if (scalar->second.id() != arr.id()) {
|
if (scalar->second.id() != arr.id()) {
|
||||||
fuse(scalar->second, arr);
|
merge(scalar->second, arr, parents_map);
|
||||||
// Don't keep orphaned scalars in the tape
|
// Don't keep orphaned scalars in the tape
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
new_tape.push_back(std::move(arr));
|
new_tape.push_back(std::move(arr));
|
||||||
}
|
}
|
||||||
|
|
||||||
tape = std::move(new_tape);
|
tape = std::move(new_tape);
|
||||||
|
|
||||||
std::unordered_set<uintptr_t> output_set;
|
std::unordered_set<uintptr_t> output_set;
|
||||||
for (auto& o : outputs) {
|
for (auto& o : outputs) {
|
||||||
output_set.insert(o.id());
|
output_set.insert(o.id());
|
||||||
}
|
}
|
||||||
// Pass 1..passes: fuse only keeping non-orphaned arrays in the tape
|
// Multi-pass merge only keeping non-orphaned arrays in the tape
|
||||||
for (int pass = 0; pass < passes; ++pass) {
|
for (int pass = 0; pass < passes; ++pass) {
|
||||||
for (auto& arr : tape) {
|
for (auto& arr : tape) {
|
||||||
// Helper to check if we can fuse the parents of the
|
// Helper to check if we can merge the parents of the
|
||||||
// given array
|
// given array
|
||||||
auto maybe_fuse_parents = [&](auto& a) {
|
auto maybe_merge_parents = [&](auto& a) {
|
||||||
auto parents = parents_map.find(a.id());
|
auto parents = parents_map.find(a.id());
|
||||||
if (parents != parents_map.end()) {
|
if (parents != parents_map.end()) {
|
||||||
auto N = parents->second.size();
|
auto N = parents->second.size();
|
||||||
@ -296,7 +436,7 @@ void compile_simplify(
|
|||||||
auto& src = parents->second[j].first;
|
auto& src = parents->second[j].first;
|
||||||
auto& dst = parents->second[i].first;
|
auto& dst = parents->second[i].first;
|
||||||
if (src.id() != dst.id() && array_equivalent(src, dst)) {
|
if (src.id() != dst.id() && array_equivalent(src, dst)) {
|
||||||
fuse(dst, src);
|
merge(dst, src, parents_map);
|
||||||
mask[j] = true;
|
mask[j] = true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -313,9 +453,9 @@ void compile_simplify(
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
bool discard = maybe_fuse_parents(arr);
|
bool discard = maybe_merge_parents(arr);
|
||||||
for (auto& s : arr.siblings()) {
|
for (auto& s : arr.siblings()) {
|
||||||
discard &= maybe_fuse_parents(s);
|
discard &= maybe_merge_parents(s);
|
||||||
}
|
}
|
||||||
// If an array and its siblings have no parents, and none of them are
|
// If an array and its siblings have no parents, and none of them are
|
||||||
// outputs, it is safe to remove it from the tape
|
// outputs, it is safe to remove it from the tape
|
||||||
@ -327,6 +467,216 @@ void compile_simplify(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Extract sub-graphs of the graph that can be compiled
|
||||||
|
// and replace them with a Compiled Primitive.
|
||||||
|
void compile_fuse(
|
||||||
|
std::vector<array>& tape,
|
||||||
|
ParentsMap& parents_map,
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
std::vector<array>& outputs) {
|
||||||
|
// Track outputs to replace with new compiled outputs
|
||||||
|
std::unordered_map<uintptr_t, array> output_map;
|
||||||
|
for (auto& o : outputs) {
|
||||||
|
output_map.insert({o.id(), o});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set of inputs to distinguish constants
|
||||||
|
std::unordered_set<uintptr_t> input_ids;
|
||||||
|
for (auto& in : inputs) {
|
||||||
|
input_ids.insert(in.id());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Go through the tape in reverse order and check for fusable sub-graphs
|
||||||
|
std::vector<array> new_tape;
|
||||||
|
std::unordered_set<uintptr_t> global_cache;
|
||||||
|
for (int i = tape.size() - 1; i >= 0; --i) {
|
||||||
|
auto& arr = tape[i];
|
||||||
|
|
||||||
|
// Already compiled
|
||||||
|
if (global_cache.find(arr.id()) != global_cache.end()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Two pass recursion:
|
||||||
|
// First pass:
|
||||||
|
// - Collect all the primitives which we can fuse with
|
||||||
|
// - Keeps a cache of fusable primitives which may be added out of
|
||||||
|
// DAG order. We have to determine if all of a fused primitive's
|
||||||
|
// outputs are also in the fused section, and this may not be the
|
||||||
|
// case the first time we visit it.
|
||||||
|
// Second pass:
|
||||||
|
// - Collect inputs to the new compiled primitive
|
||||||
|
// - Add fusable primitives to a tape in the correct order
|
||||||
|
|
||||||
|
std::function<void(const array&, int, const Stream&)> recurse;
|
||||||
|
std::unordered_set<uintptr_t> cache;
|
||||||
|
recurse = [&](const array& a, int depth, const Stream& s) {
|
||||||
|
if (cache.find(a.id()) != cache.end()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop fusing if:
|
||||||
|
// - Depth limit exceeded
|
||||||
|
// - Constant input
|
||||||
|
// - Stream mismatch
|
||||||
|
// - Non fusable primitive
|
||||||
|
if (depth >= max_compile_depth || !a.has_primitive() ||
|
||||||
|
a.primitive().stream() != s || !is_fusable(a.primitive())) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool all_parents_in = true;
|
||||||
|
if (depth > 0) {
|
||||||
|
// Guaranteed to have a parent since nested in the
|
||||||
|
// recursion.
|
||||||
|
auto& parents = parents_map.at(a.id());
|
||||||
|
for (auto& [p, idx] : parents) {
|
||||||
|
auto in_cache = cache.find(p.id()) != cache.end();
|
||||||
|
if (!in_cache) {
|
||||||
|
all_parents_in = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Arrays with a mix of parents outside the compilable section
|
||||||
|
// are not fusable
|
||||||
|
if (!all_parents_in) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
cache.insert({a.id()});
|
||||||
|
|
||||||
|
for (auto& in : a.inputs()) {
|
||||||
|
recurse(in, depth + 1, s);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
if (arr.has_primitive()) {
|
||||||
|
Stream s = arr.primitive().stream();
|
||||||
|
recurse(arr, 0, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Not worth fusing a single primitive
|
||||||
|
if (cache.size() <= 1) {
|
||||||
|
new_tape.push_back(arr);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Recurse a second time to build the tape in the right
|
||||||
|
// order and collect the inputs
|
||||||
|
std::unordered_set<uintptr_t> input_set;
|
||||||
|
std::vector<array> inputs;
|
||||||
|
std::vector<array> fused_tape;
|
||||||
|
std::unordered_set<uintptr_t> tape_set;
|
||||||
|
std::function<void(const array&)> recurse_tape;
|
||||||
|
recurse_tape = [&](const array& a) {
|
||||||
|
if (cache.find(a.id()) == cache.end()) {
|
||||||
|
if (input_set.find(a.id()) == input_set.end()) {
|
||||||
|
input_set.insert(a.id());
|
||||||
|
inputs.push_back(a);
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (tape_set.find(a.id()) != tape_set.end()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
tape_set.insert(a.id());
|
||||||
|
for (auto& in : a.inputs()) {
|
||||||
|
recurse_tape(in);
|
||||||
|
}
|
||||||
|
fused_tape.push_back(a);
|
||||||
|
};
|
||||||
|
recurse_tape(arr);
|
||||||
|
|
||||||
|
std::vector<array> old_outputs;
|
||||||
|
// Add to global cache and add any global outputs to outputs
|
||||||
|
// of new primitive
|
||||||
|
for (int j = 0; j < fused_tape.size() - 1; ++j) {
|
||||||
|
auto& f = fused_tape[j];
|
||||||
|
if (output_map.find(f.id()) != output_map.end()) {
|
||||||
|
old_outputs.push_back(f);
|
||||||
|
// Parents are now siblings, update the parent map
|
||||||
|
auto& pairs = parents_map[f.id()];
|
||||||
|
pairs.erase(
|
||||||
|
std::remove_if(
|
||||||
|
pairs.begin(),
|
||||||
|
pairs.end(),
|
||||||
|
[&](auto& p) {
|
||||||
|
return cache.find(p.first.id()) != cache.end();
|
||||||
|
}),
|
||||||
|
pairs.end());
|
||||||
|
} else {
|
||||||
|
// Remove inner fused arrays parents from the parents map
|
||||||
|
// to keep the parents map in a valid state
|
||||||
|
parents_map.erase(f.id());
|
||||||
|
}
|
||||||
|
global_cache.insert({f.id()});
|
||||||
|
}
|
||||||
|
old_outputs.push_back(arr);
|
||||||
|
|
||||||
|
std::vector<std::vector<int>> shapes;
|
||||||
|
std::vector<Dtype> types;
|
||||||
|
for (auto& o : old_outputs) {
|
||||||
|
shapes.push_back(o.shape());
|
||||||
|
types.push_back(o.dtype());
|
||||||
|
}
|
||||||
|
std::unordered_set<uintptr_t> constant_ids;
|
||||||
|
for (auto& in : inputs) {
|
||||||
|
// Scalar constant
|
||||||
|
if (in.size() == 1 && !in.has_primitive() &&
|
||||||
|
input_ids.find(in.id()) == input_ids.end()) {
|
||||||
|
constant_ids.insert(in.id());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
auto compiled_outputs = array::make_arrays(
|
||||||
|
shapes,
|
||||||
|
types,
|
||||||
|
std::make_shared<Compiled>(
|
||||||
|
outputs.back().primitive().stream(),
|
||||||
|
inputs,
|
||||||
|
old_outputs,
|
||||||
|
std::move(fused_tape),
|
||||||
|
std::move(constant_ids)),
|
||||||
|
inputs);
|
||||||
|
|
||||||
|
// One output per primitive
|
||||||
|
new_tape.push_back(compiled_outputs.back());
|
||||||
|
|
||||||
|
// Replace inputs old parents with compiled_outputs
|
||||||
|
for (int i = 0; i < inputs.size(); ++i) {
|
||||||
|
auto& pairs = parents_map[inputs[i].id()];
|
||||||
|
pairs.erase(
|
||||||
|
std::remove_if(
|
||||||
|
pairs.begin(),
|
||||||
|
pairs.end(),
|
||||||
|
[&](auto& p) { return cache.find(p.first.id()) != cache.end(); }),
|
||||||
|
pairs.end());
|
||||||
|
for (auto& o : compiled_outputs) {
|
||||||
|
pairs.push_back({o, i});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// - Update outputs parents to point to compiled outputs
|
||||||
|
// - Update any overall graph outputs to be compiled outputs
|
||||||
|
for (int o = 0; o < old_outputs.size(); ++o) {
|
||||||
|
merge(compiled_outputs[o], old_outputs[o], parents_map);
|
||||||
|
if (auto it = output_map.find(old_outputs[o].id());
|
||||||
|
it != output_map.end()) {
|
||||||
|
it->second = compiled_outputs[o];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::reverse(new_tape.begin(), new_tape.end());
|
||||||
|
tape = std::move(new_tape);
|
||||||
|
|
||||||
|
// Replace output with potentially compiled output
|
||||||
|
for (auto& o : outputs) {
|
||||||
|
o = output_map.at(o.id());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<array> compile_replace(
|
std::vector<array> compile_replace(
|
||||||
const std::vector<array>& tape,
|
const std::vector<array>& tape,
|
||||||
const std::vector<array>& trace_inputs,
|
const std::vector<array>& trace_inputs,
|
||||||
@ -380,7 +730,7 @@ std::vector<array> compile_replace(
|
|||||||
std::function<std::vector<array>(const std::vector<array>&)> compile(
|
std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||||
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
||||||
size_t fun_id) {
|
size_t fun_id) {
|
||||||
if (compiler_disabled()) {
|
if (compile_mode() == CompileMode::disabled) {
|
||||||
return fun;
|
return fun;
|
||||||
}
|
}
|
||||||
return [fun, fun_id](const std::vector<array>& inputs) {
|
return [fun, fun_id](const std::vector<array>& inputs) {
|
||||||
@ -402,10 +752,16 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
|
|||||||
compile_dfs(entry.inputs, entry.outputs);
|
compile_dfs(entry.inputs, entry.outputs);
|
||||||
|
|
||||||
// Simplify the tape
|
// Simplify the tape
|
||||||
compile_simplify(entry.tape, parents_map, entry.outputs, /* passes */ 3);
|
if (compile_mode() != CompileMode::no_simplify) {
|
||||||
|
compile_simplify(
|
||||||
|
entry.tape, parents_map, entry.outputs, /* passes */ 3);
|
||||||
|
}
|
||||||
|
|
||||||
// This is a good point to do more optimizations, e.g. kernel fusion to
|
// Kernel fusion to generate Compiled primitives. The tape and
|
||||||
// generate new primitives. The tape needs to be updated accordingly
|
// new outputs must be updated accordingly
|
||||||
|
if (compile_mode() != CompileMode::no_fuse) {
|
||||||
|
compile_fuse(entry.tape, parents_map, entry.inputs, entry.outputs);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// At this point we must have a tape, now replace the placeholders
|
// At this point we must have a tape, now replace the placeholders
|
||||||
@ -422,7 +778,7 @@ void compile_erase(size_t fun_id) {
|
|||||||
|
|
||||||
std::function<std::vector<array>(const std::vector<array>&)> compile(
|
std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||||
const std::function<std::vector<array>(const std::vector<array>&)>& fun) {
|
const std::function<std::vector<array>(const std::vector<array>&)>& fun) {
|
||||||
if (detail::compiler_disabled()) {
|
if (detail::compile_mode() == CompileMode::disabled) {
|
||||||
return fun;
|
return fun;
|
||||||
}
|
}
|
||||||
auto fun_id = detail::getAddress(fun);
|
auto fun_id = detail::getAddress(fun);
|
||||||
@ -430,11 +786,15 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
|
|||||||
}
|
}
|
||||||
|
|
||||||
void disable_compile() {
|
void disable_compile() {
|
||||||
detail::compiler_disabled() = true;
|
detail::compile_mode() = CompileMode::disabled;
|
||||||
}
|
}
|
||||||
|
|
||||||
void enable_compile() {
|
void enable_compile() {
|
||||||
detail::compiler_disabled() = false;
|
detail::compile_mode() = CompileMode::enabled;
|
||||||
|
}
|
||||||
|
|
||||||
|
void set_compile_mode(CompileMode mode) {
|
||||||
|
detail::compile_mode() = mode;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
28
mlx/compile.h
Normal file
28
mlx/compile.h
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
// Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlx/array.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
enum class CompileMode { disabled, no_simplify, no_fuse, enabled };
|
||||||
|
|
||||||
|
// Compile takes a function and returns a new function
|
||||||
|
std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||||
|
const std::function<std::vector<array>(const std::vector<array>&)>& fun);
|
||||||
|
|
||||||
|
/** Globally disable compilation.
|
||||||
|
* Setting the environment variable ``MLX_DISABLE_COMPILE`` can also
|
||||||
|
* be used to disable compilation.
|
||||||
|
*/
|
||||||
|
void disable_compile();
|
||||||
|
|
||||||
|
/** Globally enable compilation.
|
||||||
|
* This will override the environment variable ``MLX_DISABLE_COMPILE``.
|
||||||
|
*/
|
||||||
|
void enable_compile();
|
||||||
|
|
||||||
|
/** Set the compiler mode to the given value. */
|
||||||
|
void set_compile_mode(CompileMode mode);
|
||||||
|
} // namespace mlx::core
|
@ -4,6 +4,7 @@
|
|||||||
|
|
||||||
#include "mlx/array.h"
|
#include "mlx/array.h"
|
||||||
#include "mlx/backend/metal/metal.h"
|
#include "mlx/backend/metal/metal.h"
|
||||||
|
#include "mlx/compile.h"
|
||||||
#include "mlx/device.h"
|
#include "mlx/device.h"
|
||||||
#include "mlx/fft.h"
|
#include "mlx/fft.h"
|
||||||
#include "mlx/io.h"
|
#include "mlx/io.h"
|
||||||
|
@ -2,6 +2,8 @@
|
|||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include <unordered_set>
|
||||||
|
|
||||||
#include "mlx/array.h"
|
#include "mlx/array.h"
|
||||||
#include "mlx/device.h"
|
#include "mlx/device.h"
|
||||||
#include "mlx/io/load.h"
|
#include "mlx/io/load.h"
|
||||||
@ -451,6 +453,46 @@ class Ceil : public UnaryPrimitive {
|
|||||||
void eval(const std::vector<array>& inputs, array& out);
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class Compiled : public Primitive {
|
||||||
|
public:
|
||||||
|
/*
|
||||||
|
* The inputs, outputs and tape are either tracers or constants.
|
||||||
|
* - The tape should not contain the inputs, but it should contain the
|
||||||
|
* outputs.
|
||||||
|
* - The tape should also have only one array per primitive for multi-output
|
||||||
|
* primitives.
|
||||||
|
* - The constant_ids contains ids of arrays in the input list that are safe
|
||||||
|
* to treat as scalar constants.
|
||||||
|
*/
|
||||||
|
explicit Compiled(
|
||||||
|
Stream stream,
|
||||||
|
std::vector<array> inputs,
|
||||||
|
std::vector<array> outputs,
|
||||||
|
std::vector<array> tape,
|
||||||
|
std::unordered_set<uintptr_t> constant_ids);
|
||||||
|
|
||||||
|
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||||
|
override;
|
||||||
|
|
||||||
|
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||||
|
override;
|
||||||
|
|
||||||
|
DEFINE_VMAP()
|
||||||
|
DEFINE_GRADS()
|
||||||
|
|
||||||
|
void print(std::ostream& os) override;
|
||||||
|
|
||||||
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
const std::vector<array> inputs_;
|
||||||
|
const std::vector<array> outputs_;
|
||||||
|
const std::vector<array> tape_;
|
||||||
|
const std::unordered_set<uintptr_t> constant_ids_;
|
||||||
|
|
||||||
|
void eval(const std::vector<array>& inputs, std::vector<array>& out);
|
||||||
|
};
|
||||||
|
|
||||||
class Concatenate : public UnaryPrimitive {
|
class Concatenate : public UnaryPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit Concatenate(Stream stream, int axis)
|
explicit Concatenate(Stream stream, int axis)
|
||||||
|
@ -6,21 +6,6 @@
|
|||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
// Compile takes a function and returns a new function
|
|
||||||
std::function<std::vector<array>(const std::vector<array>&)> compile(
|
|
||||||
const std::function<std::vector<array>(const std::vector<array>&)>& fun);
|
|
||||||
|
|
||||||
/** Globally disable compilation.
|
|
||||||
* Setting the environment variable ``MLX_DISABLE_COMPILE`` can also
|
|
||||||
* be used to disable compilation.
|
|
||||||
*/
|
|
||||||
void disable_compile();
|
|
||||||
|
|
||||||
/** Globally enable compilation.
|
|
||||||
* This will override the environment variable ``MLX_DISABLE_COMPILE``.
|
|
||||||
*/
|
|
||||||
void enable_compile();
|
|
||||||
|
|
||||||
void eval(const std::vector<array>& outputs);
|
void eval(const std::vector<array>& outputs);
|
||||||
|
|
||||||
template <typename... Arrays>
|
template <typename... Arrays>
|
||||||
|
@ -7,6 +7,7 @@
|
|||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
|
||||||
#include "mlx/array.h"
|
#include "mlx/array.h"
|
||||||
|
#include "mlx/compile.h"
|
||||||
#include "mlx/graph_utils.h"
|
#include "mlx/graph_utils.h"
|
||||||
#include "mlx/transforms.h"
|
#include "mlx/transforms.h"
|
||||||
#include "mlx/transforms_impl.h"
|
#include "mlx/transforms_impl.h"
|
||||||
|
@ -190,6 +190,117 @@ class TestCompile(mlx_tests.MLXTestCase):
|
|||||||
n_enable_compiled = count_prims(cfun(x))
|
n_enable_compiled = count_prims(cfun(x))
|
||||||
self.assertEqual(n_compiled, n_enable_compiled)
|
self.assertEqual(n_compiled, n_enable_compiled)
|
||||||
|
|
||||||
|
def test_compile_two_input_grad(self):
|
||||||
|
def loss(w, x):
|
||||||
|
y = x * w
|
||||||
|
return (y * mx.exp(y)).sum()
|
||||||
|
|
||||||
|
x = mx.array([1.0, 0.5, 2.0, -0.5])
|
||||||
|
w = mx.array([-1.0, 0.3, 1.0, -0.9])
|
||||||
|
|
||||||
|
expected_grad = mx.grad(loss)(w, x)
|
||||||
|
compiled_grad = mx.compile(mx.grad(loss))(w, x)
|
||||||
|
self.assertTrue(mx.allclose(expected_grad, compiled_grad))
|
||||||
|
|
||||||
|
def test_vmap_compiled(self):
|
||||||
|
def simple_unary(x):
|
||||||
|
return -mx.exp(x)
|
||||||
|
|
||||||
|
x = mx.array([[1.0, 2.0], [2.0, 3.0]])
|
||||||
|
|
||||||
|
expected_out = mx.vmap(simple_unary)(x)
|
||||||
|
out = mx.vmap(mx.compile(simple_unary))(x)
|
||||||
|
self.assertTrue(mx.allclose(expected_out, out))
|
||||||
|
|
||||||
|
def simple_binary(x, y):
|
||||||
|
return mx.abs(mx.exp(x + y) + y)
|
||||||
|
|
||||||
|
x = mx.array([[1.0, -3.0], [0.5, -0.5]])
|
||||||
|
y = mx.array([[2.0, -1.0], [0.25, -0.25]])
|
||||||
|
|
||||||
|
expected_out = mx.vmap(simple_binary)(x, y)
|
||||||
|
out = mx.vmap(mx.compile(simple_binary))(x, y)
|
||||||
|
self.assertTrue(mx.allclose(expected_out, out))
|
||||||
|
|
||||||
|
expected_out = mx.vmap(simple_binary, in_axes=(0, 1))(x, y)
|
||||||
|
out = mx.vmap(mx.compile(simple_binary), in_axes=(0, 1))(x, y)
|
||||||
|
self.assertTrue(mx.allclose(expected_out, out))
|
||||||
|
|
||||||
|
y = mx.array([0.25, -0.25])
|
||||||
|
expected_out = mx.vmap(simple_binary, in_axes=(0, None))(x, y)
|
||||||
|
out = mx.vmap(mx.compile(simple_binary), in_axes=(0, None))(x, y)
|
||||||
|
self.assertTrue(mx.allclose(expected_out, out))
|
||||||
|
|
||||||
|
def simple_unary_outer(x):
|
||||||
|
x = mx.abs(x)
|
||||||
|
|
||||||
|
@mx.compile
|
||||||
|
def simple_unary_inner(z):
|
||||||
|
return -mx.exp(x)
|
||||||
|
|
||||||
|
return simple_unary_inner(x)
|
||||||
|
|
||||||
|
expected_out = -mx.exp(mx.abs(x))
|
||||||
|
out = mx.vmap(simple_unary_outer)(x)
|
||||||
|
self.assertTrue(mx.allclose(expected_out, out))
|
||||||
|
|
||||||
|
def test_vjp_vjp_compiled(self):
|
||||||
|
def simple_unary(x):
|
||||||
|
return -mx.exp(x)
|
||||||
|
|
||||||
|
x = mx.array([[1.0, 2.0], [2.0, 3.0]])
|
||||||
|
y = mx.array([[1.0, 1.0], [1.0, 1.0]])
|
||||||
|
|
||||||
|
expected_out, expected_vjp_out = mx.vjp(simple_unary, (x,), (y,))
|
||||||
|
out, vjp_out = mx.vjp(mx.compile(simple_unary), (x,), (y,))
|
||||||
|
self.assertTrue(mx.allclose(expected_vjp_out[0], vjp_out[0]))
|
||||||
|
self.assertTrue(mx.allclose(expected_out[0], out[0]))
|
||||||
|
|
||||||
|
expected_out, expected_jvp_out = mx.jvp(simple_unary, (x,), (y,))
|
||||||
|
out, jvp_out = mx.jvp(mx.compile(simple_unary), (x,), (y,))
|
||||||
|
self.assertTrue(mx.allclose(expected_jvp_out[0], jvp_out[0]))
|
||||||
|
self.assertTrue(mx.allclose(expected_out[0], out[0]))
|
||||||
|
|
||||||
|
def simple_binary(x, y):
|
||||||
|
return mx.abs(mx.exp(x + y) + y)
|
||||||
|
|
||||||
|
x = mx.array([[1.0, -3.0], [0.5, -0.5]])
|
||||||
|
y = mx.array([[2.0, -1.0], [0.25, -0.25]])
|
||||||
|
cotans = mx.ones_like(x)
|
||||||
|
|
||||||
|
expected_out, expected_vjp_out = mx.vjp(simple_binary, (x, y), (cotans,))
|
||||||
|
out, vjp_out = mx.vjp(mx.compile(simple_binary), (x, y), (cotans,))
|
||||||
|
self.assertTrue(mx.allclose(expected_out[0], out[0]))
|
||||||
|
self.assertTrue(mx.allclose(expected_vjp_out[0], vjp_out[0]))
|
||||||
|
self.assertTrue(mx.allclose(expected_vjp_out[1], vjp_out[1]))
|
||||||
|
|
||||||
|
tans = (mx.ones_like(x), mx.ones_like(y))
|
||||||
|
expected_out, expected_jvp_out = mx.jvp(simple_binary, (x, y), tans)
|
||||||
|
out, jvp_out = mx.jvp(mx.compile(simple_binary), (x, y), tans)
|
||||||
|
self.assertTrue(mx.allclose(expected_jvp_out[0], jvp_out[0]))
|
||||||
|
self.assertTrue(mx.allclose(expected_out[0], out[0]))
|
||||||
|
|
||||||
|
def test_transform_over_eval_compiled(self):
|
||||||
|
def outer(x):
|
||||||
|
y = mx.exp(mx.abs(x))
|
||||||
|
mx.eval(y)
|
||||||
|
return y.sum()
|
||||||
|
|
||||||
|
x = mx.array([2.0, -1.0, 0.5])
|
||||||
|
dfdx = mx.grad(outer)(x)
|
||||||
|
|
||||||
|
@mx.compile
|
||||||
|
def simple_unary(x):
|
||||||
|
return mx.exp(mx.abs(x))
|
||||||
|
|
||||||
|
def outer(x):
|
||||||
|
y = simple_unary(x)
|
||||||
|
mx.eval(y)
|
||||||
|
return y.sum()
|
||||||
|
|
||||||
|
cdfdx = mx.grad(outer)(x)
|
||||||
|
self.assertTrue(mx.allclose(dfdx, cdfdx))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
@ -3,6 +3,7 @@
|
|||||||
#include "doctest/doctest.h"
|
#include "doctest/doctest.h"
|
||||||
|
|
||||||
#include "mlx/mlx.h"
|
#include "mlx/mlx.h"
|
||||||
|
#include "mlx/primitives.h"
|
||||||
|
|
||||||
using namespace mlx::core;
|
using namespace mlx::core;
|
||||||
|
|
||||||
@ -120,6 +121,7 @@ auto max_scalars(const std::vector<array>&) {
|
|||||||
};
|
};
|
||||||
|
|
||||||
TEST_CASE("test simplify scalars") {
|
TEST_CASE("test simplify scalars") {
|
||||||
|
set_compile_mode(CompileMode::no_fuse);
|
||||||
{
|
{
|
||||||
auto cfun = compile(add_scalars);
|
auto cfun = compile(add_scalars);
|
||||||
auto out = cfun({});
|
auto out = cfun({});
|
||||||
@ -136,6 +138,7 @@ TEST_CASE("test simplify scalars") {
|
|||||||
auto d = out[2];
|
auto d = out[2];
|
||||||
CHECK(b.inputs()[1].id() == c.inputs()[1].id());
|
CHECK(b.inputs()[1].id() == c.inputs()[1].id());
|
||||||
}
|
}
|
||||||
|
set_compile_mode(CompileMode::enabled);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto exp_two(const std::vector<array>& inputs) {
|
auto exp_two(const std::vector<array>& inputs) {
|
||||||
@ -144,9 +147,11 @@ auto exp_two(const std::vector<array>& inputs) {
|
|||||||
};
|
};
|
||||||
|
|
||||||
TEST_CASE("test simplify") {
|
TEST_CASE("test simplify") {
|
||||||
|
set_compile_mode(CompileMode::no_fuse);
|
||||||
auto a = array({1.0f, 2.0f});
|
auto a = array({1.0f, 2.0f});
|
||||||
auto b = compile(exp_two)({a})[0];
|
auto b = compile(exp_two)({a})[0];
|
||||||
CHECK(b.inputs()[0].id() == b.inputs()[1].id());
|
CHECK(b.inputs()[0].id() == b.inputs()[1].id());
|
||||||
|
set_compile_mode(CompileMode::enabled);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto add_diff(const std::vector<array>& inputs) {
|
auto add_diff(const std::vector<array>& inputs) {
|
||||||
@ -155,9 +160,11 @@ auto add_diff(const std::vector<array>& inputs) {
|
|||||||
};
|
};
|
||||||
|
|
||||||
TEST_CASE("test no simplify") {
|
TEST_CASE("test no simplify") {
|
||||||
|
set_compile_mode(CompileMode::no_fuse);
|
||||||
auto a = array({1.0f, 2.0f});
|
auto a = array({1.0f, 2.0f});
|
||||||
auto b = compile(add_diff)({a})[0];
|
auto b = compile(add_diff)({a})[0];
|
||||||
CHECK(b.inputs()[0].id() != b.inputs()[1].id());
|
CHECK(b.inputs()[0].id() != b.inputs()[1].id());
|
||||||
|
set_compile_mode(CompileMode::enabled);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto multi_one(const std::vector<array>&) {
|
auto multi_one(const std::vector<array>&) {
|
||||||
@ -187,6 +194,7 @@ auto multi_three(const std::vector<array>&) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST_CASE("test simplify multi output") {
|
TEST_CASE("test simplify multi output") {
|
||||||
|
set_compile_mode(CompileMode::no_fuse);
|
||||||
{
|
{
|
||||||
auto out = compile(multi_one)({});
|
auto out = compile(multi_one)({});
|
||||||
auto e = out[0];
|
auto e = out[0];
|
||||||
@ -210,4 +218,372 @@ TEST_CASE("test simplify multi output") {
|
|||||||
CHECK_EQ(e.inputs()[0].id(), e.inputs()[2].id());
|
CHECK_EQ(e.inputs()[0].id(), e.inputs()[2].id());
|
||||||
CHECK_EQ(e.inputs()[1].id(), e.inputs()[3].id());
|
CHECK_EQ(e.inputs()[1].id(), e.inputs()[3].id());
|
||||||
}
|
}
|
||||||
|
set_compile_mode(CompileMode::enabled);
|
||||||
|
}
|
||||||
|
|
||||||
|
// No fusion
|
||||||
|
auto unary_fused_0(const std::vector<array>& inputs) {
|
||||||
|
return std::vector<array>{exp(inputs[0])};
|
||||||
|
}
|
||||||
|
|
||||||
|
// All compilable
|
||||||
|
auto unary_fused_1(const std::vector<array>& inputs) {
|
||||||
|
return std::vector<array>{abs(negative(exp(inputs[0])))};
|
||||||
|
}
|
||||||
|
|
||||||
|
auto unary_fused_1_copy(const std::vector<array>& inputs) {
|
||||||
|
return std::vector<array>{abs(negative(exp(inputs[0])))};
|
||||||
|
}
|
||||||
|
|
||||||
|
auto unary_fused_1_diff(const std::vector<array>& inputs) {
|
||||||
|
return std::vector<array>{abs(exp(negative(inputs[0])))};
|
||||||
|
}
|
||||||
|
|
||||||
|
// Output into un-compilable primitive
|
||||||
|
auto unary_fused_2(const std::vector<array>& inputs) {
|
||||||
|
return std::vector<array>{sum(abs(negative(exp(inputs[0]))), true)};
|
||||||
|
}
|
||||||
|
|
||||||
|
// Input from un-compilable primitive
|
||||||
|
auto unary_fused_3(const std::vector<array>& inputs) {
|
||||||
|
return std::vector<array>{exp(abs(negative(sum(inputs[0], true))))};
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test compile unary fused") {
|
||||||
|
// NB: some of these tests are brittle and may need to be
|
||||||
|
// updated if we change compile conditions
|
||||||
|
{
|
||||||
|
auto cfun = compile(unary_fused_0);
|
||||||
|
auto x = array(2.0);
|
||||||
|
auto out = cfun({x})[0];
|
||||||
|
|
||||||
|
auto& p = out.primitive();
|
||||||
|
CHECK_EQ(typeid(p), typeid(Exp));
|
||||||
|
CHECK_EQ(out.inputs()[0].id(), x.id());
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
auto cfun = compile(unary_fused_1);
|
||||||
|
auto x = array(2.0);
|
||||||
|
auto out = cfun({x})[0];
|
||||||
|
|
||||||
|
auto& p = out.primitive();
|
||||||
|
CHECK_EQ(typeid(p), typeid(Compiled));
|
||||||
|
CHECK_EQ(out.inputs()[0].id(), x.id());
|
||||||
|
|
||||||
|
auto expected_out = unary_fused_1({array(2.0)})[0];
|
||||||
|
CHECK_EQ(out.item<float>(), expected_out.item<float>());
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
auto cfun = compile(unary_fused_2);
|
||||||
|
auto x = array({1.0, 2.0});
|
||||||
|
auto out = cfun({x});
|
||||||
|
CHECK_EQ(out.size(), 1);
|
||||||
|
|
||||||
|
auto& p = out[0].primitive();
|
||||||
|
// NB: this test is brittle, will need to update
|
||||||
|
// it if we change compile conditions
|
||||||
|
CHECK_EQ(typeid(p), typeid(Reduce));
|
||||||
|
auto cout = out[0].inputs()[0];
|
||||||
|
auto& cp = cout.primitive();
|
||||||
|
CHECK_EQ(typeid(cp), typeid(Compiled));
|
||||||
|
CHECK_EQ(cout.inputs()[0].id(), x.id());
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
auto cfun = compile(unary_fused_3);
|
||||||
|
auto x = array({1.0, 2.0});
|
||||||
|
auto out = cfun({x});
|
||||||
|
|
||||||
|
auto& p = out[0].primitive();
|
||||||
|
CHECK_EQ(typeid(p), typeid(Compiled));
|
||||||
|
auto sout = out[0].inputs()[0];
|
||||||
|
CHECK_EQ(out[0].inputs().size(), 1);
|
||||||
|
auto& sp = sout.primitive();
|
||||||
|
CHECK_EQ(typeid(sp), typeid(Reduce));
|
||||||
|
CHECK_EQ(sout.inputs()[0].id(), x.id());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Is equivalent works
|
||||||
|
{
|
||||||
|
auto out1 = compile(unary_fused_1)({array(1.0)});
|
||||||
|
auto out2 = compile(unary_fused_1_copy)({array(1.0)});
|
||||||
|
CHECK(out1[0].primitive().is_equivalent(out2[0].primitive()));
|
||||||
|
auto out3 = compile(unary_fused_1_diff)({array(1.0)});
|
||||||
|
CHECK(!out1[0].primitive().is_equivalent(out3[0].primitive()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// All compilable
|
||||||
|
auto binary_fused_0(const std::vector<array>& inputs) {
|
||||||
|
return std::vector<array>{inputs[0] + inputs[1]};
|
||||||
|
}
|
||||||
|
|
||||||
|
// Binary into unary
|
||||||
|
auto binary_fused_1(const std::vector<array>& inputs) {
|
||||||
|
return std::vector<array>{abs(inputs[0] + inputs[1])};
|
||||||
|
}
|
||||||
|
|
||||||
|
// Binary into binary
|
||||||
|
auto binary_fused_2(const std::vector<array>& inputs) {
|
||||||
|
auto x = inputs[0] + inputs[1];
|
||||||
|
return std::vector<array>{x + inputs[0]};
|
||||||
|
}
|
||||||
|
|
||||||
|
// Binary into unary into un-compilable
|
||||||
|
auto binary_fused_3(const std::vector<array>& inputs) {
|
||||||
|
return std::vector<array>{sum(abs(inputs[0] + inputs[1]), true)};
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test compile binary fused") {
|
||||||
|
{
|
||||||
|
auto cfun = compile(binary_fused_0);
|
||||||
|
auto x = array(2.0);
|
||||||
|
auto y = array(2.0);
|
||||||
|
auto out = cfun({x, y})[0];
|
||||||
|
|
||||||
|
auto& p = out.primitive();
|
||||||
|
CHECK_EQ(typeid(p), typeid(Add));
|
||||||
|
CHECK_EQ(out.inputs()[0].id(), x.id());
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
auto cfun = compile(binary_fused_1);
|
||||||
|
auto x = array(2.0);
|
||||||
|
auto y = array(2.0);
|
||||||
|
auto out = cfun({x, y})[0];
|
||||||
|
|
||||||
|
auto& p = out.primitive();
|
||||||
|
CHECK_EQ(typeid(p), typeid(Compiled));
|
||||||
|
CHECK_EQ(out.inputs()[0].id(), x.id());
|
||||||
|
CHECK_EQ(out.inputs()[1].id(), y.id());
|
||||||
|
|
||||||
|
auto expected_out = binary_fused_1({x, y})[0];
|
||||||
|
CHECK_EQ(out.item<float>(), expected_out.item<float>());
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
auto cfun = compile(binary_fused_2);
|
||||||
|
auto x = array(2.0);
|
||||||
|
auto y = array(2.0);
|
||||||
|
auto out = cfun({x, y})[0];
|
||||||
|
|
||||||
|
auto& p = out.primitive();
|
||||||
|
CHECK_EQ(typeid(p), typeid(Compiled));
|
||||||
|
CHECK_EQ(out.inputs()[0].id(), x.id());
|
||||||
|
CHECK_EQ(out.inputs()[1].id(), y.id());
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
auto cfun = compile(binary_fused_3);
|
||||||
|
auto x = array({1.0, 2.0});
|
||||||
|
auto y = array({1.0, 2.0});
|
||||||
|
auto out = cfun({x, y})[0];
|
||||||
|
|
||||||
|
auto& p = out.primitive();
|
||||||
|
CHECK_EQ(typeid(p), typeid(Reduce));
|
||||||
|
|
||||||
|
auto cout = out.inputs()[0];
|
||||||
|
auto& cp = cout.primitive();
|
||||||
|
CHECK_EQ(typeid(cp), typeid(Compiled));
|
||||||
|
CHECK_EQ(cout.inputs()[0].id(), x.id());
|
||||||
|
CHECK_EQ(cout.inputs()[1].id(), y.id());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auto gelu_1(const std::vector<array>& inputs) {
|
||||||
|
auto& x = inputs[0];
|
||||||
|
auto out = x * (1.0f + erf(x / M_SQRT2)) / 2.0f;
|
||||||
|
return std::vector<array>{out};
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test compile gelu") {
|
||||||
|
{
|
||||||
|
auto cfun = compile(gelu_1);
|
||||||
|
auto x = array(1.0);
|
||||||
|
auto out = cfun({x})[0];
|
||||||
|
auto& p = out.primitive();
|
||||||
|
CHECK_EQ(typeid(p), typeid(Compiled));
|
||||||
|
CHECK_EQ(out.inputs().size(), 4);
|
||||||
|
for (auto& in : out.inputs()) {
|
||||||
|
CHECK(in.inputs().empty());
|
||||||
|
}
|
||||||
|
auto expected_out = gelu_1({x})[0];
|
||||||
|
CHECK(allclose(out, expected_out).item<bool>());
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
auto cfun = compile(gelu_1);
|
||||||
|
auto x = array({1.0, 0.5});
|
||||||
|
auto out = cfun({x})[0];
|
||||||
|
auto& p = out.primitive();
|
||||||
|
CHECK_EQ(typeid(p), typeid(Compiled));
|
||||||
|
CHECK_EQ(out.inputs().size(), 4);
|
||||||
|
for (auto& in : out.inputs()) {
|
||||||
|
CHECK(in.inputs().empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
auto expected_out = gelu_1({x})[0];
|
||||||
|
CHECK(allclose(out, expected_out).item<bool>());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Uncompilable input outside fused tape
|
||||||
|
auto unary_with_two_outputs(const std::vector<array>& inputs) {
|
||||||
|
auto x = exp(inputs[0]);
|
||||||
|
return std::vector<array>{exp(x), sum(x, true)};
|
||||||
|
}
|
||||||
|
|
||||||
|
auto uncompilable_inputs(const std::vector<array>& inputs) {
|
||||||
|
auto& x = inputs[0];
|
||||||
|
auto& y = inputs[1];
|
||||||
|
return std::vector<array>{x * abs(exp(y)), sum(x, true)};
|
||||||
|
}
|
||||||
|
|
||||||
|
auto uncompilable_inputs_order_matters(const std::vector<array>& inputs) {
|
||||||
|
auto& x = inputs[0];
|
||||||
|
auto& y = inputs[1];
|
||||||
|
return std::vector<array>{x / abs(exp(y)), sum(x, true)};
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test compile tape with outside parents") {
|
||||||
|
{
|
||||||
|
auto cfun = compile(unary_with_two_outputs);
|
||||||
|
auto x = array({2.0, 2.0});
|
||||||
|
auto out = cfun({x});
|
||||||
|
|
||||||
|
auto& p1 = out[0].primitive();
|
||||||
|
CHECK_EQ(typeid(p1), typeid(Exp));
|
||||||
|
auto& p2 = out[1].primitive();
|
||||||
|
CHECK_EQ(typeid(p2), typeid(Reduce));
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
auto cfun = compile(uncompilable_inputs);
|
||||||
|
auto x = array({2.0, 2.0});
|
||||||
|
auto y = array({1.6, 0.6});
|
||||||
|
auto outs = cfun({x, y});
|
||||||
|
|
||||||
|
auto& p1 = outs[0].primitive();
|
||||||
|
CHECK_EQ(typeid(p1), typeid(Compiled));
|
||||||
|
auto& p2 = outs[1].primitive();
|
||||||
|
CHECK_EQ(typeid(p2), typeid(Reduce));
|
||||||
|
CHECK_EQ(outs[0].inputs().size(), 2);
|
||||||
|
|
||||||
|
auto expected_outs = uncompilable_inputs({x, y});
|
||||||
|
CHECK(allclose(outs[0], expected_outs[0]).item<bool>());
|
||||||
|
CHECK(allclose(outs[1], expected_outs[1]).item<bool>());
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
auto cfun = compile(uncompilable_inputs_order_matters);
|
||||||
|
auto x = array({2.0, 2.0});
|
||||||
|
auto y = array({1.6, 0.6});
|
||||||
|
auto outs = cfun({x, y});
|
||||||
|
|
||||||
|
auto& p1 = outs[0].primitive();
|
||||||
|
CHECK_EQ(typeid(p1), typeid(Compiled));
|
||||||
|
auto& p2 = outs[1].primitive();
|
||||||
|
CHECK_EQ(typeid(p2), typeid(Reduce));
|
||||||
|
CHECK_EQ(outs[0].inputs().size(), 2);
|
||||||
|
|
||||||
|
auto expected_outs = uncompilable_inputs_order_matters({x, y});
|
||||||
|
CHECK(allclose(outs[0], expected_outs[0]).item<bool>());
|
||||||
|
CHECK(allclose(outs[1], expected_outs[1]).item<bool>());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auto compile_accross_streams(const std::vector<array>& inputs) {
|
||||||
|
auto s2 = new_stream(default_device());
|
||||||
|
auto x = exp(abs(inputs[0]));
|
||||||
|
auto y = exp(abs(x, s2), s2);
|
||||||
|
return std::vector<array>{y};
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test compile accross streams") {
|
||||||
|
auto cfun = compile(compile_accross_streams);
|
||||||
|
auto x = array({2.0f});
|
||||||
|
auto out = cfun({x})[0];
|
||||||
|
auto& p1 = out.primitive();
|
||||||
|
CHECK_EQ(typeid(p1), typeid(Compiled));
|
||||||
|
CHECK_EQ(out.inputs().size(), 1);
|
||||||
|
auto child = out.inputs()[0];
|
||||||
|
auto& p2 = child.primitive();
|
||||||
|
CHECK_EQ(typeid(p2), typeid(Compiled));
|
||||||
|
CHECK_EQ(child.inputs()[0].id(), x.id());
|
||||||
|
}
|
||||||
|
|
||||||
|
auto unary_compile_outputs(const std::vector<array>& inputs) {
|
||||||
|
auto x = abs(inputs[0]);
|
||||||
|
auto y = square(x);
|
||||||
|
return std::vector<array>{x, y};
|
||||||
|
}
|
||||||
|
|
||||||
|
auto binary_compile_outputs(const std::vector<array>& inputs) {
|
||||||
|
auto x = inputs[0];
|
||||||
|
auto y = inputs[1];
|
||||||
|
x = x + y;
|
||||||
|
y = x + y;
|
||||||
|
return std::vector<array>{x, y};
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test compile internal output") {
|
||||||
|
{
|
||||||
|
auto cfun = compile(unary_compile_outputs);
|
||||||
|
auto x = array({3, -2});
|
||||||
|
auto outs = cfun({x});
|
||||||
|
auto& p1 = outs[0].primitive();
|
||||||
|
CHECK_EQ(typeid(p1), typeid(Compiled));
|
||||||
|
auto& p2 = outs[1].primitive();
|
||||||
|
CHECK_EQ(typeid(p2), typeid(Compiled));
|
||||||
|
CHECK_EQ(outs[0].siblings()[0].id(), outs[1].id());
|
||||||
|
auto expected_outs = unary_compile_outputs({x});
|
||||||
|
CHECK(array_equal(outs[0], expected_outs[0]).item<bool>());
|
||||||
|
CHECK(array_equal(outs[1], expected_outs[1]).item<bool>());
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
auto cfun = compile(binary_compile_outputs);
|
||||||
|
auto x = array({3, -2});
|
||||||
|
auto y = array({1, -1});
|
||||||
|
auto outs = cfun({x, y});
|
||||||
|
auto& p1 = outs[0].primitive();
|
||||||
|
CHECK_EQ(typeid(p1), typeid(Compiled));
|
||||||
|
auto& p2 = outs[1].primitive();
|
||||||
|
CHECK_EQ(typeid(p2), typeid(Compiled));
|
||||||
|
auto expected_outs = binary_compile_outputs({x, y});
|
||||||
|
CHECK(array_equal(outs[0], expected_outs[0]).item<bool>());
|
||||||
|
CHECK(array_equal(outs[1], expected_outs[1]).item<bool>());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auto deep_unary_compile(const std::vector<array>& inputs) {
|
||||||
|
auto x = inputs[0];
|
||||||
|
for (int i = 0; i < 10; ++i) {
|
||||||
|
x = cos(sin(x));
|
||||||
|
}
|
||||||
|
return std::vector<array>{x};
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test compile deep graph") {
|
||||||
|
auto cfun = compile(deep_unary_compile);
|
||||||
|
auto x = array({3.0f, -2.0f});
|
||||||
|
auto out = cfun({x})[0];
|
||||||
|
auto expected_out = deep_unary_compile({x})[0];
|
||||||
|
CHECK(allclose(out, expected_out).item<bool>());
|
||||||
|
}
|
||||||
|
|
||||||
|
auto repeat_input_to_compiled(const std::vector<array>& inputs) {
|
||||||
|
auto x = abs(exp(inputs[0]));
|
||||||
|
auto y = abs(exp(sum(x)));
|
||||||
|
return std::vector<array>{x + y};
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test compile repeat input") {
|
||||||
|
auto cfun = compile(repeat_input_to_compiled);
|
||||||
|
auto x = array({3.0f, -2.0f});
|
||||||
|
auto out = cfun({x})[0];
|
||||||
|
auto expected_out = repeat_input_to_compiled({x})[0];
|
||||||
|
CHECK(allclose(out, expected_out).item<bool>());
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user