mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 10:38:10 +08:00 
			
		
		
		
	Custom VJP and checkpointing (#541)
* Implement custom_vjp and checkpointing * Add a dependency management primitive * Change the eval order to deep branches first * Add graph depth tracking to the array
This commit is contained in:
		
				
					committed by
					
						
						GitHub
					
				
			
			
				
	
			
			
			
						parent
						
							143e2690d5
						
					
				
				
					commit
					0de5988f92
				
			@@ -1,5 +1,4 @@
 | 
			
		||||
// Copyright © 2023-2024 Apple Inc.
 | 
			
		||||
#include <pybind11/functional.h>
 | 
			
		||||
#include <pybind11/pybind11.h>
 | 
			
		||||
#include <pybind11/stl.h>
 | 
			
		||||
#include <algorithm>
 | 
			
		||||
@@ -142,7 +141,8 @@ std::vector<array> tree_flatten(py::object tree, bool strict = true) {
 | 
			
		||||
    if (py::isinstance<array>(obj)) {
 | 
			
		||||
      flat_tree.push_back(py::cast<array>(obj));
 | 
			
		||||
    } else if (strict) {
 | 
			
		||||
      throw std::invalid_argument("Argument is not an array");
 | 
			
		||||
      throw std::invalid_argument(
 | 
			
		||||
          "[tree_flatten] The argument should contain only arrays");
 | 
			
		||||
    }
 | 
			
		||||
  });
 | 
			
		||||
 | 
			
		||||
@@ -162,12 +162,48 @@ py::object tree_unflatten(
 | 
			
		||||
  });
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
py::object tree_unflatten_none(
 | 
			
		||||
py::object structure_sentinel() {
 | 
			
		||||
  static py::object sentinel;
 | 
			
		||||
 | 
			
		||||
  if (sentinel.ptr() == nullptr) {
 | 
			
		||||
    sentinel = py::capsule(&sentinel);
 | 
			
		||||
    // probably not needed but this should make certain that we won't ever
 | 
			
		||||
    // delete the sentinel
 | 
			
		||||
    sentinel.inc_ref();
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  return sentinel;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
std::pair<std::vector<array>, py::object> tree_flatten_with_structure(
 | 
			
		||||
    py::object tree,
 | 
			
		||||
    bool strict = true) {
 | 
			
		||||
  auto sentinel = structure_sentinel();
 | 
			
		||||
  std::vector<array> flat_tree;
 | 
			
		||||
  auto structure = tree_map(
 | 
			
		||||
      tree,
 | 
			
		||||
      [&flat_tree, sentinel = std::move(sentinel), strict](py::handle obj) {
 | 
			
		||||
        if (py::isinstance<array>(obj)) {
 | 
			
		||||
          flat_tree.push_back(py::cast<array>(obj));
 | 
			
		||||
          return sentinel;
 | 
			
		||||
        } else if (!strict) {
 | 
			
		||||
          return py::cast<py::object>(obj);
 | 
			
		||||
        } else {
 | 
			
		||||
          throw std::invalid_argument(
 | 
			
		||||
              "[tree_flatten] The argument should contain only arrays");
 | 
			
		||||
        }
 | 
			
		||||
      });
 | 
			
		||||
 | 
			
		||||
  return {flat_tree, structure};
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
py::object tree_unflatten_from_structure(
 | 
			
		||||
    py::object structure,
 | 
			
		||||
    const std::vector<array>& values,
 | 
			
		||||
    int index = 0) {
 | 
			
		||||
  return tree_map(tree, [&](py::handle obj) {
 | 
			
		||||
    if (py::isinstance<py::none>(obj)) {
 | 
			
		||||
  auto sentinel = structure_sentinel();
 | 
			
		||||
  return tree_map(structure, [&](py::handle obj) {
 | 
			
		||||
    if (obj.is(sentinel)) {
 | 
			
		||||
      return py::cast(values[index++]);
 | 
			
		||||
    } else {
 | 
			
		||||
      return py::cast<py::object>(obj);
 | 
			
		||||
@@ -472,14 +508,10 @@ struct PyCompiledFun {
 | 
			
		||||
 | 
			
		||||
  py::object operator()(const py::args& args) {
 | 
			
		||||
    auto compile_fun = [this, &args](const std::vector<array>& a) {
 | 
			
		||||
      // Call the python function
 | 
			
		||||
      py::object py_outputs = this->fun(*tree_unflatten(args, a));
 | 
			
		||||
      // Call the python function and flatten the outputs
 | 
			
		||||
      auto [outputs, py_outputs] = tree_flatten_with_structure(
 | 
			
		||||
          std::move(this->fun(*tree_unflatten(args, a))), true);
 | 
			
		||||
 | 
			
		||||
      // Flatten the outputs
 | 
			
		||||
      auto outputs = tree_flatten(py_outputs, true);
 | 
			
		||||
 | 
			
		||||
      py_outputs =
 | 
			
		||||
          tree_map(py_outputs, [](const py::handle& x) { return py::none(); });
 | 
			
		||||
      tree_cache().insert({this->fun_id, py_outputs});
 | 
			
		||||
      return outputs;
 | 
			
		||||
    };
 | 
			
		||||
@@ -492,15 +524,75 @@ struct PyCompiledFun {
 | 
			
		||||
 | 
			
		||||
    // Put the outputs back in the container
 | 
			
		||||
    py::object py_outputs = tree_cache().at(fun_id);
 | 
			
		||||
    return tree_unflatten_none(py_outputs, outputs);
 | 
			
		||||
    return tree_unflatten_from_structure(py_outputs, outputs);
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  ~PyCompiledFun() {
 | 
			
		||||
    py::gil_scoped_acquire gil;
 | 
			
		||||
 | 
			
		||||
    tree_cache().erase(fun_id);
 | 
			
		||||
    detail::compile_erase(fun_id);
 | 
			
		||||
    fun.release().dec_ref();
 | 
			
		||||
  }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
class PyCheckpointedFun {
 | 
			
		||||
 public:
 | 
			
		||||
  PyCheckpointedFun(py::function fun) : fun_(std::move(fun)) {}
 | 
			
		||||
 | 
			
		||||
  ~PyCheckpointedFun() {
 | 
			
		||||
    py::gil_scoped_acquire gil;
 | 
			
		||||
 | 
			
		||||
    fun_.release().dec_ref();
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  struct InnerFunction {
 | 
			
		||||
    py::object fun_;
 | 
			
		||||
    py::object args_structure_;
 | 
			
		||||
    std::weak_ptr<py::object> output_structure_;
 | 
			
		||||
 | 
			
		||||
    InnerFunction(
 | 
			
		||||
        py::object fun,
 | 
			
		||||
        py::object args_structure,
 | 
			
		||||
        std::weak_ptr<py::object> output_structure)
 | 
			
		||||
        : fun_(std::move(fun)),
 | 
			
		||||
          args_structure_(std::move(args_structure)),
 | 
			
		||||
          output_structure_(output_structure) {}
 | 
			
		||||
    ~InnerFunction() {
 | 
			
		||||
      py::gil_scoped_acquire gil;
 | 
			
		||||
 | 
			
		||||
      fun_.release().dec_ref();
 | 
			
		||||
      args_structure_.release().dec_ref();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    std::vector<array> operator()(const std::vector<array>& inputs) {
 | 
			
		||||
      auto args = py::cast<py::tuple>(
 | 
			
		||||
          tree_unflatten_from_structure(args_structure_, inputs));
 | 
			
		||||
      auto [outputs, output_structure] =
 | 
			
		||||
          tree_flatten_with_structure(fun_(*args[0], **args[1]), false);
 | 
			
		||||
      if (auto s = output_structure_.lock()) {
 | 
			
		||||
        *s = output_structure;
 | 
			
		||||
      }
 | 
			
		||||
      return outputs;
 | 
			
		||||
    }
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  py::object operator()(const py::args& args, const py::kwargs& kwargs) {
 | 
			
		||||
    auto output_structure = std::make_shared<py::object>();
 | 
			
		||||
    auto full_args = py::make_tuple(args, kwargs);
 | 
			
		||||
    auto [inputs, args_structure] =
 | 
			
		||||
        tree_flatten_with_structure(full_args, false);
 | 
			
		||||
 | 
			
		||||
    auto outputs = checkpoint(
 | 
			
		||||
        InnerFunction(fun_, args_structure, output_structure))(inputs);
 | 
			
		||||
 | 
			
		||||
    return tree_unflatten_from_structure(*output_structure, outputs);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
 private:
 | 
			
		||||
  py::function fun_;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
void init_transforms(py::module_& m) {
 | 
			
		||||
  py::options options;
 | 
			
		||||
  options.disable_function_signatures();
 | 
			
		||||
@@ -802,6 +894,10 @@ void init_transforms(py::module_& m) {
 | 
			
		||||
        Globally enable compilation. This will override the environment
 | 
			
		||||
        variable ``MLX_DISABLE_COMPILE`` if set.
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
  m.def(
 | 
			
		||||
      "checkpoint",
 | 
			
		||||
      [](py::function fun) { return py::cpp_function(PyCheckpointedFun{fun}); },
 | 
			
		||||
      "fun"_a);
 | 
			
		||||
 | 
			
		||||
  // Register static Python object cleanup before the interpreter exits
 | 
			
		||||
  auto atexit = py::module_::import("atexit");
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user