Compile with capture (#629)

* Simple kernel generation

* Remove the generate kernel from graph_utils

* fix multi-output with compile

* fuse with stopgrad

* v1 input, output capture in compile

* cleanup tree update with visitor update

* nit

* remove todo

* state for model, optional explicit init and more pure optimizer steps

* move learning rate to state

* add lr to opt state, some fixes in capture

* fix optim

* update tuple of containers as well

* fix stream for compiled output

* rng state for compile

* nit

* updates and comments

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
Author: Awni Hannun
Date: 2024-02-07 17:29:22 -08:00
Committed by: GitHub
Parent: e5e816a5ef
Commit: 1b97b2958b
13 changed files with 723 additions and 157 deletions
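
For context on the diffs below: the headline change is that mx.compile now takes optional ``inputs`` and ``outputs`` arguments whose array leaves are captured as implicit state of the compiled function (see the compile binding and docstring in the transforms diff). A minimal usage sketch, assuming the usual mlx.core import; the function and variable names are illustrative:

    import mlx.core as mx

    # Mutable state that the function reads and writes as a side effect.
    state = [mx.array(0.0)]

    def step(x):
        state[0] = state[0] + x      # side-effecting update of captured state
        return state[0] ** 2

    # Capture `state` as an extra input and as an output to write back,
    # so the side effect survives compilation.
    step = mx.compile(step, inputs=state, outputs=state)

    y = step(mx.array(2.0))          # state[0] is updated in place as well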

View File

@@ -2,6 +2,7 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <chrono>
#include "python/src/utils.h"
@@ -13,13 +14,55 @@ using namespace py::literals;
using namespace mlx::core;
using namespace mlx::core::random;
class PyKeySequence {
public:
explicit PyKeySequence(uint64_t seed) {
state_.append(key(seed));
}
void seed(uint64_t seed) {
state_[0] = key(seed);
}
array next() {
auto out = split(py::cast<array>(state_[0]));
state_[0] = out.first;
return out.second;
}
py::list state() {
return state_;
}
void release() {
py::gil_scoped_acquire gil;
state_.release().dec_ref();
}
private:
py::list state_;
};
PyKeySequence& default_key() {
auto get_current_time_seed = []() {
auto now = std::chrono::system_clock::now();
return std::chrono::duration_cast<std::chrono::milliseconds>(
now.time_since_epoch())
.count();
};
static PyKeySequence ks(get_current_time_seed());
return ks;
}
void init_random(py::module_& parent_module) {
auto m = parent_module.def_submodule(
"random",
"mlx.core.random: functionality related to random number generation");
m.attr("state") = default_key().state();
m.def(
"seed",
- &seed,
+ [](uint64_t seed) { default_key().seed(seed); },
"seed"_a,
R"pbdoc(
Seed the global PRNG.
@@ -62,8 +105,9 @@ void init_random(py::module_& parent_module) {
const ScalarOrArray& high,
const std::vector<int>& shape,
std::optional<Dtype> type,
- const std::optional<array>& key,
+ const std::optional<array>& key_,
StreamOrDevice s) {
auto key = key_ ? key_.value() : default_key().next();
return uniform(
to_array(low),
to_array(high),
@@ -101,11 +145,11 @@ void init_random(py::module_& parent_module) {
std::optional<Dtype> type,
float loc,
float scale,
- const std::optional<array>& key,
+ const std::optional<array>& key_,
StreamOrDevice s) {
auto key = key_ ? key_.value() : default_key().next();
return normal(shape, type.value_or(float32), loc, scale, key, s);
},
"shape"_a = std::vector<int>{},
"dtype"_a = std::optional{float32},
"loc"_a = 0.0,
@@ -131,8 +175,9 @@ void init_random(py::module_& parent_module) {
const ScalarOrArray& high,
const std::vector<int>& shape,
std::optional<Dtype> type,
- const std::optional<array>& key,
+ const std::optional<array>& key_,
StreamOrDevice s) {
auto key = key_ ? key_.value() : default_key().next();
return randint(
to_array(low), to_array(high), shape, type.value_or(int32), key, s);
},
@@ -163,8 +208,9 @@ void init_random(py::module_& parent_module) {
"bernoulli",
[](const ScalarOrArray& p_,
const std::optional<std::vector<int>> shape,
- const std::optional<array>& key,
+ const std::optional<array>& key_,
StreamOrDevice s) {
auto key = key_ ? key_.value() : default_key().next();
auto p = to_array(p_);
if (shape.has_value()) {
return bernoulli(p, shape.value(), key, s);
@@ -199,8 +245,9 @@ void init_random(py::module_& parent_module) {
const ScalarOrArray& upper_,
const std::optional<std::vector<int>> shape_,
std::optional<Dtype> type,
- const std::optional<array>& key,
+ const std::optional<array>& key_,
StreamOrDevice s) {
auto key = key_ ? key_.value() : default_key().next();
auto lower = to_array(lower_);
auto upper = to_array(upper_);
auto t = type.value_or(float32);
@@ -239,8 +286,9 @@ void init_random(py::module_& parent_module) {
"gumbel",
[](const std::vector<int>& shape,
std::optional<Dtype> type,
- const std::optional<array>& key,
+ const std::optional<array>& key_,
StreamOrDevice s) {
auto key = key_ ? key_.value() : default_key().next();
return gumbel(shape, type.value_or(float32), key, s);
},
"shape"_a = std::vector<int>{},
@@ -267,8 +315,9 @@ void init_random(py::module_& parent_module) {
int axis,
const std::optional<std::vector<int>> shape,
const std::optional<int> num_samples,
- const std::optional<array>& key,
+ const std::optional<array>& key_,
StreamOrDevice s) {
auto key = key_ ? key_.value() : default_key().next();
if (shape.has_value() && num_samples.has_value()) {
throw std::invalid_argument(
"[categorical] At most one of shape or num_samples can be specified.");
@@ -309,4 +358,7 @@ void init_random(py::module_& parent_module) {
Returns:
array: The ``shape``-sized output array with type ``uint32``.
)pbdoc");
// Register static Python object cleanup before the interpreter exits
auto atexit = py::module_::import("atexit");
atexit.attr("register")(py::cpp_function([]() { default_key().release(); }));
}
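
The random-module changes above introduce a module-level key sequence (mx.random.state) that is split whenever no explicit ``key`` is passed, and reset by mx.random.seed. A small sketch of the user-visible behavior, assuming the usual import:

    import mlx.core as mx

    mx.random.seed(0)                    # resets the key held in mx.random.state
    a = mx.random.uniform(shape=(2,))    # splits and advances the global key

    mx.random.seed(0)
    b = mx.random.uniform(shape=(2,))
    assert mx.array_equal(a, b)          # same seed => same sequence of draws

    # An explicit key bypasses the global state entirely.
    k = mx.random.key(3)
    c = mx.random.uniform(shape=(2,), key=k)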

View File

@@ -135,6 +135,64 @@ py::object tree_map(
});
}
void tree_visit_update(
py::object tree,
std::function<py::object(py::handle)> visitor) {
std::function<py::object(py::handle)> recurse;
recurse = [&](py::handle subtree) {
if (py::isinstance<py::list>(subtree)) {
auto l = py::cast<py::list>(subtree);
for (int i = 0; i < l.size(); ++i) {
l[i] = recurse(l[i]);
}
return py::cast<py::object>(l);
} else if (py::isinstance<py::tuple>(subtree)) {
for (auto item : subtree) {
recurse(item);
}
return py::cast<py::object>(subtree);
} else if (py::isinstance<py::dict>(subtree)) {
auto d = py::cast<py::dict>(subtree);
for (auto item : d) {
d[item.first] = recurse(item.second);
}
return py::cast<py::object>(d);
} else if (py::isinstance<array>(subtree)) {
return visitor(subtree);
} else {
return py::cast<py::object>(subtree);
}
};
recurse(tree);
}
// Fill a pytree (recursive dict or list of dict or list)
// in place with the given arrays
// Non dict or list nodes are ignored
void tree_fill(py::object& tree, const std::vector<array>& values) {
size_t index = 0;
tree_visit_update(
tree, [&](py::handle node) { return py::cast(values[index++]); });
}
// Replace all the arrays from the src values with the dst values in the tree
void tree_replace(
py::object& tree,
const std::vector<array>& src,
const std::vector<array>& dst) {
std::unordered_map<uintptr_t, array> src_to_dst;
for (int i = 0; i < src.size(); ++i) {
src_to_dst.insert({src[i].id(), dst[i]});
}
tree_visit_update(tree, [&](py::handle node) {
auto arr = py::cast<array>(node);
if (auto it = src_to_dst.find(arr.id()); it != src_to_dst.end()) {
return py::cast(it->second);
}
return py::cast(arr);
});
}
std::vector<array> tree_flatten(py::object tree, bool strict = true) {
std::vector<array> flat_tree;
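
To make the semantics of the new tree_visit_update / tree_fill / tree_replace helpers concrete, here is a rough Python analogue (illustrative only, not the actual implementation): array leaves are replaced through the visitor, lists and dicts are updated in place, tuples are recursed into but returned unchanged, and all other leaves are left alone.

    import mlx.core as mx

    def tree_visit_update(tree, visitor):
        if isinstance(tree, list):
            for i, v in enumerate(tree):
                tree[i] = tree_visit_update(v, visitor)
            return tree
        if isinstance(tree, tuple):
            # Tuples are immutable: recurse so nested lists/dicts get updated,
            # but return the tuple itself unchanged.
            for v in tree:
                tree_visit_update(v, visitor)
            return tree
        if isinstance(tree, dict):
            for k, v in tree.items():
                tree[k] = tree_visit_update(v, visitor)
            return tree
        if isinstance(tree, mx.array):
            return visitor(tree)
        return tree

    def tree_fill(tree, values):
        # Write `values` into the array slots of `tree`, in flattening order.
        it = iter(values)
        tree_visit_update(tree, lambda _: next(it))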
@@ -495,9 +553,15 @@ std::unordered_map<size_t, py::object>& tree_cache() {
struct PyCompiledFun {
py::function fun;
size_t fun_id;
py::object captured_inputs;
py::object captured_outputs;
size_t num_outputs{0};
- PyCompiledFun(const py::function& fun)
- : fun(fun), fun_id(reinterpret_cast<size_t>(fun.ptr())) {}
+ PyCompiledFun(const py::function& fun, py::object inputs, py::object outputs)
+ : fun(fun),
+ fun_id(reinterpret_cast<size_t>(fun.ptr())),
+ captured_inputs(inputs),
+ captured_outputs(outputs) {}
PyCompiledFun(const PyCompiledFun&) = delete;
PyCompiledFun& operator=(const PyCompiledFun&) = delete;
@@ -505,23 +569,61 @@ struct PyCompiledFun {
PyCompiledFun(PyCompiledFun&& other)
: fun(std::move(other.fun)), fun_id(reinterpret_cast<size_t>(fun.ptr())) {
other.fun_id = 0;
captured_inputs = std::move(other.captured_inputs);
captured_outputs = std::move(other.captured_outputs);
num_outputs = other.num_outputs;
};
py::object operator()(const py::args& args) {
auto compile_fun = [this, &args](const std::vector<array>& a) {
// Call the python function and flatten the outputs
- auto [outputs, py_outputs] = tree_flatten_with_structure(
- std::move(this->fun(*tree_unflatten(args, a))), true);
// Put tracers into captured inputs
std::vector<array> flat_in_captures;
std::vector<array> trace_captures;
if (!py::isinstance<py::none>(captured_inputs)) {
flat_in_captures = tree_flatten(captured_inputs, false);
trace_captures.insert(
trace_captures.end(), a.end() - flat_in_captures.size(), a.end());
tree_fill(captured_inputs, trace_captures);
}
- tree_cache().insert({this->fun_id, py_outputs});
+ auto [outputs, py_outputs] = tree_flatten_with_structure(
+ std::move(fun(*tree_unflatten(args, a))), false);
+ tree_cache().insert({fun_id, py_outputs});
num_outputs = outputs.size();
if (!py::isinstance<py::none>(captured_outputs)) {
auto flat_out_captures = tree_flatten(captured_outputs, false);
outputs.insert(
outputs.end(),
std::make_move_iterator(flat_out_captures.begin()),
std::make_move_iterator(flat_out_captures.end()));
}
// Replace tracers with originals in captured inputs
if (!py::isinstance<py::none>(captured_inputs)) {
tree_replace(captured_inputs, trace_captures, flat_in_captures);
}
return outputs;
};
// Inputs must be array or tree of arrays
- auto inputs = tree_flatten(args, true);
+ auto inputs = tree_flatten(args, false);
if (!py::isinstance<py::none>(captured_inputs)) {
auto flat_in_captures = tree_flatten(captured_inputs, false);
inputs.insert(
inputs.end(),
std::make_move_iterator(flat_in_captures.begin()),
std::make_move_iterator(flat_in_captures.end()));
}
// Compile and call
auto outputs = detail::compile(compile_fun, fun_id)(inputs);
if (!py::isinstance<py::none>(captured_outputs)) {
std::vector<array> captures(
std::make_move_iterator(outputs.begin() + num_outputs),
std::make_move_iterator(outputs.end()));
tree_fill(captured_outputs, captures);
}
// Put the outputs back in the container
py::object py_outputs = tree_cache().at(fun_id);
@@ -534,6 +636,8 @@ struct PyCompiledFun {
tree_cache().erase(fun_id);
detail::compile_erase(fun_id);
fun.release().dec_ref();
captured_inputs.release().dec_ref();
captured_outputs.release().dec_ref();
}
};
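
One motivation for the captured inputs/outputs handled above (per the "rng state for compile" item in the commit message) is compiling functions that draw random numbers: capturing mx.random.state lets each call advance the global key rather than replaying the key recorded in the trace. A hedged sketch:

    import mlx.core as mx

    def noisy_square(x):
        return x * x + mx.random.normal(x.shape)   # implicitly uses the global key

    # Capture the PRNG state so the compiled function both reads and updates it.
    noisy_square = mx.compile(
        noisy_square, inputs=mx.random.state, outputs=mx.random.state
    )

    x = mx.ones((3,))
    print(noisy_square(x))   # different noise on each call
    print(noisy_square(x))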
@@ -601,7 +705,7 @@ void init_transforms(py::module_& m) {
m.def(
"eval",
[](const py::args& args) {
- std::vector<array> arrays = tree_flatten(args);
+ std::vector<array> arrays = tree_flatten(args, false);
{
py::gil_scoped_release nogil;
eval(arrays);
@@ -615,8 +719,8 @@ void init_transforms(py::module_& m) {
Args:
*args (arrays or trees of arrays): Each argument can be a single array
or a tree of arrays. If a tree is given the nodes can be a Python
- :class:`list`, :class:`tuple` or :class:`dict` but the leafs must all be
- an :class:`array`.
+ :class:`list`, :class:`tuple` or :class:`dict`. Leaves which are not
+ arrays are ignored.
)pbdoc");
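
Since eval now flattens with strict=false, trees passed to mx.eval may contain non-array leaves, which are simply skipped. For example (illustrative):

    import mlx.core as mx

    params = {"weight": mx.array([1.0, 2.0]), "name": "layer0"}
    mx.eval(params)   # evaluates the array leaf; the string leaf is ignored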
m.def(
"jvp",
@@ -859,10 +963,14 @@ void init_transforms(py::module_& m) {
"file"_a);
m.def(
"compile",
- [](const py::function& fun) {
- return py::cpp_function(PyCompiledFun{fun});
+ [](const py::function& fun,
+ const py::object& inputs,
+ const py::object& outputs) {
+ return py::cpp_function(PyCompiledFun{fun, inputs, outputs});
},
"fun"_a,
"inputs"_a = std::nullopt,
"outputs"_a = std::nullopt,
R"pbdoc(
compile(fun: function) -> function
@@ -872,6 +980,16 @@ void init_transforms(py::module_& m) {
fun (function): A function which takes a variable number of
:class:`array` or trees of :class:`array` and returns
a variable number of :class:`array` or trees of :class:`array`.
inputs (list or dict, optional): These inputs will be captured during
the function compilation along with the inputs to ``fun``. The ``inputs``
can be a :obj:`list` or a :obj:`dict` containing arbitrarily nested
lists, dictionaries, or arrays. Leaf nodes that are not
:obj:`array` are ignored. Default: ``None``
outputs (list or dict, optional): These outputs will be captured and
updated in a compiled function. The ``outputs`` can be a
:obj:`list` or a :obj:`dict` containing arbitrarily nested lists,
dictionaries, or arrays. Leaf nodes that are not :obj:`array` are ignored.
Default: ``None``
Returns:
function: A compiled function which has the same input arguments