mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-20 01:18:12 +08:00
Compile primitive (#571)
* Compiled primitive with basic binary, unary graph-level fusion
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
#include "doctest/doctest.h"
|
||||
|
||||
#include "mlx/mlx.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
using namespace mlx::core;
|
||||
|
||||
@@ -120,6 +121,7 @@ auto max_scalars(const std::vector<array>&) {
|
||||
};
|
||||
|
||||
TEST_CASE("test simplify scalars") {
|
||||
set_compile_mode(CompileMode::no_fuse);
|
||||
{
|
||||
auto cfun = compile(add_scalars);
|
||||
auto out = cfun({});
|
||||
@@ -136,6 +138,7 @@ TEST_CASE("test simplify scalars") {
|
||||
auto d = out[2];
|
||||
CHECK(b.inputs()[1].id() == c.inputs()[1].id());
|
||||
}
|
||||
set_compile_mode(CompileMode::enabled);
|
||||
}
|
||||
|
||||
auto exp_two(const std::vector<array>& inputs) {
|
||||
@@ -144,9 +147,11 @@ auto exp_two(const std::vector<array>& inputs) {
|
||||
};
|
||||
|
||||
TEST_CASE("test simplify") {
|
||||
set_compile_mode(CompileMode::no_fuse);
|
||||
auto a = array({1.0f, 2.0f});
|
||||
auto b = compile(exp_two)({a})[0];
|
||||
CHECK(b.inputs()[0].id() == b.inputs()[1].id());
|
||||
set_compile_mode(CompileMode::enabled);
|
||||
}
|
||||
|
||||
auto add_diff(const std::vector<array>& inputs) {
|
||||
@@ -155,9 +160,11 @@ auto add_diff(const std::vector<array>& inputs) {
|
||||
};
|
||||
|
||||
TEST_CASE("test no simplify") {
|
||||
set_compile_mode(CompileMode::no_fuse);
|
||||
auto a = array({1.0f, 2.0f});
|
||||
auto b = compile(add_diff)({a})[0];
|
||||
CHECK(b.inputs()[0].id() != b.inputs()[1].id());
|
||||
set_compile_mode(CompileMode::enabled);
|
||||
}
|
||||
|
||||
auto multi_one(const std::vector<array>&) {
|
||||
@@ -187,6 +194,7 @@ auto multi_three(const std::vector<array>&) {
|
||||
}
|
||||
|
||||
TEST_CASE("test simplify multi output") {
|
||||
set_compile_mode(CompileMode::no_fuse);
|
||||
{
|
||||
auto out = compile(multi_one)({});
|
||||
auto e = out[0];
|
||||
@@ -210,4 +218,372 @@ TEST_CASE("test simplify multi output") {
|
||||
CHECK_EQ(e.inputs()[0].id(), e.inputs()[2].id());
|
||||
CHECK_EQ(e.inputs()[1].id(), e.inputs()[3].id());
|
||||
}
|
||||
set_compile_mode(CompileMode::enabled);
|
||||
}
|
||||
|
||||
// No fusion
|
||||
auto unary_fused_0(const std::vector<array>& inputs) {
|
||||
return std::vector<array>{exp(inputs[0])};
|
||||
}
|
||||
|
||||
// All compilable
|
||||
auto unary_fused_1(const std::vector<array>& inputs) {
|
||||
return std::vector<array>{abs(negative(exp(inputs[0])))};
|
||||
}
|
||||
|
||||
auto unary_fused_1_copy(const std::vector<array>& inputs) {
|
||||
return std::vector<array>{abs(negative(exp(inputs[0])))};
|
||||
}
|
||||
|
||||
auto unary_fused_1_diff(const std::vector<array>& inputs) {
|
||||
return std::vector<array>{abs(exp(negative(inputs[0])))};
|
||||
}
|
||||
|
||||
// Output into un-compilable primitive
|
||||
auto unary_fused_2(const std::vector<array>& inputs) {
|
||||
return std::vector<array>{sum(abs(negative(exp(inputs[0]))), true)};
|
||||
}
|
||||
|
||||
// Input from un-compilable primitive
|
||||
auto unary_fused_3(const std::vector<array>& inputs) {
|
||||
return std::vector<array>{exp(abs(negative(sum(inputs[0], true))))};
|
||||
}
|
||||
|
||||
TEST_CASE("test compile unary fused") {
|
||||
// NB: some of these tests are brittle and may need to be
|
||||
// updated if we change compile conditions
|
||||
{
|
||||
auto cfun = compile(unary_fused_0);
|
||||
auto x = array(2.0);
|
||||
auto out = cfun({x})[0];
|
||||
|
||||
auto& p = out.primitive();
|
||||
CHECK_EQ(typeid(p), typeid(Exp));
|
||||
CHECK_EQ(out.inputs()[0].id(), x.id());
|
||||
}
|
||||
|
||||
{
|
||||
auto cfun = compile(unary_fused_1);
|
||||
auto x = array(2.0);
|
||||
auto out = cfun({x})[0];
|
||||
|
||||
auto& p = out.primitive();
|
||||
CHECK_EQ(typeid(p), typeid(Compiled));
|
||||
CHECK_EQ(out.inputs()[0].id(), x.id());
|
||||
|
||||
auto expected_out = unary_fused_1({array(2.0)})[0];
|
||||
CHECK_EQ(out.item<float>(), expected_out.item<float>());
|
||||
}
|
||||
|
||||
{
|
||||
auto cfun = compile(unary_fused_2);
|
||||
auto x = array({1.0, 2.0});
|
||||
auto out = cfun({x});
|
||||
CHECK_EQ(out.size(), 1);
|
||||
|
||||
auto& p = out[0].primitive();
|
||||
// NB: this test is brittle, will need to update
|
||||
// it if we change compile conditions
|
||||
CHECK_EQ(typeid(p), typeid(Reduce));
|
||||
auto cout = out[0].inputs()[0];
|
||||
auto& cp = cout.primitive();
|
||||
CHECK_EQ(typeid(cp), typeid(Compiled));
|
||||
CHECK_EQ(cout.inputs()[0].id(), x.id());
|
||||
}
|
||||
|
||||
{
|
||||
auto cfun = compile(unary_fused_3);
|
||||
auto x = array({1.0, 2.0});
|
||||
auto out = cfun({x});
|
||||
|
||||
auto& p = out[0].primitive();
|
||||
CHECK_EQ(typeid(p), typeid(Compiled));
|
||||
auto sout = out[0].inputs()[0];
|
||||
CHECK_EQ(out[0].inputs().size(), 1);
|
||||
auto& sp = sout.primitive();
|
||||
CHECK_EQ(typeid(sp), typeid(Reduce));
|
||||
CHECK_EQ(sout.inputs()[0].id(), x.id());
|
||||
}
|
||||
|
||||
// Is equivalent works
|
||||
{
|
||||
auto out1 = compile(unary_fused_1)({array(1.0)});
|
||||
auto out2 = compile(unary_fused_1_copy)({array(1.0)});
|
||||
CHECK(out1[0].primitive().is_equivalent(out2[0].primitive()));
|
||||
auto out3 = compile(unary_fused_1_diff)({array(1.0)});
|
||||
CHECK(!out1[0].primitive().is_equivalent(out3[0].primitive()));
|
||||
}
|
||||
}
|
||||
|
||||
// All compilable
|
||||
auto binary_fused_0(const std::vector<array>& inputs) {
|
||||
return std::vector<array>{inputs[0] + inputs[1]};
|
||||
}
|
||||
|
||||
// Binary into unary
|
||||
auto binary_fused_1(const std::vector<array>& inputs) {
|
||||
return std::vector<array>{abs(inputs[0] + inputs[1])};
|
||||
}
|
||||
|
||||
// Binary into binary
|
||||
auto binary_fused_2(const std::vector<array>& inputs) {
|
||||
auto x = inputs[0] + inputs[1];
|
||||
return std::vector<array>{x + inputs[0]};
|
||||
}
|
||||
|
||||
// Binary into unary into un-compilable
|
||||
auto binary_fused_3(const std::vector<array>& inputs) {
|
||||
return std::vector<array>{sum(abs(inputs[0] + inputs[1]), true)};
|
||||
}
|
||||
|
||||
TEST_CASE("test compile binary fused") {
|
||||
{
|
||||
auto cfun = compile(binary_fused_0);
|
||||
auto x = array(2.0);
|
||||
auto y = array(2.0);
|
||||
auto out = cfun({x, y})[0];
|
||||
|
||||
auto& p = out.primitive();
|
||||
CHECK_EQ(typeid(p), typeid(Add));
|
||||
CHECK_EQ(out.inputs()[0].id(), x.id());
|
||||
}
|
||||
|
||||
{
|
||||
auto cfun = compile(binary_fused_1);
|
||||
auto x = array(2.0);
|
||||
auto y = array(2.0);
|
||||
auto out = cfun({x, y})[0];
|
||||
|
||||
auto& p = out.primitive();
|
||||
CHECK_EQ(typeid(p), typeid(Compiled));
|
||||
CHECK_EQ(out.inputs()[0].id(), x.id());
|
||||
CHECK_EQ(out.inputs()[1].id(), y.id());
|
||||
|
||||
auto expected_out = binary_fused_1({x, y})[0];
|
||||
CHECK_EQ(out.item<float>(), expected_out.item<float>());
|
||||
}
|
||||
|
||||
{
|
||||
auto cfun = compile(binary_fused_2);
|
||||
auto x = array(2.0);
|
||||
auto y = array(2.0);
|
||||
auto out = cfun({x, y})[0];
|
||||
|
||||
auto& p = out.primitive();
|
||||
CHECK_EQ(typeid(p), typeid(Compiled));
|
||||
CHECK_EQ(out.inputs()[0].id(), x.id());
|
||||
CHECK_EQ(out.inputs()[1].id(), y.id());
|
||||
}
|
||||
|
||||
{
|
||||
auto cfun = compile(binary_fused_3);
|
||||
auto x = array({1.0, 2.0});
|
||||
auto y = array({1.0, 2.0});
|
||||
auto out = cfun({x, y})[0];
|
||||
|
||||
auto& p = out.primitive();
|
||||
CHECK_EQ(typeid(p), typeid(Reduce));
|
||||
|
||||
auto cout = out.inputs()[0];
|
||||
auto& cp = cout.primitive();
|
||||
CHECK_EQ(typeid(cp), typeid(Compiled));
|
||||
CHECK_EQ(cout.inputs()[0].id(), x.id());
|
||||
CHECK_EQ(cout.inputs()[1].id(), y.id());
|
||||
}
|
||||
}
|
||||
|
||||
auto gelu_1(const std::vector<array>& inputs) {
|
||||
auto& x = inputs[0];
|
||||
auto out = x * (1.0f + erf(x / M_SQRT2)) / 2.0f;
|
||||
return std::vector<array>{out};
|
||||
}
|
||||
|
||||
TEST_CASE("test compile gelu") {
|
||||
{
|
||||
auto cfun = compile(gelu_1);
|
||||
auto x = array(1.0);
|
||||
auto out = cfun({x})[0];
|
||||
auto& p = out.primitive();
|
||||
CHECK_EQ(typeid(p), typeid(Compiled));
|
||||
CHECK_EQ(out.inputs().size(), 4);
|
||||
for (auto& in : out.inputs()) {
|
||||
CHECK(in.inputs().empty());
|
||||
}
|
||||
auto expected_out = gelu_1({x})[0];
|
||||
CHECK(allclose(out, expected_out).item<bool>());
|
||||
}
|
||||
|
||||
{
|
||||
auto cfun = compile(gelu_1);
|
||||
auto x = array({1.0, 0.5});
|
||||
auto out = cfun({x})[0];
|
||||
auto& p = out.primitive();
|
||||
CHECK_EQ(typeid(p), typeid(Compiled));
|
||||
CHECK_EQ(out.inputs().size(), 4);
|
||||
for (auto& in : out.inputs()) {
|
||||
CHECK(in.inputs().empty());
|
||||
}
|
||||
|
||||
auto expected_out = gelu_1({x})[0];
|
||||
CHECK(allclose(out, expected_out).item<bool>());
|
||||
}
|
||||
}
|
||||
|
||||
// Uncompilable input outside fused tape
|
||||
auto unary_with_two_outputs(const std::vector<array>& inputs) {
|
||||
auto x = exp(inputs[0]);
|
||||
return std::vector<array>{exp(x), sum(x, true)};
|
||||
}
|
||||
|
||||
auto uncompilable_inputs(const std::vector<array>& inputs) {
|
||||
auto& x = inputs[0];
|
||||
auto& y = inputs[1];
|
||||
return std::vector<array>{x * abs(exp(y)), sum(x, true)};
|
||||
}
|
||||
|
||||
auto uncompilable_inputs_order_matters(const std::vector<array>& inputs) {
|
||||
auto& x = inputs[0];
|
||||
auto& y = inputs[1];
|
||||
return std::vector<array>{x / abs(exp(y)), sum(x, true)};
|
||||
}
|
||||
|
||||
TEST_CASE("test compile tape with outside parents") {
|
||||
{
|
||||
auto cfun = compile(unary_with_two_outputs);
|
||||
auto x = array({2.0, 2.0});
|
||||
auto out = cfun({x});
|
||||
|
||||
auto& p1 = out[0].primitive();
|
||||
CHECK_EQ(typeid(p1), typeid(Exp));
|
||||
auto& p2 = out[1].primitive();
|
||||
CHECK_EQ(typeid(p2), typeid(Reduce));
|
||||
}
|
||||
|
||||
{
|
||||
auto cfun = compile(uncompilable_inputs);
|
||||
auto x = array({2.0, 2.0});
|
||||
auto y = array({1.6, 0.6});
|
||||
auto outs = cfun({x, y});
|
||||
|
||||
auto& p1 = outs[0].primitive();
|
||||
CHECK_EQ(typeid(p1), typeid(Compiled));
|
||||
auto& p2 = outs[1].primitive();
|
||||
CHECK_EQ(typeid(p2), typeid(Reduce));
|
||||
CHECK_EQ(outs[0].inputs().size(), 2);
|
||||
|
||||
auto expected_outs = uncompilable_inputs({x, y});
|
||||
CHECK(allclose(outs[0], expected_outs[0]).item<bool>());
|
||||
CHECK(allclose(outs[1], expected_outs[1]).item<bool>());
|
||||
}
|
||||
|
||||
{
|
||||
auto cfun = compile(uncompilable_inputs_order_matters);
|
||||
auto x = array({2.0, 2.0});
|
||||
auto y = array({1.6, 0.6});
|
||||
auto outs = cfun({x, y});
|
||||
|
||||
auto& p1 = outs[0].primitive();
|
||||
CHECK_EQ(typeid(p1), typeid(Compiled));
|
||||
auto& p2 = outs[1].primitive();
|
||||
CHECK_EQ(typeid(p2), typeid(Reduce));
|
||||
CHECK_EQ(outs[0].inputs().size(), 2);
|
||||
|
||||
auto expected_outs = uncompilable_inputs_order_matters({x, y});
|
||||
CHECK(allclose(outs[0], expected_outs[0]).item<bool>());
|
||||
CHECK(allclose(outs[1], expected_outs[1]).item<bool>());
|
||||
}
|
||||
}
|
||||
|
||||
auto compile_accross_streams(const std::vector<array>& inputs) {
|
||||
auto s2 = new_stream(default_device());
|
||||
auto x = exp(abs(inputs[0]));
|
||||
auto y = exp(abs(x, s2), s2);
|
||||
return std::vector<array>{y};
|
||||
}
|
||||
|
||||
TEST_CASE("test compile accross streams") {
|
||||
auto cfun = compile(compile_accross_streams);
|
||||
auto x = array({2.0f});
|
||||
auto out = cfun({x})[0];
|
||||
auto& p1 = out.primitive();
|
||||
CHECK_EQ(typeid(p1), typeid(Compiled));
|
||||
CHECK_EQ(out.inputs().size(), 1);
|
||||
auto child = out.inputs()[0];
|
||||
auto& p2 = child.primitive();
|
||||
CHECK_EQ(typeid(p2), typeid(Compiled));
|
||||
CHECK_EQ(child.inputs()[0].id(), x.id());
|
||||
}
|
||||
|
||||
auto unary_compile_outputs(const std::vector<array>& inputs) {
|
||||
auto x = abs(inputs[0]);
|
||||
auto y = square(x);
|
||||
return std::vector<array>{x, y};
|
||||
}
|
||||
|
||||
auto binary_compile_outputs(const std::vector<array>& inputs) {
|
||||
auto x = inputs[0];
|
||||
auto y = inputs[1];
|
||||
x = x + y;
|
||||
y = x + y;
|
||||
return std::vector<array>{x, y};
|
||||
}
|
||||
|
||||
TEST_CASE("test compile internal output") {
|
||||
{
|
||||
auto cfun = compile(unary_compile_outputs);
|
||||
auto x = array({3, -2});
|
||||
auto outs = cfun({x});
|
||||
auto& p1 = outs[0].primitive();
|
||||
CHECK_EQ(typeid(p1), typeid(Compiled));
|
||||
auto& p2 = outs[1].primitive();
|
||||
CHECK_EQ(typeid(p2), typeid(Compiled));
|
||||
CHECK_EQ(outs[0].siblings()[0].id(), outs[1].id());
|
||||
auto expected_outs = unary_compile_outputs({x});
|
||||
CHECK(array_equal(outs[0], expected_outs[0]).item<bool>());
|
||||
CHECK(array_equal(outs[1], expected_outs[1]).item<bool>());
|
||||
}
|
||||
|
||||
{
|
||||
auto cfun = compile(binary_compile_outputs);
|
||||
auto x = array({3, -2});
|
||||
auto y = array({1, -1});
|
||||
auto outs = cfun({x, y});
|
||||
auto& p1 = outs[0].primitive();
|
||||
CHECK_EQ(typeid(p1), typeid(Compiled));
|
||||
auto& p2 = outs[1].primitive();
|
||||
CHECK_EQ(typeid(p2), typeid(Compiled));
|
||||
auto expected_outs = binary_compile_outputs({x, y});
|
||||
CHECK(array_equal(outs[0], expected_outs[0]).item<bool>());
|
||||
CHECK(array_equal(outs[1], expected_outs[1]).item<bool>());
|
||||
}
|
||||
}
|
||||
|
||||
auto deep_unary_compile(const std::vector<array>& inputs) {
|
||||
auto x = inputs[0];
|
||||
for (int i = 0; i < 10; ++i) {
|
||||
x = cos(sin(x));
|
||||
}
|
||||
return std::vector<array>{x};
|
||||
}
|
||||
|
||||
TEST_CASE("test compile deep graph") {
|
||||
auto cfun = compile(deep_unary_compile);
|
||||
auto x = array({3.0f, -2.0f});
|
||||
auto out = cfun({x})[0];
|
||||
auto expected_out = deep_unary_compile({x})[0];
|
||||
CHECK(allclose(out, expected_out).item<bool>());
|
||||
}
|
||||
|
||||
auto repeat_input_to_compiled(const std::vector<array>& inputs) {
|
||||
auto x = abs(exp(inputs[0]));
|
||||
auto y = abs(exp(sum(x)));
|
||||
return std::vector<array>{x + y};
|
||||
}
|
||||
|
||||
TEST_CASE("test compile repeat input") {
|
||||
auto cfun = compile(repeat_input_to_compiled);
|
||||
auto x = array({3.0f, -2.0f});
|
||||
auto out = cfun({x})[0];
|
||||
auto expected_out = repeat_input_to_compiled({x})[0];
|
||||
CHECK(allclose(out, expected_out).item<bool>());
|
||||
}
|
||||
|
Reference in New Issue
Block a user