mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-12 15:24:57 +08:00
Custom transforms (#1246)
This commit is contained in:

committed by
GitHub

parent
a3c287354f
commit
5c1fa64fb0
@@ -593,7 +593,454 @@ class PyCheckpointedFun {
|
||||
nb::callable fun_;
|
||||
};
|
||||
|
||||
/**
|
||||
* PyCustomFunction is the class that implements the python decorator
|
||||
* `mx.custom_function`.
|
||||
*
|
||||
* It implements a callable that instead of simply calling `fun` it creates a
|
||||
* CustomTransforms primitive via the `custom_function` C++ op which allows us
|
||||
* to redefine the vjp, jvp and vmap transformations.
|
||||
*
|
||||
* The implementation is verbose due to explicit handling of the destruction of
|
||||
* various python objects to make sure that there is no double-free and that
|
||||
* all of them are deleted while under GIL.
|
||||
*
|
||||
* Namely, for every one of the functions passed to the C++ `custom_function`
|
||||
* we create a callable struct that holds the following python objects (when
|
||||
* needed).
|
||||
*
|
||||
* - An nb::callable which holds the passed function or transform
|
||||
* - An nb::object holding input structure, namely the `(args, kwargs)`
|
||||
* passed to the function in order to be able to recreate the arguments
|
||||
* from the input arrays.
|
||||
* - A std::shared_ptr<nb::object> holding the output structure name the
|
||||
* structure of the return value of `fun`. It is a shared_ptr so that it
|
||||
* can be set when the function is called and then used in the `vjp`
|
||||
* transform. We delete the object only when the shared_ptr is about to be
|
||||
* deleted see `output_structure_.use_count() == 1` to make sure that the
|
||||
* object is deleted under GIL.
|
||||
*/
|
||||
class PyCustomFunction {
|
||||
public:
|
||||
PyCustomFunction(nb::callable fun) : fun_(std::move(fun)) {}
|
||||
~PyCustomFunction() {
|
||||
nb::gil_scoped_acquire gil;
|
||||
|
||||
fun_.release().dec_ref();
|
||||
if (vjp_fun_.has_value()) {
|
||||
(*vjp_fun_).release().dec_ref();
|
||||
}
|
||||
if (jvp_fun_.has_value()) {
|
||||
(*jvp_fun_).release().dec_ref();
|
||||
}
|
||||
if (vmap_fun_.has_value()) {
|
||||
(*vmap_fun_).release().dec_ref();
|
||||
}
|
||||
}
|
||||
|
||||
struct InnerFunction {
|
||||
nb::callable fun_;
|
||||
nb::object input_structure_;
|
||||
std::shared_ptr<nb::object> output_structure_;
|
||||
|
||||
InnerFunction(
|
||||
nb::callable fun,
|
||||
nb::object input_structure,
|
||||
std::shared_ptr<nb::object> output_structure)
|
||||
: fun_(std::move(fun)),
|
||||
input_structure_(std::move(input_structure)),
|
||||
output_structure_(std::move(output_structure)) {}
|
||||
~InnerFunction() {
|
||||
nb::gil_scoped_acquire gil;
|
||||
|
||||
fun_.release().dec_ref();
|
||||
input_structure_.release().dec_ref();
|
||||
if (output_structure_.use_count() == 1) {
|
||||
output_structure_->release().dec_ref();
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<array> operator()(const std::vector<array>& inputs) {
|
||||
nb::gil_scoped_acquire gil;
|
||||
|
||||
auto new_inputs = nb::cast<nb::tuple>(
|
||||
tree_unflatten_from_structure(input_structure_, inputs));
|
||||
std::vector<array> outputs;
|
||||
std::tie(outputs, *output_structure_) =
|
||||
tree_flatten_with_structure(fun_(*new_inputs[0], **new_inputs[1]));
|
||||
return outputs;
|
||||
}
|
||||
};
|
||||
|
||||
struct InnerVJPFunction {
|
||||
nb::callable vjp_fun_;
|
||||
nb::object input_structure_;
|
||||
std::shared_ptr<nb::object> output_structure_;
|
||||
|
||||
InnerVJPFunction(
|
||||
nb::callable vjp_fun,
|
||||
nb::object input_structure,
|
||||
std::shared_ptr<nb::object> output_structure)
|
||||
: vjp_fun_(std::move(vjp_fun)),
|
||||
input_structure_(std::move(input_structure)),
|
||||
output_structure_(std::move(output_structure)) {}
|
||||
~InnerVJPFunction() {
|
||||
nb::gil_scoped_acquire gil;
|
||||
|
||||
vjp_fun_.release().dec_ref();
|
||||
input_structure_.release().dec_ref();
|
||||
if (output_structure_.use_count() == 1) {
|
||||
output_structure_->release().dec_ref();
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<array> operator()(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<array>& outputs) {
|
||||
nb::gil_scoped_acquire gil;
|
||||
|
||||
auto new_inputs = nb::cast<nb::tuple>(
|
||||
tree_unflatten_from_structure(input_structure_, primals));
|
||||
auto args = nb::cast<nb::tuple>(new_inputs[0]);
|
||||
auto new_cotangents =
|
||||
tree_unflatten_from_structure(*output_structure_, cotangents);
|
||||
auto new_outputs =
|
||||
tree_unflatten_from_structure(*output_structure_, outputs);
|
||||
|
||||
if (args.size() == 1) {
|
||||
return tree_flatten(
|
||||
vjp_fun_(args[0], new_cotangents, new_outputs, **new_inputs[1]),
|
||||
false);
|
||||
} else {
|
||||
return tree_flatten(
|
||||
vjp_fun_(args, new_cotangents, new_outputs, **new_inputs[1]),
|
||||
false);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct InnerJVPFunction {
|
||||
nb::callable jvp_fun_;
|
||||
nb::object input_structure_;
|
||||
|
||||
InnerJVPFunction(nb::callable jvp_fun, nb::object input_structure)
|
||||
: jvp_fun_(std::move(jvp_fun)),
|
||||
input_structure_(std::move(input_structure)) {}
|
||||
~InnerJVPFunction() {
|
||||
nb::gil_scoped_acquire gil;
|
||||
|
||||
jvp_fun_.release().dec_ref();
|
||||
input_structure_.release().dec_ref();
|
||||
}
|
||||
|
||||
std::vector<array> operator()(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& tangents,
|
||||
const std::vector<int>& argnums) {
|
||||
nb::gil_scoped_acquire gil;
|
||||
|
||||
auto new_inputs = nb::cast<nb::tuple>(
|
||||
tree_unflatten_from_structure(input_structure_, primals));
|
||||
auto args = nb::cast<nb::tuple>(new_inputs[0]);
|
||||
auto kwargs = nb::cast<nb::dict>(new_inputs[1]);
|
||||
if (kwargs.size() > 0) {
|
||||
throw std::invalid_argument(
|
||||
"[custom jvp] Function should only accept positional arguments");
|
||||
}
|
||||
|
||||
// Make a new pytree which has tangents or None when a tangent is not
|
||||
// available.
|
||||
std::vector<bool> have_tangents(primals.size(), false);
|
||||
for (auto arg : argnums) {
|
||||
have_tangents[arg] = true;
|
||||
}
|
||||
int array_index = 0;
|
||||
int tangent_index = 0;
|
||||
auto new_tangents =
|
||||
nb::cast<nb::tuple>(tree_map(args, [&](nb::handle element) {
|
||||
if (nb::isinstance<array>(element) &&
|
||||
have_tangents[array_index++]) {
|
||||
return nb::cast(tangents[tangent_index++]);
|
||||
} else {
|
||||
return nb::none();
|
||||
}
|
||||
}));
|
||||
|
||||
if (args.size() == 1) {
|
||||
return tree_flatten(jvp_fun_(args[0], new_tangents[0]), false);
|
||||
} else {
|
||||
return tree_flatten(jvp_fun_(args, new_tangents), false);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct InnerVmapFunction {
|
||||
nb::callable vmap_fun_;
|
||||
nb::object input_structure_;
|
||||
|
||||
InnerVmapFunction(nb::callable vmap_fun, nb::object input_structure)
|
||||
: vmap_fun_(std::move(vmap_fun)),
|
||||
input_structure_(std::move(input_structure)) {}
|
||||
~InnerVmapFunction() {
|
||||
nb::gil_scoped_acquire gil;
|
||||
|
||||
vmap_fun_.release().dec_ref();
|
||||
input_structure_.release().dec_ref();
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, std::vector<int>> operator()(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
nb::gil_scoped_acquire gil;
|
||||
|
||||
auto new_inputs = nb::cast<nb::tuple>(
|
||||
tree_unflatten_from_structure(input_structure_, inputs));
|
||||
auto args = nb::cast<nb::tuple>(new_inputs[0]);
|
||||
auto kwargs = nb::cast<nb::dict>(new_inputs[1]);
|
||||
if (kwargs.size() > 0) {
|
||||
throw std::invalid_argument(
|
||||
"[custom vmap] Function should only accept positional arguments");
|
||||
}
|
||||
|
||||
int arr_index;
|
||||
auto new_axes =
|
||||
nb::cast<nb::tuple>(tree_map(args, [&](nb::handle element) {
|
||||
int axis = axes[arr_index++];
|
||||
if (nb::isinstance<array>(element) && axis >= 0) {
|
||||
return nb::cast(axis);
|
||||
} else {
|
||||
return nb::none();
|
||||
}
|
||||
}));
|
||||
|
||||
nb::object result;
|
||||
if (args.size() == 1) {
|
||||
result = vmap_fun_(args[0], new_axes[0]);
|
||||
} else {
|
||||
result = vmap_fun_(args, new_axes);
|
||||
}
|
||||
|
||||
if (!nb::isinstance<nb::tuple>(result)) {
|
||||
throw std::invalid_argument(
|
||||
"[custom vmap] Vmap function should return a tuple with 2 items.");
|
||||
}
|
||||
nb::tuple result_tuple = nb::cast<nb::tuple>(result);
|
||||
if (result_tuple.size() != 2) {
|
||||
throw std::invalid_argument(
|
||||
"[custom vmap] Vmap function should return a tuple with 2 items.");
|
||||
}
|
||||
|
||||
std::vector<array> outputs;
|
||||
std::vector<int> output_axes;
|
||||
tree_visit({result_tuple[0], result_tuple[1]}, [&](auto objects) {
|
||||
if (nb::isinstance<array>(objects[0])) {
|
||||
outputs.push_back(nb::cast<array>(objects[0]));
|
||||
output_axes.push_back(
|
||||
objects[1].is_none() ? -1 : nb::cast<int>(objects[1]));
|
||||
}
|
||||
});
|
||||
|
||||
return {outputs, output_axes};
|
||||
}
|
||||
};
|
||||
|
||||
nb::object call_impl(const nb::args& args, const nb::kwargs& kwargs) {
|
||||
if (!vjp_fun_.has_value() && !jvp_fun_.has_value() &&
|
||||
!vmap_fun_.has_value()) {
|
||||
return fun_(*args, **kwargs);
|
||||
}
|
||||
|
||||
// Extract the inputs and their structure in capturable vars
|
||||
std::vector<array> input_arrays;
|
||||
nb::object input_structure;
|
||||
auto full_args = nb::make_tuple(args, kwargs);
|
||||
std::tie(input_arrays, input_structure) =
|
||||
tree_flatten_with_structure(full_args, false);
|
||||
|
||||
// The output structure will be stored here to be used in the custom vjp
|
||||
// function
|
||||
auto output_structure = std::make_shared<nb::object>();
|
||||
|
||||
// Make a function that calls fun_ in the forward pass and vjp_ in the
|
||||
// backward pass. Then call it immediately and return the results.
|
||||
auto f = custom_function(
|
||||
InnerFunction(fun_, input_structure, output_structure),
|
||||
make_vjp_function(input_structure, output_structure),
|
||||
make_jvp_function(input_structure),
|
||||
make_vmap_function(input_structure));
|
||||
|
||||
auto outputs = f(input_arrays);
|
||||
return tree_unflatten_from_structure(*output_structure, outputs);
|
||||
}
|
||||
|
||||
PyCustomFunction& set_vjp(nb::callable vjp_fun) {
|
||||
vjp_fun_ = vjp_fun;
|
||||
return *this;
|
||||
}
|
||||
|
||||
PyCustomFunction& set_jvp(nb::callable jvp_fun) {
|
||||
jvp_fun_ = jvp_fun;
|
||||
return *this;
|
||||
}
|
||||
|
||||
PyCustomFunction& set_vmap(nb::callable vmap_fun) {
|
||||
vmap_fun_ = vmap_fun;
|
||||
return *this;
|
||||
}
|
||||
|
||||
private:
|
||||
std::optional<InnerVJPFunction> make_vjp_function(
|
||||
nb::object input_structure,
|
||||
std::shared_ptr<nb::object> output_structure) {
|
||||
if (!vjp_fun_.has_value()) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
return InnerVJPFunction(*vjp_fun_, input_structure, output_structure);
|
||||
}
|
||||
|
||||
std::optional<InnerJVPFunction> make_jvp_function(
|
||||
nb::object input_structure) {
|
||||
if (!jvp_fun_.has_value()) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
return InnerJVPFunction(*jvp_fun_, input_structure);
|
||||
}
|
||||
|
||||
std::optional<InnerVmapFunction> make_vmap_function(
|
||||
nb::object input_structure) {
|
||||
if (!vmap_fun_.has_value()) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
return InnerVmapFunction(*vmap_fun_, input_structure);
|
||||
}
|
||||
|
||||
nb::callable fun_;
|
||||
std::optional<nb::callable> vjp_fun_;
|
||||
std::optional<nb::callable> jvp_fun_;
|
||||
std::optional<nb::callable> vmap_fun_;
|
||||
};
|
||||
|
||||
void init_transforms(nb::module_& m) {
|
||||
nb::class_<PyCustomFunction>(
|
||||
m,
|
||||
"custom_function",
|
||||
R"pbdoc(
|
||||
Set up a function for custom gradient and vmap definitions.
|
||||
|
||||
This class is meant to be used as a function decorator. Instances are
|
||||
callables that behave identically to the wrapped function. However, when
|
||||
a function transformation is used (e.g. computing gradients using
|
||||
:func:`value_and_grad`) then the functions defined via :method:`vjp`,
|
||||
:method:`jvp` and :method:`vmap` are used instead of the default
|
||||
transformation.
|
||||
|
||||
Note, all custom transformations are optional. Undefined transformations
|
||||
fall back to the default behaviour.
|
||||
|
||||
Example usage:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
@mx.custom_function
|
||||
def f(x, y):
|
||||
return mx.sin(x) * y
|
||||
|
||||
@f.vjp
|
||||
def f_vjp(primals, cotangent, output):
|
||||
x, y = primals
|
||||
return cotangent * mx.cos(x) * y, cotangent * mx.sin(x)
|
||||
|
||||
@f.jvp
|
||||
def f_jvp(primals, tangents):
|
||||
x, y = primals
|
||||
dx, dy = tangents
|
||||
return dx * mx.cos(x) * y + dy * mx.sin(x)
|
||||
|
||||
@f.vmap
|
||||
def f_vmap(inputs, axes):
|
||||
x, y = inputs
|
||||
ax, ay = axes
|
||||
if ay != ax and ax is not None:
|
||||
y = y.swapaxes(ay, ax)
|
||||
return mx.sin(x) * y, (ax or ay)
|
||||
)pbdoc")
|
||||
.def(
|
||||
nb::init<nb::callable>(),
|
||||
"f"_a,
|
||||
nb::sig("def __init__(self, f: callable)"))
|
||||
.def("__call__", &PyCustomFunction::call_impl)
|
||||
.def(
|
||||
"vjp",
|
||||
&PyCustomFunction::set_vjp,
|
||||
"f"_a,
|
||||
nb::sig("def vjp(self, f_vjp: callable)"),
|
||||
R"pbdoc(
|
||||
Define a custom vjp for the wrapped function.
|
||||
|
||||
The vjp function takes three arguments:
|
||||
|
||||
- *primals*: A pytree that contains all the positional arguments to
|
||||
the function. It could be a single array, a tuple of arrays or a
|
||||
full blown tuple of dicts of arrays etc.
|
||||
- *cotangents*: A pytree that matches the structure of the output
|
||||
but contains the cotangents (usually the gradients of the loss
|
||||
function with respect to the outputs).
|
||||
- *outputs*: The outputs of the function to be used to avoid
|
||||
recomputing them for the gradient computation.
|
||||
|
||||
The vjp function should return the same pytree structure as the
|
||||
primals but containing the corresponding computed cotangents.
|
||||
)pbdoc")
|
||||
.def(
|
||||
"jvp",
|
||||
&PyCustomFunction::set_jvp,
|
||||
"f"_a,
|
||||
nb::sig("def jvp(self, f_jvp: callable)"),
|
||||
R"pbdoc(
|
||||
Define a custom jvp for the wrapped function.
|
||||
|
||||
The jvp function takes two arguments:
|
||||
|
||||
- *primals*: A pytree that contains all the positional arguments to
|
||||
the function. It could be a single array, a tuple of arrays or a
|
||||
full blown tuple of dicts of arrays etc.
|
||||
- *tangents*: A pytree that matches the structure of the inputs but
|
||||
instead contains the gradients wrt to each input. Tangents could
|
||||
be ``None`` if some inputs don't have an associated gradient.
|
||||
|
||||
The jvp function should return the same pytree structure as the
|
||||
outputs of the function but containing the tangents.
|
||||
)pbdoc")
|
||||
.def(
|
||||
"vmap",
|
||||
&PyCustomFunction::set_vmap,
|
||||
"f"_a,
|
||||
nb::sig("def vmap(self, f_vmap: callable)"),
|
||||
R"pbdoc(
|
||||
Define a custom vectorization transformation for the wrapped function.
|
||||
|
||||
The vmap function takes two arguments:
|
||||
|
||||
- *inputs*: A pytree that contains all the positional arguments to
|
||||
the function. It could be a single array, a tuple of arrays or a
|
||||
full blown tuple of dicts of arrays etc.
|
||||
- *axes*: A pytree that matches the structure of the inputs but
|
||||
instead contains the vectorization axis for each input or
|
||||
``None`` if an input is not vectorized.
|
||||
|
||||
The vmap function should return the outputs of the original
|
||||
function but vectorized over the provided axes. It should also
|
||||
return a pytree with the vectorization axes of each output. If some
|
||||
outputs are no longer vectorized, then their vectorization axis
|
||||
should be ``None``.
|
||||
)pbdoc");
|
||||
|
||||
m.def(
|
||||
"eval",
|
||||
[](const nb::args& args) {
|
||||
@@ -888,8 +1335,10 @@ void init_transforms(nb::module_& m) {
|
||||
const nb::object& outputs,
|
||||
bool shapeless) {
|
||||
// Try to get the name
|
||||
auto n = fun.attr("__name__");
|
||||
auto name = n.is_none() ? "compiled" : nb::cast<std::string>(n);
|
||||
auto n =
|
||||
nb::hasattr(fun, "__name__") ? fun.attr("__name__") : nb::none();
|
||||
auto name = n.is_none() ? "compiled"
|
||||
: nb::cast<std::string>(fun.attr("__name__"));
|
||||
|
||||
// Try to get the signature
|
||||
std::ostringstream sig;
|
||||
|
Reference in New Issue
Block a user