Shapeless compilation for some graphs (#687)

* shapeless compilation for some graphs

* update compile benchmark

* default compile a few activations

* buffer donation

* bugfix

* shapeless fix

* update tests to work for cpu and gpu fusion

* test kwargs

* add kwargs to compile

* Recompile when python arguments change

* no compile for tanh

* some constant tests

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
Awni Hannun
2024-02-19 21:43:54 -08:00
committed by GitHub
parent d0fda82595
commit 5798256fcf
14 changed files with 645 additions and 113 deletions

View File

@@ -624,31 +624,23 @@ TEST_CASE("test transform compiled function") {
CHECK(!outs[0].inputs()[1].has_primitive());
}
TEST_CASE("test metal fusion kernel reuse") {
if (default_device() != Device::gpu) {
return;
}
TEST_CASE("test fusion kernel reuse") {
auto cfun = compile(gelu_1);
auto x = array({2.0f, -2.0f});
auto y = cfun({x})[0];
auto p = std::dynamic_pointer_cast<Compiled>(y.primitive_ptr());
eval(y);
std::string lib_name = p->metal_lib_name();
std::string lib_source = p->metal_lib_source();
std::string lib_name = p->lib_name();
CHECK(!lib_name.empty());
CHECK(!lib_source.empty());
x = astype(reshape(arange(10), {2, 5}), float32);
auto z = cfun({x})[0];
auto pz = std::dynamic_pointer_cast<Compiled>(z.primitive_ptr());
eval(z);
std::string lib_name_z = pz->metal_lib_name();
std::string lib_source_z = pz->metal_lib_source();
std::string lib_name_z = pz->lib_name();
CHECK(!lib_name_z.empty());
CHECK(lib_source_z.empty());
CHECK_EQ(lib_name, lib_name_z);
}
@@ -657,29 +649,57 @@ auto add3(const std::vector<array>& xs) {
return std::vector<array>{xs[0] + xs[0] + xs[0]};
}
TEST_CASE("test metal fusion types") {
if (default_device() != Device::gpu) {
return;
}
TEST_CASE("test fusion types") {
auto cfun = compile(add3);
auto x = array({2.0f, -2.0f});
auto y = cfun({x})[0];
auto p = std::dynamic_pointer_cast<Compiled>(y.primitive_ptr());
eval(y);
std::string lib_name = p->metal_lib_name();
std::string lib_source = p->metal_lib_source();
std::string lib_name = p->lib_name();
CHECK(!lib_name.empty());
CHECK(!lib_source.empty());
x = array({2, -2}, int32);
auto z = cfun({x})[0];
auto pz = std::dynamic_pointer_cast<Compiled>(z.primitive_ptr());
eval(z);
std::string lib_name_z = pz->metal_lib_name();
std::string lib_source_z = pz->metal_lib_source();
std::string lib_name_z = pz->lib_name();
CHECK(!lib_name_z.empty());
CHECK(!lib_source_z.empty());
}
auto compile_shapeless_not_ok(const std::vector<array>& inputs) {
auto x = reshape(inputs[0], {2, 2});
return std::vector<array>{x};
}
auto compile_shapeless_ok(const std::vector<array>& inputs) {
auto x = inputs[0] + array({2});
return std::vector<array>{x};
}
TEST_CASE("test shapeless compile") {
{
auto cfun = compile(compile_shapeless_not_ok, /* shapeless */ true);
CHECK_THROWS(cfun({array({1, 2, 3, 4})}));
}
{
auto cfun = compile(compile_shapeless_ok, /* shapeless */ true);
auto out = cfun({array({1, 2})})[0];
auto out2 = cfun({array({1, 2, 3, 4})})[0];
// Not making a new constant array since no recompile,
// hence the ids should be the same
CHECK_EQ(out.inputs()[1].id(), out2.inputs()[1].id());
CHECK(array_equal(out2, array({3, 4, 5, 6})).item<bool>());
// Recompile since type changes
out2 = cfun({array({1.0, 2.0})})[0];
CHECK_NE(out.inputs()[1].id(), out2.inputs()[1].id());
// Recompile since ndim changes
out2 = cfun({array({1.0, 2.0}, {1, 2})})[0];
CHECK_NE(out.inputs()[1].id(), out2.inputs()[1].id());
}
}