diff --git a/docs/src/python/nn/functions.rst b/docs/src/python/nn/functions.rst index fc99dcad1..db276afdf 100644 --- a/docs/src/python/nn/functions.rst +++ b/docs/src/python/nn/functions.rst @@ -12,13 +12,24 @@ simple functions. :toctree: _autosummary_functions :template: nn-module-template.rst + elu gelu gelu_approx gelu_fast_approx + glu + hardswish + leaky_relu + log_sigmoid + log_softmax mish prelu relu + relu6 selu - softshrink + sigmoid silu + softmax + softplus + softshrink step + tanh diff --git a/python/mlx/nn/layers/__init__.py b/python/mlx/nn/layers/__init__.py index d992b0426..6d286220a 100644 --- a/python/mlx/nn/layers/__init__.py +++ b/python/mlx/nn/layers/__init__.py @@ -37,6 +37,7 @@ from mlx.nn.layers.activations import ( relu, relu6, selu, + sigmoid, silu, softmax, softplus, diff --git a/python/mlx/nn/layers/activations.py b/python/mlx/nn/layers/activations.py index dfd435cfd..178cbc7b1 100644 --- a/python/mlx/nn/layers/activations.py +++ b/python/mlx/nn/layers/activations.py @@ -18,7 +18,7 @@ def _make_activation_module(f): @partial(mx.compile, shapeless=True) def sigmoid(x): - r"""Applies the element-wise function: + r"""Applies the sigmoid function. .. 
math:: \text{Sigmoid}(x) = \sigma(x) = \frac{1}{1 + \exp(-x)} diff --git a/python/src/CMakeLists.txt b/python/src/CMakeLists.txt index 4df503a4a..7a3729436 100644 --- a/python/src/CMakeLists.txt +++ b/python/src/CMakeLists.txt @@ -14,6 +14,7 @@ pybind11_add_module( ${CMAKE_CURRENT_SOURCE_DIR}/random.cpp ${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp ${CMAKE_CURRENT_SOURCE_DIR}/constants.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/trees.cpp ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp ) diff --git a/python/src/transforms.cpp b/python/src/transforms.cpp index cda1d6316..12c067be6 100644 --- a/python/src/transforms.cpp +++ b/python/src/transforms.cpp @@ -11,6 +11,7 @@ #include "mlx/graph_utils.h" #include "mlx/transforms.h" #include "mlx/transforms_impl.h" +#include "python/src/trees.h" namespace py = pybind11; using namespace py::literals; @@ -30,246 +31,6 @@ std::vector to_vector(const std::variant>& v) { return vals; } -void tree_visit(py::object tree, std::function visitor) { - std::function recurse; - recurse = [&](py::handle subtree) { - if (py::isinstance(subtree) || - py::isinstance(subtree)) { - for (auto item : subtree) { - recurse(item); - } - } else if (py::isinstance(subtree)) { - for (auto item : py::cast(subtree)) { - recurse(item.second); - } - } else { - visitor(subtree); - } - }; - - recurse(tree); -} - -template -void validate_subtrees(const std::vector& subtrees) { - int len = py::cast(subtrees[0]).size(); - for (auto& subtree : subtrees) { - if ((py::isinstance(subtree) && py::cast(subtree).size() != len) || - py::isinstance(subtree) || py::isinstance(subtree)) { - throw std::invalid_argument( - "[tree_map] Additional input tree is not a valid prefix of the first tree."); - } - } -} - -py::object tree_map( - const std::vector& trees, - std::function&)> transform) { - std::function&)> recurse; - - recurse = [&](const std::vector& subtrees) { - if (py::isinstance(subtrees[0])) { - py::list l; - std::vector items(subtrees.size()); - validate_subtrees(subtrees); - for (int i 
= 0; i < py::cast(subtrees[0]).size(); ++i) { - for (int j = 0; j < subtrees.size(); ++j) { - if (py::isinstance(subtrees[j])) { - items[j] = py::cast(subtrees[j])[i]; - } else { - items[j] = subtrees[j]; - } - } - l.append(recurse(items)); - } - return py::cast(l); - } else if (py::isinstance(subtrees[0])) { - // Check the rest of the subtrees - std::vector items(subtrees.size()); - int len = py::cast(subtrees[0]).size(); - py::tuple l(len); - validate_subtrees(subtrees); - for (int i = 0; i < len; ++i) { - for (int j = 0; j < subtrees.size(); ++j) { - if (py::isinstance(subtrees[j])) { - items[j] = py::cast(subtrees[j])[i]; - } else { - items[j] = subtrees[j]; - } - } - l[i] = recurse(items); - } - return py::cast(l); - } else if (py::isinstance(subtrees[0])) { - std::vector items(subtrees.size()); - validate_subtrees(subtrees); - py::dict d; - for (auto item : py::cast(subtrees[0])) { - for (int j = 0; j < subtrees.size(); ++j) { - if (py::isinstance(subtrees[j])) { - auto subdict = py::cast(subtrees[j]); - if (!subdict.contains(item.first)) { - throw std::invalid_argument( - "[tree_map] Tree is not a valid prefix tree of the first tree."); - } - items[j] = subdict[item.first]; - } else { - items[j] = subtrees[j]; - } - } - d[item.first] = recurse(items); - } - return py::cast(d); - } else { - return transform(subtrees); - } - }; - return recurse(trees); -} - -py::object tree_map( - py::object tree, - std::function transform) { - return tree_map({tree}, [&](std::vector inputs) { - return transform(inputs[0]); - }); -} - -void tree_visit_update( - py::object tree, - std::function visitor) { - std::function recurse; - recurse = [&](py::handle subtree) { - if (py::isinstance(subtree)) { - auto l = py::cast(subtree); - for (int i = 0; i < l.size(); ++i) { - l[i] = recurse(l[i]); - } - return py::cast(l); - } else if (py::isinstance(subtree)) { - for (auto item : subtree) { - recurse(item); - } - return py::cast(subtree); - } else if (py::isinstance(subtree)) { - 
auto d = py::cast(subtree); - for (auto item : d) { - d[item.first] = recurse(item.second); - } - return py::cast(d); - } else if (py::isinstance(subtree)) { - return visitor(subtree); - } else { - return py::cast(subtree); - } - }; - recurse(tree); -} - -// Fill a pytree (recursive dict or list of dict or list) -// in place with the given arrays -// Non dict or list nodes are ignored -void tree_fill(py::object& tree, const std::vector& values) { - size_t index = 0; - tree_visit_update( - tree, [&](py::handle node) { return py::cast(values[index++]); }); -} - -// Replace all the arrays from the src values with the dst values in the tree -void tree_replace( - py::object& tree, - const std::vector& src, - const std::vector& dst) { - std::unordered_map src_to_dst; - for (int i = 0; i < src.size(); ++i) { - src_to_dst.insert({src[i].id(), dst[i]}); - } - tree_visit_update(tree, [&](py::handle node) { - auto arr = py::cast(node); - if (auto it = src_to_dst.find(arr.id()); it != src_to_dst.end()) { - return py::cast(it->second); - } - return py::cast(arr); - }); -} - -std::vector tree_flatten(py::object tree, bool strict = true) { - std::vector flat_tree; - - tree_visit(tree, [&](py::handle obj) { - if (py::isinstance(obj)) { - flat_tree.push_back(py::cast(obj)); - } else if (strict) { - throw std::invalid_argument( - "[tree_flatten] The argument should contain only arrays"); - } - }); - - return flat_tree; -} - -py::object tree_unflatten( - py::object tree, - const std::vector& values, - int index = 0) { - return tree_map(tree, [&](py::handle obj) { - if (py::isinstance(obj)) { - return py::cast(values[index++]); - } else { - return py::cast(obj); - } - }); -} - -py::object structure_sentinel() { - static py::object sentinel; - - if (sentinel.ptr() == nullptr) { - sentinel = py::capsule(&sentinel); - // probably not needed but this should make certain that we won't ever - // delete the sentinel - sentinel.inc_ref(); - } - - return sentinel; -} - -std::pair, py::object> 
tree_flatten_with_structure( - py::object tree, - bool strict = true) { - auto sentinel = structure_sentinel(); - std::vector flat_tree; - auto structure = tree_map( - tree, - [&flat_tree, sentinel = std::move(sentinel), strict](py::handle obj) { - if (py::isinstance(obj)) { - flat_tree.push_back(py::cast(obj)); - return sentinel; - } else if (!strict) { - return py::cast(obj); - } else { - throw std::invalid_argument( - "[tree_flatten] The argument should contain only arrays"); - } - }); - - return {flat_tree, structure}; -} - -py::object tree_unflatten_from_structure( - py::object structure, - const std::vector& values, - int index = 0) { - auto sentinel = structure_sentinel(); - return tree_map(structure, [&](py::handle obj) { - if (obj.is(sentinel)) { - return py::cast(values[index++]); - } else { - return py::cast(obj); - } - }); -} - auto validate_argnums_argnames( const std::optional& argnums, const StrOrVec& argnames) { @@ -582,9 +343,69 @@ struct PyCompiledFun { }; py::object operator()(const py::args& args, const py::kwargs& kwargs) { - auto inputs = tree_flatten(args, false); + // Flat array inputs + std::vector inputs; - auto compile_fun = [this, &args, &kwargs, num_args = inputs.size()]( + // Compilation constants which include the tree structure of the arguments + std::vector constants; + + // Reserve some large primes to signify the presence of an array, a list or + // a dict in order to encode the structure of the pytree. We choose primes + // to reduce slightly the chances of these numbers occurring by a + // multiplication as values in the constants list. 
+ constexpr uint64_t array_identifier = 18446744073709551557UL; + constexpr uint64_t list_identifier = 18446744073709551533UL; + constexpr uint64_t dict_identifier = 18446744073709551521UL; + + // Flatten the tree with hashed constants and structure + std::function recurse; + recurse = [&](py::handle obj) { + if (py::isinstance(obj)) { + auto l = py::cast(obj); + constants.push_back(list_identifier); + for (int i = 0; i < l.size(); ++i) { + recurse(l[i]); + } + } else if (py::isinstance(obj)) { + auto l = py::cast(obj); + constants.push_back(list_identifier); + for (auto item : obj) { + recurse(item); + } + } else if (py::isinstance(obj)) { + auto d = py::cast(obj); + constants.push_back(dict_identifier); + for (auto item : d) { + auto r = py::hash(item.first); + constants.push_back(*reinterpret_cast(&r)); + recurse(item.second); + } + } else if (py::isinstance(obj)) { + inputs.push_back(py::cast(obj)); + constants.push_back(array_identifier); + } else if (py::isinstance(obj)) { + auto r = py::hash(obj); + constants.push_back(*reinterpret_cast(&r)); + } else if (py::isinstance(obj)) { + auto r = obj.cast(); + constants.push_back(*reinterpret_cast(&r)); + } else if (py::isinstance(obj)) { + auto r = obj.cast(); + constants.push_back(*reinterpret_cast(&r)); + } else { + std::ostringstream msg; + msg << "[compile] Function arguments must be trees of arrays " + << "or constants (floats, ints, or strings), but received " + << "type " << obj.get_type() << "."; + throw std::invalid_argument(msg.str()); + } + }; + + recurse(args); + int num_args = inputs.size(); + recurse(kwargs); + + auto compile_fun = [this, &args, &kwargs, num_args]( const std::vector& a) { // Put tracers into captured inputs std::vector flat_in_captures; @@ -619,14 +440,6 @@ struct PyCompiledFun { return outputs; }; - { - auto flat_kwargs = tree_flatten(kwargs, false); - inputs.insert( - inputs.end(), - std::make_move_iterator(flat_kwargs.begin()), - std::make_move_iterator(flat_kwargs.end())); - } - 
if (!py::isinstance(captured_inputs)) { auto flat_in_captures = tree_flatten(captured_inputs, false); inputs.insert( @@ -635,36 +448,6 @@ struct PyCompiledFun { std::make_move_iterator(flat_in_captures.end())); } - // Collect the compilation constants - std::vector constants; - auto value_hash = [](py::handle o) -> std::optional { - // Consider expanding tuples to their contents including start and end - // ids - if (py::isinstance(o) || py::isinstance(o)) { - auto r = py::hash(o); - return *reinterpret_cast(&r); - } else if (py::isinstance(o)) { - auto r = o.cast(); - return *reinterpret_cast(&r); - } else if (py::isinstance(o)) { - auto r = o.cast(); - return *reinterpret_cast(&r); - } else { - return std::nullopt; - } - }; - for (int i = 0; i < args.size(); i++) { - if (auto h = value_hash(args[i]); h.has_value()) { - constants.push_back(*h); - } - } - for (auto& pair : kwargs) { - if (auto h = value_hash(pair.second); h.has_value()) { - constants.push_back(*value_hash(pair.first)); - constants.push_back(*h); - } - } - // Compile and call auto outputs = detail::compile(compile_fun, fun_id, shapeless, constants)(inputs); diff --git a/python/src/trees.cpp b/python/src/trees.cpp new file mode 100644 index 000000000..bd2c3f975 --- /dev/null +++ b/python/src/trees.cpp @@ -0,0 +1,243 @@ +// Copyright © 2023-2024 Apple Inc. 
+ +#include "python/src/trees.h" + +void tree_visit(py::object tree, std::function visitor) { + std::function recurse; + recurse = [&](py::handle subtree) { + if (py::isinstance(subtree) || + py::isinstance(subtree)) { + for (auto item : subtree) { + recurse(item); + } + } else if (py::isinstance(subtree)) { + for (auto item : py::cast(subtree)) { + recurse(item.second); + } + } else { + visitor(subtree); + } + }; + + recurse(tree); +} + +template +void validate_subtrees(const std::vector& subtrees) { + int len = py::cast(subtrees[0]).size(); + for (auto& subtree : subtrees) { + if ((py::isinstance(subtree) && py::cast(subtree).size() != len) || + py::isinstance(subtree) || py::isinstance(subtree)) { + throw std::invalid_argument( + "[tree_map] Additional input tree is not a valid prefix of the first tree."); + } + } +} + +py::object tree_map( + const std::vector& trees, + std::function&)> transform) { + std::function&)> recurse; + + recurse = [&](const std::vector& subtrees) { + if (py::isinstance(subtrees[0])) { + py::list l; + std::vector items(subtrees.size()); + validate_subtrees(subtrees); + for (int i = 0; i < py::cast(subtrees[0]).size(); ++i) { + for (int j = 0; j < subtrees.size(); ++j) { + if (py::isinstance(subtrees[j])) { + items[j] = py::cast(subtrees[j])[i]; + } else { + items[j] = subtrees[j]; + } + } + l.append(recurse(items)); + } + return py::cast(l); + } else if (py::isinstance(subtrees[0])) { + // Check the rest of the subtrees + std::vector items(subtrees.size()); + int len = py::cast(subtrees[0]).size(); + py::tuple l(len); + validate_subtrees(subtrees); + for (int i = 0; i < len; ++i) { + for (int j = 0; j < subtrees.size(); ++j) { + if (py::isinstance(subtrees[j])) { + items[j] = py::cast(subtrees[j])[i]; + } else { + items[j] = subtrees[j]; + } + } + l[i] = recurse(items); + } + return py::cast(l); + } else if (py::isinstance(subtrees[0])) { + std::vector items(subtrees.size()); + validate_subtrees(subtrees); + py::dict d; + for (auto item 
: py::cast(subtrees[0])) { + for (int j = 0; j < subtrees.size(); ++j) { + if (py::isinstance(subtrees[j])) { + auto subdict = py::cast(subtrees[j]); + if (!subdict.contains(item.first)) { + throw std::invalid_argument( + "[tree_map] Tree is not a valid prefix tree of the first tree."); + } + items[j] = subdict[item.first]; + } else { + items[j] = subtrees[j]; + } + } + d[item.first] = recurse(items); + } + return py::cast(d); + } else { + return transform(subtrees); + } + }; + return recurse(trees); +} + +py::object tree_map( + py::object tree, + std::function transform) { + return tree_map({tree}, [&](std::vector inputs) { + return transform(inputs[0]); + }); +} + +void tree_visit_update( + py::object tree, + std::function visitor) { + std::function recurse; + recurse = [&](py::handle subtree) { + if (py::isinstance(subtree)) { + auto l = py::cast(subtree); + for (int i = 0; i < l.size(); ++i) { + l[i] = recurse(l[i]); + } + return py::cast(l); + } else if (py::isinstance(subtree)) { + for (auto item : subtree) { + recurse(item); + } + return py::cast(subtree); + } else if (py::isinstance(subtree)) { + auto d = py::cast(subtree); + for (auto item : d) { + d[item.first] = recurse(item.second); + } + return py::cast(d); + } else if (py::isinstance(subtree)) { + return visitor(subtree); + } else { + return py::cast(subtree); + } + }; + recurse(tree); +} + +// Fill a pytree (recursive dict or list of dict or list) +// in place with the given arrays +// Non dict or list nodes are ignored +void tree_fill(py::object& tree, const std::vector& values) { + size_t index = 0; + tree_visit_update( + tree, [&](py::handle node) { return py::cast(values[index++]); }); +} + +// Replace all the arrays from the src values with the dst values in the tree +void tree_replace( + py::object& tree, + const std::vector& src, + const std::vector& dst) { + std::unordered_map src_to_dst; + for (int i = 0; i < src.size(); ++i) { + src_to_dst.insert({src[i].id(), dst[i]}); + } + 
tree_visit_update(tree, [&](py::handle node) { + auto arr = py::cast(node); + if (auto it = src_to_dst.find(arr.id()); it != src_to_dst.end()) { + return py::cast(it->second); + } + return py::cast(arr); + }); +} + +std::vector tree_flatten(py::object tree, bool strict /* = true */) { + std::vector flat_tree; + + tree_visit(tree, [&](py::handle obj) { + if (py::isinstance(obj)) { + flat_tree.push_back(py::cast(obj)); + } else if (strict) { + throw std::invalid_argument( + "[tree_flatten] The argument should contain only arrays"); + } + }); + + return flat_tree; +} + +py::object tree_unflatten( + py::object tree, + const std::vector& values, + int index /* = 0 */) { + return tree_map(tree, [&](py::handle obj) { + if (py::isinstance(obj)) { + return py::cast(values[index++]); + } else { + return py::cast(obj); + } + }); +} + +py::object structure_sentinel() { + static py::object sentinel; + + if (sentinel.ptr() == nullptr) { + sentinel = py::capsule(&sentinel); + // probably not needed but this should make certain that we won't ever + // delete the sentinel + sentinel.inc_ref(); + } + + return sentinel; +} + +std::pair, py::object> tree_flatten_with_structure( + py::object tree, + bool strict /* = true */) { + auto sentinel = structure_sentinel(); + std::vector flat_tree; + auto structure = tree_map( + tree, + [&flat_tree, sentinel = std::move(sentinel), strict](py::handle obj) { + if (py::isinstance(obj)) { + flat_tree.push_back(py::cast(obj)); + return sentinel; + } else if (!strict) { + return py::cast(obj); + } else { + throw std::invalid_argument( + "[tree_flatten] The argument should contain only arrays"); + } + }); + + return {flat_tree, structure}; +} + +py::object tree_unflatten_from_structure( + py::object structure, + const std::vector& values, + int index /* = 0 */) { + auto sentinel = structure_sentinel(); + return tree_map(structure, [&](py::handle obj) { + if (obj.is(sentinel)) { + return py::cast(values[index++]); + } else { + return py::cast(obj); + 
} + }); +} diff --git a/python/src/trees.h b/python/src/trees.h new file mode 100644 index 000000000..bb44f2320 --- /dev/null +++ b/python/src/trees.h @@ -0,0 +1,60 @@ +// Copyright © 2023-2024 Apple Inc. +#pragma once +#include +#include + +#include "mlx/array.h" + +namespace py = pybind11; +using namespace mlx::core; + +void tree_visit(py::object tree, std::function visitor); + +py::object tree_map( + const std::vector& trees, + std::function&)> transform); + +py::object tree_map( + py::object tree, + std::function transform); + +void tree_visit_update( + py::object tree, + std::function visitor); + +/** + * Fill a pytree (recursive dict or list of dict or list) in place with the + * given arrays. */ +void tree_fill(py::object& tree, const std::vector& values); + +/** + * Replace all the arrays from the src values with the dst values in the + * tree. + */ +void tree_replace( + py::object& tree, + const std::vector& src, + const std::vector& dst); + +/** + * Flatten a tree into a vector of arrays. If strict is true, then the + * function will throw if the tree contains a leaf which is not an array. + */ +std::vector tree_flatten(py::object tree, bool strict = true); + +/** + * Unflatten a tree from a vector of arrays. 
+ */ +py::object tree_unflatten( + py::object tree, + const std::vector& values, + int index = 0); + +std::pair, py::object> tree_flatten_with_structure( + py::object tree, + bool strict = true); + +py::object tree_unflatten_from_structure( + py::object structure, + const std::vector& values, + int index = 0); diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py index 2d0b22cdd..18f523211 100644 --- a/python/tests/test_compile.py +++ b/python/tests/test_compile.py @@ -539,6 +539,48 @@ class TestCompile(mlx_tests.MLXTestCase): z = fun(mx.array(1), "two") self.assertEqual(z.item(), 3) + # Test nested constant + @partial(mx.compile) + def fun(x, y): + if y[0][0] == 1: + return x + 1 + else: + return x + 2 + + z = fun(mx.array(1), [[1]]) + self.assertEqual(z.item(), 2) + + z = fun(mx.array(1), [[0]]) + self.assertEqual(z.item(), 3) + + @partial(mx.compile) + def fun(x, a, b): + for ai in a: + for bi in b: + x = bi * x + ai + return x + + z = fun(mx.array(1), [1, 1], [2]) + self.assertEqual(z.item(), 7) + + z = fun(mx.array(1), [1], [1, 2]) + self.assertEqual(z.item(), 5) + + counter = [0] + + @partial(mx.compile) + def fun(x, y): + counter[0] += 1 + return x + y + + z = fun(mx.array(1), 1) + self.assertEqual(z.item(), 2) + + z = fun(1, mx.array(1)) + self.assertEqual(z.item(), 2) + + self.assertEqual(counter[0], 2) + def test_compile_inf(self): @mx.compile @@ -548,6 +590,21 @@ class TestCompile(mlx_tests.MLXTestCase): out = fun(mx.array([0.0])) self.assertEqual(out.item(), False) + def test_unsupported_input_types(self): + + class MyClass: + value = 1 + + @mx.compile + def fun(x, y): + return x + y.value + + with self.assertRaises(ValueError): + out = fun(mx.array(0.0), MyClass()) + + with self.assertRaises(ValueError): + out = fun(mx.array(0.0), y=MyClass()) + if __name__ == "__main__": unittest.main()