Remove "using namespace mlx::core" in python/src (#1689)

This commit is contained in:
Cheng
2024-12-12 08:45:39 +09:00
committed by GitHub
parent f3dfa36a3a
commit 0bf19037ca
22 changed files with 1423 additions and 1302 deletions

View File

@@ -20,9 +20,12 @@
#include "mlx/utils.h"
#include "python/src/trees.h"
namespace mx = mlx::core;
namespace nb = nanobind;
using namespace nb::literals;
using namespace mlx::core;
// Needed for printing shapes and strides.
using mx::operator<<;
using IntOrVec = std::variant<int, std::vector<int>>;
using StrOrVec = std::variant<std::string, std::vector<std::string>>;
@@ -108,7 +111,7 @@ auto py_value_and_grad(
}
// Collect the arrays
std::vector<array> arrays;
std::vector<mx::array> arrays;
std::vector<int> counts(1, 0);
for (auto i : argnums) {
auto argsi = tree_flatten(args[i]);
@@ -127,7 +130,7 @@ auto py_value_and_grad(
// value_out will hold the output of the python function in order to be
// able to reconstruct the python tree of extra return values
nb::object py_value_out;
auto value_and_grads = value_and_grad(
auto value_and_grads = mx::value_and_grad(
[&fun,
&args,
&kwargs,
@@ -136,7 +139,7 @@ auto py_value_and_grad(
&counts,
&py_value_out,
&error_msg_tag,
scalar_func_only](const std::vector<array>& a) {
scalar_func_only](const std::vector<mx::array>& a) {
// Copy the arguments
nb::list args_cpy;
nb::kwargs kwargs_cpy = nb::kwargs();
@@ -165,7 +168,7 @@ auto py_value_and_grad(
py_value_out = fun(*args_cpy, **kwargs_cpy);
// Validate the return value of the python function
if (!nb::isinstance<array>(py_value_out)) {
if (!nb::isinstance<mx::array>(py_value_out)) {
if (scalar_func_only) {
std::ostringstream msg;
msg << error_msg_tag << " The return value of the function "
@@ -193,7 +196,7 @@ auto py_value_and_grad(
<< "we got an empty tuple.";
throw std::invalid_argument(msg.str());
}
if (!nb::isinstance<array>(ret[0])) {
if (!nb::isinstance<mx::array>(ret[0])) {
std::ostringstream msg;
msg << error_msg_tag << " The return value of the function "
<< "whose gradient we want to compute should be either a "
@@ -275,12 +278,12 @@ auto py_vmap(
{tree, axes},
[&flat_axes, &encountered_tuple, output_axes](
const std::vector<nb::object>& inputs) {
if (nb::isinstance<array>(inputs[0])) {
if (nb::isinstance<mx::array>(inputs[0])) {
if (inputs[1].is_none()) {
flat_axes.push_back(-1);
} else if (nb::isinstance<nb::int_>(inputs[1])) {
int axis = nb::cast<int>(nb::cast<nb::int_>(inputs[1]));
const array& x = nb::cast<array>(inputs[0]);
const mx::array& x = nb::cast<mx::array>(inputs[0]);
if (axis < 0) {
axis += x.ndim() + output_axes;
}
@@ -297,7 +300,7 @@ auto py_vmap(
auto l = nb::cast<nb::tuple>(inputs[1]);
if (l.size() == 1 && nb::isinstance<nb::int_>(l[0])) {
int axis = nb::cast<int>(nb::cast<nb::int_>(l[0]));
const array& x = nb::cast<array>(inputs[0]);
const mx::array& x = nb::cast<mx::array>(inputs[0]);
if (axis < 0) {
axis += x.ndim() + output_axes;
}
@@ -323,7 +326,7 @@ auto py_vmap(
"[vmap] The arguments should contain only arrays");
}
});
if (encountered_tuple && !nb::isinstance<array>(tree)) {
if (encountered_tuple && !nb::isinstance<mx::array>(tree)) {
throw std::invalid_argument("[vmap] axis must be int or None.");
}
return flat_axes;
@@ -339,7 +342,7 @@ auto py_vmap(
nb::object py_outputs;
auto vmap_fn =
[&fun, &args, &inputs, &py_outputs](const std::vector<array>& a) {
[&fun, &args, &inputs, &py_outputs](const std::vector<mx::array>& a) {
// Call the python function
py_outputs = fun(*tree_unflatten(args, a));
@@ -348,12 +351,12 @@ auto py_vmap(
};
auto [trace_inputs, trace_outputs] =
detail::vmap_trace(vmap_fn, inputs, flat_in_axes);
mx::detail::vmap_trace(vmap_fn, inputs, flat_in_axes);
auto flat_out_axes = axes_to_flat_tree(py_outputs, out_axes, true);
// Perform the vmap
auto outputs = detail::vmap_replace(
auto outputs = mx::detail::vmap_replace(
inputs, trace_inputs, trace_outputs, flat_in_axes, flat_out_axes);
// Put the outputs back in the container
@@ -401,7 +404,7 @@ struct PyCompiledFun {
nb::object call_impl(const nb::args& args, const nb::kwargs& kwargs) {
// Flat array inputs
std::vector<array> inputs;
std::vector<mx::array> inputs;
// Compilation constants which includes the tree structure of the arguments
std::vector<uint64_t> constants;
@@ -437,8 +440,8 @@ struct PyCompiledFun {
constants.push_back(nb::cast<int64_t>(r));
recurse(item.second);
}
} else if (nb::isinstance<array>(obj)) {
inputs.push_back(nb::cast<array>(obj));
} else if (nb::isinstance<mx::array>(obj)) {
inputs.push_back(nb::cast<mx::array>(obj));
constants.push_back(array_identifier);
} else if (nb::isinstance<nb::str>(obj)) {
auto r = obj.attr("__hash__")();
@@ -461,10 +464,10 @@ struct PyCompiledFun {
int num_args = inputs.size();
recurse(kwargs);
auto compile_fun = [this, &args, &kwargs, num_args](
const std::vector<array>& a) {
const std::vector<mx::array>& a) {
// Put tracers into captured inputs
std::vector<array> flat_in_captures;
std::vector<array> trace_captures;
std::vector<mx::array> flat_in_captures;
std::vector<mx::array> trace_captures;
if (!captured_inputs.is_none()) {
flat_in_captures = tree_flatten(captured_inputs, false);
trace_captures.insert(
@@ -505,9 +508,9 @@ struct PyCompiledFun {
// Compile and call
auto outputs =
detail::compile(compile_fun, fun_id, shapeless, constants)(inputs);
mx::detail::compile(compile_fun, fun_id, shapeless, constants)(inputs);
if (!captured_outputs.is_none()) {
std::vector<array> captures(
std::vector<mx::array> captures(
std::make_move_iterator(outputs.begin() + num_outputs),
std::make_move_iterator(outputs.end()));
tree_fill(captured_outputs, captures);
@@ -526,7 +529,7 @@ struct PyCompiledFun {
nb::gil_scoped_acquire gil;
tree_cache().erase(fun_id);
detail::compile_erase(fun_id);
mx::detail::compile_erase(fun_id);
fun.release().dec_ref();
captured_inputs.release().dec_ref();
captured_outputs.release().dec_ref();
@@ -561,7 +564,7 @@ class PyCheckpointedFun {
args_structure_.release().dec_ref();
}
std::vector<array> operator()(const std::vector<array>& inputs) {
std::vector<mx::array> operator()(const std::vector<mx::array>& inputs) {
auto args = nb::cast<nb::tuple>(
tree_unflatten_from_structure(args_structure_, inputs));
auto [outputs, output_structure] =
@@ -579,7 +582,7 @@ class PyCheckpointedFun {
auto [inputs, args_structure] =
tree_flatten_with_structure(full_args, false);
auto outputs = checkpoint(
auto outputs = mx::checkpoint(
InnerFunction(fun_, args_structure, output_structure))(inputs);
return tree_unflatten_from_structure(*output_structure, outputs);
@@ -660,12 +663,12 @@ class PyCustomFunction {
}
}
std::vector<array> operator()(const std::vector<array>& inputs) {
std::vector<mx::array> operator()(const std::vector<mx::array>& inputs) {
nb::gil_scoped_acquire gil;
auto new_inputs = nb::cast<nb::tuple>(
tree_unflatten_from_structure(input_structure_, inputs));
std::vector<array> outputs;
std::vector<mx::array> outputs;
std::tie(outputs, *output_structure_) =
tree_flatten_with_structure(fun_(*new_inputs[0], **new_inputs[1]));
return outputs;
@@ -694,10 +697,10 @@ class PyCustomFunction {
}
}
std::vector<array> operator()(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<array>& outputs) {
std::vector<mx::array> operator()(
const std::vector<mx::array>& primals,
const std::vector<mx::array>& cotangents,
const std::vector<mx::array>& outputs) {
nb::gil_scoped_acquire gil;
auto new_inputs = nb::cast<nb::tuple>(
@@ -734,9 +737,9 @@ class PyCustomFunction {
input_structure_.release().dec_ref();
}
std::vector<array> operator()(
const std::vector<array>& primals,
const std::vector<array>& tangents,
std::vector<mx::array> operator()(
const std::vector<mx::array>& primals,
const std::vector<mx::array>& tangents,
const std::vector<int>& argnums) {
nb::gil_scoped_acquire gil;
@@ -759,7 +762,7 @@ class PyCustomFunction {
int tangent_index = 0;
auto new_tangents =
nb::cast<nb::tuple>(tree_map(args, [&](nb::handle element) {
if (nb::isinstance<array>(element) &&
if (nb::isinstance<mx::array>(element) &&
have_tangents[array_index++]) {
return nb::cast(tangents[tangent_index++]);
} else {
@@ -789,8 +792,8 @@ class PyCustomFunction {
input_structure_.release().dec_ref();
}
std::pair<std::vector<array>, std::vector<int>> operator()(
const std::vector<array>& inputs,
std::pair<std::vector<mx::array>, std::vector<int>> operator()(
const std::vector<mx::array>& inputs,
const std::vector<int>& axes) {
nb::gil_scoped_acquire gil;
@@ -807,7 +810,7 @@ class PyCustomFunction {
auto new_axes =
nb::cast<nb::tuple>(tree_map(args, [&](nb::handle element) {
int axis = axes[arr_index++];
if (nb::isinstance<array>(element) && axis >= 0) {
if (nb::isinstance<mx::array>(element) && axis >= 0) {
return nb::cast(axis);
} else {
return nb::none();
@@ -831,11 +834,11 @@ class PyCustomFunction {
"[custom vmap] Vmap function should return a tuple with 2 items.");
}
std::vector<array> outputs;
std::vector<mx::array> outputs;
std::vector<int> output_axes;
tree_visit({result_tuple[0], result_tuple[1]}, [&](auto objects) {
if (nb::isinstance<array>(objects[0])) {
outputs.push_back(nb::cast<array>(objects[0]));
if (nb::isinstance<mx::array>(objects[0])) {
outputs.push_back(nb::cast<mx::array>(objects[0]));
output_axes.push_back(
objects[1].is_none() ? -1 : nb::cast<int>(objects[1]));
}
@@ -852,7 +855,7 @@ class PyCustomFunction {
}
// Extract the inputs and their structure in capturable vars
std::vector<array> input_arrays;
std::vector<mx::array> input_arrays;
nb::object input_structure;
auto full_args = nb::make_tuple(args, kwargs);
std::tie(input_arrays, input_structure) =
@@ -864,7 +867,7 @@ class PyCustomFunction {
// Make a function that calls fun_ in the forward pass and vjp_ in the
// backward pass. Then call it immediately and return the results.
auto f = custom_function(
auto f = mx::custom_function(
InnerFunction(fun_, input_structure, output_structure),
make_vjp_function(input_structure, output_structure),
make_jvp_function(input_structure),
@@ -1044,7 +1047,7 @@ void init_transforms(nb::module_& m) {
m.def(
"eval",
[](const nb::args& args) {
std::vector<array> arrays = tree_flatten(args, false);
std::vector<mx::array> arrays = tree_flatten(args, false);
{
nb::gil_scoped_release nogil;
eval(arrays);
@@ -1064,7 +1067,7 @@ void init_transforms(nb::module_& m) {
m.def(
"async_eval",
[](const nb::args& args) {
std::vector<array> arrays = tree_flatten(args, false);
std::vector<mx::array> arrays = tree_flatten(args, false);
{
nb::gil_scoped_release nogil;
async_eval(arrays);
@@ -1100,14 +1103,14 @@ void init_transforms(nb::module_& m) {
m.def(
"jvp",
[](const nb::callable& fun,
const std::vector<array>& primals,
const std::vector<array>& tangents) {
auto vfun = [&fun](const std::vector<array>& primals) {
const std::vector<mx::array>& primals,
const std::vector<mx::array>& tangents) {
auto vfun = [&fun](const std::vector<mx::array>& primals) {
auto out = fun(*nb::cast(primals));
if (nb::isinstance<array>(out)) {
return std::vector<array>{nb::cast<array>(out)};
if (nb::isinstance<mx::array>(out)) {
return std::vector<mx::array>{nb::cast<mx::array>(out)};
} else {
return nb::cast<std::vector<array>>(out);
return nb::cast<std::vector<mx::array>>(out);
}
};
return jvp(vfun, primals, tangents);
@@ -1139,14 +1142,14 @@ void init_transforms(nb::module_& m) {
m.def(
"vjp",
[](const nb::callable& fun,
const std::vector<array>& primals,
const std::vector<array>& cotangents) {
auto vfun = [&fun](const std::vector<array>& primals) {
const std::vector<mx::array>& primals,
const std::vector<mx::array>& cotangents) {
auto vfun = [&fun](const std::vector<mx::array>& primals) {
auto out = fun(*nb::cast(primals));
if (nb::isinstance<array>(out)) {
return std::vector<array>{nb::cast<array>(out)};
if (nb::isinstance<mx::array>(out)) {
return std::vector<mx::array>{nb::cast<mx::array>(out)};
} else {
return nb::cast<std::vector<array>>(out);
return nb::cast<std::vector<mx::array>>(out);
}
};
return vjp(vfun, primals, cotangents);
@@ -1312,7 +1315,7 @@ void init_transforms(nb::module_& m) {
m.def(
"export_to_dot",
[](nb::object file, const nb::args& args) {
std::vector<array> arrays = tree_flatten(args);
std::vector<mx::array> arrays = tree_flatten(args);
if (nb::isinstance<nb::str>(file)) {
std::ofstream out(nb::cast<std::string>(file));
export_to_dot(out, arrays);
@@ -1399,14 +1402,14 @@ void init_transforms(nb::module_& m) {
)pbdoc");
m.def(
"disable_compile",
&disable_compile,
&mx::disable_compile,
R"pbdoc(
Globally disable compilation. Setting the environment variable
``MLX_DISABLE_COMPILE`` can also be used to disable compilation.
)pbdoc");
m.def(
"enable_compile",
&enable_compile,
&mx::enable_compile,
R"pbdoc(
Globally enable compilation. This will override the environment
variable ``MLX_DISABLE_COMPILE`` if set.
@@ -1420,6 +1423,6 @@ void init_transforms(nb::module_& m) {
auto atexit = nb::module_::import_("atexit");
atexit.attr("register")(nb::cpp_function([]() {
tree_cache().clear();
detail::compile_clear_cache();
mx::detail::compile_clear_cache();
}));
}