mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 12:49:44 +08:00
Multi output primitives (#330)
* Multi-output primitives --------- Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
@@ -7,19 +7,29 @@
|
||||
using namespace mlx::core;
|
||||
|
||||
TEST_CASE("test simplify scalars") {
|
||||
auto a = array({-1.0f, 2.0f});
|
||||
auto b = maximum(a, array(0.0f));
|
||||
auto c = maximum(-a, array(0.0f));
|
||||
auto d = b + c;
|
||||
simplify({d});
|
||||
CHECK(b.inputs()[1].id() == c.inputs()[1].id());
|
||||
{
|
||||
auto a = array(-1.0f);
|
||||
auto b = array(-1.0f);
|
||||
auto c = abs(a);
|
||||
auto d = abs(b);
|
||||
simplify({c, d});
|
||||
CHECK(c.inputs()[0].id() == d.inputs()[0].id());
|
||||
}
|
||||
|
||||
{
|
||||
auto a = array({-1.0f, 2.0f});
|
||||
auto b = maximum(a, array(0.0f));
|
||||
auto c = maximum(-a, array(0.0f));
|
||||
auto d = b + c;
|
||||
simplify({d});
|
||||
CHECK(b.inputs()[1].id() == c.inputs()[1].id());
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("test simplify") {
|
||||
auto a = array({1.0f, 2.0f});
|
||||
auto b = exp(a) + exp(a);
|
||||
simplify(b);
|
||||
eval(b);
|
||||
CHECK(b.inputs()[0].id() == b.inputs()[1].id());
|
||||
}
|
||||
|
||||
@@ -27,6 +37,44 @@ TEST_CASE("test no simplify") {
|
||||
auto a = array({1.0f, 2.0f});
|
||||
auto b = cos(a) + sin(a);
|
||||
simplify(b);
|
||||
eval(b);
|
||||
CHECK(b.inputs()[0].id() != b.inputs()[1].id());
|
||||
}
|
||||
|
||||
TEST_CASE("test simplify multi output") {
|
||||
{
|
||||
auto a = array(1.0);
|
||||
auto b = array(2.0);
|
||||
auto c = divmod(a, b);
|
||||
auto d = divmod(a, b);
|
||||
auto e = c[0] + d[0];
|
||||
auto f = c[1] + d[1];
|
||||
|
||||
simplify({e, f});
|
||||
CHECK_EQ(e.inputs()[0].id(), e.inputs()[1].id());
|
||||
CHECK_EQ(f.inputs()[0].id(), f.inputs()[1].id());
|
||||
}
|
||||
|
||||
{
|
||||
auto a = array(1.0);
|
||||
auto b = array(1.0);
|
||||
auto c = divmod(a, b);
|
||||
simplify(c);
|
||||
CHECK_EQ(c[0].inputs()[0].id(), c[0].inputs()[1].id());
|
||||
CHECK_EQ(c[0].inputs()[0].id(), c[1].inputs()[0].id());
|
||||
CHECK_EQ(c[1].inputs()[0].id(), c[1].inputs()[1].id());
|
||||
}
|
||||
|
||||
// Make sure the output order of multi-output primitives
|
||||
// is respected in simplification
|
||||
{
|
||||
auto a = array(1.0);
|
||||
auto b = array(2.0);
|
||||
auto c = divmod(a, b);
|
||||
auto d = divmod(a, b);
|
||||
auto e = stack({c[0], c[1], d[0], d[1]});
|
||||
simplify(e);
|
||||
CHECK(array_equal(e, array({0.0f, 1.0f, 0.0f, 1.0f})).item<bool>());
|
||||
CHECK_EQ(e.inputs()[0].id(), e.inputs()[2].id());
|
||||
CHECK_EQ(e.inputs()[1].id(), e.inputs()[3].id());
|
||||
}
|
||||
}
|
||||
|
@@ -30,7 +30,7 @@ TEST_CASE("[mlx.core.linalg.norm] no ord") {
|
||||
CHECK_EQ(norm(x).item<float>(), doctest::Approx(expected));
|
||||
CHECK_EQ(
|
||||
norm(x, std::vector<int>{0, 1}).item<float>(), doctest::Approx(expected));
|
||||
CHECK(array_equal(
|
||||
CHECK(allclose(
|
||||
norm(x, 0, false),
|
||||
array(
|
||||
{std::sqrt(0 + 3 * 3 + 6 * 6),
|
||||
|
@@ -2397,3 +2397,58 @@ TEST_CASE("inner") {
|
||||
expected = array({7., 0., 0., 7.}, {2, 2});
|
||||
CHECK(array_equal(z, expected).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test divmod") {
|
||||
auto x = array({1, 2, 3});
|
||||
auto y = array({1, 1, 1});
|
||||
auto out = divmod(x, y);
|
||||
CHECK(array_equal(out[0], array({1, 2, 3})).item<bool>());
|
||||
CHECK(array_equal(out[1], array({0, 0, 0})).item<bool>());
|
||||
|
||||
x = array({5, 6, 7});
|
||||
y = array({2, 2, 2});
|
||||
out = divmod(x, y);
|
||||
CHECK(array_equal(out[0], array({2, 3, 3})).item<bool>());
|
||||
CHECK(array_equal(out[1], array({1, 0, 1})).item<bool>());
|
||||
|
||||
// Siblings should be gone after evaling the graph
|
||||
CHECK(out[0].siblings().empty());
|
||||
CHECK(out[1].siblings().empty());
|
||||
|
||||
x = array({5.0, 6.0, 7.0});
|
||||
y = array({2.0, 2.0, 2.0});
|
||||
out = divmod(x, y);
|
||||
CHECK(array_equal(out[0], array({2.0, 3.0, 3.0})).item<bool>());
|
||||
CHECK(array_equal(out[1], array({1.0, 0.0, 1.0})).item<bool>());
|
||||
|
||||
x = array({1.0}, complex64);
|
||||
y = array({2.0}, complex64);
|
||||
CHECK_THROWS(divmod(x, y));
|
||||
|
||||
// Check that we can eval on both outputs
|
||||
x = array({1.0});
|
||||
y = array({2.0});
|
||||
out = divmod(x, y);
|
||||
eval(out);
|
||||
CHECK_EQ(out[0].item<float>(), 0.0);
|
||||
CHECK_EQ(out[1].item<float>(), 1.0);
|
||||
|
||||
// Check nested in the graph
|
||||
x = array({1.0});
|
||||
y = array({2.0});
|
||||
out = divmod(x, y);
|
||||
auto z = out[0] + out[1];
|
||||
CHECK_EQ(z.item<float>(), 1.0);
|
||||
|
||||
// Check that we can still eval when one output goes out of scope
|
||||
std::vector<array> out_holder;
|
||||
{ out_holder.push_back(divmod(x, y)[0]); }
|
||||
eval(out_holder);
|
||||
CHECK_EQ(out_holder[0].item<float>(), 0.0);
|
||||
|
||||
// Check that we can still eval when the other output goes out of scope
|
||||
out_holder.clear();
|
||||
{ out_holder.push_back(divmod(x, y)[1]); }
|
||||
eval(out_holder);
|
||||
CHECK_EQ(out_holder[0].item<float>(), 1.0);
|
||||
}
|
||||
|
Reference in New Issue
Block a user