fix grad, more tests

This commit is contained in:
Awni Hannun 2024-01-14 19:42:36 -08:00
parent 9d70412a91
commit 9739c72781
2 changed files with 47 additions and 27 deletions

View File

@ -107,10 +107,8 @@ std::pair<std::vector<array>, ParentsMap> compile_dfs(
std::unordered_set<std::uintptr_t> cache; std::unordered_set<std::uintptr_t> cache;
std::unordered_map<std::uintptr_t, std::vector<std::pair<array, int>>> std::unordered_map<std::uintptr_t, std::vector<std::pair<array, int>>>
parents_map; parents_map;
std::unordered_set<std::uintptr_t> needs_compile;
for (int i = 0; i < inputs.size(); ++i) { for (int i = 0; i < inputs.size(); ++i) {
auto in = inputs[i]; auto in = inputs[i];
needs_compile.insert(in.id());
cache.insert(in.id()); cache.insert(in.id());
} }
@ -132,16 +130,7 @@ std::pair<std::vector<array>, ParentsMap> compile_dfs(
for (auto& s : a.siblings()) { for (auto& s : a.siblings()) {
cache.insert(s.id()); cache.insert(s.id());
} }
for (auto& input : a.inputs()) { tape.push_back(a);
if (needs_compile.find(input.id()) != needs_compile.end()) {
tape.push_back(a);
needs_compile.insert(a.id());
for (auto s : a.siblings()) {
needs_compile.insert(s.id());
}
break;
}
}
}; };
for (auto& a : outputs) { for (auto& a : outputs) {
recurse(a); recurse(a);
@ -254,7 +243,7 @@ void compile_simplify(
for (auto& o : outputs) { for (auto& o : outputs) {
output_set.insert(o.id()); output_set.insert(o.id());
} }
// Pass 1 to passes: fuse only keeping non-orphaned arrays in the tape // Pass 1..passes: fuse only keeping non-orphaned arrays in the tape
for (int pass = 0; pass < passes; ++pass) { for (int pass = 0; pass < passes; ++pass) {
for (auto& arr : tape) { for (auto& arr : tape) {
// Helper to check if we can fuse the parents of the // Helper to check if we can fuse the parents of the
@ -312,18 +301,21 @@ std::vector<array> compile_replace(
} }
for (auto& a : tape) { for (auto& a : tape) {
// Arrays in the tape without primitives are constants
// and can be used directly
if (!a.has_primitive()) { if (!a.has_primitive()) {
std::runtime_error( trace_to_real.insert({a.id(), a});
"[compile] Something went wrong, no primitive in tape"); } else {
std::vector<array> real_inputs;
for (auto& in : a.inputs()) {
real_inputs.push_back(trace_to_real.at(in.id()));
}
auto real_a = array(
a.shape(), a.dtype(), a.primitive_ptr(), std::move(real_inputs));
trace_to_real.insert({a.id(), std::move(real_a)});
} }
std::vector<array> real_inputs;
for (auto& in : a.inputs()) {
real_inputs.push_back(trace_to_real.at(in.id()));
}
auto real_a =
array(a.shape(), a.dtype(), a.primitive_ptr(), std::move(real_inputs));
trace_to_real.insert({a.id(), std::move(real_a)});
} }
std::vector<array> outputs; std::vector<array> outputs;
for (auto& o : trace_outputs) { for (auto& o : trace_outputs) {
outputs.push_back(trace_to_real.at(o.id())); outputs.push_back(trace_to_real.at(o.id()));
@ -342,8 +334,7 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
if (entry.empty) { if (entry.empty) {
// Mark the entry as not empty since we are about to fill it // Mark the entry as not empty since we are about to fill it
entry.empty = false; entry.empty = false;
// Trace to build the graph
// Trace te build the graph
std::tie(entry.inputs, entry.outputs) = compile_trace(fun, inputs); std::tie(entry.inputs, entry.outputs) = compile_trace(fun, inputs);
// DFS the graph and get a tape, and a map of array id to (parent, // DFS the graph and get a tape, and a map of array id to (parent,

View File

@ -8,13 +8,42 @@ using namespace mlx::core;
std::vector<array> simple_fun(const std::vector<array>& inputs) { std::vector<array> simple_fun(const std::vector<array>& inputs) {
return std::vector<array>{inputs[0] + inputs[1]}; return std::vector<array>{inputs[0] + inputs[1]};
}; }
TEST_CASE("test simple compile") { TEST_CASE("test simple compile") {
auto compfn = compile(simple_fun); auto compfn = compile(simple_fun);
auto out = compfn({array(1.0), array(2.0)})[0]; auto out = compfn({array(1.0f), array(2.0f)})[0];
CHECK_EQ(out.item<float>(), 3.0f); CHECK_EQ(out.item<float>(), 3.0f);
out = compfn({array(1.0), array(2.0)})[0]; out = compfn({array(1.0f), array(2.0f)})[0];
CHECK_EQ(out.item<float>(), 3.0f); CHECK_EQ(out.item<float>(), 3.0f);
// Change the shapes
out = compfn({array({1.0f, 2.0f}), array(2.0f)})[0];
CHECK(array_equal(out, array({3.0f, 4.0f})).item<bool>());
out = compfn({array(2.0f), array({1.0f, 2.0f})})[0];
CHECK(array_equal(out, array({3.0f, 4.0f})).item<bool>());
// Change the types
out = compfn({array(2, int32), array({1.0f, 2.0f})})[0];
CHECK(array_equal(out, array({3.0f, 4.0f})).item<bool>());
out = compfn({array(2.0f), array({1, 2}, int32)})[0];
CHECK(array_equal(out, array({3.0f, 4.0f})).item<bool>());
} }
std::vector<array> fun1(const std::vector<array>& inputs) {
auto loss = [](std::vector<array> ins) { return exp(ins[0] + ins[1]); };
return grad(loss)(inputs);
}
TEST_CASE("test compile with grad") {
auto x = array(1.0f);
auto y = array(1.0f);
auto grads_expected = fun1({x, y});
auto grads_compile = compile(fun1)({x, y});
CHECK_EQ(grads_compile[0].item<float>(), grads_expected[0].item<float>());
}
TEST_CASE("test nested compile") {}