mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 10:38:10 +08:00 
			
		
		
		
	Compile with capture (#629)
* Simple kernel generation * Remove the generate kernel from graph_utils * fix multi-output with compile * fuse with stopgrad * v1 input, output capture in compile * cleanup tree update with visitor update * nit * remove todo * state for model, optional explicit init and more pure optimizer steps * move learning rate to state * add lr to opt state, some fixes in capture * fix optim * update tuple of containers as well * fix stream for compiled output * rng state for compile * nit * updates and comments --------- Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
		@@ -2,6 +2,7 @@
 | 
			
		||||
 | 
			
		||||
#include <pybind11/pybind11.h>
 | 
			
		||||
#include <pybind11/stl.h>
 | 
			
		||||
#include <chrono>
 | 
			
		||||
 | 
			
		||||
#include "python/src/utils.h"
 | 
			
		||||
 | 
			
		||||
@@ -13,13 +14,55 @@ using namespace py::literals;
 | 
			
		||||
using namespace mlx::core;
 | 
			
		||||
using namespace mlx::core::random;
 | 
			
		||||
 | 
			
		||||
class PyKeySequence {
 | 
			
		||||
 public:
 | 
			
		||||
  explicit PyKeySequence(uint64_t seed) {
 | 
			
		||||
    state_.append(key(seed));
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  void seed(uint64_t seed) {
 | 
			
		||||
    state_[0] = key(seed);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  array next() {
 | 
			
		||||
    auto out = split(py::cast<array>(state_[0]));
 | 
			
		||||
    state_[0] = out.first;
 | 
			
		||||
    return out.second;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  py::list state() {
 | 
			
		||||
    return state_;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  void release() {
 | 
			
		||||
    py::gil_scoped_acquire gil;
 | 
			
		||||
    state_.release().dec_ref();
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
 private:
 | 
			
		||||
  py::list state_;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
PyKeySequence& default_key() {
 | 
			
		||||
  auto get_current_time_seed = []() {
 | 
			
		||||
    auto now = std::chrono::system_clock::now();
 | 
			
		||||
    return std::chrono::duration_cast<std::chrono::milliseconds>(
 | 
			
		||||
               now.time_since_epoch())
 | 
			
		||||
        .count();
 | 
			
		||||
  };
 | 
			
		||||
  static PyKeySequence ks(get_current_time_seed());
 | 
			
		||||
  return ks;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void init_random(py::module_& parent_module) {
 | 
			
		||||
  auto m = parent_module.def_submodule(
 | 
			
		||||
      "random",
 | 
			
		||||
      "mlx.core.random: functionality related to random number generation");
 | 
			
		||||
 | 
			
		||||
  m.attr("state") = default_key().state();
 | 
			
		||||
  m.def(
 | 
			
		||||
      "seed",
 | 
			
		||||
      &seed,
 | 
			
		||||
      [](uint64_t seed) { default_key().seed(seed); },
 | 
			
		||||
      "seed"_a,
 | 
			
		||||
      R"pbdoc(
 | 
			
		||||
        Seed the global PRNG.
 | 
			
		||||
@@ -62,8 +105,9 @@ void init_random(py::module_& parent_module) {
 | 
			
		||||
         const ScalarOrArray& high,
 | 
			
		||||
         const std::vector<int>& shape,
 | 
			
		||||
         std::optional<Dtype> type,
 | 
			
		||||
         const std::optional<array>& key,
 | 
			
		||||
         const std::optional<array>& key_,
 | 
			
		||||
         StreamOrDevice s) {
 | 
			
		||||
        auto key = key_ ? key_.value() : default_key().next();
 | 
			
		||||
        return uniform(
 | 
			
		||||
            to_array(low),
 | 
			
		||||
            to_array(high),
 | 
			
		||||
@@ -101,11 +145,11 @@ void init_random(py::module_& parent_module) {
 | 
			
		||||
         std::optional<Dtype> type,
 | 
			
		||||
         float loc,
 | 
			
		||||
         float scale,
 | 
			
		||||
         const std::optional<array>& key,
 | 
			
		||||
         const std::optional<array>& key_,
 | 
			
		||||
         StreamOrDevice s) {
 | 
			
		||||
        auto key = key_ ? key_.value() : default_key().next();
 | 
			
		||||
        return normal(shape, type.value_or(float32), loc, scale, key, s);
 | 
			
		||||
      },
 | 
			
		||||
 | 
			
		||||
      "shape"_a = std::vector<int>{},
 | 
			
		||||
      "dtype"_a = std::optional{float32},
 | 
			
		||||
      "loc"_a = 0.0,
 | 
			
		||||
@@ -131,8 +175,9 @@ void init_random(py::module_& parent_module) {
 | 
			
		||||
         const ScalarOrArray& high,
 | 
			
		||||
         const std::vector<int>& shape,
 | 
			
		||||
         std::optional<Dtype> type,
 | 
			
		||||
         const std::optional<array>& key,
 | 
			
		||||
         const std::optional<array>& key_,
 | 
			
		||||
         StreamOrDevice s) {
 | 
			
		||||
        auto key = key_ ? key_.value() : default_key().next();
 | 
			
		||||
        return randint(
 | 
			
		||||
            to_array(low), to_array(high), shape, type.value_or(int32), key, s);
 | 
			
		||||
      },
 | 
			
		||||
@@ -163,8 +208,9 @@ void init_random(py::module_& parent_module) {
 | 
			
		||||
      "bernoulli",
 | 
			
		||||
      [](const ScalarOrArray& p_,
 | 
			
		||||
         const std::optional<std::vector<int>> shape,
 | 
			
		||||
         const std::optional<array>& key,
 | 
			
		||||
         const std::optional<array>& key_,
 | 
			
		||||
         StreamOrDevice s) {
 | 
			
		||||
        auto key = key_ ? key_.value() : default_key().next();
 | 
			
		||||
        auto p = to_array(p_);
 | 
			
		||||
        if (shape.has_value()) {
 | 
			
		||||
          return bernoulli(p, shape.value(), key, s);
 | 
			
		||||
@@ -199,8 +245,9 @@ void init_random(py::module_& parent_module) {
 | 
			
		||||
         const ScalarOrArray& upper_,
 | 
			
		||||
         const std::optional<std::vector<int>> shape_,
 | 
			
		||||
         std::optional<Dtype> type,
 | 
			
		||||
         const std::optional<array>& key,
 | 
			
		||||
         const std::optional<array>& key_,
 | 
			
		||||
         StreamOrDevice s) {
 | 
			
		||||
        auto key = key_ ? key_.value() : default_key().next();
 | 
			
		||||
        auto lower = to_array(lower_);
 | 
			
		||||
        auto upper = to_array(upper_);
 | 
			
		||||
        auto t = type.value_or(float32);
 | 
			
		||||
@@ -239,8 +286,9 @@ void init_random(py::module_& parent_module) {
 | 
			
		||||
      "gumbel",
 | 
			
		||||
      [](const std::vector<int>& shape,
 | 
			
		||||
         std::optional<Dtype> type,
 | 
			
		||||
         const std::optional<array>& key,
 | 
			
		||||
         const std::optional<array>& key_,
 | 
			
		||||
         StreamOrDevice s) {
 | 
			
		||||
        auto key = key_ ? key_.value() : default_key().next();
 | 
			
		||||
        return gumbel(shape, type.value_or(float32), key, s);
 | 
			
		||||
      },
 | 
			
		||||
      "shape"_a = std::vector<int>{},
 | 
			
		||||
@@ -267,8 +315,9 @@ void init_random(py::module_& parent_module) {
 | 
			
		||||
         int axis,
 | 
			
		||||
         const std::optional<std::vector<int>> shape,
 | 
			
		||||
         const std::optional<int> num_samples,
 | 
			
		||||
         const std::optional<array>& key,
 | 
			
		||||
         const std::optional<array>& key_,
 | 
			
		||||
         StreamOrDevice s) {
 | 
			
		||||
        auto key = key_ ? key_.value() : default_key().next();
 | 
			
		||||
        if (shape.has_value() && num_samples.has_value()) {
 | 
			
		||||
          throw std::invalid_argument(
 | 
			
		||||
              "[categorical] At most one of shape or num_samples can be specified.");
 | 
			
		||||
@@ -309,4 +358,7 @@ void init_random(py::module_& parent_module) {
 | 
			
		||||
        Returns:
 | 
			
		||||
            array: The ``shape``-sized output array with type ``uint32``.
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
  // Register static Python object cleanup before the interpreter exits
 | 
			
		||||
  auto atexit = py::module_::import("atexit");
 | 
			
		||||
  atexit.attr("register")(py::cpp_function([]() { default_key().release(); }));
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -135,6 +135,64 @@ py::object tree_map(
 | 
			
		||||
  });
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void tree_visit_update(
 | 
			
		||||
    py::object tree,
 | 
			
		||||
    std::function<py::object(py::handle)> visitor) {
 | 
			
		||||
  std::function<py::object(py::handle)> recurse;
 | 
			
		||||
  recurse = [&](py::handle subtree) {
 | 
			
		||||
    if (py::isinstance<py::list>(subtree)) {
 | 
			
		||||
      auto l = py::cast<py::list>(subtree);
 | 
			
		||||
      for (int i = 0; i < l.size(); ++i) {
 | 
			
		||||
        l[i] = recurse(l[i]);
 | 
			
		||||
      }
 | 
			
		||||
      return py::cast<py::object>(l);
 | 
			
		||||
    } else if (py::isinstance<py::tuple>(subtree)) {
 | 
			
		||||
      for (auto item : subtree) {
 | 
			
		||||
        recurse(item);
 | 
			
		||||
      }
 | 
			
		||||
      return py::cast<py::object>(subtree);
 | 
			
		||||
    } else if (py::isinstance<py::dict>(subtree)) {
 | 
			
		||||
      auto d = py::cast<py::dict>(subtree);
 | 
			
		||||
      for (auto item : d) {
 | 
			
		||||
        d[item.first] = recurse(item.second);
 | 
			
		||||
      }
 | 
			
		||||
      return py::cast<py::object>(d);
 | 
			
		||||
    } else if (py::isinstance<array>(subtree)) {
 | 
			
		||||
      return visitor(subtree);
 | 
			
		||||
    } else {
 | 
			
		||||
      return py::cast<py::object>(subtree);
 | 
			
		||||
    }
 | 
			
		||||
  };
 | 
			
		||||
  recurse(tree);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Fill a pytree (recursive dict or list of dict or list)
 | 
			
		||||
// in place with the given arrays
 | 
			
		||||
// Non dict or list nodes are ignored
 | 
			
		||||
void tree_fill(py::object& tree, const std::vector<array>& values) {
 | 
			
		||||
  size_t index = 0;
 | 
			
		||||
  tree_visit_update(
 | 
			
		||||
      tree, [&](py::handle node) { return py::cast(values[index++]); });
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Replace all the arrays from the src values with the dst values in the tree
 | 
			
		||||
void tree_replace(
 | 
			
		||||
    py::object& tree,
 | 
			
		||||
    const std::vector<array>& src,
 | 
			
		||||
    const std::vector<array>& dst) {
 | 
			
		||||
  std::unordered_map<uintptr_t, array> src_to_dst;
 | 
			
		||||
  for (int i = 0; i < src.size(); ++i) {
 | 
			
		||||
    src_to_dst.insert({src[i].id(), dst[i]});
 | 
			
		||||
  }
 | 
			
		||||
  tree_visit_update(tree, [&](py::handle node) {
 | 
			
		||||
    auto arr = py::cast<array>(node);
 | 
			
		||||
    if (auto it = src_to_dst.find(arr.id()); it != src_to_dst.end()) {
 | 
			
		||||
      return py::cast(it->second);
 | 
			
		||||
    }
 | 
			
		||||
    return py::cast(arr);
 | 
			
		||||
  });
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
std::vector<array> tree_flatten(py::object tree, bool strict = true) {
 | 
			
		||||
  std::vector<array> flat_tree;
 | 
			
		||||
 | 
			
		||||
@@ -495,9 +553,15 @@ std::unordered_map<size_t, py::object>& tree_cache() {
 | 
			
		||||
struct PyCompiledFun {
 | 
			
		||||
  py::function fun;
 | 
			
		||||
  size_t fun_id;
 | 
			
		||||
  py::object captured_inputs;
 | 
			
		||||
  py::object captured_outputs;
 | 
			
		||||
  size_t num_outputs{0};
 | 
			
		||||
 | 
			
		||||
  PyCompiledFun(const py::function& fun)
 | 
			
		||||
      : fun(fun), fun_id(reinterpret_cast<size_t>(fun.ptr())) {}
 | 
			
		||||
  PyCompiledFun(const py::function& fun, py::object inputs, py::object outputs)
 | 
			
		||||
      : fun(fun),
 | 
			
		||||
        fun_id(reinterpret_cast<size_t>(fun.ptr())),
 | 
			
		||||
        captured_inputs(inputs),
 | 
			
		||||
        captured_outputs(outputs) {}
 | 
			
		||||
 | 
			
		||||
  PyCompiledFun(const PyCompiledFun&) = delete;
 | 
			
		||||
  PyCompiledFun& operator=(const PyCompiledFun&) = delete;
 | 
			
		||||
@@ -505,23 +569,61 @@ struct PyCompiledFun {
 | 
			
		||||
  PyCompiledFun(PyCompiledFun&& other)
 | 
			
		||||
      : fun(std::move(other.fun)), fun_id(reinterpret_cast<size_t>(fun.ptr())) {
 | 
			
		||||
    other.fun_id = 0;
 | 
			
		||||
    captured_inputs = std::move(other.captured_inputs);
 | 
			
		||||
    captured_outputs = std::move(other.captured_outputs);
 | 
			
		||||
    num_outputs = other.num_outputs;
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  py::object operator()(const py::args& args) {
 | 
			
		||||
    auto compile_fun = [this, &args](const std::vector<array>& a) {
 | 
			
		||||
      // Call the python function and flatten the outputs
 | 
			
		||||
      auto [outputs, py_outputs] = tree_flatten_with_structure(
 | 
			
		||||
          std::move(this->fun(*tree_unflatten(args, a))), true);
 | 
			
		||||
      // Put tracers into captured inputs
 | 
			
		||||
      std::vector<array> flat_in_captures;
 | 
			
		||||
      std::vector<array> trace_captures;
 | 
			
		||||
      if (!py::isinstance<py::none>(captured_inputs)) {
 | 
			
		||||
        flat_in_captures = tree_flatten(captured_inputs, false);
 | 
			
		||||
        trace_captures.insert(
 | 
			
		||||
            trace_captures.end(), a.end() - flat_in_captures.size(), a.end());
 | 
			
		||||
        tree_fill(captured_inputs, trace_captures);
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
      tree_cache().insert({this->fun_id, py_outputs});
 | 
			
		||||
      auto [outputs, py_outputs] = tree_flatten_with_structure(
 | 
			
		||||
          std::move(fun(*tree_unflatten(args, a))), false);
 | 
			
		||||
 | 
			
		||||
      tree_cache().insert({fun_id, py_outputs});
 | 
			
		||||
 | 
			
		||||
      num_outputs = outputs.size();
 | 
			
		||||
      if (!py::isinstance<py::none>(captured_outputs)) {
 | 
			
		||||
        auto flat_out_captures = tree_flatten(captured_outputs, false);
 | 
			
		||||
        outputs.insert(
 | 
			
		||||
            outputs.end(),
 | 
			
		||||
            std::make_move_iterator(flat_out_captures.begin()),
 | 
			
		||||
            std::make_move_iterator(flat_out_captures.end()));
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
      // Replace tracers with originals in captured inputs
 | 
			
		||||
      if (!py::isinstance<py::none>(captured_inputs)) {
 | 
			
		||||
        tree_replace(captured_inputs, trace_captures, flat_in_captures);
 | 
			
		||||
      }
 | 
			
		||||
      return outputs;
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
    // Inputs must be array or tree of arrays
 | 
			
		||||
    auto inputs = tree_flatten(args, true);
 | 
			
		||||
    auto inputs = tree_flatten(args, false);
 | 
			
		||||
    if (!py::isinstance<py::none>(captured_inputs)) {
 | 
			
		||||
      auto flat_in_captures = tree_flatten(captured_inputs, false);
 | 
			
		||||
      inputs.insert(
 | 
			
		||||
          inputs.end(),
 | 
			
		||||
          std::make_move_iterator(flat_in_captures.begin()),
 | 
			
		||||
          std::make_move_iterator(flat_in_captures.end()));
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    // Compile and call
 | 
			
		||||
    auto outputs = detail::compile(compile_fun, fun_id)(inputs);
 | 
			
		||||
    if (!py::isinstance<py::none>(captured_outputs)) {
 | 
			
		||||
      std::vector<array> captures(
 | 
			
		||||
          std::make_move_iterator(outputs.begin() + num_outputs),
 | 
			
		||||
          std::make_move_iterator(outputs.end()));
 | 
			
		||||
      tree_fill(captured_outputs, captures);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    // Put the outputs back in the container
 | 
			
		||||
    py::object py_outputs = tree_cache().at(fun_id);
 | 
			
		||||
@@ -534,6 +636,8 @@ struct PyCompiledFun {
 | 
			
		||||
    tree_cache().erase(fun_id);
 | 
			
		||||
    detail::compile_erase(fun_id);
 | 
			
		||||
    fun.release().dec_ref();
 | 
			
		||||
    captured_inputs.release().dec_ref();
 | 
			
		||||
    captured_outputs.release().dec_ref();
 | 
			
		||||
  }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
@@ -601,7 +705,7 @@ void init_transforms(py::module_& m) {
 | 
			
		||||
  m.def(
 | 
			
		||||
      "eval",
 | 
			
		||||
      [](const py::args& args) {
 | 
			
		||||
        std::vector<array> arrays = tree_flatten(args);
 | 
			
		||||
        std::vector<array> arrays = tree_flatten(args, false);
 | 
			
		||||
        {
 | 
			
		||||
          py::gil_scoped_release nogil;
 | 
			
		||||
          eval(arrays);
 | 
			
		||||
@@ -615,8 +719,8 @@ void init_transforms(py::module_& m) {
 | 
			
		||||
        Args:
 | 
			
		||||
            *args (arrays or trees of arrays): Each argument can be a single array
 | 
			
		||||
              or a tree of arrays. If a tree is given the nodes can be a Python
 | 
			
		||||
              :class:`list`, :class:`tuple` or :class:`dict` but the leafs must all be
 | 
			
		||||
              an :class:`array`.
 | 
			
		||||
              :class:`list`, :class:`tuple` or :class:`dict`. Leaves which are not
 | 
			
		||||
              arrays are ignored.
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
  m.def(
 | 
			
		||||
      "jvp",
 | 
			
		||||
@@ -859,10 +963,14 @@ void init_transforms(py::module_& m) {
 | 
			
		||||
      "file"_a);
 | 
			
		||||
  m.def(
 | 
			
		||||
      "compile",
 | 
			
		||||
      [](const py::function& fun) {
 | 
			
		||||
        return py::cpp_function(PyCompiledFun{fun});
 | 
			
		||||
      [](const py::function& fun,
 | 
			
		||||
         const py::object& inputs,
 | 
			
		||||
         const py::object& outputs) {
 | 
			
		||||
        return py::cpp_function(PyCompiledFun{fun, inputs, outputs});
 | 
			
		||||
      },
 | 
			
		||||
      "fun"_a,
 | 
			
		||||
      "inputs"_a = std::nullopt,
 | 
			
		||||
      "outputs"_a = std::nullopt,
 | 
			
		||||
      R"pbdoc(
 | 
			
		||||
        compile(fun: function) -> function
 | 
			
		||||
 | 
			
		||||
@@ -872,6 +980,16 @@ void init_transforms(py::module_& m) {
 | 
			
		||||
            fun (function): A function which takes a variable number of
 | 
			
		||||
              :class:`array` or trees of :class:`array` and returns
 | 
			
		||||
              a variable number of :class:`array` or trees of :class:`array`.
 | 
			
		||||
            inputs (list or dict, optional): These inputs will be captured during
 | 
			
		||||
              the function compilation along with the inputs to ``fun``. The ``inputs``
 | 
			
		||||
              can be a :obj:`list` or a :obj:`dict` containing arbitrarily nested
 | 
			
		||||
              lists, dictionaries, or arrays. Leaf nodes that are not
 | 
			
		||||
              :obj:`array` are ignored. Default: ``None``
 | 
			
		||||
            outputs (list or dict, optional): These outputs will be captured and
 | 
			
		||||
              updated in a compiled function. The ``outputs`` can be a
 | 
			
		||||
              :obj:`list` or a :obj:`dict` containing arbitrarily nested lists,
 | 
			
		||||
              dictionaries, or arrays. Leaf nodes that are not :obj:`array` are ignored.
 | 
			
		||||
              Default: ``None``
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            function: A compiled function which has the same input arguments
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user