mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-03 18:18:15 +08:00
remove simplify
This commit is contained in:
@@ -25,7 +25,6 @@ target_sources(tests PRIVATE
|
||||
device_tests.cpp
|
||||
eval_tests.cpp
|
||||
fft_tests.cpp
|
||||
graph_optimize_tests.cpp
|
||||
load_tests.cpp
|
||||
ops_tests.cpp
|
||||
random_tests.cpp
|
||||
|
||||
@@ -104,3 +104,77 @@ TEST_CASE("test enable and disable compile") {
|
||||
enable_compiler();
|
||||
CHECK_THROWS(compile(nullptr));
|
||||
}
|
||||
|
||||
TEST_CASE("test simplify scalars") {
|
||||
{
|
||||
auto a = array(-1.0f);
|
||||
auto b = array(-1.0f);
|
||||
auto c = abs(a);
|
||||
auto d = abs(b);
|
||||
simplify({c, d});
|
||||
CHECK(c.inputs()[0].id() == d.inputs()[0].id());
|
||||
}
|
||||
|
||||
{
|
||||
auto a = array({-1.0f, 2.0f});
|
||||
auto b = maximum(a, array(0.0f));
|
||||
auto c = maximum(-a, array(0.0f));
|
||||
auto d = b + c;
|
||||
simplify({d});
|
||||
CHECK(b.inputs()[1].id() == c.inputs()[1].id());
|
||||
}
|
||||
}
|
||||
|
||||
// TODO rework these tests for compile
|
||||
/*TEST_CASE("test simplify") {
|
||||
auto a = array({1.0f, 2.0f});
|
||||
auto b = exp(a) + exp(a);
|
||||
simplify(b);
|
||||
CHECK(b.inputs()[0].id() == b.inputs()[1].id());
|
||||
}
|
||||
|
||||
TEST_CASE("test no simplify") {
|
||||
auto a = array({1.0f, 2.0f});
|
||||
auto b = cos(a) + sin(a);
|
||||
simplify(b);
|
||||
CHECK(b.inputs()[0].id() != b.inputs()[1].id());
|
||||
}
|
||||
|
||||
TEST_CASE("test simplify multi output") {
|
||||
{
|
||||
auto a = array(1.0);
|
||||
auto b = array(2.0);
|
||||
auto c = divmod(a, b);
|
||||
auto d = divmod(a, b);
|
||||
auto e = c[0] + d[0];
|
||||
auto f = c[1] + d[1];
|
||||
|
||||
simplify({e, f});
|
||||
CHECK_EQ(e.inputs()[0].id(), e.inputs()[1].id());
|
||||
CHECK_EQ(f.inputs()[0].id(), f.inputs()[1].id());
|
||||
}
|
||||
|
||||
{
|
||||
auto a = array(1.0);
|
||||
auto b = array(1.0);
|
||||
auto c = divmod(a, b);
|
||||
simplify(c);
|
||||
CHECK_EQ(c[0].inputs()[0].id(), c[0].inputs()[1].id());
|
||||
CHECK_EQ(c[0].inputs()[0].id(), c[1].inputs()[0].id());
|
||||
CHECK_EQ(c[1].inputs()[0].id(), c[1].inputs()[1].id());
|
||||
}
|
||||
|
||||
// Make sure the output order of multi-output primitives
|
||||
// is respected in simplification
|
||||
{
|
||||
auto a = array(1.0);
|
||||
auto b = array(2.0);
|
||||
auto c = divmod(a, b);
|
||||
auto d = divmod(a, b);
|
||||
auto e = stack({c[0], c[1], d[0], d[1]});
|
||||
simplify(e);
|
||||
CHECK(array_equal(e, array({0.0f, 1.0f, 0.0f, 1.0f})).item<bool>());
|
||||
CHECK_EQ(e.inputs()[0].id(), e.inputs()[2].id());
|
||||
CHECK_EQ(e.inputs()[1].id(), e.inputs()[3].id());
|
||||
}
|
||||
}*/
|
||||
|
||||
@@ -1,80 +0,0 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#include "doctest/doctest.h"
|
||||
|
||||
#include "mlx/mlx.h"
|
||||
|
||||
using namespace mlx::core;
|
||||
|
||||
TEST_CASE("test simplify scalars") {
|
||||
{
|
||||
auto a = array(-1.0f);
|
||||
auto b = array(-1.0f);
|
||||
auto c = abs(a);
|
||||
auto d = abs(b);
|
||||
simplify({c, d});
|
||||
CHECK(c.inputs()[0].id() == d.inputs()[0].id());
|
||||
}
|
||||
|
||||
{
|
||||
auto a = array({-1.0f, 2.0f});
|
||||
auto b = maximum(a, array(0.0f));
|
||||
auto c = maximum(-a, array(0.0f));
|
||||
auto d = b + c;
|
||||
simplify({d});
|
||||
CHECK(b.inputs()[1].id() == c.inputs()[1].id());
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("test simplify") {
|
||||
auto a = array({1.0f, 2.0f});
|
||||
auto b = exp(a) + exp(a);
|
||||
simplify(b);
|
||||
CHECK(b.inputs()[0].id() == b.inputs()[1].id());
|
||||
}
|
||||
|
||||
TEST_CASE("test no simplify") {
|
||||
auto a = array({1.0f, 2.0f});
|
||||
auto b = cos(a) + sin(a);
|
||||
simplify(b);
|
||||
CHECK(b.inputs()[0].id() != b.inputs()[1].id());
|
||||
}
|
||||
|
||||
TEST_CASE("test simplify multi output") {
|
||||
{
|
||||
auto a = array(1.0);
|
||||
auto b = array(2.0);
|
||||
auto c = divmod(a, b);
|
||||
auto d = divmod(a, b);
|
||||
auto e = c[0] + d[0];
|
||||
auto f = c[1] + d[1];
|
||||
|
||||
simplify({e, f});
|
||||
CHECK_EQ(e.inputs()[0].id(), e.inputs()[1].id());
|
||||
CHECK_EQ(f.inputs()[0].id(), f.inputs()[1].id());
|
||||
}
|
||||
|
||||
{
|
||||
auto a = array(1.0);
|
||||
auto b = array(1.0);
|
||||
auto c = divmod(a, b);
|
||||
simplify(c);
|
||||
CHECK_EQ(c[0].inputs()[0].id(), c[0].inputs()[1].id());
|
||||
CHECK_EQ(c[0].inputs()[0].id(), c[1].inputs()[0].id());
|
||||
CHECK_EQ(c[1].inputs()[0].id(), c[1].inputs()[1].id());
|
||||
}
|
||||
|
||||
// Make sure the output order of multi-output primitives
|
||||
// is respected in simplification
|
||||
{
|
||||
auto a = array(1.0);
|
||||
auto b = array(2.0);
|
||||
auto c = divmod(a, b);
|
||||
auto d = divmod(a, b);
|
||||
auto e = stack({c[0], c[1], d[0], d[1]});
|
||||
simplify(e);
|
||||
CHECK(array_equal(e, array({0.0f, 1.0f, 0.0f, 1.0f})).item<bool>());
|
||||
CHECK_EQ(e.inputs()[0].id(), e.inputs()[2].id());
|
||||
CHECK_EQ(e.inputs()[1].id(), e.inputs()[3].id());
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user