Fix compile with non standard types (#745)

* refactor tree utils

* fix compile + tree code refactor

* Add an extra test

* add a few missing activations to docs

* hash structure

* Encode the full argument structure

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
Awni Hannun 2024-02-26 19:28:53 -08:00 committed by GitHub
parent 08226ab491
commit fe1dabf272
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 438 additions and 282 deletions

View File

@ -12,13 +12,24 @@ simple functions.
:toctree: _autosummary_functions
:template: nn-module-template.rst
elu
gelu
gelu_approx
gelu_fast_approx
glu
hardswish
leaky_relu
log_sigmoid
log_softmax
mish
prelu
relu
relu6
selu
softshrink
sigmoid
silu
softmax
softplus
softshrink
step
tanh

View File

@ -37,6 +37,7 @@ from mlx.nn.layers.activations import (
relu,
relu6,
selu,
sigmoid,
silu,
softmax,
softplus,

View File

@ -18,7 +18,7 @@ def _make_activation_module(f):
@partial(mx.compile, shapeless=True)
def sigmoid(x):
r"""Applies the element-wise function:
r"""Applies the sigmoid function.
.. math::
\text{Sigmoid}(x) = \sigma(x) = \frac{1}{1 + \exp(-x)}

View File

@ -14,6 +14,7 @@ pybind11_add_module(
${CMAKE_CURRENT_SOURCE_DIR}/random.cpp
${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
${CMAKE_CURRENT_SOURCE_DIR}/constants.cpp
${CMAKE_CURRENT_SOURCE_DIR}/trees.cpp
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
)

View File

@ -11,6 +11,7 @@
#include "mlx/graph_utils.h"
#include "mlx/transforms.h"
#include "mlx/transforms_impl.h"
#include "python/src/trees.h"
namespace py = pybind11;
using namespace py::literals;
@ -30,246 +31,6 @@ std::vector<T> to_vector(const std::variant<T, std::vector<T>>& v) {
return vals;
}
void tree_visit(py::object tree, std::function<void(py::handle)> visitor) {
std::function<void(py::handle)> recurse;
recurse = [&](py::handle subtree) {
if (py::isinstance<py::list>(subtree) ||
py::isinstance<py::tuple>(subtree)) {
for (auto item : subtree) {
recurse(item);
}
} else if (py::isinstance<py::dict>(subtree)) {
for (auto item : py::cast<py::dict>(subtree)) {
recurse(item.second);
}
} else {
visitor(subtree);
}
};
recurse(tree);
}
template <typename T, typename U, typename V>
void validate_subtrees(const std::vector<py::object>& subtrees) {
int len = py::cast<T>(subtrees[0]).size();
for (auto& subtree : subtrees) {
if ((py::isinstance<T>(subtree) && py::cast<T>(subtree).size() != len) ||
py::isinstance<U>(subtree) || py::isinstance<V>(subtree)) {
throw std::invalid_argument(
"[tree_map] Additional input tree is not a valid prefix of the first tree.");
}
}
}
py::object tree_map(
const std::vector<py::object>& trees,
std::function<py::object(const std::vector<py::object>&)> transform) {
std::function<py::object(const std::vector<py::object>&)> recurse;
recurse = [&](const std::vector<py::object>& subtrees) {
if (py::isinstance<py::list>(subtrees[0])) {
py::list l;
std::vector<py::object> items(subtrees.size());
validate_subtrees<py::list, py::tuple, py::dict>(subtrees);
for (int i = 0; i < py::cast<py::list>(subtrees[0]).size(); ++i) {
for (int j = 0; j < subtrees.size(); ++j) {
if (py::isinstance<py::list>(subtrees[j])) {
items[j] = py::cast<py::list>(subtrees[j])[i];
} else {
items[j] = subtrees[j];
}
}
l.append(recurse(items));
}
return py::cast<py::object>(l);
} else if (py::isinstance<py::tuple>(subtrees[0])) {
// Check the rest of the subtrees
std::vector<py::object> items(subtrees.size());
int len = py::cast<py::tuple>(subtrees[0]).size();
py::tuple l(len);
validate_subtrees<py::tuple, py::list, py::dict>(subtrees);
for (int i = 0; i < len; ++i) {
for (int j = 0; j < subtrees.size(); ++j) {
if (py::isinstance<py::tuple>(subtrees[j])) {
items[j] = py::cast<py::tuple>(subtrees[j])[i];
} else {
items[j] = subtrees[j];
}
}
l[i] = recurse(items);
}
return py::cast<py::object>(l);
} else if (py::isinstance<py::dict>(subtrees[0])) {
std::vector<py::object> items(subtrees.size());
validate_subtrees<py::dict, py::list, py::tuple>(subtrees);
py::dict d;
for (auto item : py::cast<py::dict>(subtrees[0])) {
for (int j = 0; j < subtrees.size(); ++j) {
if (py::isinstance<py::dict>(subtrees[j])) {
auto subdict = py::cast<py::dict>(subtrees[j]);
if (!subdict.contains(item.first)) {
throw std::invalid_argument(
"[tree_map] Tree is not a valid prefix tree of the first tree.");
}
items[j] = subdict[item.first];
} else {
items[j] = subtrees[j];
}
}
d[item.first] = recurse(items);
}
return py::cast<py::object>(d);
} else {
return transform(subtrees);
}
};
return recurse(trees);
}
py::object tree_map(
py::object tree,
std::function<py::object(py::handle)> transform) {
return tree_map({tree}, [&](std::vector<py::object> inputs) {
return transform(inputs[0]);
});
}
void tree_visit_update(
py::object tree,
std::function<py::object(py::handle)> visitor) {
std::function<py::object(py::handle)> recurse;
recurse = [&](py::handle subtree) {
if (py::isinstance<py::list>(subtree)) {
auto l = py::cast<py::list>(subtree);
for (int i = 0; i < l.size(); ++i) {
l[i] = recurse(l[i]);
}
return py::cast<py::object>(l);
} else if (py::isinstance<py::tuple>(subtree)) {
for (auto item : subtree) {
recurse(item);
}
return py::cast<py::object>(subtree);
} else if (py::isinstance<py::dict>(subtree)) {
auto d = py::cast<py::dict>(subtree);
for (auto item : d) {
d[item.first] = recurse(item.second);
}
return py::cast<py::object>(d);
} else if (py::isinstance<array>(subtree)) {
return visitor(subtree);
} else {
return py::cast<py::object>(subtree);
}
};
recurse(tree);
}
// Fill a pytree (recursive dict or list of dict or list)
// in place with the given arrays
// Non dict or list nodes are ignored
void tree_fill(py::object& tree, const std::vector<array>& values) {
size_t index = 0;
tree_visit_update(
tree, [&](py::handle node) { return py::cast(values[index++]); });
}
// Replace all the arrays from the src values with the dst values in the tree
void tree_replace(
py::object& tree,
const std::vector<array>& src,
const std::vector<array>& dst) {
std::unordered_map<uintptr_t, array> src_to_dst;
for (int i = 0; i < src.size(); ++i) {
src_to_dst.insert({src[i].id(), dst[i]});
}
tree_visit_update(tree, [&](py::handle node) {
auto arr = py::cast<array>(node);
if (auto it = src_to_dst.find(arr.id()); it != src_to_dst.end()) {
return py::cast(it->second);
}
return py::cast(arr);
});
}
std::vector<array> tree_flatten(py::object tree, bool strict = true) {
std::vector<array> flat_tree;
tree_visit(tree, [&](py::handle obj) {
if (py::isinstance<array>(obj)) {
flat_tree.push_back(py::cast<array>(obj));
} else if (strict) {
throw std::invalid_argument(
"[tree_flatten] The argument should contain only arrays");
}
});
return flat_tree;
}
py::object tree_unflatten(
py::object tree,
const std::vector<array>& values,
int index = 0) {
return tree_map(tree, [&](py::handle obj) {
if (py::isinstance<array>(obj)) {
return py::cast(values[index++]);
} else {
return py::cast<py::object>(obj);
}
});
}
py::object structure_sentinel() {
static py::object sentinel;
if (sentinel.ptr() == nullptr) {
sentinel = py::capsule(&sentinel);
// probably not needed but this should make certain that we won't ever
// delete the sentinel
sentinel.inc_ref();
}
return sentinel;
}
std::pair<std::vector<array>, py::object> tree_flatten_with_structure(
py::object tree,
bool strict = true) {
auto sentinel = structure_sentinel();
std::vector<array> flat_tree;
auto structure = tree_map(
tree,
[&flat_tree, sentinel = std::move(sentinel), strict](py::handle obj) {
if (py::isinstance<array>(obj)) {
flat_tree.push_back(py::cast<array>(obj));
return sentinel;
} else if (!strict) {
return py::cast<py::object>(obj);
} else {
throw std::invalid_argument(
"[tree_flatten] The argument should contain only arrays");
}
});
return {flat_tree, structure};
}
py::object tree_unflatten_from_structure(
py::object structure,
const std::vector<array>& values,
int index = 0) {
auto sentinel = structure_sentinel();
return tree_map(structure, [&](py::handle obj) {
if (obj.is(sentinel)) {
return py::cast(values[index++]);
} else {
return py::cast<py::object>(obj);
}
});
}
auto validate_argnums_argnames(
const std::optional<IntOrVec>& argnums,
const StrOrVec& argnames) {
@ -582,9 +343,69 @@ struct PyCompiledFun {
};
py::object operator()(const py::args& args, const py::kwargs& kwargs) {
auto inputs = tree_flatten(args, false);
// Flat array inputs
std::vector<array> inputs;
auto compile_fun = [this, &args, &kwargs, num_args = inputs.size()](
// Compilation constants which includes the tree structure of the arguments
std::vector<uint64_t> constants;
// Reserve some large primes to signify the presence of an array, a list or
// a dict in order to encode the structure of the pytree. We choose primes
// to reduce slightly the chances of these numbers occuring by a
// multiplication as values in the constants list.
constexpr uint64_t array_identifier = 18446744073709551557UL;
constexpr uint64_t list_identifier = 18446744073709551533UL;
constexpr uint64_t dict_identifier = 18446744073709551521UL;
// Flatten the tree with hashed constants and structure
std::function<void(py::handle)> recurse;
recurse = [&](py::handle obj) {
if (py::isinstance<py::list>(obj)) {
auto l = py::cast<py::list>(obj);
constants.push_back(list_identifier);
for (int i = 0; i < l.size(); ++i) {
recurse(l[i]);
}
} else if (py::isinstance<py::tuple>(obj)) {
auto l = py::cast<py::tuple>(obj);
constants.push_back(list_identifier);
for (auto item : obj) {
recurse(item);
}
} else if (py::isinstance<py::dict>(obj)) {
auto d = py::cast<py::dict>(obj);
constants.push_back(dict_identifier);
for (auto item : d) {
auto r = py::hash(item.first);
constants.push_back(*reinterpret_cast<uint64_t*>(&r));
recurse(item.second);
}
} else if (py::isinstance<array>(obj)) {
inputs.push_back(py::cast<array>(obj));
constants.push_back(array_identifier);
} else if (py::isinstance<py::str>(obj)) {
auto r = py::hash(obj);
constants.push_back(*reinterpret_cast<uint64_t*>(&r));
} else if (py::isinstance<py::int_>(obj)) {
auto r = obj.cast<int64_t>();
constants.push_back(*reinterpret_cast<uint64_t*>(&r));
} else if (py::isinstance<py::float_>(obj)) {
auto r = obj.cast<double>();
constants.push_back(*reinterpret_cast<uint64_t*>(&r));
} else {
std::ostringstream msg;
msg << "[compile] Function arguments must be trees of arrays "
<< "or constants (floats, ints, or strings), but received "
<< "type " << obj.get_type() << ".";
throw std::invalid_argument(msg.str());
}
};
recurse(args);
int num_args = inputs.size();
recurse(kwargs);
auto compile_fun = [this, &args, &kwargs, num_args](
const std::vector<array>& a) {
// Put tracers into captured inputs
std::vector<array> flat_in_captures;
@ -619,14 +440,6 @@ struct PyCompiledFun {
return outputs;
};
{
auto flat_kwargs = tree_flatten(kwargs, false);
inputs.insert(
inputs.end(),
std::make_move_iterator(flat_kwargs.begin()),
std::make_move_iterator(flat_kwargs.end()));
}
if (!py::isinstance<py::none>(captured_inputs)) {
auto flat_in_captures = tree_flatten(captured_inputs, false);
inputs.insert(
@ -635,36 +448,6 @@ struct PyCompiledFun {
std::make_move_iterator(flat_in_captures.end()));
}
// Collect the compilation constants
std::vector<uint64_t> constants;
auto value_hash = [](py::handle o) -> std::optional<uint64_t> {
// Consider expanding tuples to their contents including start and end
// ids
if (py::isinstance<py::tuple>(o) || py::isinstance<py::str>(o)) {
auto r = py::hash(o);
return *reinterpret_cast<uint64_t*>(&r);
} else if (py::isinstance<py::int_>(o)) {
auto r = o.cast<int64_t>();
return *reinterpret_cast<uint64_t*>(&r);
} else if (py::isinstance<py::float_>(o)) {
auto r = o.cast<double>();
return *reinterpret_cast<uint64_t*>(&r);
} else {
return std::nullopt;
}
};
for (int i = 0; i < args.size(); i++) {
if (auto h = value_hash(args[i]); h.has_value()) {
constants.push_back(*h);
}
}
for (auto& pair : kwargs) {
if (auto h = value_hash(pair.second); h.has_value()) {
constants.push_back(*value_hash(pair.first));
constants.push_back(*h);
}
}
// Compile and call
auto outputs =
detail::compile(compile_fun, fun_id, shapeless, constants)(inputs);

243
python/src/trees.cpp Normal file
View File

@ -0,0 +1,243 @@
// Copyright © 2023-2024 Apple Inc.
#include "python/src/trees.h"
void tree_visit(py::object tree, std::function<void(py::handle)> visitor) {
std::function<void(py::handle)> recurse;
recurse = [&](py::handle subtree) {
if (py::isinstance<py::list>(subtree) ||
py::isinstance<py::tuple>(subtree)) {
for (auto item : subtree) {
recurse(item);
}
} else if (py::isinstance<py::dict>(subtree)) {
for (auto item : py::cast<py::dict>(subtree)) {
recurse(item.second);
}
} else {
visitor(subtree);
}
};
recurse(tree);
}
template <typename T, typename U, typename V>
void validate_subtrees(const std::vector<py::object>& subtrees) {
int len = py::cast<T>(subtrees[0]).size();
for (auto& subtree : subtrees) {
if ((py::isinstance<T>(subtree) && py::cast<T>(subtree).size() != len) ||
py::isinstance<U>(subtree) || py::isinstance<V>(subtree)) {
throw std::invalid_argument(
"[tree_map] Additional input tree is not a valid prefix of the first tree.");
}
}
}
py::object tree_map(
const std::vector<py::object>& trees,
std::function<py::object(const std::vector<py::object>&)> transform) {
std::function<py::object(const std::vector<py::object>&)> recurse;
recurse = [&](const std::vector<py::object>& subtrees) {
if (py::isinstance<py::list>(subtrees[0])) {
py::list l;
std::vector<py::object> items(subtrees.size());
validate_subtrees<py::list, py::tuple, py::dict>(subtrees);
for (int i = 0; i < py::cast<py::list>(subtrees[0]).size(); ++i) {
for (int j = 0; j < subtrees.size(); ++j) {
if (py::isinstance<py::list>(subtrees[j])) {
items[j] = py::cast<py::list>(subtrees[j])[i];
} else {
items[j] = subtrees[j];
}
}
l.append(recurse(items));
}
return py::cast<py::object>(l);
} else if (py::isinstance<py::tuple>(subtrees[0])) {
// Check the rest of the subtrees
std::vector<py::object> items(subtrees.size());
int len = py::cast<py::tuple>(subtrees[0]).size();
py::tuple l(len);
validate_subtrees<py::tuple, py::list, py::dict>(subtrees);
for (int i = 0; i < len; ++i) {
for (int j = 0; j < subtrees.size(); ++j) {
if (py::isinstance<py::tuple>(subtrees[j])) {
items[j] = py::cast<py::tuple>(subtrees[j])[i];
} else {
items[j] = subtrees[j];
}
}
l[i] = recurse(items);
}
return py::cast<py::object>(l);
} else if (py::isinstance<py::dict>(subtrees[0])) {
std::vector<py::object> items(subtrees.size());
validate_subtrees<py::dict, py::list, py::tuple>(subtrees);
py::dict d;
for (auto item : py::cast<py::dict>(subtrees[0])) {
for (int j = 0; j < subtrees.size(); ++j) {
if (py::isinstance<py::dict>(subtrees[j])) {
auto subdict = py::cast<py::dict>(subtrees[j]);
if (!subdict.contains(item.first)) {
throw std::invalid_argument(
"[tree_map] Tree is not a valid prefix tree of the first tree.");
}
items[j] = subdict[item.first];
} else {
items[j] = subtrees[j];
}
}
d[item.first] = recurse(items);
}
return py::cast<py::object>(d);
} else {
return transform(subtrees);
}
};
return recurse(trees);
}
py::object tree_map(
py::object tree,
std::function<py::object(py::handle)> transform) {
return tree_map({tree}, [&](std::vector<py::object> inputs) {
return transform(inputs[0]);
});
}
void tree_visit_update(
py::object tree,
std::function<py::object(py::handle)> visitor) {
std::function<py::object(py::handle)> recurse;
recurse = [&](py::handle subtree) {
if (py::isinstance<py::list>(subtree)) {
auto l = py::cast<py::list>(subtree);
for (int i = 0; i < l.size(); ++i) {
l[i] = recurse(l[i]);
}
return py::cast<py::object>(l);
} else if (py::isinstance<py::tuple>(subtree)) {
for (auto item : subtree) {
recurse(item);
}
return py::cast<py::object>(subtree);
} else if (py::isinstance<py::dict>(subtree)) {
auto d = py::cast<py::dict>(subtree);
for (auto item : d) {
d[item.first] = recurse(item.second);
}
return py::cast<py::object>(d);
} else if (py::isinstance<array>(subtree)) {
return visitor(subtree);
} else {
return py::cast<py::object>(subtree);
}
};
recurse(tree);
}
// Fill a pytree (recursive dict or list of dict or list)
// in place with the given arrays
// Non dict or list nodes are ignored
void tree_fill(py::object& tree, const std::vector<array>& values) {
size_t index = 0;
tree_visit_update(
tree, [&](py::handle node) { return py::cast(values[index++]); });
}
// Replace all the arrays from the src values with the dst values in the tree
void tree_replace(
py::object& tree,
const std::vector<array>& src,
const std::vector<array>& dst) {
std::unordered_map<uintptr_t, array> src_to_dst;
for (int i = 0; i < src.size(); ++i) {
src_to_dst.insert({src[i].id(), dst[i]});
}
tree_visit_update(tree, [&](py::handle node) {
auto arr = py::cast<array>(node);
if (auto it = src_to_dst.find(arr.id()); it != src_to_dst.end()) {
return py::cast(it->second);
}
return py::cast(arr);
});
}
std::vector<array> tree_flatten(py::object tree, bool strict /* = true */) {
std::vector<array> flat_tree;
tree_visit(tree, [&](py::handle obj) {
if (py::isinstance<array>(obj)) {
flat_tree.push_back(py::cast<array>(obj));
} else if (strict) {
throw std::invalid_argument(
"[tree_flatten] The argument should contain only arrays");
}
});
return flat_tree;
}
py::object tree_unflatten(
py::object tree,
const std::vector<array>& values,
int index /* = 0 */) {
return tree_map(tree, [&](py::handle obj) {
if (py::isinstance<array>(obj)) {
return py::cast(values[index++]);
} else {
return py::cast<py::object>(obj);
}
});
}
py::object structure_sentinel() {
static py::object sentinel;
if (sentinel.ptr() == nullptr) {
sentinel = py::capsule(&sentinel);
// probably not needed but this should make certain that we won't ever
// delete the sentinel
sentinel.inc_ref();
}
return sentinel;
}
std::pair<std::vector<array>, py::object> tree_flatten_with_structure(
py::object tree,
bool strict /* = true */) {
auto sentinel = structure_sentinel();
std::vector<array> flat_tree;
auto structure = tree_map(
tree,
[&flat_tree, sentinel = std::move(sentinel), strict](py::handle obj) {
if (py::isinstance<array>(obj)) {
flat_tree.push_back(py::cast<array>(obj));
return sentinel;
} else if (!strict) {
return py::cast<py::object>(obj);
} else {
throw std::invalid_argument(
"[tree_flatten] The argument should contain only arrays");
}
});
return {flat_tree, structure};
}
py::object tree_unflatten_from_structure(
py::object structure,
const std::vector<array>& values,
int index /* = 0 */) {
auto sentinel = structure_sentinel();
return tree_map(structure, [&](py::handle obj) {
if (obj.is(sentinel)) {
return py::cast(values[index++]);
} else {
return py::cast<py::object>(obj);
}
});
}

60
python/src/trees.h Normal file
View File

@ -0,0 +1,60 @@
// Copyright © 2023-2024 Apple Inc.
#pragma once
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "mlx/array.h"
namespace py = pybind11;
using namespace mlx::core;
void tree_visit(py::object tree, std::function<void(py::handle)> visitor);
py::object tree_map(
const std::vector<py::object>& trees,
std::function<py::object(const std::vector<py::object>&)> transform);
py::object tree_map(
py::object tree,
std::function<py::object(py::handle)> transform);
void tree_visit_update(
py::object tree,
std::function<py::object(py::handle)> visitor);
/**
* Fill a pytree (recursive dict or list of dict or list) in place with the
* given arrays. */
void tree_fill(py::object& tree, const std::vector<array>& values);
/**
* Replace all the arrays from the src values with the dst values in the
* tree.
*/
void tree_replace(
py::object& tree,
const std::vector<array>& src,
const std::vector<array>& dst);
/**
* Flatten a tree into a vector of arrays. If strict is true, then the
* function will throw if the tree contains a leaf which is not an array.
*/
std::vector<array> tree_flatten(py::object tree, bool strict = true);
/**
* Unflatten a tree from a vector of arrays.
*/
py::object tree_unflatten(
py::object tree,
const std::vector<array>& values,
int index = 0);
std::pair<std::vector<array>, py::object> tree_flatten_with_structure(
py::object tree,
bool strict = true);
py::object tree_unflatten_from_structure(
py::object structure,
const std::vector<array>& values,
int index = 0);

View File

@ -539,6 +539,48 @@ class TestCompile(mlx_tests.MLXTestCase):
z = fun(mx.array(1), "two")
self.assertEqual(z.item(), 3)
# Test nested constant
@partial(mx.compile)
def fun(x, y):
if y[0][0] == 1:
return x + 1
else:
return x + 2
z = fun(mx.array(1), [[1]])
self.assertEqual(z.item(), 2)
z = fun(mx.array(1), [[0]])
self.assertEqual(z.item(), 3)
@partial(mx.compile)
def fun(x, a, b):
for ai in a:
for bi in b:
x = bi * x + ai
return x
z = fun(mx.array(1), [1, 1], [2])
self.assertEqual(z.item(), 7)
z = fun(mx.array(1), [1], [1, 2])
self.assertEqual(z.item(), 5)
counter = [0]
@partial(mx.compile)
def fun(x, y):
counter[0] += 1
return x + y
z = fun(mx.array(1), 1)
self.assertEqual(z.item(), 2)
z = fun(1, mx.array(1))
self.assertEqual(z.item(), 2)
self.assertEqual(counter[0], 2)
def test_compile_inf(self):
@mx.compile
@ -548,6 +590,21 @@ class TestCompile(mlx_tests.MLXTestCase):
out = fun(mx.array([0.0]))
self.assertEqual(out.item(), False)
def test_unsupported_input_types(self):
class MyClass:
value = 1
@mx.compile
def fun(x, y):
return x + y.value
with self.assertRaises(ValueError):
out = fun(mx.array(0.0), MyClass())
with self.assertRaises(ValueError):
out = fun(mx.array(0.0), y=MyClass())
if __name__ == "__main__":
unittest.main()