mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-14 17:12:49 +08:00
Remove "using namespace mlx::core" in python/src (#1689)
This commit is contained in:
@@ -20,9 +20,12 @@
|
||||
#include "mlx/utils.h"
|
||||
#include "python/src/trees.h"
|
||||
|
||||
namespace mx = mlx::core;
|
||||
namespace nb = nanobind;
|
||||
using namespace nb::literals;
|
||||
using namespace mlx::core;
|
||||
|
||||
// Needed for printing shapes and strides.
|
||||
using mx::operator<<;
|
||||
|
||||
using IntOrVec = std::variant<int, std::vector<int>>;
|
||||
using StrOrVec = std::variant<std::string, std::vector<std::string>>;
|
||||
@@ -108,7 +111,7 @@ auto py_value_and_grad(
|
||||
}
|
||||
|
||||
// Collect the arrays
|
||||
std::vector<array> arrays;
|
||||
std::vector<mx::array> arrays;
|
||||
std::vector<int> counts(1, 0);
|
||||
for (auto i : argnums) {
|
||||
auto argsi = tree_flatten(args[i]);
|
||||
@@ -127,7 +130,7 @@ auto py_value_and_grad(
|
||||
// value_out will hold the output of the python function in order to be
|
||||
// able to reconstruct the python tree of extra return values
|
||||
nb::object py_value_out;
|
||||
auto value_and_grads = value_and_grad(
|
||||
auto value_and_grads = mx::value_and_grad(
|
||||
[&fun,
|
||||
&args,
|
||||
&kwargs,
|
||||
@@ -136,7 +139,7 @@ auto py_value_and_grad(
|
||||
&counts,
|
||||
&py_value_out,
|
||||
&error_msg_tag,
|
||||
scalar_func_only](const std::vector<array>& a) {
|
||||
scalar_func_only](const std::vector<mx::array>& a) {
|
||||
// Copy the arguments
|
||||
nb::list args_cpy;
|
||||
nb::kwargs kwargs_cpy = nb::kwargs();
|
||||
@@ -165,7 +168,7 @@ auto py_value_and_grad(
|
||||
py_value_out = fun(*args_cpy, **kwargs_cpy);
|
||||
|
||||
// Validate the return value of the python function
|
||||
if (!nb::isinstance<array>(py_value_out)) {
|
||||
if (!nb::isinstance<mx::array>(py_value_out)) {
|
||||
if (scalar_func_only) {
|
||||
std::ostringstream msg;
|
||||
msg << error_msg_tag << " The return value of the function "
|
||||
@@ -193,7 +196,7 @@ auto py_value_and_grad(
|
||||
<< "we got an empty tuple.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (!nb::isinstance<array>(ret[0])) {
|
||||
if (!nb::isinstance<mx::array>(ret[0])) {
|
||||
std::ostringstream msg;
|
||||
msg << error_msg_tag << " The return value of the function "
|
||||
<< "whose gradient we want to compute should be either a "
|
||||
@@ -275,12 +278,12 @@ auto py_vmap(
|
||||
{tree, axes},
|
||||
[&flat_axes, &encountered_tuple, output_axes](
|
||||
const std::vector<nb::object>& inputs) {
|
||||
if (nb::isinstance<array>(inputs[0])) {
|
||||
if (nb::isinstance<mx::array>(inputs[0])) {
|
||||
if (inputs[1].is_none()) {
|
||||
flat_axes.push_back(-1);
|
||||
} else if (nb::isinstance<nb::int_>(inputs[1])) {
|
||||
int axis = nb::cast<int>(nb::cast<nb::int_>(inputs[1]));
|
||||
const array& x = nb::cast<array>(inputs[0]);
|
||||
const mx::array& x = nb::cast<mx::array>(inputs[0]);
|
||||
if (axis < 0) {
|
||||
axis += x.ndim() + output_axes;
|
||||
}
|
||||
@@ -297,7 +300,7 @@ auto py_vmap(
|
||||
auto l = nb::cast<nb::tuple>(inputs[1]);
|
||||
if (l.size() == 1 && nb::isinstance<nb::int_>(l[0])) {
|
||||
int axis = nb::cast<int>(nb::cast<nb::int_>(l[0]));
|
||||
const array& x = nb::cast<array>(inputs[0]);
|
||||
const mx::array& x = nb::cast<mx::array>(inputs[0]);
|
||||
if (axis < 0) {
|
||||
axis += x.ndim() + output_axes;
|
||||
}
|
||||
@@ -323,7 +326,7 @@ auto py_vmap(
|
||||
"[vmap] The arguments should contain only arrays");
|
||||
}
|
||||
});
|
||||
if (encountered_tuple && !nb::isinstance<array>(tree)) {
|
||||
if (encountered_tuple && !nb::isinstance<mx::array>(tree)) {
|
||||
throw std::invalid_argument("[vmap] axis must be int or None.");
|
||||
}
|
||||
return flat_axes;
|
||||
@@ -339,7 +342,7 @@ auto py_vmap(
|
||||
nb::object py_outputs;
|
||||
|
||||
auto vmap_fn =
|
||||
[&fun, &args, &inputs, &py_outputs](const std::vector<array>& a) {
|
||||
[&fun, &args, &inputs, &py_outputs](const std::vector<mx::array>& a) {
|
||||
// Call the python function
|
||||
py_outputs = fun(*tree_unflatten(args, a));
|
||||
|
||||
@@ -348,12 +351,12 @@ auto py_vmap(
|
||||
};
|
||||
|
||||
auto [trace_inputs, trace_outputs] =
|
||||
detail::vmap_trace(vmap_fn, inputs, flat_in_axes);
|
||||
mx::detail::vmap_trace(vmap_fn, inputs, flat_in_axes);
|
||||
|
||||
auto flat_out_axes = axes_to_flat_tree(py_outputs, out_axes, true);
|
||||
|
||||
// Perform the vmap
|
||||
auto outputs = detail::vmap_replace(
|
||||
auto outputs = mx::detail::vmap_replace(
|
||||
inputs, trace_inputs, trace_outputs, flat_in_axes, flat_out_axes);
|
||||
|
||||
// Put the outputs back in the container
|
||||
@@ -401,7 +404,7 @@ struct PyCompiledFun {
|
||||
|
||||
nb::object call_impl(const nb::args& args, const nb::kwargs& kwargs) {
|
||||
// Flat array inputs
|
||||
std::vector<array> inputs;
|
||||
std::vector<mx::array> inputs;
|
||||
|
||||
// Compilation constants which includes the tree structure of the arguments
|
||||
std::vector<uint64_t> constants;
|
||||
@@ -437,8 +440,8 @@ struct PyCompiledFun {
|
||||
constants.push_back(nb::cast<int64_t>(r));
|
||||
recurse(item.second);
|
||||
}
|
||||
} else if (nb::isinstance<array>(obj)) {
|
||||
inputs.push_back(nb::cast<array>(obj));
|
||||
} else if (nb::isinstance<mx::array>(obj)) {
|
||||
inputs.push_back(nb::cast<mx::array>(obj));
|
||||
constants.push_back(array_identifier);
|
||||
} else if (nb::isinstance<nb::str>(obj)) {
|
||||
auto r = obj.attr("__hash__")();
|
||||
@@ -461,10 +464,10 @@ struct PyCompiledFun {
|
||||
int num_args = inputs.size();
|
||||
recurse(kwargs);
|
||||
auto compile_fun = [this, &args, &kwargs, num_args](
|
||||
const std::vector<array>& a) {
|
||||
const std::vector<mx::array>& a) {
|
||||
// Put tracers into captured inputs
|
||||
std::vector<array> flat_in_captures;
|
||||
std::vector<array> trace_captures;
|
||||
std::vector<mx::array> flat_in_captures;
|
||||
std::vector<mx::array> trace_captures;
|
||||
if (!captured_inputs.is_none()) {
|
||||
flat_in_captures = tree_flatten(captured_inputs, false);
|
||||
trace_captures.insert(
|
||||
@@ -505,9 +508,9 @@ struct PyCompiledFun {
|
||||
|
||||
// Compile and call
|
||||
auto outputs =
|
||||
detail::compile(compile_fun, fun_id, shapeless, constants)(inputs);
|
||||
mx::detail::compile(compile_fun, fun_id, shapeless, constants)(inputs);
|
||||
if (!captured_outputs.is_none()) {
|
||||
std::vector<array> captures(
|
||||
std::vector<mx::array> captures(
|
||||
std::make_move_iterator(outputs.begin() + num_outputs),
|
||||
std::make_move_iterator(outputs.end()));
|
||||
tree_fill(captured_outputs, captures);
|
||||
@@ -526,7 +529,7 @@ struct PyCompiledFun {
|
||||
nb::gil_scoped_acquire gil;
|
||||
|
||||
tree_cache().erase(fun_id);
|
||||
detail::compile_erase(fun_id);
|
||||
mx::detail::compile_erase(fun_id);
|
||||
fun.release().dec_ref();
|
||||
captured_inputs.release().dec_ref();
|
||||
captured_outputs.release().dec_ref();
|
||||
@@ -561,7 +564,7 @@ class PyCheckpointedFun {
|
||||
args_structure_.release().dec_ref();
|
||||
}
|
||||
|
||||
std::vector<array> operator()(const std::vector<array>& inputs) {
|
||||
std::vector<mx::array> operator()(const std::vector<mx::array>& inputs) {
|
||||
auto args = nb::cast<nb::tuple>(
|
||||
tree_unflatten_from_structure(args_structure_, inputs));
|
||||
auto [outputs, output_structure] =
|
||||
@@ -579,7 +582,7 @@ class PyCheckpointedFun {
|
||||
auto [inputs, args_structure] =
|
||||
tree_flatten_with_structure(full_args, false);
|
||||
|
||||
auto outputs = checkpoint(
|
||||
auto outputs = mx::checkpoint(
|
||||
InnerFunction(fun_, args_structure, output_structure))(inputs);
|
||||
|
||||
return tree_unflatten_from_structure(*output_structure, outputs);
|
||||
@@ -660,12 +663,12 @@ class PyCustomFunction {
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<array> operator()(const std::vector<array>& inputs) {
|
||||
std::vector<mx::array> operator()(const std::vector<mx::array>& inputs) {
|
||||
nb::gil_scoped_acquire gil;
|
||||
|
||||
auto new_inputs = nb::cast<nb::tuple>(
|
||||
tree_unflatten_from_structure(input_structure_, inputs));
|
||||
std::vector<array> outputs;
|
||||
std::vector<mx::array> outputs;
|
||||
std::tie(outputs, *output_structure_) =
|
||||
tree_flatten_with_structure(fun_(*new_inputs[0], **new_inputs[1]));
|
||||
return outputs;
|
||||
@@ -694,10 +697,10 @@ class PyCustomFunction {
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<array> operator()(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<array>& outputs) {
|
||||
std::vector<mx::array> operator()(
|
||||
const std::vector<mx::array>& primals,
|
||||
const std::vector<mx::array>& cotangents,
|
||||
const std::vector<mx::array>& outputs) {
|
||||
nb::gil_scoped_acquire gil;
|
||||
|
||||
auto new_inputs = nb::cast<nb::tuple>(
|
||||
@@ -734,9 +737,9 @@ class PyCustomFunction {
|
||||
input_structure_.release().dec_ref();
|
||||
}
|
||||
|
||||
std::vector<array> operator()(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& tangents,
|
||||
std::vector<mx::array> operator()(
|
||||
const std::vector<mx::array>& primals,
|
||||
const std::vector<mx::array>& tangents,
|
||||
const std::vector<int>& argnums) {
|
||||
nb::gil_scoped_acquire gil;
|
||||
|
||||
@@ -759,7 +762,7 @@ class PyCustomFunction {
|
||||
int tangent_index = 0;
|
||||
auto new_tangents =
|
||||
nb::cast<nb::tuple>(tree_map(args, [&](nb::handle element) {
|
||||
if (nb::isinstance<array>(element) &&
|
||||
if (nb::isinstance<mx::array>(element) &&
|
||||
have_tangents[array_index++]) {
|
||||
return nb::cast(tangents[tangent_index++]);
|
||||
} else {
|
||||
@@ -789,8 +792,8 @@ class PyCustomFunction {
|
||||
input_structure_.release().dec_ref();
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, std::vector<int>> operator()(
|
||||
const std::vector<array>& inputs,
|
||||
std::pair<std::vector<mx::array>, std::vector<int>> operator()(
|
||||
const std::vector<mx::array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
nb::gil_scoped_acquire gil;
|
||||
|
||||
@@ -807,7 +810,7 @@ class PyCustomFunction {
|
||||
auto new_axes =
|
||||
nb::cast<nb::tuple>(tree_map(args, [&](nb::handle element) {
|
||||
int axis = axes[arr_index++];
|
||||
if (nb::isinstance<array>(element) && axis >= 0) {
|
||||
if (nb::isinstance<mx::array>(element) && axis >= 0) {
|
||||
return nb::cast(axis);
|
||||
} else {
|
||||
return nb::none();
|
||||
@@ -831,11 +834,11 @@ class PyCustomFunction {
|
||||
"[custom vmap] Vmap function should return a tuple with 2 items.");
|
||||
}
|
||||
|
||||
std::vector<array> outputs;
|
||||
std::vector<mx::array> outputs;
|
||||
std::vector<int> output_axes;
|
||||
tree_visit({result_tuple[0], result_tuple[1]}, [&](auto objects) {
|
||||
if (nb::isinstance<array>(objects[0])) {
|
||||
outputs.push_back(nb::cast<array>(objects[0]));
|
||||
if (nb::isinstance<mx::array>(objects[0])) {
|
||||
outputs.push_back(nb::cast<mx::array>(objects[0]));
|
||||
output_axes.push_back(
|
||||
objects[1].is_none() ? -1 : nb::cast<int>(objects[1]));
|
||||
}
|
||||
@@ -852,7 +855,7 @@ class PyCustomFunction {
|
||||
}
|
||||
|
||||
// Extract the inputs and their structure in capturable vars
|
||||
std::vector<array> input_arrays;
|
||||
std::vector<mx::array> input_arrays;
|
||||
nb::object input_structure;
|
||||
auto full_args = nb::make_tuple(args, kwargs);
|
||||
std::tie(input_arrays, input_structure) =
|
||||
@@ -864,7 +867,7 @@ class PyCustomFunction {
|
||||
|
||||
// Make a function that calls fun_ in the forward pass and vjp_ in the
|
||||
// backward pass. Then call it immediately and return the results.
|
||||
auto f = custom_function(
|
||||
auto f = mx::custom_function(
|
||||
InnerFunction(fun_, input_structure, output_structure),
|
||||
make_vjp_function(input_structure, output_structure),
|
||||
make_jvp_function(input_structure),
|
||||
@@ -1044,7 +1047,7 @@ void init_transforms(nb::module_& m) {
|
||||
m.def(
|
||||
"eval",
|
||||
[](const nb::args& args) {
|
||||
std::vector<array> arrays = tree_flatten(args, false);
|
||||
std::vector<mx::array> arrays = tree_flatten(args, false);
|
||||
{
|
||||
nb::gil_scoped_release nogil;
|
||||
eval(arrays);
|
||||
@@ -1064,7 +1067,7 @@ void init_transforms(nb::module_& m) {
|
||||
m.def(
|
||||
"async_eval",
|
||||
[](const nb::args& args) {
|
||||
std::vector<array> arrays = tree_flatten(args, false);
|
||||
std::vector<mx::array> arrays = tree_flatten(args, false);
|
||||
{
|
||||
nb::gil_scoped_release nogil;
|
||||
async_eval(arrays);
|
||||
@@ -1100,14 +1103,14 @@ void init_transforms(nb::module_& m) {
|
||||
m.def(
|
||||
"jvp",
|
||||
[](const nb::callable& fun,
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& tangents) {
|
||||
auto vfun = [&fun](const std::vector<array>& primals) {
|
||||
const std::vector<mx::array>& primals,
|
||||
const std::vector<mx::array>& tangents) {
|
||||
auto vfun = [&fun](const std::vector<mx::array>& primals) {
|
||||
auto out = fun(*nb::cast(primals));
|
||||
if (nb::isinstance<array>(out)) {
|
||||
return std::vector<array>{nb::cast<array>(out)};
|
||||
if (nb::isinstance<mx::array>(out)) {
|
||||
return std::vector<mx::array>{nb::cast<mx::array>(out)};
|
||||
} else {
|
||||
return nb::cast<std::vector<array>>(out);
|
||||
return nb::cast<std::vector<mx::array>>(out);
|
||||
}
|
||||
};
|
||||
return jvp(vfun, primals, tangents);
|
||||
@@ -1139,14 +1142,14 @@ void init_transforms(nb::module_& m) {
|
||||
m.def(
|
||||
"vjp",
|
||||
[](const nb::callable& fun,
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents) {
|
||||
auto vfun = [&fun](const std::vector<array>& primals) {
|
||||
const std::vector<mx::array>& primals,
|
||||
const std::vector<mx::array>& cotangents) {
|
||||
auto vfun = [&fun](const std::vector<mx::array>& primals) {
|
||||
auto out = fun(*nb::cast(primals));
|
||||
if (nb::isinstance<array>(out)) {
|
||||
return std::vector<array>{nb::cast<array>(out)};
|
||||
if (nb::isinstance<mx::array>(out)) {
|
||||
return std::vector<mx::array>{nb::cast<mx::array>(out)};
|
||||
} else {
|
||||
return nb::cast<std::vector<array>>(out);
|
||||
return nb::cast<std::vector<mx::array>>(out);
|
||||
}
|
||||
};
|
||||
return vjp(vfun, primals, cotangents);
|
||||
@@ -1312,7 +1315,7 @@ void init_transforms(nb::module_& m) {
|
||||
m.def(
|
||||
"export_to_dot",
|
||||
[](nb::object file, const nb::args& args) {
|
||||
std::vector<array> arrays = tree_flatten(args);
|
||||
std::vector<mx::array> arrays = tree_flatten(args);
|
||||
if (nb::isinstance<nb::str>(file)) {
|
||||
std::ofstream out(nb::cast<std::string>(file));
|
||||
export_to_dot(out, arrays);
|
||||
@@ -1399,14 +1402,14 @@ void init_transforms(nb::module_& m) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"disable_compile",
|
||||
&disable_compile,
|
||||
&mx::disable_compile,
|
||||
R"pbdoc(
|
||||
Globally disable compilation. Setting the environment variable
|
||||
``MLX_DISABLE_COMPILE`` can also be used to disable compilation.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"enable_compile",
|
||||
&enable_compile,
|
||||
&mx::enable_compile,
|
||||
R"pbdoc(
|
||||
Globally enable compilation. This will override the environment
|
||||
variable ``MLX_DISABLE_COMPILE`` if set.
|
||||
@@ -1420,6 +1423,6 @@ void init_transforms(nb::module_& m) {
|
||||
auto atexit = nb::module_::import_("atexit");
|
||||
atexit.attr("register")(nb::cpp_function([]() {
|
||||
tree_cache().clear();
|
||||
detail::compile_clear_cache();
|
||||
mx::detail::compile_clear_cache();
|
||||
}));
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user