Properly handle negative axes in python vmap (#944)

This commit is contained in:
Angelos Katharopoulos
2024-04-02 18:07:23 -07:00
committed by GitHub
parent 741eb28443
commit 3fc993f82d
4 changed files with 147 additions and 41 deletions

View File

@@ -7,6 +7,9 @@
namespace nb = nanobind;
using namespace mlx::core;
void tree_visit(
const std::vector<nb::object>& trees,
std::function<void(const std::vector<nb::object>&)> visitor);
void tree_visit(nb::object tree, std::function<void(nb::handle)> visitor);
nb::object tree_map(