mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 04:24:36 +08:00
Shapeless compilation for some graphs (#687)
* shapeless compilation for some graphs * update compile benchmark * default compile a few activations * buffer donation * bugfix * shapeless fix * update tests to work for cpu and gpu fusion * test kwargs * add kwargs to compile * Recompile when python arguments change * no compile for tanh * some constant tests --------- Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
@@ -624,31 +624,23 @@ TEST_CASE("test transform compiled function") {
|
||||
CHECK(!outs[0].inputs()[1].has_primitive());
|
||||
}
|
||||
|
||||
TEST_CASE("test metal fusion kernel reuse") {
|
||||
if (default_device() != Device::gpu) {
|
||||
return;
|
||||
}
|
||||
|
||||
TEST_CASE("test fusion kernel reuse") {
|
||||
auto cfun = compile(gelu_1);
|
||||
auto x = array({2.0f, -2.0f});
|
||||
auto y = cfun({x})[0];
|
||||
auto p = std::dynamic_pointer_cast<Compiled>(y.primitive_ptr());
|
||||
eval(y);
|
||||
|
||||
std::string lib_name = p->metal_lib_name();
|
||||
std::string lib_source = p->metal_lib_source();
|
||||
std::string lib_name = p->lib_name();
|
||||
CHECK(!lib_name.empty());
|
||||
CHECK(!lib_source.empty());
|
||||
|
||||
x = astype(reshape(arange(10), {2, 5}), float32);
|
||||
auto z = cfun({x})[0];
|
||||
auto pz = std::dynamic_pointer_cast<Compiled>(z.primitive_ptr());
|
||||
eval(z);
|
||||
|
||||
std::string lib_name_z = pz->metal_lib_name();
|
||||
std::string lib_source_z = pz->metal_lib_source();
|
||||
std::string lib_name_z = pz->lib_name();
|
||||
CHECK(!lib_name_z.empty());
|
||||
CHECK(lib_source_z.empty());
|
||||
|
||||
CHECK_EQ(lib_name, lib_name_z);
|
||||
}
|
||||
@@ -657,29 +649,57 @@ auto add3(const std::vector<array>& xs) {
|
||||
return std::vector<array>{xs[0] + xs[0] + xs[0]};
|
||||
}
|
||||
|
||||
TEST_CASE("test metal fusion types") {
|
||||
if (default_device() != Device::gpu) {
|
||||
return;
|
||||
}
|
||||
|
||||
TEST_CASE("test fusion types") {
|
||||
auto cfun = compile(add3);
|
||||
auto x = array({2.0f, -2.0f});
|
||||
auto y = cfun({x})[0];
|
||||
auto p = std::dynamic_pointer_cast<Compiled>(y.primitive_ptr());
|
||||
eval(y);
|
||||
|
||||
std::string lib_name = p->metal_lib_name();
|
||||
std::string lib_source = p->metal_lib_source();
|
||||
std::string lib_name = p->lib_name();
|
||||
CHECK(!lib_name.empty());
|
||||
CHECK(!lib_source.empty());
|
||||
|
||||
x = array({2, -2}, int32);
|
||||
auto z = cfun({x})[0];
|
||||
auto pz = std::dynamic_pointer_cast<Compiled>(z.primitive_ptr());
|
||||
eval(z);
|
||||
|
||||
std::string lib_name_z = pz->metal_lib_name();
|
||||
std::string lib_source_z = pz->metal_lib_source();
|
||||
std::string lib_name_z = pz->lib_name();
|
||||
CHECK(!lib_name_z.empty());
|
||||
CHECK(!lib_source_z.empty());
|
||||
}
|
||||
|
||||
auto compile_shapeless_not_ok(const std::vector<array>& inputs) {
|
||||
auto x = reshape(inputs[0], {2, 2});
|
||||
return std::vector<array>{x};
|
||||
}
|
||||
|
||||
auto compile_shapeless_ok(const std::vector<array>& inputs) {
|
||||
auto x = inputs[0] + array({2});
|
||||
return std::vector<array>{x};
|
||||
}
|
||||
|
||||
TEST_CASE("test shapeless compile") {
|
||||
{
|
||||
auto cfun = compile(compile_shapeless_not_ok, /* shapeless */ true);
|
||||
CHECK_THROWS(cfun({array({1, 2, 3, 4})}));
|
||||
}
|
||||
|
||||
{
|
||||
auto cfun = compile(compile_shapeless_ok, /* shapeless */ true);
|
||||
auto out = cfun({array({1, 2})})[0];
|
||||
auto out2 = cfun({array({1, 2, 3, 4})})[0];
|
||||
|
||||
// Not making a new constant array since no recompile,
|
||||
// hence the ids should be the same
|
||||
CHECK_EQ(out.inputs()[1].id(), out2.inputs()[1].id());
|
||||
CHECK(array_equal(out2, array({3, 4, 5, 6})).item<bool>());
|
||||
|
||||
// Recompile since type changes
|
||||
out2 = cfun({array({1.0, 2.0})})[0];
|
||||
CHECK_NE(out.inputs()[1].id(), out2.inputs()[1].id());
|
||||
|
||||
// Recompile since ndim changes
|
||||
out2 = cfun({array({1.0, 2.0}, {1, 2})})[0];
|
||||
CHECK_NE(out.inputs()[1].id(), out2.inputs()[1].id());
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user