mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-20 17:38:09 +08:00
Remove "using namespace mlx::core" in python/src (#1689)
This commit is contained in:
@@ -188,7 +188,7 @@ void tree_visit_update(
|
||||
d[item.first] = recurse(item.second);
|
||||
}
|
||||
return nb::cast<nb::object>(d);
|
||||
} else if (nb::isinstance<array>(subtree)) {
|
||||
} else if (nb::isinstance<mx::array>(subtree)) {
|
||||
return visitor(subtree);
|
||||
} else {
|
||||
return nb::cast<nb::object>(subtree);
|
||||
@@ -200,7 +200,7 @@ void tree_visit_update(
|
||||
// Fill a pytree (recursive dict or list of dict or list)
|
||||
// in place with the given arrays
|
||||
// Non dict or list nodes are ignored
|
||||
void tree_fill(nb::object& tree, const std::vector<array>& values) {
|
||||
void tree_fill(nb::object& tree, const std::vector<mx::array>& values) {
|
||||
size_t index = 0;
|
||||
tree_visit_update(
|
||||
tree, [&](nb::handle node) { return nb::cast(values[index++]); });
|
||||
@@ -209,14 +209,14 @@ void tree_fill(nb::object& tree, const std::vector<array>& values) {
|
||||
// Replace all the arrays from the src values with the dst values in the tree
|
||||
void tree_replace(
|
||||
nb::object& tree,
|
||||
const std::vector<array>& src,
|
||||
const std::vector<array>& dst) {
|
||||
std::unordered_map<uintptr_t, array> src_to_dst;
|
||||
const std::vector<mx::array>& src,
|
||||
const std::vector<mx::array>& dst) {
|
||||
std::unordered_map<uintptr_t, mx::array> src_to_dst;
|
||||
for (int i = 0; i < src.size(); ++i) {
|
||||
src_to_dst.insert({src[i].id(), dst[i]});
|
||||
}
|
||||
tree_visit_update(tree, [&](nb::handle node) {
|
||||
auto arr = nb::cast<array>(node);
|
||||
auto arr = nb::cast<mx::array>(node);
|
||||
if (auto it = src_to_dst.find(arr.id()); it != src_to_dst.end()) {
|
||||
return nb::cast(it->second);
|
||||
}
|
||||
@@ -224,12 +224,12 @@ void tree_replace(
|
||||
});
|
||||
}
|
||||
|
||||
std::vector<array> tree_flatten(nb::object tree, bool strict /* = true */) {
|
||||
std::vector<array> flat_tree;
|
||||
std::vector<mx::array> tree_flatten(nb::object tree, bool strict /* = true */) {
|
||||
std::vector<mx::array> flat_tree;
|
||||
|
||||
tree_visit(tree, [&](nb::handle obj) {
|
||||
if (nb::isinstance<array>(obj)) {
|
||||
flat_tree.push_back(nb::cast<array>(obj));
|
||||
if (nb::isinstance<mx::array>(obj)) {
|
||||
flat_tree.push_back(nb::cast<mx::array>(obj));
|
||||
} else if (strict) {
|
||||
throw std::invalid_argument(
|
||||
"[tree_flatten] The argument should contain only arrays");
|
||||
@@ -241,10 +241,10 @@ std::vector<array> tree_flatten(nb::object tree, bool strict /* = true */) {
|
||||
|
||||
nb::object tree_unflatten(
|
||||
nb::object tree,
|
||||
const std::vector<array>& values,
|
||||
const std::vector<mx::array>& values,
|
||||
int index /* = 0 */) {
|
||||
return tree_map(tree, [&](nb::handle obj) {
|
||||
if (nb::isinstance<array>(obj)) {
|
||||
if (nb::isinstance<mx::array>(obj)) {
|
||||
return nb::cast(values[index++]);
|
||||
} else {
|
||||
return nb::cast<nb::object>(obj);
|
||||
@@ -265,16 +265,16 @@ nb::object structure_sentinel() {
|
||||
return sentinel;
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, nb::object> tree_flatten_with_structure(
|
||||
std::pair<std::vector<mx::array>, nb::object> tree_flatten_with_structure(
|
||||
nb::object tree,
|
||||
bool strict /* = true */) {
|
||||
auto sentinel = structure_sentinel();
|
||||
std::vector<array> flat_tree;
|
||||
std::vector<mx::array> flat_tree;
|
||||
auto structure = tree_map(
|
||||
tree,
|
||||
[&flat_tree, sentinel = std::move(sentinel), strict](nb::handle obj) {
|
||||
if (nb::isinstance<array>(obj)) {
|
||||
flat_tree.push_back(nb::cast<array>(obj));
|
||||
if (nb::isinstance<mx::array>(obj)) {
|
||||
flat_tree.push_back(nb::cast<mx::array>(obj));
|
||||
return sentinel;
|
||||
} else if (!strict) {
|
||||
return nb::cast<nb::object>(obj);
|
||||
@@ -289,7 +289,7 @@ std::pair<std::vector<array>, nb::object> tree_flatten_with_structure(
|
||||
|
||||
nb::object tree_unflatten_from_structure(
|
||||
nb::object structure,
|
||||
const std::vector<array>& values,
|
||||
const std::vector<mx::array>& values,
|
||||
int index /* = 0 */) {
|
||||
auto sentinel = structure_sentinel();
|
||||
return tree_map(structure, [&](nb::handle obj) {
|
||||
|
Reference in New Issue
Block a user