mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-30 22:51:24 +08:00
simplify tests use compile now
This commit is contained in:
parent
1c3f82ca17
commit
ed4d867092
@ -4,6 +4,7 @@
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/transforms.h"
|
||||
#include "mlx/transforms_impl.h"
|
||||
@ -92,12 +93,12 @@ struct CompilerCache {
|
||||
cache_.erase(fun_id);
|
||||
}
|
||||
|
||||
void clear() {
|
||||
cache_.clear();
|
||||
}
|
||||
|
||||
private:
|
||||
CompilerCache() {}
|
||||
CompilerCache() {
|
||||
// Make sure the allocator is fully
|
||||
// initialized before the compiler cache
|
||||
allocator::allocator();
|
||||
}
|
||||
friend CompilerCache& compiler_cache();
|
||||
std::unordered_map<size_t, std::vector<CacheEntry>> cache_;
|
||||
};
|
||||
@ -400,10 +401,6 @@ void compile_erase(size_t fun_id) {
|
||||
detail::compiler_cache().erase(fun_id);
|
||||
}
|
||||
|
||||
void compile_clear() {
|
||||
detail::compiler_cache().clear();
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||
|
@ -1,7 +1,6 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
#include <algorithm>
|
||||
#include <future>
|
||||
#include <map>
|
||||
#include <numeric>
|
||||
#include <set>
|
||||
#include <sstream>
|
||||
|
@ -23,9 +23,6 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||
// Erase cached compile functions
|
||||
void compile_erase(size_t fun_id);
|
||||
|
||||
// Clear the compiler cache
|
||||
void compile_clear();
|
||||
|
||||
// Create an InTracing object during tracing operations to signify to the rest
|
||||
// of the codebase that we are during tracing so evals should not throw away
|
||||
// the graph.
|
||||
|
@ -836,8 +836,5 @@ void init_transforms(py::module_& m) {
|
||||
|
||||
// Register static Python object cleanup before the interpreter exits
|
||||
auto atexit = py::module_::import("atexit");
|
||||
atexit.attr("register")(py::cpp_function([]() {
|
||||
detail::compile_clear();
|
||||
tree_cache().clear();
|
||||
}));
|
||||
atexit.attr("register")(py::cpp_function([]() { tree_cache().clear(); }));
|
||||
}
|
||||
|
@ -105,60 +105,98 @@ TEST_CASE("test enable and disable compile") {
|
||||
CHECK_THROWS(compile(nullptr));
|
||||
}
|
||||
|
||||
auto add_scalars(const std::vector<array>&) {
|
||||
auto a = array(-1.0f);
|
||||
auto b = array(-1.0f);
|
||||
return std::vector<array>{abs(a), abs(b)};
|
||||
};
|
||||
|
||||
auto max_scalars(const std::vector<array>&) {
|
||||
auto a = array({-1.0f, 2.0f});
|
||||
auto b = maximum(a, array(0.0f));
|
||||
auto c = maximum(-a, array(0.0f));
|
||||
auto d = b + c;
|
||||
return std::vector<array>{b, c, d};
|
||||
};
|
||||
|
||||
TEST_CASE("test simplify scalars") {
|
||||
{
|
||||
auto a = array(-1.0f);
|
||||
auto b = array(-1.0f);
|
||||
auto c = abs(a);
|
||||
auto d = abs(b);
|
||||
simplify({c, d});
|
||||
auto cfun = compile(add_scalars);
|
||||
auto out = cfun({});
|
||||
auto c = out[0];
|
||||
auto d = out[1];
|
||||
CHECK(c.inputs()[0].id() == d.inputs()[0].id());
|
||||
}
|
||||
|
||||
{
|
||||
auto a = array({-1.0f, 2.0f});
|
||||
auto b = maximum(a, array(0.0f));
|
||||
auto c = maximum(-a, array(0.0f));
|
||||
auto d = b + c;
|
||||
simplify({d});
|
||||
auto out = compile(max_scalars)({a});
|
||||
auto b = out[0];
|
||||
auto c = out[1];
|
||||
auto d = out[2];
|
||||
CHECK(b.inputs()[1].id() == c.inputs()[1].id());
|
||||
}
|
||||
}
|
||||
|
||||
// TODO rework these tests for compile
|
||||
/*TEST_CASE("test simplify") {
|
||||
auto exp_two(const std::vector<array>& inputs) {
|
||||
auto a = inputs[0];
|
||||
return std::vector<array>{exp(a) + exp(a)};
|
||||
};
|
||||
|
||||
TEST_CASE("test simplify") {
|
||||
auto a = array({1.0f, 2.0f});
|
||||
auto b = exp(a) + exp(a);
|
||||
simplify(b);
|
||||
auto b = compile(exp_two)({a})[0];
|
||||
CHECK(b.inputs()[0].id() == b.inputs()[1].id());
|
||||
}
|
||||
|
||||
auto add_diff(const std::vector<array>& inputs) {
|
||||
auto a = inputs[0];
|
||||
return std::vector<array>{cos(a) + sin(a)};
|
||||
};
|
||||
|
||||
TEST_CASE("test no simplify") {
|
||||
auto a = array({1.0f, 2.0f});
|
||||
auto b = cos(a) + sin(a);
|
||||
simplify(b);
|
||||
auto b = compile(add_diff)({a})[0];
|
||||
CHECK(b.inputs()[0].id() != b.inputs()[1].id());
|
||||
}
|
||||
|
||||
auto multi_one(const std::vector<array>&) {
|
||||
auto a = array(1.0);
|
||||
auto b = array(2.0);
|
||||
auto c = divmod(a, b);
|
||||
auto d = divmod(a, b);
|
||||
auto e = c[0] + d[0];
|
||||
auto f = c[1] + d[1];
|
||||
return std::vector<array>{e, f};
|
||||
}
|
||||
|
||||
auto multi_two(const std::vector<array>&) {
|
||||
auto a = array(1.0);
|
||||
auto b = array(1.0);
|
||||
auto c = divmod(a, b);
|
||||
return std::vector<array>{c};
|
||||
}
|
||||
|
||||
auto multi_three(const std::vector<array>&) {
|
||||
auto a = array(1.0);
|
||||
auto b = array(2.0);
|
||||
auto c = divmod(a, b);
|
||||
auto d = divmod(a, b);
|
||||
auto e = stack({c[0], c[1], d[0], d[1]});
|
||||
return std::vector<array>{e};
|
||||
}
|
||||
|
||||
TEST_CASE("test simplify multi output") {
|
||||
{
|
||||
auto a = array(1.0);
|
||||
auto b = array(2.0);
|
||||
auto c = divmod(a, b);
|
||||
auto d = divmod(a, b);
|
||||
auto e = c[0] + d[0];
|
||||
auto f = c[1] + d[1];
|
||||
|
||||
simplify({e, f});
|
||||
auto out = compile(multi_one)({});
|
||||
auto e = out[0];
|
||||
auto f = out[1];
|
||||
CHECK_EQ(e.inputs()[0].id(), e.inputs()[1].id());
|
||||
CHECK_EQ(f.inputs()[0].id(), f.inputs()[1].id());
|
||||
}
|
||||
|
||||
{
|
||||
auto a = array(1.0);
|
||||
auto b = array(1.0);
|
||||
auto c = divmod(a, b);
|
||||
simplify(c);
|
||||
auto c = compile(multi_two)({});
|
||||
CHECK_EQ(c[0].inputs()[0].id(), c[0].inputs()[1].id());
|
||||
CHECK_EQ(c[0].inputs()[0].id(), c[1].inputs()[0].id());
|
||||
CHECK_EQ(c[1].inputs()[0].id(), c[1].inputs()[1].id());
|
||||
@ -167,14 +205,9 @@ TEST_CASE("test simplify multi output") {
|
||||
// Make sure the output order of multi-output primitives
|
||||
// is respected in simplification
|
||||
{
|
||||
auto a = array(1.0);
|
||||
auto b = array(2.0);
|
||||
auto c = divmod(a, b);
|
||||
auto d = divmod(a, b);
|
||||
auto e = stack({c[0], c[1], d[0], d[1]});
|
||||
simplify(e);
|
||||
auto e = compile(multi_three)({})[0];
|
||||
CHECK(array_equal(e, array({0.0f, 1.0f, 0.0f, 1.0f})).item<bool>());
|
||||
CHECK_EQ(e.inputs()[0].id(), e.inputs()[2].id());
|
||||
CHECK_EQ(e.inputs()[1].id(), e.inputs()[3].id());
|
||||
}
|
||||
}*/
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user