mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-20 01:18:12 +08:00
Kernel generation (#614)
Generate reusable element-wise kernels given a computation graph.
This commit is contained in:

committed by
GitHub

parent
5fd11c347d
commit
28eac18571
@@ -623,3 +623,63 @@ TEST_CASE("test transform compiled function") {
|
||||
CHECK(!outs[0].inputs()[0].has_primitive());
|
||||
CHECK(!outs[0].inputs()[1].has_primitive());
|
||||
}
|
||||
|
||||
TEST_CASE("test metal fusion kernel reuse") {
|
||||
if (default_device() != Device::gpu) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto cfun = compile(gelu_1);
|
||||
auto x = array({2.0f, -2.0f});
|
||||
auto y = cfun({x})[0];
|
||||
auto p = std::dynamic_pointer_cast<Compiled>(y.primitive_ptr());
|
||||
eval(y);
|
||||
|
||||
std::string lib_name = p->metal_lib_name();
|
||||
std::string lib_source = p->metal_lib_source();
|
||||
CHECK(!lib_name.empty());
|
||||
CHECK(!lib_source.empty());
|
||||
|
||||
x = astype(reshape(arange(10), {2, 5}), float32);
|
||||
auto z = cfun({x})[0];
|
||||
auto pz = std::dynamic_pointer_cast<Compiled>(z.primitive_ptr());
|
||||
eval(z);
|
||||
|
||||
std::string lib_name_z = pz->metal_lib_name();
|
||||
std::string lib_source_z = pz->metal_lib_source();
|
||||
CHECK(!lib_name_z.empty());
|
||||
CHECK(lib_source_z.empty());
|
||||
|
||||
CHECK_EQ(lib_name, lib_name_z);
|
||||
}
|
||||
|
||||
auto add3(const std::vector<array>& xs) {
|
||||
return std::vector<array>{xs[0] + xs[0] + xs[0]};
|
||||
}
|
||||
|
||||
TEST_CASE("test metal fusion types") {
|
||||
if (default_device() != Device::gpu) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto cfun = compile(add3);
|
||||
auto x = array({2.0f, -2.0f});
|
||||
auto y = cfun({x})[0];
|
||||
auto p = std::dynamic_pointer_cast<Compiled>(y.primitive_ptr());
|
||||
eval(y);
|
||||
|
||||
std::string lib_name = p->metal_lib_name();
|
||||
std::string lib_source = p->metal_lib_source();
|
||||
CHECK(!lib_name.empty());
|
||||
CHECK(!lib_source.empty());
|
||||
|
||||
x = array({2, -2}, int32);
|
||||
auto z = cfun({x})[0];
|
||||
auto pz = std::dynamic_pointer_cast<Compiled>(z.primitive_ptr());
|
||||
eval(z);
|
||||
|
||||
std::string lib_name_z = pz->metal_lib_name();
|
||||
std::string lib_source_z = pz->metal_lib_source();
|
||||
CHECK(!lib_name_z.empty());
|
||||
CHECK(!lib_source_z.empty());
|
||||
}
|
||||
|
Reference in New Issue
Block a user