Remove "using namespace mlx::core" in python/src (#1689)

This commit is contained in:
Cheng
2024-12-12 08:45:39 +09:00
committed by GitHub
parent f3dfa36a3a
commit 0bf19037ca
22 changed files with 1423 additions and 1302 deletions

View File

@@ -188,7 +188,7 @@ void tree_visit_update(
d[item.first] = recurse(item.second);
}
return nb::cast<nb::object>(d);
} else if (nb::isinstance<array>(subtree)) {
} else if (nb::isinstance<mx::array>(subtree)) {
return visitor(subtree);
} else {
return nb::cast<nb::object>(subtree);
@@ -200,7 +200,7 @@ void tree_visit_update(
// Fill a pytree (recursive dict or list of dict or list)
// in place with the given arrays
// Non dict or list nodes are ignored
void tree_fill(nb::object& tree, const std::vector<array>& values) {
void tree_fill(nb::object& tree, const std::vector<mx::array>& values) {
size_t index = 0;
tree_visit_update(
tree, [&](nb::handle node) { return nb::cast(values[index++]); });
@@ -209,14 +209,14 @@ void tree_fill(nb::object& tree, const std::vector<array>& values) {
// Replace all the arrays from the src values with the dst values in the tree
void tree_replace(
nb::object& tree,
const std::vector<array>& src,
const std::vector<array>& dst) {
std::unordered_map<uintptr_t, array> src_to_dst;
const std::vector<mx::array>& src,
const std::vector<mx::array>& dst) {
std::unordered_map<uintptr_t, mx::array> src_to_dst;
for (int i = 0; i < src.size(); ++i) {
src_to_dst.insert({src[i].id(), dst[i]});
}
tree_visit_update(tree, [&](nb::handle node) {
auto arr = nb::cast<array>(node);
auto arr = nb::cast<mx::array>(node);
if (auto it = src_to_dst.find(arr.id()); it != src_to_dst.end()) {
return nb::cast(it->second);
}
@@ -224,12 +224,12 @@ void tree_replace(
});
}
std::vector<array> tree_flatten(nb::object tree, bool strict /* = true */) {
std::vector<array> flat_tree;
std::vector<mx::array> tree_flatten(nb::object tree, bool strict /* = true */) {
std::vector<mx::array> flat_tree;
tree_visit(tree, [&](nb::handle obj) {
if (nb::isinstance<array>(obj)) {
flat_tree.push_back(nb::cast<array>(obj));
if (nb::isinstance<mx::array>(obj)) {
flat_tree.push_back(nb::cast<mx::array>(obj));
} else if (strict) {
throw std::invalid_argument(
"[tree_flatten] The argument should contain only arrays");
@@ -241,10 +241,10 @@ std::vector<array> tree_flatten(nb::object tree, bool strict /* = true */) {
nb::object tree_unflatten(
nb::object tree,
const std::vector<array>& values,
const std::vector<mx::array>& values,
int index /* = 0 */) {
return tree_map(tree, [&](nb::handle obj) {
if (nb::isinstance<array>(obj)) {
if (nb::isinstance<mx::array>(obj)) {
return nb::cast(values[index++]);
} else {
return nb::cast<nb::object>(obj);
@@ -265,16 +265,16 @@ nb::object structure_sentinel() {
return sentinel;
}
std::pair<std::vector<array>, nb::object> tree_flatten_with_structure(
std::pair<std::vector<mx::array>, nb::object> tree_flatten_with_structure(
nb::object tree,
bool strict /* = true */) {
auto sentinel = structure_sentinel();
std::vector<array> flat_tree;
std::vector<mx::array> flat_tree;
auto structure = tree_map(
tree,
[&flat_tree, sentinel = std::move(sentinel), strict](nb::handle obj) {
if (nb::isinstance<array>(obj)) {
flat_tree.push_back(nb::cast<array>(obj));
if (nb::isinstance<mx::array>(obj)) {
flat_tree.push_back(nb::cast<mx::array>(obj));
return sentinel;
} else if (!strict) {
return nb::cast<nb::object>(obj);
@@ -289,7 +289,7 @@ std::pair<std::vector<array>, nb::object> tree_flatten_with_structure(
nb::object tree_unflatten_from_structure(
nb::object structure,
const std::vector<array>& values,
const std::vector<mx::array>& values,
int index /* = 0 */) {
auto sentinel = structure_sentinel();
return tree_map(structure, [&](nb::handle obj) {