mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-14 17:12:49 +08:00
Switch to nanobind (#839)
* mostly builds * most tests pass * fix circle build * add back buffer protocol * includes * fix for py38 * limit to cpu device * include * fix stubs * move signatures for docs * stubgen + docs fix * doc for compiled function, comments
This commit is contained in:
@@ -1,6 +1,11 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/stl/optional.h>
|
||||
#include <nanobind/stl/pair.h>
|
||||
#include <nanobind/stl/string.h>
|
||||
#include <nanobind/stl/variant.h>
|
||||
#include <nanobind/stl/vector.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <fstream>
|
||||
#include <numeric>
|
||||
@@ -13,13 +18,17 @@
|
||||
#include "mlx/transforms_impl.h"
|
||||
#include "python/src/trees.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
using namespace py::literals;
|
||||
namespace nb = nanobind;
|
||||
using namespace nb::literals;
|
||||
using namespace mlx::core;
|
||||
|
||||
using IntOrVec = std::variant<int, std::vector<int>>;
|
||||
using StrOrVec = std::variant<std::string, std::vector<std::string>>;
|
||||
|
||||
inline std::string type_name_str(const nb::handle& o) {
|
||||
return nb::cast<std::string>(nb::type_name(o.type()));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::vector<T> to_vector(const std::variant<T, std::vector<T>>& v) {
|
||||
std::vector<T> vals;
|
||||
@@ -49,7 +58,7 @@ auto validate_argnums_argnames(
|
||||
}
|
||||
|
||||
auto py_value_and_grad(
|
||||
const py::function& fun,
|
||||
const nb::callable& fun,
|
||||
std::vector<int> argnums,
|
||||
std::vector<std::string> argnames,
|
||||
const std::string& error_msg_tag,
|
||||
@@ -71,7 +80,7 @@ auto py_value_and_grad(
|
||||
}
|
||||
|
||||
return [fun, argnums, argnames, error_msg_tag, scalar_func_only](
|
||||
const py::args& args, const py::kwargs& kwargs) {
|
||||
const nb::args& args, const nb::kwargs& kwargs) {
|
||||
// Sanitize the input
|
||||
if (argnums.size() > 0 && argnums.back() >= args.size()) {
|
||||
std::ostringstream msg;
|
||||
@@ -89,7 +98,7 @@ auto py_value_and_grad(
|
||||
<< "' because the function is called with the "
|
||||
<< "following keyword arguments {";
|
||||
for (auto item : kwargs) {
|
||||
msg << item.first.cast<std::string>() << ",";
|
||||
msg << nb::cast<std::string>(item.first) << ",";
|
||||
}
|
||||
msg << "}";
|
||||
throw std::invalid_argument(msg.str());
|
||||
@@ -115,7 +124,7 @@ auto py_value_and_grad(
|
||||
|
||||
// value_out will hold the output of the python function in order to be
|
||||
// able to reconstruct the python tree of extra return values
|
||||
py::object py_value_out;
|
||||
nb::object py_value_out;
|
||||
auto value_and_grads = value_and_grad(
|
||||
[&fun,
|
||||
&args,
|
||||
@@ -127,15 +136,15 @@ auto py_value_and_grad(
|
||||
&error_msg_tag,
|
||||
scalar_func_only](const std::vector<array>& a) {
|
||||
// Copy the arguments
|
||||
py::args args_cpy = py::tuple(args.size());
|
||||
py::kwargs kwargs_cpy = py::kwargs();
|
||||
nb::list args_cpy;
|
||||
nb::kwargs kwargs_cpy = nb::kwargs();
|
||||
int j = 0;
|
||||
for (int i = 0; i < args.size(); ++i) {
|
||||
if (j < argnums.size() && i == argnums[j]) {
|
||||
args_cpy[i] = tree_unflatten(args[i], a, counts[j]);
|
||||
args_cpy.append(tree_unflatten(args[i], a, counts[j]));
|
||||
j++;
|
||||
} else {
|
||||
args_cpy[i] = args[i];
|
||||
args_cpy.append(args[i]);
|
||||
}
|
||||
}
|
||||
for (auto& key : argnames) {
|
||||
@@ -154,25 +163,25 @@ auto py_value_and_grad(
|
||||
py_value_out = fun(*args_cpy, **kwargs_cpy);
|
||||
|
||||
// Validate the return value of the python function
|
||||
if (!py::isinstance<array>(py_value_out)) {
|
||||
if (!nb::isinstance<array>(py_value_out)) {
|
||||
if (scalar_func_only) {
|
||||
std::ostringstream msg;
|
||||
msg << error_msg_tag << " The return value of the function "
|
||||
<< "whose gradient we want to compute should be a "
|
||||
<< "scalar array; but " << py_value_out.get_type()
|
||||
<< "scalar array; but " << type_name_str(py_value_out)
|
||||
<< " was returned.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (!py::isinstance<py::tuple>(py_value_out)) {
|
||||
if (!nb::isinstance<nb::tuple>(py_value_out)) {
|
||||
std::ostringstream msg;
|
||||
msg << error_msg_tag << " The return value of the function "
|
||||
<< "whose gradient we want to compute should be either a "
|
||||
<< "scalar array or a tuple with the first value being a "
|
||||
<< "scalar array (Union[array, Tuple[array, Any, ...]]); but "
|
||||
<< py_value_out.get_type() << " was returned.";
|
||||
<< type_name_str(py_value_out) << " was returned.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
py::tuple ret = py::cast<py::tuple>(py_value_out);
|
||||
nb::tuple ret = nb::cast<nb::tuple>(py_value_out);
|
||||
if (ret.size() == 0) {
|
||||
std::ostringstream msg;
|
||||
msg << error_msg_tag << " The return value of the function "
|
||||
@@ -182,14 +191,14 @@ auto py_value_and_grad(
|
||||
<< "we got an empty tuple.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (!py::isinstance<array>(ret[0])) {
|
||||
if (!nb::isinstance<array>(ret[0])) {
|
||||
std::ostringstream msg;
|
||||
msg << error_msg_tag << " The return value of the function "
|
||||
<< "whose gradient we want to compute should be either a "
|
||||
<< "scalar array or a tuple with the first value being a "
|
||||
<< "scalar array (Union[array, Tuple[array, Any, ...]]); but it "
|
||||
<< "was a tuple with the first value being of type "
|
||||
<< ret[0].get_type() << " .";
|
||||
<< type_name_str(ret[0]) << " .";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
}
|
||||
@@ -212,61 +221,60 @@ auto py_value_and_grad(
|
||||
// In case 2 we return a tuple of the above.
|
||||
// In case 3 we return a tuple containing a tuple and dict (sth like
|
||||
// (tuple(), dict(x=mx.array(5))) ).
|
||||
py::object positional_grads;
|
||||
py::object keyword_grads;
|
||||
py::object py_grads;
|
||||
nb::object positional_grads;
|
||||
nb::object keyword_grads;
|
||||
nb::object py_grads;
|
||||
|
||||
// Collect the gradients for the positional arguments
|
||||
if (argnums.size() == 1) {
|
||||
positional_grads = tree_unflatten(args[argnums[0]], gradients, counts[0]);
|
||||
} else if (argnums.size() > 1) {
|
||||
py::tuple grads_(argnums.size());
|
||||
nb::list grads_;
|
||||
for (int i = 0; i < argnums.size(); i++) {
|
||||
grads_[i] = tree_unflatten(args[argnums[i]], gradients, counts[i]);
|
||||
grads_.append(tree_unflatten(args[argnums[i]], gradients, counts[i]));
|
||||
}
|
||||
positional_grads = py::cast<py::object>(grads_);
|
||||
positional_grads = nb::tuple(grads_);
|
||||
} else {
|
||||
positional_grads = py::none();
|
||||
positional_grads = nb::none();
|
||||
}
|
||||
|
||||
// No keyword argument gradients so return the tuple of gradients
|
||||
if (argnames.size() == 0) {
|
||||
py_grads = positional_grads;
|
||||
} else {
|
||||
py::dict grads_;
|
||||
nb::dict grads_;
|
||||
for (int i = 0; i < argnames.size(); i++) {
|
||||
auto& k = argnames[i];
|
||||
grads_[k.c_str()] = tree_unflatten(
|
||||
kwargs[k.c_str()], gradients, counts[i + argnums.size()]);
|
||||
}
|
||||
keyword_grads = py::cast<py::object>(grads_);
|
||||
keyword_grads = grads_;
|
||||
|
||||
py_grads =
|
||||
py::cast<py::object>(py::make_tuple(positional_grads, keyword_grads));
|
||||
py_grads = nb::make_tuple(positional_grads, keyword_grads);
|
||||
}
|
||||
|
||||
// Put the values back in the container
|
||||
py::object return_value = tree_unflatten(py_value_out, value);
|
||||
nb::object return_value = tree_unflatten(py_value_out, value);
|
||||
return std::make_pair(return_value, py_grads);
|
||||
};
|
||||
}
|
||||
|
||||
auto py_vmap(
|
||||
const py::function& fun,
|
||||
const py::object& in_axes,
|
||||
const py::object& out_axes) {
|
||||
return [fun, in_axes, out_axes](const py::args& args) {
|
||||
auto axes_to_flat_tree = [](const py::object& tree,
|
||||
const py::object& axes) {
|
||||
const nb::callable& fun,
|
||||
const nb::object& in_axes,
|
||||
const nb::object& out_axes) {
|
||||
return [fun, in_axes, out_axes](const nb::args& args) {
|
||||
auto axes_to_flat_tree = [](const nb::object& tree,
|
||||
const nb::object& axes) {
|
||||
auto tree_axes = tree_map(
|
||||
{tree, axes},
|
||||
[](const std::vector<py::object>& inputs) { return inputs[1]; });
|
||||
[](const std::vector<nb::object>& inputs) { return inputs[1]; });
|
||||
std::vector<int> flat_axes;
|
||||
tree_visit(tree_axes, [&flat_axes](py::handle obj) {
|
||||
tree_visit(tree_axes, [&flat_axes](nb::handle obj) {
|
||||
if (obj.is_none()) {
|
||||
flat_axes.push_back(-1);
|
||||
} else if (py::isinstance<py::int_>(obj)) {
|
||||
flat_axes.push_back(py::cast<int>(py::cast<py::int_>(obj)));
|
||||
} else if (nb::isinstance<nb::int_>(obj)) {
|
||||
flat_axes.push_back(nb::cast<int>(nb::cast<nb::int_>(obj)));
|
||||
} else {
|
||||
throw std::invalid_argument("[vmap] axis must be int or None.");
|
||||
}
|
||||
@@ -280,7 +288,7 @@ auto py_vmap(
|
||||
|
||||
// py_value_out will hold the output of the python function in order to be
|
||||
// able to reconstruct the python tree of extra return values
|
||||
py::object py_outputs;
|
||||
nb::object py_outputs;
|
||||
|
||||
auto vmap_fn =
|
||||
[&fun, &args, &inputs, &py_outputs](const std::vector<array>& a) {
|
||||
@@ -305,24 +313,24 @@ auto py_vmap(
|
||||
};
|
||||
}
|
||||
|
||||
std::unordered_map<size_t, py::object>& tree_cache() {
|
||||
std::unordered_map<size_t, nb::object>& tree_cache() {
|
||||
// This map is used to Cache the tree structure of the outputs
|
||||
static std::unordered_map<size_t, py::object> tree_cache_;
|
||||
static std::unordered_map<size_t, nb::object> tree_cache_;
|
||||
return tree_cache_;
|
||||
}
|
||||
|
||||
struct PyCompiledFun {
|
||||
py::function fun;
|
||||
nb::callable fun;
|
||||
size_t fun_id;
|
||||
py::object captured_inputs;
|
||||
py::object captured_outputs;
|
||||
nb::object captured_inputs;
|
||||
nb::object captured_outputs;
|
||||
bool shapeless;
|
||||
size_t num_outputs{0};
|
||||
mutable size_t num_outputs{0};
|
||||
|
||||
PyCompiledFun(
|
||||
const py::function& fun,
|
||||
py::object inputs,
|
||||
py::object outputs,
|
||||
const nb::callable& fun,
|
||||
nb::object inputs,
|
||||
nb::object outputs,
|
||||
bool shapeless)
|
||||
: fun(fun),
|
||||
fun_id(reinterpret_cast<size_t>(fun.ptr())),
|
||||
@@ -342,7 +350,7 @@ struct PyCompiledFun {
|
||||
num_outputs = other.num_outputs;
|
||||
};
|
||||
|
||||
py::object operator()(const py::args& args, const py::kwargs& kwargs) {
|
||||
nb::object call_impl(const nb::args& args, const nb::kwargs& kwargs) {
|
||||
// Flat array inputs
|
||||
std::vector<array> inputs;
|
||||
|
||||
@@ -358,45 +366,45 @@ struct PyCompiledFun {
|
||||
constexpr uint64_t dict_identifier = 18446744073709551521UL;
|
||||
|
||||
// Flatten the tree with hashed constants and structure
|
||||
std::function<void(py::handle)> recurse;
|
||||
recurse = [&](py::handle obj) {
|
||||
if (py::isinstance<py::list>(obj)) {
|
||||
auto l = py::cast<py::list>(obj);
|
||||
std::function<void(nb::handle)> recurse;
|
||||
recurse = [&](nb::handle obj) {
|
||||
if (nb::isinstance<nb::list>(obj)) {
|
||||
auto l = nb::cast<nb::list>(obj);
|
||||
constants.push_back(list_identifier);
|
||||
for (int i = 0; i < l.size(); ++i) {
|
||||
recurse(l[i]);
|
||||
}
|
||||
} else if (py::isinstance<py::tuple>(obj)) {
|
||||
auto l = py::cast<py::tuple>(obj);
|
||||
} else if (nb::isinstance<nb::tuple>(obj)) {
|
||||
auto l = nb::cast<nb::tuple>(obj);
|
||||
constants.push_back(list_identifier);
|
||||
for (auto item : obj) {
|
||||
recurse(item);
|
||||
}
|
||||
} else if (py::isinstance<py::dict>(obj)) {
|
||||
auto d = py::cast<py::dict>(obj);
|
||||
} else if (nb::isinstance<nb::dict>(obj)) {
|
||||
auto d = nb::cast<nb::dict>(obj);
|
||||
constants.push_back(dict_identifier);
|
||||
for (auto item : d) {
|
||||
auto r = py::hash(item.first);
|
||||
auto r = item.first.attr("__hash__");
|
||||
constants.push_back(*reinterpret_cast<uint64_t*>(&r));
|
||||
recurse(item.second);
|
||||
}
|
||||
} else if (py::isinstance<array>(obj)) {
|
||||
inputs.push_back(py::cast<array>(obj));
|
||||
} else if (nb::isinstance<array>(obj)) {
|
||||
inputs.push_back(nb::cast<array>(obj));
|
||||
constants.push_back(array_identifier);
|
||||
} else if (py::isinstance<py::str>(obj)) {
|
||||
auto r = py::hash(obj);
|
||||
} else if (nb::isinstance<nb::str>(obj)) {
|
||||
auto r = obj.attr("__hash__");
|
||||
constants.push_back(*reinterpret_cast<uint64_t*>(&r));
|
||||
} else if (py::isinstance<py::int_>(obj)) {
|
||||
auto r = obj.cast<int64_t>();
|
||||
} else if (nb::isinstance<nb::int_>(obj)) {
|
||||
auto r = nb::cast<int64_t>(obj);
|
||||
constants.push_back(*reinterpret_cast<uint64_t*>(&r));
|
||||
} else if (py::isinstance<py::float_>(obj)) {
|
||||
auto r = obj.cast<double>();
|
||||
} else if (nb::isinstance<nb::float_>(obj)) {
|
||||
auto r = nb::cast<double>(obj);
|
||||
constants.push_back(*reinterpret_cast<uint64_t*>(&r));
|
||||
} else {
|
||||
std::ostringstream msg;
|
||||
msg << "[compile] Function arguments must be trees of arrays "
|
||||
<< "or constants (floats, ints, or strings), but received "
|
||||
<< "type " << obj.get_type() << ".";
|
||||
<< "type " << type_name_str(obj) << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
};
|
||||
@@ -404,13 +412,12 @@ struct PyCompiledFun {
|
||||
recurse(args);
|
||||
int num_args = inputs.size();
|
||||
recurse(kwargs);
|
||||
|
||||
auto compile_fun = [this, &args, &kwargs, num_args](
|
||||
const std::vector<array>& a) {
|
||||
// Put tracers into captured inputs
|
||||
std::vector<array> flat_in_captures;
|
||||
std::vector<array> trace_captures;
|
||||
if (!py::isinstance<py::none>(captured_inputs)) {
|
||||
if (!captured_inputs.is_none()) {
|
||||
flat_in_captures = tree_flatten(captured_inputs, false);
|
||||
trace_captures.insert(
|
||||
trace_captures.end(), a.end() - flat_in_captures.size(), a.end());
|
||||
@@ -425,7 +432,7 @@ struct PyCompiledFun {
|
||||
tree_cache().insert({fun_id, py_outputs});
|
||||
|
||||
num_outputs = outputs.size();
|
||||
if (!py::isinstance<py::none>(captured_outputs)) {
|
||||
if (!captured_outputs.is_none()) {
|
||||
auto flat_out_captures = tree_flatten(captured_outputs, false);
|
||||
outputs.insert(
|
||||
outputs.end(),
|
||||
@@ -434,13 +441,13 @@ struct PyCompiledFun {
|
||||
}
|
||||
|
||||
// Replace tracers with originals in captured inputs
|
||||
if (!py::isinstance<py::none>(captured_inputs)) {
|
||||
if (!captured_inputs.is_none()) {
|
||||
tree_replace(captured_inputs, trace_captures, flat_in_captures);
|
||||
}
|
||||
return outputs;
|
||||
};
|
||||
|
||||
if (!py::isinstance<py::none>(captured_inputs)) {
|
||||
if (!captured_inputs.is_none()) {
|
||||
auto flat_in_captures = tree_flatten(captured_inputs, false);
|
||||
inputs.insert(
|
||||
inputs.end(),
|
||||
@@ -451,7 +458,7 @@ struct PyCompiledFun {
|
||||
// Compile and call
|
||||
auto outputs =
|
||||
detail::compile(compile_fun, fun_id, shapeless, constants)(inputs);
|
||||
if (!py::isinstance<py::none>(captured_outputs)) {
|
||||
if (!captured_outputs.is_none()) {
|
||||
std::vector<array> captures(
|
||||
std::make_move_iterator(outputs.begin() + num_outputs),
|
||||
std::make_move_iterator(outputs.end()));
|
||||
@@ -459,12 +466,16 @@ struct PyCompiledFun {
|
||||
}
|
||||
|
||||
// Put the outputs back in the container
|
||||
py::object py_outputs = tree_cache().at(fun_id);
|
||||
nb::object py_outputs = tree_cache().at(fun_id);
|
||||
return tree_unflatten_from_structure(py_outputs, outputs);
|
||||
}
|
||||
|
||||
nb::object operator()(const nb::args& args, const nb::kwargs& kwargs) const {
|
||||
return const_cast<PyCompiledFun*>(this)->call_impl(args, kwargs);
|
||||
};
|
||||
|
||||
~PyCompiledFun() {
|
||||
py::gil_scoped_acquire gil;
|
||||
nb::gil_scoped_acquire gil;
|
||||
|
||||
tree_cache().erase(fun_id);
|
||||
detail::compile_erase(fun_id);
|
||||
@@ -476,35 +487,35 @@ struct PyCompiledFun {
|
||||
|
||||
class PyCheckpointedFun {
|
||||
public:
|
||||
PyCheckpointedFun(py::function fun) : fun_(std::move(fun)) {}
|
||||
PyCheckpointedFun(nb::callable fun) : fun_(std::move(fun)) {}
|
||||
|
||||
~PyCheckpointedFun() {
|
||||
py::gil_scoped_acquire gil;
|
||||
nb::gil_scoped_acquire gil;
|
||||
|
||||
fun_.release().dec_ref();
|
||||
}
|
||||
|
||||
struct InnerFunction {
|
||||
py::object fun_;
|
||||
py::object args_structure_;
|
||||
std::weak_ptr<py::object> output_structure_;
|
||||
nb::object fun_;
|
||||
nb::object args_structure_;
|
||||
std::weak_ptr<nb::object> output_structure_;
|
||||
|
||||
InnerFunction(
|
||||
py::object fun,
|
||||
py::object args_structure,
|
||||
std::weak_ptr<py::object> output_structure)
|
||||
nb::object fun,
|
||||
nb::object args_structure,
|
||||
std::weak_ptr<nb::object> output_structure)
|
||||
: fun_(std::move(fun)),
|
||||
args_structure_(std::move(args_structure)),
|
||||
output_structure_(output_structure) {}
|
||||
~InnerFunction() {
|
||||
py::gil_scoped_acquire gil;
|
||||
nb::gil_scoped_acquire gil;
|
||||
|
||||
fun_.release().dec_ref();
|
||||
args_structure_.release().dec_ref();
|
||||
}
|
||||
|
||||
std::vector<array> operator()(const std::vector<array>& inputs) {
|
||||
auto args = py::cast<py::tuple>(
|
||||
auto args = nb::cast<nb::tuple>(
|
||||
tree_unflatten_from_structure(args_structure_, inputs));
|
||||
auto [outputs, output_structure] =
|
||||
tree_flatten_with_structure(fun_(*args[0], **args[1]), false);
|
||||
@@ -515,9 +526,9 @@ class PyCheckpointedFun {
|
||||
}
|
||||
};
|
||||
|
||||
py::object operator()(const py::args& args, const py::kwargs& kwargs) {
|
||||
auto output_structure = std::make_shared<py::object>();
|
||||
auto full_args = py::make_tuple(args, kwargs);
|
||||
nb::object call_impl(const nb::args& args, const nb::kwargs& kwargs) {
|
||||
auto output_structure = std::make_shared<nb::object>();
|
||||
auto full_args = nb::make_tuple(args, kwargs);
|
||||
auto [inputs, args_structure] =
|
||||
tree_flatten_with_structure(full_args, false);
|
||||
|
||||
@@ -527,26 +538,27 @@ class PyCheckpointedFun {
|
||||
return tree_unflatten_from_structure(*output_structure, outputs);
|
||||
}
|
||||
|
||||
nb::object operator()(const nb::args& args, const nb::kwargs& kwargs) const {
|
||||
return const_cast<PyCheckpointedFun*>(this)->call_impl(args, kwargs);
|
||||
}
|
||||
|
||||
private:
|
||||
py::function fun_;
|
||||
nb::callable fun_;
|
||||
};
|
||||
|
||||
void init_transforms(py::module_& m) {
|
||||
py::options options;
|
||||
options.disable_function_signatures();
|
||||
|
||||
void init_transforms(nb::module_& m) {
|
||||
m.def(
|
||||
"eval",
|
||||
[](const py::args& args) {
|
||||
[](const nb::args& args) {
|
||||
std::vector<array> arrays = tree_flatten(args, false);
|
||||
{
|
||||
py::gil_scoped_release nogil;
|
||||
nb::gil_scoped_release nogil;
|
||||
eval(arrays);
|
||||
}
|
||||
},
|
||||
nb::arg(),
|
||||
nb::sig("def eval(*args) -> None"),
|
||||
R"pbdoc(
|
||||
eval(*args) -> None
|
||||
|
||||
Evaluate an :class:`array` or tree of :class:`array`.
|
||||
|
||||
Args:
|
||||
@@ -557,19 +569,15 @@ void init_transforms(py::module_& m) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"jvp",
|
||||
[](const py::function& fun,
|
||||
[](const nb::callable& fun,
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& tangents) {
|
||||
auto vfun = [&fun](const std::vector<array>& primals) {
|
||||
py::args args = py::tuple(primals.size());
|
||||
for (int i = 0; i < primals.size(); ++i) {
|
||||
args[i] = primals[i];
|
||||
}
|
||||
auto out = fun(*args);
|
||||
if (py::isinstance<array>(out)) {
|
||||
return std::vector<array>{py::cast<array>(out)};
|
||||
auto out = fun(*nb::cast(primals));
|
||||
if (nb::isinstance<array>(out)) {
|
||||
return std::vector<array>{nb::cast<array>(out)};
|
||||
} else {
|
||||
return py::cast<std::vector<array>>(out);
|
||||
return nb::cast<std::vector<array>>(out);
|
||||
}
|
||||
};
|
||||
return jvp(vfun, primals, tangents);
|
||||
@@ -577,17 +585,16 @@ void init_transforms(py::module_& m) {
|
||||
"fun"_a,
|
||||
"primals"_a,
|
||||
"tangents"_a,
|
||||
nb::sig(
|
||||
"def jvp(fun: callable, primals: List[array], tangents: List[array]) -> Tuple[List[array], List[array]]"),
|
||||
R"pbdoc(
|
||||
jvp(fun: function, primals: List[array], tangents: List[array]) -> Tuple[List[array], List[array]]
|
||||
|
||||
|
||||
Compute the Jacobian-vector product.
|
||||
|
||||
This computes the product of the Jacobian of a function ``fun`` evaluated
|
||||
at ``primals`` with the ``tangents``.
|
||||
|
||||
Args:
|
||||
fun (function): A function which takes a variable number of :class:`array`
|
||||
fun (callable): A function which takes a variable number of :class:`array`
|
||||
and returns a single :class:`array` or list of :class:`array`.
|
||||
primals (list(array)): A list of :class:`array` at which to
|
||||
evaluate the Jacobian.
|
||||
@@ -601,19 +608,15 @@ void init_transforms(py::module_& m) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"vjp",
|
||||
[](const py::function& fun,
|
||||
[](const nb::callable& fun,
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents) {
|
||||
auto vfun = [&fun](const std::vector<array>& primals) {
|
||||
py::args args = py::tuple(primals.size());
|
||||
for (int i = 0; i < primals.size(); ++i) {
|
||||
args[i] = primals[i];
|
||||
}
|
||||
auto out = fun(*args);
|
||||
if (py::isinstance<array>(out)) {
|
||||
return std::vector<array>{py::cast<array>(out)};
|
||||
auto out = fun(*nb::cast(primals));
|
||||
if (nb::isinstance<array>(out)) {
|
||||
return std::vector<array>{nb::cast<array>(out)};
|
||||
} else {
|
||||
return py::cast<std::vector<array>>(out);
|
||||
return nb::cast<std::vector<array>>(out);
|
||||
}
|
||||
};
|
||||
return vjp(vfun, primals, cotangents);
|
||||
@@ -621,16 +624,16 @@ void init_transforms(py::module_& m) {
|
||||
"fun"_a,
|
||||
"primals"_a,
|
||||
"cotangents"_a,
|
||||
nb::sig(
|
||||
"def vjp(fun: callable, primals: List[array], cotangents: List[array]) -> Tuple[List[array], List[array]]"),
|
||||
R"pbdoc(
|
||||
vjp(fun: function, primals: List[array], cotangents: List[array]) -> Tuple[List[array], List[array]]
|
||||
|
||||
Compute the vector-Jacobian product.
|
||||
|
||||
Computes the product of the ``cotangents`` with the Jacobian of a
|
||||
function ``fun`` evaluated at ``primals``.
|
||||
|
||||
Args:
|
||||
fun (function): A function which takes a variable number of :class:`array`
|
||||
fun (callable): A function which takes a variable number of :class:`array`
|
||||
and returns a single :class:`array` or list of :class:`array`.
|
||||
primals (list(array)): A list of :class:`array` at which to
|
||||
evaluate the Jacobian.
|
||||
@@ -644,20 +647,20 @@ void init_transforms(py::module_& m) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"value_and_grad",
|
||||
[](const py::function& fun,
|
||||
[](const nb::callable& fun,
|
||||
const std::optional<IntOrVec>& argnums,
|
||||
const StrOrVec& argnames) {
|
||||
auto [argnums_vec, argnames_vec] =
|
||||
validate_argnums_argnames(argnums, argnames);
|
||||
return py::cpp_function(py_value_and_grad(
|
||||
return nb::cpp_function(py_value_and_grad(
|
||||
fun, argnums_vec, argnames_vec, "[value_and_grad]", false));
|
||||
},
|
||||
"fun"_a,
|
||||
"argnums"_a = std::nullopt,
|
||||
"argnums"_a = nb::none(),
|
||||
"argnames"_a = std::vector<std::string>{},
|
||||
nb::sig(
|
||||
"def value_and_grad(fun: callable, argnums: Optional[Union[int, List[int]]] = None, argnames: Union[str, List[str]] = []) -> callable"),
|
||||
R"pbdoc(
|
||||
value_and_grad(fun: function, argnums: Optional[Union[int, List[int]]] = None, argnames: Union[str, List[str]] = []) -> function
|
||||
|
||||
Returns a function which computes the value and gradient of ``fun``.
|
||||
|
||||
The function passed to :func:`value_and_grad` should return either
|
||||
@@ -688,7 +691,7 @@ void init_transforms(py::module_& m) {
|
||||
(loss, mse, l1), grads = mx.value_and_grad(lasso)(params, inputs, targets)
|
||||
|
||||
Args:
|
||||
fun (function): A function which takes a variable number of
|
||||
fun (callable): A function which takes a variable number of
|
||||
:class:`array` or trees of :class:`array` and returns
|
||||
a scalar output :class:`array` or a tuple the first element
|
||||
of which should be a scalar :class:`array`.
|
||||
@@ -702,34 +705,34 @@ void init_transforms(py::module_& m) {
|
||||
no gradients for keyword arguments by default.
|
||||
|
||||
Returns:
|
||||
function: A function which returns a tuple where the first element
|
||||
callable: A function which returns a tuple where the first element
|
||||
is the output of `fun` and the second element is the gradients w.r.t.
|
||||
the loss.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"grad",
|
||||
[](const py::function& fun,
|
||||
[](const nb::callable& fun,
|
||||
const std::optional<IntOrVec>& argnums,
|
||||
const StrOrVec& argnames) {
|
||||
auto [argnums_vec, argnames_vec] =
|
||||
validate_argnums_argnames(argnums, argnames);
|
||||
auto fn =
|
||||
py_value_and_grad(fun, argnums_vec, argnames_vec, "[grad]", true);
|
||||
return py::cpp_function(
|
||||
[fn](const py::args& args, const py::kwargs& kwargs) {
|
||||
return nb::cpp_function(
|
||||
[fn](const nb::args& args, const nb::kwargs& kwargs) {
|
||||
return fn(args, kwargs).second;
|
||||
});
|
||||
},
|
||||
"fun"_a,
|
||||
"argnums"_a = std::nullopt,
|
||||
"argnums"_a = nb::none(),
|
||||
"argnames"_a = std::vector<std::string>{},
|
||||
nb::sig(
|
||||
"def grad(fun: callable, argnums: Optional[Union[int, List[int]]] = None, argnames: Union[str, List[str]] = []) -> callable"),
|
||||
R"pbdoc(
|
||||
grad(fun: function, argnums: Optional[Union[int, List[int]]] = None, argnames: Union[str, List[str]] = []) -> function
|
||||
|
||||
Returns a function which computes the gradient of ``fun``.
|
||||
|
||||
Args:
|
||||
fun (function): A function which takes a variable number of
|
||||
fun (callable): A function which takes a variable number of
|
||||
:class:`array` or trees of :class:`array` and returns
|
||||
a scalar output :class:`array`.
|
||||
argnums (int or list(int), optional): Specify the index (or indices)
|
||||
@@ -742,26 +745,26 @@ void init_transforms(py::module_& m) {
|
||||
no gradients for keyword arguments by default.
|
||||
|
||||
Returns:
|
||||
function: A function which has the same input arguments as ``fun`` and
|
||||
callable: A function which has the same input arguments as ``fun`` and
|
||||
returns the gradient(s).
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"vmap",
|
||||
[](const py::function& fun,
|
||||
const py::object& in_axes,
|
||||
const py::object& out_axes) {
|
||||
return py::cpp_function(py_vmap(fun, in_axes, out_axes));
|
||||
[](const nb::callable& fun,
|
||||
const nb::object& in_axes,
|
||||
const nb::object& out_axes) {
|
||||
return nb::cpp_function(py_vmap(fun, in_axes, out_axes));
|
||||
},
|
||||
"fun"_a,
|
||||
"in_axes"_a = 0,
|
||||
"out_axes"_a = 0,
|
||||
nb::sig(
|
||||
"def vmap(fun: callable, in_axes: object = 0, out_axes: object = 0) -> callable"),
|
||||
R"pbdoc(
|
||||
vmap(fun: function, in_axes: object = 0, out_axes: object = 0) -> function
|
||||
|
||||
Returns a vectorized version of ``fun``.
|
||||
|
||||
Args:
|
||||
fun (function): A function which takes a variable number of
|
||||
fun (callable): A function which takes a variable number of
|
||||
:class:`array` or a tree of :class:`array` and returns
|
||||
a variable number of :class:`array` or a tree of :class:`array`.
|
||||
in_axes (int, optional): An integer or a valid prefix tree of the
|
||||
@@ -774,16 +777,16 @@ void init_transforms(py::module_& m) {
|
||||
Defaults to ``0``.
|
||||
|
||||
Returns:
|
||||
function: The vectorized function.
|
||||
callable: The vectorized function.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"export_to_dot",
|
||||
[](py::object file, const py::args& args) {
|
||||
[](nb::object file, const nb::args& args) {
|
||||
std::vector<array> arrays = tree_flatten(args);
|
||||
if (py::isinstance<py::str>(file)) {
|
||||
std::ofstream out(py::cast<std::string>(file));
|
||||
if (nb::isinstance<nb::str>(file)) {
|
||||
std::ofstream out(nb::cast<std::string>(file));
|
||||
export_to_dot(out, arrays);
|
||||
} else if (py::hasattr(file, "write")) {
|
||||
} else if (nb::hasattr(file, "write")) {
|
||||
std::ostringstream out;
|
||||
export_to_dot(out, arrays);
|
||||
auto write = file.attr("write");
|
||||
@@ -793,57 +796,50 @@ void init_transforms(py::module_& m) {
|
||||
"export_to_dot accepts file-like objects or strings to be used as filenames");
|
||||
}
|
||||
},
|
||||
"file"_a);
|
||||
"file"_a,
|
||||
"args"_a);
|
||||
m.def(
|
||||
"compile",
|
||||
[](const py::function& fun,
|
||||
const py::object& inputs,
|
||||
const py::object& outputs,
|
||||
[](const nb::callable& fun,
|
||||
const nb::object& inputs,
|
||||
const nb::object& outputs,
|
||||
bool shapeless) {
|
||||
py::options options;
|
||||
options.disable_function_signatures();
|
||||
|
||||
std::ostringstream doc;
|
||||
auto name = fun.attr("__name__").cast<std::string>();
|
||||
doc << name;
|
||||
// Try to get the name
|
||||
auto n = fun.attr("__name__");
|
||||
auto name = n.is_none() ? "compiled" : nb::cast<std::string>(n);
|
||||
|
||||
// Try to get the signature
|
||||
auto inspect = py::module::import("inspect");
|
||||
if (!inspect.attr("isbuiltin")(fun).cast<bool>()) {
|
||||
doc << inspect.attr("signature")(fun)
|
||||
.attr("__str__")()
|
||||
.cast<std::string>();
|
||||
std::ostringstream sig;
|
||||
sig << "def " << name;
|
||||
auto inspect = nb::module_::import_("inspect");
|
||||
if (nb::cast<bool>(inspect.attr("isroutine")(fun))) {
|
||||
sig << nb::cast<std::string>(
|
||||
inspect.attr("signature")(fun).attr("__str__")());
|
||||
} else {
|
||||
sig << "(*args, **kwargs)";
|
||||
}
|
||||
|
||||
// Try to get the doc string
|
||||
if (auto d = fun.attr("__doc__"); py::isinstance<py::str>(d)) {
|
||||
doc << "\n\n";
|
||||
auto dstr = d.cast<std::string>();
|
||||
// Add spaces to match first line indentation with remainder of
|
||||
// docstring
|
||||
int i = 0;
|
||||
for (int i = dstr.size() - 1; i >= 0 && dstr[i] == ' '; i--) {
|
||||
doc << ' ';
|
||||
}
|
||||
doc << dstr;
|
||||
}
|
||||
auto doc_str = doc.str();
|
||||
return py::cpp_function(
|
||||
auto d = inspect.attr("getdoc")(fun);
|
||||
std::string doc =
|
||||
d.is_none() ? "MLX compiled function." : nb::cast<std::string>(d);
|
||||
|
||||
auto sig_str = sig.str();
|
||||
return nb::cpp_function(
|
||||
PyCompiledFun{fun, inputs, outputs, shapeless},
|
||||
py::name(name.c_str()),
|
||||
py::doc(doc_str.c_str()));
|
||||
nb::name(name.c_str()),
|
||||
nb::sig(sig_str.c_str()),
|
||||
doc.c_str());
|
||||
},
|
||||
"fun"_a,
|
||||
"inputs"_a = std::nullopt,
|
||||
"outputs"_a = std::nullopt,
|
||||
"inputs"_a = nb::none(),
|
||||
"outputs"_a = nb::none(),
|
||||
"shapeless"_a = false,
|
||||
R"pbdoc(
|
||||
compile(fun: function) -> function
|
||||
|
||||
Returns a compiled function which produces the same output as ``fun``.
|
||||
|
||||
Args:
|
||||
fun (function): A function which takes a variable number of
|
||||
fun (callable): A function which takes a variable number of
|
||||
:class:`array` or trees of :class:`array` and returns
|
||||
a variable number of :class:`array` or trees of :class:`array`.
|
||||
inputs (list or dict, optional): These inputs will be captured during
|
||||
@@ -864,15 +860,13 @@ void init_transforms(py::module_& m) {
|
||||
``shapeless`` set to ``True``. Default: ``False``
|
||||
|
||||
Returns:
|
||||
function: A compiled function which has the same input arguments
|
||||
callable: A compiled function which has the same input arguments
|
||||
as ``fun`` and returns the the same output(s).
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"disable_compile",
|
||||
&disable_compile,
|
||||
R"pbdoc(
|
||||
disable_compile() -> None
|
||||
|
||||
Globally disable compilation. Setting the environment variable
|
||||
``MLX_DISABLE_COMPILE`` can also be used to disable compilation.
|
||||
)pbdoc");
|
||||
@@ -880,17 +874,15 @@ void init_transforms(py::module_& m) {
|
||||
"enable_compile",
|
||||
&enable_compile,
|
||||
R"pbdoc(
|
||||
enable_compile() -> None
|
||||
|
||||
Globally enable compilation. This will override the environment
|
||||
variable ``MLX_DISABLE_COMPILE`` if set.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"checkpoint",
|
||||
[](py::function fun) { return py::cpp_function(PyCheckpointedFun{fun}); },
|
||||
[](nb::callable fun) { return nb::cpp_function(PyCheckpointedFun{fun}); },
|
||||
"fun"_a);
|
||||
|
||||
// Register static Python object cleanup before the interpreter exits
|
||||
auto atexit = py::module_::import("atexit");
|
||||
atexit.attr("register")(py::cpp_function([]() { tree_cache().clear(); }));
|
||||
auto atexit = nb::module_::import_("atexit");
|
||||
atexit.attr("register")(nb::cpp_function([]() { tree_cache().clear(); }));
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user