Kernel generation (#614)

Generate reusable element-wise kernels given a computation graph.
This commit is contained in:
Angelos Katharopoulos
2024-02-07 13:15:59 -08:00
committed by GitHub
parent 5fd11c347d
commit 28eac18571
19 changed files with 1302 additions and 459 deletions

View File

@@ -623,3 +623,63 @@ TEST_CASE("test transform compiled function") {
CHECK(!outs[0].inputs()[0].has_primitive());
CHECK(!outs[0].inputs()[1].has_primitive());
}
TEST_CASE("test metal fusion kernel reuse") {
if (default_device() != Device::gpu) {
return;
}
auto cfun = compile(gelu_1);
auto x = array({2.0f, -2.0f});
auto y = cfun({x})[0];
auto p = std::dynamic_pointer_cast<Compiled>(y.primitive_ptr());
eval(y);
std::string lib_name = p->metal_lib_name();
std::string lib_source = p->metal_lib_source();
CHECK(!lib_name.empty());
CHECK(!lib_source.empty());
x = astype(reshape(arange(10), {2, 5}), float32);
auto z = cfun({x})[0];
auto pz = std::dynamic_pointer_cast<Compiled>(z.primitive_ptr());
eval(z);
std::string lib_name_z = pz->metal_lib_name();
std::string lib_source_z = pz->metal_lib_source();
CHECK(!lib_name_z.empty());
CHECK(lib_source_z.empty());
CHECK_EQ(lib_name, lib_name_z);
}
auto add3(const std::vector<array>& xs) {
return std::vector<array>{xs[0] + xs[0] + xs[0]};
}
TEST_CASE("test metal fusion types") {
if (default_device() != Device::gpu) {
return;
}
auto cfun = compile(add3);
auto x = array({2.0f, -2.0f});
auto y = cfun({x})[0];
auto p = std::dynamic_pointer_cast<Compiled>(y.primitive_ptr());
eval(y);
std::string lib_name = p->metal_lib_name();
std::string lib_source = p->metal_lib_source();
CHECK(!lib_name.empty());
CHECK(!lib_source.empty());
x = array({2, -2}, int32);
auto z = cfun({x})[0];
auto pz = std::dynamic_pointer_cast<Compiled>(z.primitive_ptr());
eval(z);
std::string lib_name_z = pz->metal_lib_name();
std::string lib_source_z = pz->metal_lib_source();
CHECK(!lib_name_z.empty());
CHECK(!lib_source_z.empty());
}