mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 10:26:56 +08:00
Compile with capture (#629)
* Simple kernel generation
* Remove the generate kernel from graph_utils
* fix multi-output with compile
* fuse with stopgrad
* v1 input, output capture in compile
* cleanup tree update with visitor update
* nit
* remove todo
* state for model, optional explicit init and more pure optimizer steps
* move learning rate to state
* add lr to opt state, some fixes in capture
* fix optim
* update tuple of containers as well
* fix stream for compiled output
* rng state for compile
* nit
* updates and comments

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
#include <chrono>
|
||||
|
||||
#include "python/src/utils.h"
|
||||
|
||||
@@ -13,13 +14,55 @@ using namespace py::literals;
|
||||
using namespace mlx::core;
|
||||
using namespace mlx::core::random;
|
||||
|
||||
class PyKeySequence {
|
||||
public:
|
||||
explicit PyKeySequence(uint64_t seed) {
|
||||
state_.append(key(seed));
|
||||
}
|
||||
|
||||
void seed(uint64_t seed) {
|
||||
state_[0] = key(seed);
|
||||
}
|
||||
|
||||
array next() {
|
||||
auto out = split(py::cast<array>(state_[0]));
|
||||
state_[0] = out.first;
|
||||
return out.second;
|
||||
}
|
||||
|
||||
py::list state() {
|
||||
return state_;
|
||||
}
|
||||
|
||||
void release() {
|
||||
py::gil_scoped_acquire gil;
|
||||
state_.release().dec_ref();
|
||||
}
|
||||
|
||||
private:
|
||||
py::list state_;
|
||||
};
|
||||
|
||||
// Accessor for the process-wide key sequence. The function-local static is
// constructed on first use and seeded with the system clock's time since
// epoch expressed in milliseconds.
PyKeySequence& default_key() {
  static PyKeySequence ks([]() {
    const auto since_epoch =
        std::chrono::system_clock::now().time_since_epoch();
    return std::chrono::duration_cast<std::chrono::milliseconds>(since_epoch)
        .count();
  }());
  return ks;
}
|
||||
|
||||
void init_random(py::module_& parent_module) {
|
||||
auto m = parent_module.def_submodule(
|
||||
"random",
|
||||
"mlx.core.random: functionality related to random number generation");
|
||||
|
||||
m.attr("state") = default_key().state();
|
||||
m.def(
|
||||
"seed",
|
||||
&seed,
|
||||
[](uint64_t seed) { default_key().seed(seed); },
|
||||
"seed"_a,
|
||||
R"pbdoc(
|
||||
Seed the global PRNG.
|
||||
@@ -62,8 +105,9 @@ void init_random(py::module_& parent_module) {
|
||||
const ScalarOrArray& high,
|
||||
const std::vector<int>& shape,
|
||||
std::optional<Dtype> type,
|
||||
const std::optional<array>& key,
|
||||
const std::optional<array>& key_,
|
||||
StreamOrDevice s) {
|
||||
auto key = key_ ? key_.value() : default_key().next();
|
||||
return uniform(
|
||||
to_array(low),
|
||||
to_array(high),
|
||||
@@ -101,11 +145,11 @@ void init_random(py::module_& parent_module) {
|
||||
std::optional<Dtype> type,
|
||||
float loc,
|
||||
float scale,
|
||||
const std::optional<array>& key,
|
||||
const std::optional<array>& key_,
|
||||
StreamOrDevice s) {
|
||||
auto key = key_ ? key_.value() : default_key().next();
|
||||
return normal(shape, type.value_or(float32), loc, scale, key, s);
|
||||
},
|
||||
|
||||
"shape"_a = std::vector<int>{},
|
||||
"dtype"_a = std::optional{float32},
|
||||
"loc"_a = 0.0,
|
||||
@@ -131,8 +175,9 @@ void init_random(py::module_& parent_module) {
|
||||
const ScalarOrArray& high,
|
||||
const std::vector<int>& shape,
|
||||
std::optional<Dtype> type,
|
||||
const std::optional<array>& key,
|
||||
const std::optional<array>& key_,
|
||||
StreamOrDevice s) {
|
||||
auto key = key_ ? key_.value() : default_key().next();
|
||||
return randint(
|
||||
to_array(low), to_array(high), shape, type.value_or(int32), key, s);
|
||||
},
|
||||
@@ -163,8 +208,9 @@ void init_random(py::module_& parent_module) {
|
||||
"bernoulli",
|
||||
[](const ScalarOrArray& p_,
|
||||
const std::optional<std::vector<int>> shape,
|
||||
const std::optional<array>& key,
|
||||
const std::optional<array>& key_,
|
||||
StreamOrDevice s) {
|
||||
auto key = key_ ? key_.value() : default_key().next();
|
||||
auto p = to_array(p_);
|
||||
if (shape.has_value()) {
|
||||
return bernoulli(p, shape.value(), key, s);
|
||||
@@ -199,8 +245,9 @@ void init_random(py::module_& parent_module) {
|
||||
const ScalarOrArray& upper_,
|
||||
const std::optional<std::vector<int>> shape_,
|
||||
std::optional<Dtype> type,
|
||||
const std::optional<array>& key,
|
||||
const std::optional<array>& key_,
|
||||
StreamOrDevice s) {
|
||||
auto key = key_ ? key_.value() : default_key().next();
|
||||
auto lower = to_array(lower_);
|
||||
auto upper = to_array(upper_);
|
||||
auto t = type.value_or(float32);
|
||||
@@ -239,8 +286,9 @@ void init_random(py::module_& parent_module) {
|
||||
"gumbel",
|
||||
[](const std::vector<int>& shape,
|
||||
std::optional<Dtype> type,
|
||||
const std::optional<array>& key,
|
||||
const std::optional<array>& key_,
|
||||
StreamOrDevice s) {
|
||||
auto key = key_ ? key_.value() : default_key().next();
|
||||
return gumbel(shape, type.value_or(float32), key, s);
|
||||
},
|
||||
"shape"_a = std::vector<int>{},
|
||||
@@ -267,8 +315,9 @@ void init_random(py::module_& parent_module) {
|
||||
int axis,
|
||||
const std::optional<std::vector<int>> shape,
|
||||
const std::optional<int> num_samples,
|
||||
const std::optional<array>& key,
|
||||
const std::optional<array>& key_,
|
||||
StreamOrDevice s) {
|
||||
auto key = key_ ? key_.value() : default_key().next();
|
||||
if (shape.has_value() && num_samples.has_value()) {
|
||||
throw std::invalid_argument(
|
||||
"[categorical] At most one of shape or num_samples can be specified.");
|
||||
@@ -309,4 +358,7 @@ void init_random(py::module_& parent_module) {
|
||||
Returns:
|
||||
array: The ``shape``-sized output array with type ``uint32``.
|
||||
)pbdoc");
|
||||
// Register static Python object cleanup before the interpreter exits
|
||||
auto atexit = py::module_::import("atexit");
|
||||
atexit.attr("register")(py::cpp_function([]() { default_key().release(); }));
|
||||
}
|
||||
|
@@ -135,6 +135,64 @@ py::object tree_map(
|
||||
});
|
||||
}
|
||||
|
||||
// Walk a pytree depth-first, calling `visitor` on every array leaf and
// writing the results back into the mutable containers in place.
//
//  - list:  every element is replaced by the result of visiting it.
//  - dict:  every value is replaced by the result of visiting it (the key
//           set is never changed).
//  - tuple: every element is visited, but the results are discarded, so an
//           array held directly by a tuple is not replaced. Lists and dicts
//           nested inside the tuple are still updated in place.
//  - array: the value returned by `visitor` is handed to the parent.
//  - other: returned unchanged.
//
// The result of visiting the root is discarded, so passing a bare array as
// `tree` invokes `visitor` but updates nothing.
void tree_visit_update(
    py::object tree,
    std::function<py::object(py::handle)> visitor) {
  std::function<py::object(py::handle)> recurse;
  recurse = [&](py::handle subtree) -> py::object {
    if (py::isinstance<py::list>(subtree)) {
      auto l = py::cast<py::list>(subtree);
      // size_t index: l.size() is unsigned, avoid a signed/unsigned compare.
      for (size_t i = 0; i < l.size(); ++i) {
        l[i] = recurse(l[i]);
      }
      return py::cast<py::object>(l);
    } else if (py::isinstance<py::tuple>(subtree)) {
      // Results are intentionally dropped; only nested mutable containers
      // observe the update.
      for (auto item : subtree) {
        recurse(item);
      }
      return py::cast<py::object>(subtree);
    } else if (py::isinstance<py::dict>(subtree)) {
      auto d = py::cast<py::dict>(subtree);
      // Only values of existing keys are reassigned during iteration.
      for (auto item : d) {
        d[item.first] = recurse(item.second);
      }
      return py::cast<py::object>(d);
    } else if (py::isinstance<array>(subtree)) {
      return visitor(subtree);
    } else {
      return py::cast<py::object>(subtree);
    }
  };

  recurse(tree);
}
|
||||
|
||||
// Fill a pytree (recursive dict or list of dict or list)
|
||||
// in place with the given arrays
|
||||
// Non dict or list nodes are ignored
|
||||
void tree_fill(py::object& tree, const std::vector<array>& values) {
|
||||
size_t index = 0;
|
||||
tree_visit_update(
|
||||
tree, [&](py::handle node) { return py::cast(values[index++]); });
|
||||
}
|
||||
|
||||
// Replace all the arrays from the src values with the dst values in the tree
|
||||
void tree_replace(
|
||||
py::object& tree,
|
||||
const std::vector<array>& src,
|
||||
const std::vector<array>& dst) {
|
||||
std::unordered_map<uintptr_t, array> src_to_dst;
|
||||
for (int i = 0; i < src.size(); ++i) {
|
||||
src_to_dst.insert({src[i].id(), dst[i]});
|
||||
}
|
||||
tree_visit_update(tree, [&](py::handle node) {
|
||||
auto arr = py::cast<array>(node);
|
||||
if (auto it = src_to_dst.find(arr.id()); it != src_to_dst.end()) {
|
||||
return py::cast(it->second);
|
||||
}
|
||||
return py::cast(arr);
|
||||
});
|
||||
}
|
||||
|
||||
std::vector<array> tree_flatten(py::object tree, bool strict = true) {
|
||||
std::vector<array> flat_tree;
|
||||
|
||||
@@ -495,9 +553,15 @@ std::unordered_map<size_t, py::object>& tree_cache() {
|
||||
struct PyCompiledFun {
|
||||
py::function fun;
|
||||
size_t fun_id;
|
||||
py::object captured_inputs;
|
||||
py::object captured_outputs;
|
||||
size_t num_outputs{0};
|
||||
|
||||
PyCompiledFun(const py::function& fun)
|
||||
: fun(fun), fun_id(reinterpret_cast<size_t>(fun.ptr())) {}
|
||||
PyCompiledFun(const py::function& fun, py::object inputs, py::object outputs)
|
||||
: fun(fun),
|
||||
fun_id(reinterpret_cast<size_t>(fun.ptr())),
|
||||
captured_inputs(inputs),
|
||||
captured_outputs(outputs) {}
|
||||
|
||||
PyCompiledFun(const PyCompiledFun&) = delete;
|
||||
PyCompiledFun& operator=(const PyCompiledFun&) = delete;
|
||||
@@ -505,23 +569,61 @@ struct PyCompiledFun {
|
||||
PyCompiledFun(PyCompiledFun&& other)
|
||||
: fun(std::move(other.fun)), fun_id(reinterpret_cast<size_t>(fun.ptr())) {
|
||||
other.fun_id = 0;
|
||||
captured_inputs = std::move(other.captured_inputs);
|
||||
captured_outputs = std::move(other.captured_outputs);
|
||||
num_outputs = other.num_outputs;
|
||||
};
|
||||
|
||||
py::object operator()(const py::args& args) {
|
||||
auto compile_fun = [this, &args](const std::vector<array>& a) {
|
||||
// Call the python function and flatten the outputs
|
||||
auto [outputs, py_outputs] = tree_flatten_with_structure(
|
||||
std::move(this->fun(*tree_unflatten(args, a))), true);
|
||||
// Put tracers into captured inputs
|
||||
std::vector<array> flat_in_captures;
|
||||
std::vector<array> trace_captures;
|
||||
if (!py::isinstance<py::none>(captured_inputs)) {
|
||||
flat_in_captures = tree_flatten(captured_inputs, false);
|
||||
trace_captures.insert(
|
||||
trace_captures.end(), a.end() - flat_in_captures.size(), a.end());
|
||||
tree_fill(captured_inputs, trace_captures);
|
||||
}
|
||||
|
||||
tree_cache().insert({this->fun_id, py_outputs});
|
||||
auto [outputs, py_outputs] = tree_flatten_with_structure(
|
||||
std::move(fun(*tree_unflatten(args, a))), false);
|
||||
|
||||
tree_cache().insert({fun_id, py_outputs});
|
||||
|
||||
num_outputs = outputs.size();
|
||||
if (!py::isinstance<py::none>(captured_outputs)) {
|
||||
auto flat_out_captures = tree_flatten(captured_outputs, false);
|
||||
outputs.insert(
|
||||
outputs.end(),
|
||||
std::make_move_iterator(flat_out_captures.begin()),
|
||||
std::make_move_iterator(flat_out_captures.end()));
|
||||
}
|
||||
|
||||
// Replace tracers with originals in captured inputs
|
||||
if (!py::isinstance<py::none>(captured_inputs)) {
|
||||
tree_replace(captured_inputs, trace_captures, flat_in_captures);
|
||||
}
|
||||
return outputs;
|
||||
};
|
||||
|
||||
// Inputs must be array or tree of arrays
|
||||
auto inputs = tree_flatten(args, true);
|
||||
auto inputs = tree_flatten(args, false);
|
||||
if (!py::isinstance<py::none>(captured_inputs)) {
|
||||
auto flat_in_captures = tree_flatten(captured_inputs, false);
|
||||
inputs.insert(
|
||||
inputs.end(),
|
||||
std::make_move_iterator(flat_in_captures.begin()),
|
||||
std::make_move_iterator(flat_in_captures.end()));
|
||||
}
|
||||
|
||||
// Compile and call
|
||||
auto outputs = detail::compile(compile_fun, fun_id)(inputs);
|
||||
if (!py::isinstance<py::none>(captured_outputs)) {
|
||||
std::vector<array> captures(
|
||||
std::make_move_iterator(outputs.begin() + num_outputs),
|
||||
std::make_move_iterator(outputs.end()));
|
||||
tree_fill(captured_outputs, captures);
|
||||
}
|
||||
|
||||
// Put the outputs back in the container
|
||||
py::object py_outputs = tree_cache().at(fun_id);
|
||||
@@ -534,6 +636,8 @@ struct PyCompiledFun {
|
||||
tree_cache().erase(fun_id);
|
||||
detail::compile_erase(fun_id);
|
||||
fun.release().dec_ref();
|
||||
captured_inputs.release().dec_ref();
|
||||
captured_outputs.release().dec_ref();
|
||||
}
|
||||
};
|
||||
|
||||
@@ -601,7 +705,7 @@ void init_transforms(py::module_& m) {
|
||||
m.def(
|
||||
"eval",
|
||||
[](const py::args& args) {
|
||||
std::vector<array> arrays = tree_flatten(args);
|
||||
std::vector<array> arrays = tree_flatten(args, false);
|
||||
{
|
||||
py::gil_scoped_release nogil;
|
||||
eval(arrays);
|
||||
@@ -615,8 +719,8 @@ void init_transforms(py::module_& m) {
|
||||
Args:
|
||||
*args (arrays or trees of arrays): Each argument can be a single array
|
||||
or a tree of arrays. If a tree is given the nodes can be a Python
|
||||
:class:`list`, :class:`tuple` or :class:`dict` but the leafs must all be
|
||||
an :class:`array`.
|
||||
:class:`list`, :class:`tuple` or :class:`dict`. Leaves which are not
|
||||
arrays are ignored.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"jvp",
|
||||
@@ -859,10 +963,14 @@ void init_transforms(py::module_& m) {
|
||||
"file"_a);
|
||||
m.def(
|
||||
"compile",
|
||||
[](const py::function& fun) {
|
||||
return py::cpp_function(PyCompiledFun{fun});
|
||||
[](const py::function& fun,
|
||||
const py::object& inputs,
|
||||
const py::object& outputs) {
|
||||
return py::cpp_function(PyCompiledFun{fun, inputs, outputs});
|
||||
},
|
||||
"fun"_a,
|
||||
"inputs"_a = std::nullopt,
|
||||
"outputs"_a = std::nullopt,
|
||||
R"pbdoc(
|
||||
compile(fun: function) -> function
|
||||
|
||||
@@ -872,6 +980,16 @@ void init_transforms(py::module_& m) {
|
||||
fun (function): A function which takes a variable number of
|
||||
:class:`array` or trees of :class:`array` and returns
|
||||
a variable number of :class:`array` or trees of :class:`array`.
|
||||
inputs (list or dict, optional): These inputs will be captured during
|
||||
the function compilation along with the inputs to ``fun``. The ``inputs``
|
||||
can be a :obj:`list` or a :obj:`dict` containing arbitrarily nested
|
||||
lists, dictionaries, or arrays. Leaf nodes that are not
|
||||
:obj:`array` are ignored. Default: ``None``
|
||||
outputs (list or dict, optional): These outputs will be captured and
|
||||
updated in a compiled function. The ``outputs`` can be a
|
||||
:obj:`list` or a :obj:`dict` containing arbitrarily nested lists,
|
||||
dictionaries, or arrays. Leaf nodes that are not :obj:`array` are ignored.
|
||||
Default: ``None``
|
||||
|
||||
Returns:
|
||||
function: A compiled function which has the same input arguments
|
||||
|
Reference in New Issue
Block a user