mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 18:48:15 +08:00 
			
		
		
		
	fix grad, more tests
This commit is contained in:
		@@ -107,10 +107,8 @@ std::pair<std::vector<array>, ParentsMap> compile_dfs(
 | 
			
		||||
  std::unordered_set<std::uintptr_t> cache;
 | 
			
		||||
  std::unordered_map<std::uintptr_t, std::vector<std::pair<array, int>>>
 | 
			
		||||
      parents_map;
 | 
			
		||||
  std::unordered_set<std::uintptr_t> needs_compile;
 | 
			
		||||
  for (int i = 0; i < inputs.size(); ++i) {
 | 
			
		||||
    auto in = inputs[i];
 | 
			
		||||
    needs_compile.insert(in.id());
 | 
			
		||||
    cache.insert(in.id());
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
@@ -132,16 +130,7 @@ std::pair<std::vector<array>, ParentsMap> compile_dfs(
 | 
			
		||||
    for (auto& s : a.siblings()) {
 | 
			
		||||
      cache.insert(s.id());
 | 
			
		||||
    }
 | 
			
		||||
    for (auto& input : a.inputs()) {
 | 
			
		||||
      if (needs_compile.find(input.id()) != needs_compile.end()) {
 | 
			
		||||
        tape.push_back(a);
 | 
			
		||||
        needs_compile.insert(a.id());
 | 
			
		||||
        for (auto s : a.siblings()) {
 | 
			
		||||
          needs_compile.insert(s.id());
 | 
			
		||||
        }
 | 
			
		||||
        break;
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
    tape.push_back(a);
 | 
			
		||||
  };
 | 
			
		||||
  for (auto& a : outputs) {
 | 
			
		||||
    recurse(a);
 | 
			
		||||
@@ -254,7 +243,7 @@ void compile_simplify(
 | 
			
		||||
  for (auto& o : outputs) {
 | 
			
		||||
    output_set.insert(o.id());
 | 
			
		||||
  }
 | 
			
		||||
  // Pass 1 to passes: fuse only keeping non-orphaned arrays in the tape
 | 
			
		||||
  // Pass 1..passes: fuse only keeping non-orphaned arrays in the tape
 | 
			
		||||
  for (int pass = 0; pass < passes; ++pass) {
 | 
			
		||||
    for (auto& arr : tape) {
 | 
			
		||||
      // Helper to check if we can fuse the parents of the
 | 
			
		||||
@@ -312,18 +301,21 @@ std::vector<array> compile_replace(
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  for (auto& a : tape) {
 | 
			
		||||
    // Arrays in the tape without primitives are constants
 | 
			
		||||
    // and can be used directly
 | 
			
		||||
    if (!a.has_primitive()) {
 | 
			
		||||
      std::runtime_error(
 | 
			
		||||
          "[compile] Something went wrong, no primitive in tape");
 | 
			
		||||
      trace_to_real.insert({a.id(), a});
 | 
			
		||||
    } else {
 | 
			
		||||
      std::vector<array> real_inputs;
 | 
			
		||||
      for (auto& in : a.inputs()) {
 | 
			
		||||
        real_inputs.push_back(trace_to_real.at(in.id()));
 | 
			
		||||
      }
 | 
			
		||||
      auto real_a = array(
 | 
			
		||||
          a.shape(), a.dtype(), a.primitive_ptr(), std::move(real_inputs));
 | 
			
		||||
      trace_to_real.insert({a.id(), std::move(real_a)});
 | 
			
		||||
    }
 | 
			
		||||
    std::vector<array> real_inputs;
 | 
			
		||||
    for (auto& in : a.inputs()) {
 | 
			
		||||
      real_inputs.push_back(trace_to_real.at(in.id()));
 | 
			
		||||
    }
 | 
			
		||||
    auto real_a =
 | 
			
		||||
        array(a.shape(), a.dtype(), a.primitive_ptr(), std::move(real_inputs));
 | 
			
		||||
    trace_to_real.insert({a.id(), std::move(real_a)});
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  std::vector<array> outputs;
 | 
			
		||||
  for (auto& o : trace_outputs) {
 | 
			
		||||
    outputs.push_back(trace_to_real.at(o.id()));
 | 
			
		||||
@@ -342,8 +334,7 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
 | 
			
		||||
    if (entry.empty) {
 | 
			
		||||
      // Mark the entry as not empty since we are about to fill it
 | 
			
		||||
      entry.empty = false;
 | 
			
		||||
 | 
			
		||||
      // Trace te build the graph
 | 
			
		||||
      // Trace to build the graph
 | 
			
		||||
      std::tie(entry.inputs, entry.outputs) = compile_trace(fun, inputs);
 | 
			
		||||
 | 
			
		||||
      // DFS the graph and get a tape, and a map of array id to (parent,
 | 
			
		||||
 
 | 
			
		||||
@@ -8,13 +8,42 @@ using namespace mlx::core;
 | 
			
		||||
 | 
			
		||||
std::vector<array> simple_fun(const std::vector<array>& inputs) {
 | 
			
		||||
  return std::vector<array>{inputs[0] + inputs[1]};
 | 
			
		||||
};
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_CASE("test simple compile") {
 | 
			
		||||
  auto compfn = compile(simple_fun);
 | 
			
		||||
  auto out = compfn({array(1.0), array(2.0)})[0];
 | 
			
		||||
  auto out = compfn({array(1.0f), array(2.0f)})[0];
 | 
			
		||||
  CHECK_EQ(out.item<float>(), 3.0f);
 | 
			
		||||
 | 
			
		||||
  out = compfn({array(1.0), array(2.0)})[0];
 | 
			
		||||
  out = compfn({array(1.0f), array(2.0f)})[0];
 | 
			
		||||
  CHECK_EQ(out.item<float>(), 3.0f);
 | 
			
		||||
 | 
			
		||||
  // Change the shapes
 | 
			
		||||
  out = compfn({array({1.0f, 2.0f}), array(2.0f)})[0];
 | 
			
		||||
  CHECK(array_equal(out, array({3.0f, 4.0f})).item<bool>());
 | 
			
		||||
 | 
			
		||||
  out = compfn({array(2.0f), array({1.0f, 2.0f})})[0];
 | 
			
		||||
  CHECK(array_equal(out, array({3.0f, 4.0f})).item<bool>());
 | 
			
		||||
 | 
			
		||||
  // Change the types
 | 
			
		||||
  out = compfn({array(2, int32), array({1.0f, 2.0f})})[0];
 | 
			
		||||
  CHECK(array_equal(out, array({3.0f, 4.0f})).item<bool>());
 | 
			
		||||
 | 
			
		||||
  out = compfn({array(2.0f), array({1, 2}, int32)})[0];
 | 
			
		||||
  CHECK(array_equal(out, array({3.0f, 4.0f})).item<bool>());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
std::vector<array> fun1(const std::vector<array>& inputs) {
 | 
			
		||||
  auto loss = [](std::vector<array> ins) { return exp(ins[0] + ins[1]); };
 | 
			
		||||
  return grad(loss)(inputs);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_CASE("test compile with grad") {
 | 
			
		||||
  auto x = array(1.0f);
 | 
			
		||||
  auto y = array(1.0f);
 | 
			
		||||
  auto grads_expected = fun1({x, y});
 | 
			
		||||
  auto grads_compile = compile(fun1)({x, y});
 | 
			
		||||
  CHECK_EQ(grads_compile[0].item<float>(), grads_expected[0].item<float>());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_CASE("test nested compile") {}
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user