mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Fix compile with non standard types (#745)
* refactor tree utils * fix compile + tree code refactor * Add an extra test * add a few missing activations to docs * hash structure * Encode the full argument structure --------- Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
parent
08226ab491
commit
fe1dabf272
@ -12,13 +12,24 @@ simple functions.
|
|||||||
:toctree: _autosummary_functions
|
:toctree: _autosummary_functions
|
||||||
:template: nn-module-template.rst
|
:template: nn-module-template.rst
|
||||||
|
|
||||||
|
elu
|
||||||
gelu
|
gelu
|
||||||
gelu_approx
|
gelu_approx
|
||||||
gelu_fast_approx
|
gelu_fast_approx
|
||||||
|
glu
|
||||||
|
hardswish
|
||||||
|
leaky_relu
|
||||||
|
log_sigmoid
|
||||||
|
log_softmax
|
||||||
mish
|
mish
|
||||||
prelu
|
prelu
|
||||||
relu
|
relu
|
||||||
|
relu6
|
||||||
selu
|
selu
|
||||||
softshrink
|
sigmoid
|
||||||
silu
|
silu
|
||||||
|
softmax
|
||||||
|
softplus
|
||||||
|
softshrink
|
||||||
step
|
step
|
||||||
|
tanh
|
||||||
|
@ -37,6 +37,7 @@ from mlx.nn.layers.activations import (
|
|||||||
relu,
|
relu,
|
||||||
relu6,
|
relu6,
|
||||||
selu,
|
selu,
|
||||||
|
sigmoid,
|
||||||
silu,
|
silu,
|
||||||
softmax,
|
softmax,
|
||||||
softplus,
|
softplus,
|
||||||
|
@ -18,7 +18,7 @@ def _make_activation_module(f):
|
|||||||
|
|
||||||
@partial(mx.compile, shapeless=True)
|
@partial(mx.compile, shapeless=True)
|
||||||
def sigmoid(x):
|
def sigmoid(x):
|
||||||
r"""Applies the element-wise function:
|
r"""Applies the sigmoid function.
|
||||||
|
|
||||||
.. math::
|
.. math::
|
||||||
\text{Sigmoid}(x) = \sigma(x) = \frac{1}{1 + \exp(-x)}
|
\text{Sigmoid}(x) = \sigma(x) = \frac{1}{1 + \exp(-x)}
|
||||||
|
@ -14,6 +14,7 @@ pybind11_add_module(
|
|||||||
${CMAKE_CURRENT_SOURCE_DIR}/random.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/random.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/constants.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/constants.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/trees.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -11,6 +11,7 @@
|
|||||||
#include "mlx/graph_utils.h"
|
#include "mlx/graph_utils.h"
|
||||||
#include "mlx/transforms.h"
|
#include "mlx/transforms.h"
|
||||||
#include "mlx/transforms_impl.h"
|
#include "mlx/transforms_impl.h"
|
||||||
|
#include "python/src/trees.h"
|
||||||
|
|
||||||
namespace py = pybind11;
|
namespace py = pybind11;
|
||||||
using namespace py::literals;
|
using namespace py::literals;
|
||||||
@ -30,246 +31,6 @@ std::vector<T> to_vector(const std::variant<T, std::vector<T>>& v) {
|
|||||||
return vals;
|
return vals;
|
||||||
}
|
}
|
||||||
|
|
||||||
void tree_visit(py::object tree, std::function<void(py::handle)> visitor) {
|
|
||||||
std::function<void(py::handle)> recurse;
|
|
||||||
recurse = [&](py::handle subtree) {
|
|
||||||
if (py::isinstance<py::list>(subtree) ||
|
|
||||||
py::isinstance<py::tuple>(subtree)) {
|
|
||||||
for (auto item : subtree) {
|
|
||||||
recurse(item);
|
|
||||||
}
|
|
||||||
} else if (py::isinstance<py::dict>(subtree)) {
|
|
||||||
for (auto item : py::cast<py::dict>(subtree)) {
|
|
||||||
recurse(item.second);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
visitor(subtree);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
recurse(tree);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T, typename U, typename V>
|
|
||||||
void validate_subtrees(const std::vector<py::object>& subtrees) {
|
|
||||||
int len = py::cast<T>(subtrees[0]).size();
|
|
||||||
for (auto& subtree : subtrees) {
|
|
||||||
if ((py::isinstance<T>(subtree) && py::cast<T>(subtree).size() != len) ||
|
|
||||||
py::isinstance<U>(subtree) || py::isinstance<V>(subtree)) {
|
|
||||||
throw std::invalid_argument(
|
|
||||||
"[tree_map] Additional input tree is not a valid prefix of the first tree.");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
py::object tree_map(
|
|
||||||
const std::vector<py::object>& trees,
|
|
||||||
std::function<py::object(const std::vector<py::object>&)> transform) {
|
|
||||||
std::function<py::object(const std::vector<py::object>&)> recurse;
|
|
||||||
|
|
||||||
recurse = [&](const std::vector<py::object>& subtrees) {
|
|
||||||
if (py::isinstance<py::list>(subtrees[0])) {
|
|
||||||
py::list l;
|
|
||||||
std::vector<py::object> items(subtrees.size());
|
|
||||||
validate_subtrees<py::list, py::tuple, py::dict>(subtrees);
|
|
||||||
for (int i = 0; i < py::cast<py::list>(subtrees[0]).size(); ++i) {
|
|
||||||
for (int j = 0; j < subtrees.size(); ++j) {
|
|
||||||
if (py::isinstance<py::list>(subtrees[j])) {
|
|
||||||
items[j] = py::cast<py::list>(subtrees[j])[i];
|
|
||||||
} else {
|
|
||||||
items[j] = subtrees[j];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
l.append(recurse(items));
|
|
||||||
}
|
|
||||||
return py::cast<py::object>(l);
|
|
||||||
} else if (py::isinstance<py::tuple>(subtrees[0])) {
|
|
||||||
// Check the rest of the subtrees
|
|
||||||
std::vector<py::object> items(subtrees.size());
|
|
||||||
int len = py::cast<py::tuple>(subtrees[0]).size();
|
|
||||||
py::tuple l(len);
|
|
||||||
validate_subtrees<py::tuple, py::list, py::dict>(subtrees);
|
|
||||||
for (int i = 0; i < len; ++i) {
|
|
||||||
for (int j = 0; j < subtrees.size(); ++j) {
|
|
||||||
if (py::isinstance<py::tuple>(subtrees[j])) {
|
|
||||||
items[j] = py::cast<py::tuple>(subtrees[j])[i];
|
|
||||||
} else {
|
|
||||||
items[j] = subtrees[j];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
l[i] = recurse(items);
|
|
||||||
}
|
|
||||||
return py::cast<py::object>(l);
|
|
||||||
} else if (py::isinstance<py::dict>(subtrees[0])) {
|
|
||||||
std::vector<py::object> items(subtrees.size());
|
|
||||||
validate_subtrees<py::dict, py::list, py::tuple>(subtrees);
|
|
||||||
py::dict d;
|
|
||||||
for (auto item : py::cast<py::dict>(subtrees[0])) {
|
|
||||||
for (int j = 0; j < subtrees.size(); ++j) {
|
|
||||||
if (py::isinstance<py::dict>(subtrees[j])) {
|
|
||||||
auto subdict = py::cast<py::dict>(subtrees[j]);
|
|
||||||
if (!subdict.contains(item.first)) {
|
|
||||||
throw std::invalid_argument(
|
|
||||||
"[tree_map] Tree is not a valid prefix tree of the first tree.");
|
|
||||||
}
|
|
||||||
items[j] = subdict[item.first];
|
|
||||||
} else {
|
|
||||||
items[j] = subtrees[j];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
d[item.first] = recurse(items);
|
|
||||||
}
|
|
||||||
return py::cast<py::object>(d);
|
|
||||||
} else {
|
|
||||||
return transform(subtrees);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
return recurse(trees);
|
|
||||||
}
|
|
||||||
|
|
||||||
py::object tree_map(
|
|
||||||
py::object tree,
|
|
||||||
std::function<py::object(py::handle)> transform) {
|
|
||||||
return tree_map({tree}, [&](std::vector<py::object> inputs) {
|
|
||||||
return transform(inputs[0]);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
void tree_visit_update(
|
|
||||||
py::object tree,
|
|
||||||
std::function<py::object(py::handle)> visitor) {
|
|
||||||
std::function<py::object(py::handle)> recurse;
|
|
||||||
recurse = [&](py::handle subtree) {
|
|
||||||
if (py::isinstance<py::list>(subtree)) {
|
|
||||||
auto l = py::cast<py::list>(subtree);
|
|
||||||
for (int i = 0; i < l.size(); ++i) {
|
|
||||||
l[i] = recurse(l[i]);
|
|
||||||
}
|
|
||||||
return py::cast<py::object>(l);
|
|
||||||
} else if (py::isinstance<py::tuple>(subtree)) {
|
|
||||||
for (auto item : subtree) {
|
|
||||||
recurse(item);
|
|
||||||
}
|
|
||||||
return py::cast<py::object>(subtree);
|
|
||||||
} else if (py::isinstance<py::dict>(subtree)) {
|
|
||||||
auto d = py::cast<py::dict>(subtree);
|
|
||||||
for (auto item : d) {
|
|
||||||
d[item.first] = recurse(item.second);
|
|
||||||
}
|
|
||||||
return py::cast<py::object>(d);
|
|
||||||
} else if (py::isinstance<array>(subtree)) {
|
|
||||||
return visitor(subtree);
|
|
||||||
} else {
|
|
||||||
return py::cast<py::object>(subtree);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
recurse(tree);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fill a pytree (recursive dict or list of dict or list)
|
|
||||||
// in place with the given arrays
|
|
||||||
// Non dict or list nodes are ignored
|
|
||||||
void tree_fill(py::object& tree, const std::vector<array>& values) {
|
|
||||||
size_t index = 0;
|
|
||||||
tree_visit_update(
|
|
||||||
tree, [&](py::handle node) { return py::cast(values[index++]); });
|
|
||||||
}
|
|
||||||
|
|
||||||
// Replace all the arrays from the src values with the dst values in the tree
|
|
||||||
void tree_replace(
|
|
||||||
py::object& tree,
|
|
||||||
const std::vector<array>& src,
|
|
||||||
const std::vector<array>& dst) {
|
|
||||||
std::unordered_map<uintptr_t, array> src_to_dst;
|
|
||||||
for (int i = 0; i < src.size(); ++i) {
|
|
||||||
src_to_dst.insert({src[i].id(), dst[i]});
|
|
||||||
}
|
|
||||||
tree_visit_update(tree, [&](py::handle node) {
|
|
||||||
auto arr = py::cast<array>(node);
|
|
||||||
if (auto it = src_to_dst.find(arr.id()); it != src_to_dst.end()) {
|
|
||||||
return py::cast(it->second);
|
|
||||||
}
|
|
||||||
return py::cast(arr);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<array> tree_flatten(py::object tree, bool strict = true) {
|
|
||||||
std::vector<array> flat_tree;
|
|
||||||
|
|
||||||
tree_visit(tree, [&](py::handle obj) {
|
|
||||||
if (py::isinstance<array>(obj)) {
|
|
||||||
flat_tree.push_back(py::cast<array>(obj));
|
|
||||||
} else if (strict) {
|
|
||||||
throw std::invalid_argument(
|
|
||||||
"[tree_flatten] The argument should contain only arrays");
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
return flat_tree;
|
|
||||||
}
|
|
||||||
|
|
||||||
py::object tree_unflatten(
|
|
||||||
py::object tree,
|
|
||||||
const std::vector<array>& values,
|
|
||||||
int index = 0) {
|
|
||||||
return tree_map(tree, [&](py::handle obj) {
|
|
||||||
if (py::isinstance<array>(obj)) {
|
|
||||||
return py::cast(values[index++]);
|
|
||||||
} else {
|
|
||||||
return py::cast<py::object>(obj);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
py::object structure_sentinel() {
|
|
||||||
static py::object sentinel;
|
|
||||||
|
|
||||||
if (sentinel.ptr() == nullptr) {
|
|
||||||
sentinel = py::capsule(&sentinel);
|
|
||||||
// probably not needed but this should make certain that we won't ever
|
|
||||||
// delete the sentinel
|
|
||||||
sentinel.inc_ref();
|
|
||||||
}
|
|
||||||
|
|
||||||
return sentinel;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::pair<std::vector<array>, py::object> tree_flatten_with_structure(
|
|
||||||
py::object tree,
|
|
||||||
bool strict = true) {
|
|
||||||
auto sentinel = structure_sentinel();
|
|
||||||
std::vector<array> flat_tree;
|
|
||||||
auto structure = tree_map(
|
|
||||||
tree,
|
|
||||||
[&flat_tree, sentinel = std::move(sentinel), strict](py::handle obj) {
|
|
||||||
if (py::isinstance<array>(obj)) {
|
|
||||||
flat_tree.push_back(py::cast<array>(obj));
|
|
||||||
return sentinel;
|
|
||||||
} else if (!strict) {
|
|
||||||
return py::cast<py::object>(obj);
|
|
||||||
} else {
|
|
||||||
throw std::invalid_argument(
|
|
||||||
"[tree_flatten] The argument should contain only arrays");
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
return {flat_tree, structure};
|
|
||||||
}
|
|
||||||
|
|
||||||
py::object tree_unflatten_from_structure(
|
|
||||||
py::object structure,
|
|
||||||
const std::vector<array>& values,
|
|
||||||
int index = 0) {
|
|
||||||
auto sentinel = structure_sentinel();
|
|
||||||
return tree_map(structure, [&](py::handle obj) {
|
|
||||||
if (obj.is(sentinel)) {
|
|
||||||
return py::cast(values[index++]);
|
|
||||||
} else {
|
|
||||||
return py::cast<py::object>(obj);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
auto validate_argnums_argnames(
|
auto validate_argnums_argnames(
|
||||||
const std::optional<IntOrVec>& argnums,
|
const std::optional<IntOrVec>& argnums,
|
||||||
const StrOrVec& argnames) {
|
const StrOrVec& argnames) {
|
||||||
@ -582,9 +343,69 @@ struct PyCompiledFun {
|
|||||||
};
|
};
|
||||||
|
|
||||||
py::object operator()(const py::args& args, const py::kwargs& kwargs) {
|
py::object operator()(const py::args& args, const py::kwargs& kwargs) {
|
||||||
auto inputs = tree_flatten(args, false);
|
// Flat array inputs
|
||||||
|
std::vector<array> inputs;
|
||||||
|
|
||||||
auto compile_fun = [this, &args, &kwargs, num_args = inputs.size()](
|
// Compilation constants which includes the tree structure of the arguments
|
||||||
|
std::vector<uint64_t> constants;
|
||||||
|
|
||||||
|
// Reserve some large primes to signify the presence of an array, a list or
|
||||||
|
// a dict in order to encode the structure of the pytree. We choose primes
|
||||||
|
// to reduce slightly the chances of these numbers occurring by a
|
||||||
|
// multiplication as values in the constants list.
|
||||||
|
constexpr uint64_t array_identifier = 18446744073709551557UL;
|
||||||
|
constexpr uint64_t list_identifier = 18446744073709551533UL;
|
||||||
|
constexpr uint64_t dict_identifier = 18446744073709551521UL;
|
||||||
|
|
||||||
|
// Flatten the tree with hashed constants and structure
|
||||||
|
std::function<void(py::handle)> recurse;
|
||||||
|
recurse = [&](py::handle obj) {
|
||||||
|
if (py::isinstance<py::list>(obj)) {
|
||||||
|
auto l = py::cast<py::list>(obj);
|
||||||
|
constants.push_back(list_identifier);
|
||||||
|
for (int i = 0; i < l.size(); ++i) {
|
||||||
|
recurse(l[i]);
|
||||||
|
}
|
||||||
|
} else if (py::isinstance<py::tuple>(obj)) {
|
||||||
|
auto l = py::cast<py::tuple>(obj);
|
||||||
|
constants.push_back(list_identifier);
|
||||||
|
for (auto item : obj) {
|
||||||
|
recurse(item);
|
||||||
|
}
|
||||||
|
} else if (py::isinstance<py::dict>(obj)) {
|
||||||
|
auto d = py::cast<py::dict>(obj);
|
||||||
|
constants.push_back(dict_identifier);
|
||||||
|
for (auto item : d) {
|
||||||
|
auto r = py::hash(item.first);
|
||||||
|
constants.push_back(*reinterpret_cast<uint64_t*>(&r));
|
||||||
|
recurse(item.second);
|
||||||
|
}
|
||||||
|
} else if (py::isinstance<array>(obj)) {
|
||||||
|
inputs.push_back(py::cast<array>(obj));
|
||||||
|
constants.push_back(array_identifier);
|
||||||
|
} else if (py::isinstance<py::str>(obj)) {
|
||||||
|
auto r = py::hash(obj);
|
||||||
|
constants.push_back(*reinterpret_cast<uint64_t*>(&r));
|
||||||
|
} else if (py::isinstance<py::int_>(obj)) {
|
||||||
|
auto r = obj.cast<int64_t>();
|
||||||
|
constants.push_back(*reinterpret_cast<uint64_t*>(&r));
|
||||||
|
} else if (py::isinstance<py::float_>(obj)) {
|
||||||
|
auto r = obj.cast<double>();
|
||||||
|
constants.push_back(*reinterpret_cast<uint64_t*>(&r));
|
||||||
|
} else {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[compile] Function arguments must be trees of arrays "
|
||||||
|
<< "or constants (floats, ints, or strings), but received "
|
||||||
|
<< "type " << obj.get_type() << ".";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
recurse(args);
|
||||||
|
int num_args = inputs.size();
|
||||||
|
recurse(kwargs);
|
||||||
|
|
||||||
|
auto compile_fun = [this, &args, &kwargs, num_args](
|
||||||
const std::vector<array>& a) {
|
const std::vector<array>& a) {
|
||||||
// Put tracers into captured inputs
|
// Put tracers into captured inputs
|
||||||
std::vector<array> flat_in_captures;
|
std::vector<array> flat_in_captures;
|
||||||
@ -619,14 +440,6 @@ struct PyCompiledFun {
|
|||||||
return outputs;
|
return outputs;
|
||||||
};
|
};
|
||||||
|
|
||||||
{
|
|
||||||
auto flat_kwargs = tree_flatten(kwargs, false);
|
|
||||||
inputs.insert(
|
|
||||||
inputs.end(),
|
|
||||||
std::make_move_iterator(flat_kwargs.begin()),
|
|
||||||
std::make_move_iterator(flat_kwargs.end()));
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!py::isinstance<py::none>(captured_inputs)) {
|
if (!py::isinstance<py::none>(captured_inputs)) {
|
||||||
auto flat_in_captures = tree_flatten(captured_inputs, false);
|
auto flat_in_captures = tree_flatten(captured_inputs, false);
|
||||||
inputs.insert(
|
inputs.insert(
|
||||||
@ -635,36 +448,6 @@ struct PyCompiledFun {
|
|||||||
std::make_move_iterator(flat_in_captures.end()));
|
std::make_move_iterator(flat_in_captures.end()));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Collect the compilation constants
|
|
||||||
std::vector<uint64_t> constants;
|
|
||||||
auto value_hash = [](py::handle o) -> std::optional<uint64_t> {
|
|
||||||
// Consider expanding tuples to their contents including start and end
|
|
||||||
// ids
|
|
||||||
if (py::isinstance<py::tuple>(o) || py::isinstance<py::str>(o)) {
|
|
||||||
auto r = py::hash(o);
|
|
||||||
return *reinterpret_cast<uint64_t*>(&r);
|
|
||||||
} else if (py::isinstance<py::int_>(o)) {
|
|
||||||
auto r = o.cast<int64_t>();
|
|
||||||
return *reinterpret_cast<uint64_t*>(&r);
|
|
||||||
} else if (py::isinstance<py::float_>(o)) {
|
|
||||||
auto r = o.cast<double>();
|
|
||||||
return *reinterpret_cast<uint64_t*>(&r);
|
|
||||||
} else {
|
|
||||||
return std::nullopt;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
for (int i = 0; i < args.size(); i++) {
|
|
||||||
if (auto h = value_hash(args[i]); h.has_value()) {
|
|
||||||
constants.push_back(*h);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for (auto& pair : kwargs) {
|
|
||||||
if (auto h = value_hash(pair.second); h.has_value()) {
|
|
||||||
constants.push_back(*value_hash(pair.first));
|
|
||||||
constants.push_back(*h);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Compile and call
|
// Compile and call
|
||||||
auto outputs =
|
auto outputs =
|
||||||
detail::compile(compile_fun, fun_id, shapeless, constants)(inputs);
|
detail::compile(compile_fun, fun_id, shapeless, constants)(inputs);
|
||||||
|
243
python/src/trees.cpp
Normal file
243
python/src/trees.cpp
Normal file
@ -0,0 +1,243 @@
|
|||||||
|
// Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
|
#include "python/src/trees.h"
|
||||||
|
|
||||||
|
void tree_visit(py::object tree, std::function<void(py::handle)> visitor) {
|
||||||
|
std::function<void(py::handle)> recurse;
|
||||||
|
recurse = [&](py::handle subtree) {
|
||||||
|
if (py::isinstance<py::list>(subtree) ||
|
||||||
|
py::isinstance<py::tuple>(subtree)) {
|
||||||
|
for (auto item : subtree) {
|
||||||
|
recurse(item);
|
||||||
|
}
|
||||||
|
} else if (py::isinstance<py::dict>(subtree)) {
|
||||||
|
for (auto item : py::cast<py::dict>(subtree)) {
|
||||||
|
recurse(item.second);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
visitor(subtree);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
recurse(tree);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename U, typename V>
|
||||||
|
void validate_subtrees(const std::vector<py::object>& subtrees) {
|
||||||
|
int len = py::cast<T>(subtrees[0]).size();
|
||||||
|
for (auto& subtree : subtrees) {
|
||||||
|
if ((py::isinstance<T>(subtree) && py::cast<T>(subtree).size() != len) ||
|
||||||
|
py::isinstance<U>(subtree) || py::isinstance<V>(subtree)) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[tree_map] Additional input tree is not a valid prefix of the first tree.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
py::object tree_map(
|
||||||
|
const std::vector<py::object>& trees,
|
||||||
|
std::function<py::object(const std::vector<py::object>&)> transform) {
|
||||||
|
std::function<py::object(const std::vector<py::object>&)> recurse;
|
||||||
|
|
||||||
|
recurse = [&](const std::vector<py::object>& subtrees) {
|
||||||
|
if (py::isinstance<py::list>(subtrees[0])) {
|
||||||
|
py::list l;
|
||||||
|
std::vector<py::object> items(subtrees.size());
|
||||||
|
validate_subtrees<py::list, py::tuple, py::dict>(subtrees);
|
||||||
|
for (int i = 0; i < py::cast<py::list>(subtrees[0]).size(); ++i) {
|
||||||
|
for (int j = 0; j < subtrees.size(); ++j) {
|
||||||
|
if (py::isinstance<py::list>(subtrees[j])) {
|
||||||
|
items[j] = py::cast<py::list>(subtrees[j])[i];
|
||||||
|
} else {
|
||||||
|
items[j] = subtrees[j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
l.append(recurse(items));
|
||||||
|
}
|
||||||
|
return py::cast<py::object>(l);
|
||||||
|
} else if (py::isinstance<py::tuple>(subtrees[0])) {
|
||||||
|
// Check the rest of the subtrees
|
||||||
|
std::vector<py::object> items(subtrees.size());
|
||||||
|
int len = py::cast<py::tuple>(subtrees[0]).size();
|
||||||
|
py::tuple l(len);
|
||||||
|
validate_subtrees<py::tuple, py::list, py::dict>(subtrees);
|
||||||
|
for (int i = 0; i < len; ++i) {
|
||||||
|
for (int j = 0; j < subtrees.size(); ++j) {
|
||||||
|
if (py::isinstance<py::tuple>(subtrees[j])) {
|
||||||
|
items[j] = py::cast<py::tuple>(subtrees[j])[i];
|
||||||
|
} else {
|
||||||
|
items[j] = subtrees[j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
l[i] = recurse(items);
|
||||||
|
}
|
||||||
|
return py::cast<py::object>(l);
|
||||||
|
} else if (py::isinstance<py::dict>(subtrees[0])) {
|
||||||
|
std::vector<py::object> items(subtrees.size());
|
||||||
|
validate_subtrees<py::dict, py::list, py::tuple>(subtrees);
|
||||||
|
py::dict d;
|
||||||
|
for (auto item : py::cast<py::dict>(subtrees[0])) {
|
||||||
|
for (int j = 0; j < subtrees.size(); ++j) {
|
||||||
|
if (py::isinstance<py::dict>(subtrees[j])) {
|
||||||
|
auto subdict = py::cast<py::dict>(subtrees[j]);
|
||||||
|
if (!subdict.contains(item.first)) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[tree_map] Tree is not a valid prefix tree of the first tree.");
|
||||||
|
}
|
||||||
|
items[j] = subdict[item.first];
|
||||||
|
} else {
|
||||||
|
items[j] = subtrees[j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
d[item.first] = recurse(items);
|
||||||
|
}
|
||||||
|
return py::cast<py::object>(d);
|
||||||
|
} else {
|
||||||
|
return transform(subtrees);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
return recurse(trees);
|
||||||
|
}
|
||||||
|
|
||||||
|
py::object tree_map(
|
||||||
|
py::object tree,
|
||||||
|
std::function<py::object(py::handle)> transform) {
|
||||||
|
return tree_map({tree}, [&](std::vector<py::object> inputs) {
|
||||||
|
return transform(inputs[0]);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
void tree_visit_update(
|
||||||
|
py::object tree,
|
||||||
|
std::function<py::object(py::handle)> visitor) {
|
||||||
|
std::function<py::object(py::handle)> recurse;
|
||||||
|
recurse = [&](py::handle subtree) {
|
||||||
|
if (py::isinstance<py::list>(subtree)) {
|
||||||
|
auto l = py::cast<py::list>(subtree);
|
||||||
|
for (int i = 0; i < l.size(); ++i) {
|
||||||
|
l[i] = recurse(l[i]);
|
||||||
|
}
|
||||||
|
return py::cast<py::object>(l);
|
||||||
|
} else if (py::isinstance<py::tuple>(subtree)) {
|
||||||
|
for (auto item : subtree) {
|
||||||
|
recurse(item);
|
||||||
|
}
|
||||||
|
return py::cast<py::object>(subtree);
|
||||||
|
} else if (py::isinstance<py::dict>(subtree)) {
|
||||||
|
auto d = py::cast<py::dict>(subtree);
|
||||||
|
for (auto item : d) {
|
||||||
|
d[item.first] = recurse(item.second);
|
||||||
|
}
|
||||||
|
return py::cast<py::object>(d);
|
||||||
|
} else if (py::isinstance<array>(subtree)) {
|
||||||
|
return visitor(subtree);
|
||||||
|
} else {
|
||||||
|
return py::cast<py::object>(subtree);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
recurse(tree);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fill a pytree (recursive dict or list of dict or list)
|
||||||
|
// in place with the given arrays
|
||||||
|
// Non dict or list nodes are ignored
|
||||||
|
void tree_fill(py::object& tree, const std::vector<array>& values) {
|
||||||
|
size_t index = 0;
|
||||||
|
tree_visit_update(
|
||||||
|
tree, [&](py::handle node) { return py::cast(values[index++]); });
|
||||||
|
}
|
||||||
|
|
||||||
|
// Replace all the arrays from the src values with the dst values in the tree
|
||||||
|
void tree_replace(
|
||||||
|
py::object& tree,
|
||||||
|
const std::vector<array>& src,
|
||||||
|
const std::vector<array>& dst) {
|
||||||
|
std::unordered_map<uintptr_t, array> src_to_dst;
|
||||||
|
for (int i = 0; i < src.size(); ++i) {
|
||||||
|
src_to_dst.insert({src[i].id(), dst[i]});
|
||||||
|
}
|
||||||
|
tree_visit_update(tree, [&](py::handle node) {
|
||||||
|
auto arr = py::cast<array>(node);
|
||||||
|
if (auto it = src_to_dst.find(arr.id()); it != src_to_dst.end()) {
|
||||||
|
return py::cast(it->second);
|
||||||
|
}
|
||||||
|
return py::cast(arr);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<array> tree_flatten(py::object tree, bool strict /* = true */) {
|
||||||
|
std::vector<array> flat_tree;
|
||||||
|
|
||||||
|
tree_visit(tree, [&](py::handle obj) {
|
||||||
|
if (py::isinstance<array>(obj)) {
|
||||||
|
flat_tree.push_back(py::cast<array>(obj));
|
||||||
|
} else if (strict) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[tree_flatten] The argument should contain only arrays");
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
return flat_tree;
|
||||||
|
}
|
||||||
|
|
||||||
|
py::object tree_unflatten(
|
||||||
|
py::object tree,
|
||||||
|
const std::vector<array>& values,
|
||||||
|
int index /* = 0 */) {
|
||||||
|
return tree_map(tree, [&](py::handle obj) {
|
||||||
|
if (py::isinstance<array>(obj)) {
|
||||||
|
return py::cast(values[index++]);
|
||||||
|
} else {
|
||||||
|
return py::cast<py::object>(obj);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
py::object structure_sentinel() {
|
||||||
|
static py::object sentinel;
|
||||||
|
|
||||||
|
if (sentinel.ptr() == nullptr) {
|
||||||
|
sentinel = py::capsule(&sentinel);
|
||||||
|
// probably not needed but this should make certain that we won't ever
|
||||||
|
// delete the sentinel
|
||||||
|
sentinel.inc_ref();
|
||||||
|
}
|
||||||
|
|
||||||
|
return sentinel;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::pair<std::vector<array>, py::object> tree_flatten_with_structure(
|
||||||
|
py::object tree,
|
||||||
|
bool strict /* = true */) {
|
||||||
|
auto sentinel = structure_sentinel();
|
||||||
|
std::vector<array> flat_tree;
|
||||||
|
auto structure = tree_map(
|
||||||
|
tree,
|
||||||
|
[&flat_tree, sentinel = std::move(sentinel), strict](py::handle obj) {
|
||||||
|
if (py::isinstance<array>(obj)) {
|
||||||
|
flat_tree.push_back(py::cast<array>(obj));
|
||||||
|
return sentinel;
|
||||||
|
} else if (!strict) {
|
||||||
|
return py::cast<py::object>(obj);
|
||||||
|
} else {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[tree_flatten] The argument should contain only arrays");
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
return {flat_tree, structure};
|
||||||
|
}
|
||||||
|
|
||||||
|
py::object tree_unflatten_from_structure(
|
||||||
|
py::object structure,
|
||||||
|
const std::vector<array>& values,
|
||||||
|
int index /* = 0 */) {
|
||||||
|
auto sentinel = structure_sentinel();
|
||||||
|
return tree_map(structure, [&](py::handle obj) {
|
||||||
|
if (obj.is(sentinel)) {
|
||||||
|
return py::cast(values[index++]);
|
||||||
|
} else {
|
||||||
|
return py::cast<py::object>(obj);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
60
python/src/trees.h
Normal file
60
python/src/trees.h
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
// Copyright © 2023-2024 Apple Inc.
|
||||||
|
#pragma once
|
||||||
|
#include <pybind11/pybind11.h>
|
||||||
|
#include <pybind11/stl.h>
|
||||||
|
|
||||||
|
#include "mlx/array.h"
|
||||||
|
|
||||||
|
namespace py = pybind11;
|
||||||
|
using namespace mlx::core;
|
||||||
|
|
||||||
|
void tree_visit(py::object tree, std::function<void(py::handle)> visitor);
|
||||||
|
|
||||||
|
py::object tree_map(
|
||||||
|
const std::vector<py::object>& trees,
|
||||||
|
std::function<py::object(const std::vector<py::object>&)> transform);
|
||||||
|
|
||||||
|
py::object tree_map(
|
||||||
|
py::object tree,
|
||||||
|
std::function<py::object(py::handle)> transform);
|
||||||
|
|
||||||
|
void tree_visit_update(
|
||||||
|
py::object tree,
|
||||||
|
std::function<py::object(py::handle)> visitor);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Fill a pytree (recursive dict or list of dict or list) in place with the
|
||||||
|
* given arrays. */
|
||||||
|
void tree_fill(py::object& tree, const std::vector<array>& values);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Replace all the arrays from the src values with the dst values in the
|
||||||
|
* tree.
|
||||||
|
*/
|
||||||
|
void tree_replace(
|
||||||
|
py::object& tree,
|
||||||
|
const std::vector<array>& src,
|
||||||
|
const std::vector<array>& dst);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Flatten a tree into a vector of arrays. If strict is true, then the
|
||||||
|
* function will throw if the tree contains a leaf which is not an array.
|
||||||
|
*/
|
||||||
|
std::vector<array> tree_flatten(py::object tree, bool strict = true);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Unflatten a tree from a vector of arrays.
|
||||||
|
*/
|
||||||
|
py::object tree_unflatten(
|
||||||
|
py::object tree,
|
||||||
|
const std::vector<array>& values,
|
||||||
|
int index = 0);
|
||||||
|
|
||||||
|
std::pair<std::vector<array>, py::object> tree_flatten_with_structure(
|
||||||
|
py::object tree,
|
||||||
|
bool strict = true);
|
||||||
|
|
||||||
|
py::object tree_unflatten_from_structure(
|
||||||
|
py::object structure,
|
||||||
|
const std::vector<array>& values,
|
||||||
|
int index = 0);
|
@ -539,6 +539,48 @@ class TestCompile(mlx_tests.MLXTestCase):
|
|||||||
z = fun(mx.array(1), "two")
|
z = fun(mx.array(1), "two")
|
||||||
self.assertEqual(z.item(), 3)
|
self.assertEqual(z.item(), 3)
|
||||||
|
|
||||||
|
# Test nested constant
|
||||||
|
@partial(mx.compile)
|
||||||
|
def fun(x, y):
|
||||||
|
if y[0][0] == 1:
|
||||||
|
return x + 1
|
||||||
|
else:
|
||||||
|
return x + 2
|
||||||
|
|
||||||
|
z = fun(mx.array(1), [[1]])
|
||||||
|
self.assertEqual(z.item(), 2)
|
||||||
|
|
||||||
|
z = fun(mx.array(1), [[0]])
|
||||||
|
self.assertEqual(z.item(), 3)
|
||||||
|
|
||||||
|
@partial(mx.compile)
|
||||||
|
def fun(x, a, b):
|
||||||
|
for ai in a:
|
||||||
|
for bi in b:
|
||||||
|
x = bi * x + ai
|
||||||
|
return x
|
||||||
|
|
||||||
|
z = fun(mx.array(1), [1, 1], [2])
|
||||||
|
self.assertEqual(z.item(), 7)
|
||||||
|
|
||||||
|
z = fun(mx.array(1), [1], [1, 2])
|
||||||
|
self.assertEqual(z.item(), 5)
|
||||||
|
|
||||||
|
counter = [0]
|
||||||
|
|
||||||
|
@partial(mx.compile)
|
||||||
|
def fun(x, y):
|
||||||
|
counter[0] += 1
|
||||||
|
return x + y
|
||||||
|
|
||||||
|
z = fun(mx.array(1), 1)
|
||||||
|
self.assertEqual(z.item(), 2)
|
||||||
|
|
||||||
|
z = fun(1, mx.array(1))
|
||||||
|
self.assertEqual(z.item(), 2)
|
||||||
|
|
||||||
|
self.assertEqual(counter[0], 2)
|
||||||
|
|
||||||
def test_compile_inf(self):
|
def test_compile_inf(self):
|
||||||
|
|
||||||
@mx.compile
|
@mx.compile
|
||||||
@ -548,6 +590,21 @@ class TestCompile(mlx_tests.MLXTestCase):
|
|||||||
out = fun(mx.array([0.0]))
|
out = fun(mx.array([0.0]))
|
||||||
self.assertEqual(out.item(), False)
|
self.assertEqual(out.item(), False)
|
||||||
|
|
||||||
|
def test_unsupported_input_types(self):
|
||||||
|
|
||||||
|
class MyClass:
|
||||||
|
value = 1
|
||||||
|
|
||||||
|
@mx.compile
|
||||||
|
def fun(x, y):
|
||||||
|
return x + y.value
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
out = fun(mx.array(0.0), MyClass())
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
out = fun(mx.array(0.0), y=MyClass())
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user