mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 18:11:15 +08:00
fix grad, more tests
This commit is contained in:
parent
9d70412a91
commit
9739c72781
@ -107,10 +107,8 @@ std::pair<std::vector<array>, ParentsMap> compile_dfs(
|
||||
std::unordered_set<std::uintptr_t> cache;
|
||||
std::unordered_map<std::uintptr_t, std::vector<std::pair<array, int>>>
|
||||
parents_map;
|
||||
std::unordered_set<std::uintptr_t> needs_compile;
|
||||
for (int i = 0; i < inputs.size(); ++i) {
|
||||
auto in = inputs[i];
|
||||
needs_compile.insert(in.id());
|
||||
cache.insert(in.id());
|
||||
}
|
||||
|
||||
@ -132,16 +130,7 @@ std::pair<std::vector<array>, ParentsMap> compile_dfs(
|
||||
for (auto& s : a.siblings()) {
|
||||
cache.insert(s.id());
|
||||
}
|
||||
for (auto& input : a.inputs()) {
|
||||
if (needs_compile.find(input.id()) != needs_compile.end()) {
|
||||
tape.push_back(a);
|
||||
needs_compile.insert(a.id());
|
||||
for (auto s : a.siblings()) {
|
||||
needs_compile.insert(s.id());
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
};
|
||||
for (auto& a : outputs) {
|
||||
recurse(a);
|
||||
@ -254,7 +243,7 @@ void compile_simplify(
|
||||
for (auto& o : outputs) {
|
||||
output_set.insert(o.id());
|
||||
}
|
||||
// Pass 1 to passes: fuse only keeping non-orphaned arrays in the tape
|
||||
// Pass 1..passes: fuse only keeping non-orphaned arrays in the tape
|
||||
for (int pass = 0; pass < passes; ++pass) {
|
||||
for (auto& arr : tape) {
|
||||
// Helper to check if we can fuse the parents of the
|
||||
@ -312,18 +301,21 @@ std::vector<array> compile_replace(
|
||||
}
|
||||
|
||||
for (auto& a : tape) {
|
||||
// Arrays in the tape without primitives are constants
|
||||
// and can be used directly
|
||||
if (!a.has_primitive()) {
|
||||
std::runtime_error(
|
||||
"[compile] Something went wrong, no primitive in tape");
|
||||
}
|
||||
trace_to_real.insert({a.id(), a});
|
||||
} else {
|
||||
std::vector<array> real_inputs;
|
||||
for (auto& in : a.inputs()) {
|
||||
real_inputs.push_back(trace_to_real.at(in.id()));
|
||||
}
|
||||
auto real_a =
|
||||
array(a.shape(), a.dtype(), a.primitive_ptr(), std::move(real_inputs));
|
||||
auto real_a = array(
|
||||
a.shape(), a.dtype(), a.primitive_ptr(), std::move(real_inputs));
|
||||
trace_to_real.insert({a.id(), std::move(real_a)});
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<array> outputs;
|
||||
for (auto& o : trace_outputs) {
|
||||
outputs.push_back(trace_to_real.at(o.id()));
|
||||
@ -342,8 +334,7 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||
if (entry.empty) {
|
||||
// Mark the entry as not empty since we are about to fill it
|
||||
entry.empty = false;
|
||||
|
||||
// Trace te build the graph
|
||||
// Trace to build the graph
|
||||
std::tie(entry.inputs, entry.outputs) = compile_trace(fun, inputs);
|
||||
|
||||
// DFS the graph and get a tape, and a map of array id to (parent,
|
||||
|
@ -8,13 +8,42 @@ using namespace mlx::core;
|
||||
|
||||
std::vector<array> simple_fun(const std::vector<array>& inputs) {
|
||||
return std::vector<array>{inputs[0] + inputs[1]};
|
||||
};
|
||||
}
|
||||
|
||||
TEST_CASE("test simple compile") {
|
||||
auto compfn = compile(simple_fun);
|
||||
auto out = compfn({array(1.0), array(2.0)})[0];
|
||||
auto out = compfn({array(1.0f), array(2.0f)})[0];
|
||||
CHECK_EQ(out.item<float>(), 3.0f);
|
||||
|
||||
out = compfn({array(1.0), array(2.0)})[0];
|
||||
out = compfn({array(1.0f), array(2.0f)})[0];
|
||||
CHECK_EQ(out.item<float>(), 3.0f);
|
||||
|
||||
// Change the shapes
|
||||
out = compfn({array({1.0f, 2.0f}), array(2.0f)})[0];
|
||||
CHECK(array_equal(out, array({3.0f, 4.0f})).item<bool>());
|
||||
|
||||
out = compfn({array(2.0f), array({1.0f, 2.0f})})[0];
|
||||
CHECK(array_equal(out, array({3.0f, 4.0f})).item<bool>());
|
||||
|
||||
// Change the types
|
||||
out = compfn({array(2, int32), array({1.0f, 2.0f})})[0];
|
||||
CHECK(array_equal(out, array({3.0f, 4.0f})).item<bool>());
|
||||
|
||||
out = compfn({array(2.0f), array({1, 2}, int32)})[0];
|
||||
CHECK(array_equal(out, array({3.0f, 4.0f})).item<bool>());
|
||||
}
|
||||
|
||||
std::vector<array> fun1(const std::vector<array>& inputs) {
|
||||
auto loss = [](std::vector<array> ins) { return exp(ins[0] + ins[1]); };
|
||||
return grad(loss)(inputs);
|
||||
}
|
||||
|
||||
TEST_CASE("test compile with grad") {
|
||||
auto x = array(1.0f);
|
||||
auto y = array(1.0f);
|
||||
auto grads_expected = fun1({x, y});
|
||||
auto grads_compile = compile(fun1)({x, y});
|
||||
CHECK_EQ(grads_compile[0].item<float>(), grads_expected[0].item<float>());
|
||||
}
|
||||
|
||||
TEST_CASE("test nested compile") {}
|
||||
|
Loading…
Reference in New Issue
Block a user