diff --git a/mlx/graph_utils.cpp b/mlx/graph_utils.cpp index 1cbc5a987..7c1c17740 100644 --- a/mlx/graph_utils.cpp +++ b/mlx/graph_utils.cpp @@ -15,8 +15,8 @@ namespace mlx::core { struct NodeNamer { std::unordered_map names; - std::string get_name(uintptr_t id) { - auto it = names.find(id); + std::string get_name(const array& x) { + auto it = names.find(x.id()); if (it == names.end()) { // Get the next name in the sequence // [A, B, ..., Z, AA, AB, ...] @@ -27,15 +27,11 @@ struct NodeNamer { var_num = (var_num - 1) / 26; } std::string name(letters.rbegin(), letters.rend()); - names.insert({id, name}); + names.insert({x.id(), name}); return name; } return it->second; } - - std::string get_name(const array& x) { - return get_name(x.id()); - } }; void depth_first_traversal( @@ -124,15 +120,14 @@ void export_to_dot(std::ostream& os, const std::vector& outputs) { // Node for primitive if (x.has_primitive()) { os << "{ "; - os << namer.get_name(x.primitive_id()); + os << x.primitive_id(); os << " [label =\""; x.primitive().print(os); os << "\", shape=rectangle]"; os << "; }" << std::endl; // Arrows to primitive's inputs for (auto& a : x.inputs()) { - os << namer.get_name(x.primitive_id()) << " -> " - << namer.get_name(a) << std::endl; + os << namer.get_name(a) << " -> " << x.primitive_id() << std::endl; } } @@ -145,8 +140,7 @@ void export_to_dot(std::ostream& os, const std::vector& outputs) { os << namer.get_name(a); os << "; }" << std::endl; if (x.has_primitive()) { - os << namer.get_name(a) << " -> " - << namer.get_name(x.primitive_id()) << std::endl; + os << x.primitive_id() << " -> " << namer.get_name(a) << std::endl; } } }, diff --git a/python/src/array.cpp b/python/src/array.cpp index 26f6e68ac..f7dfea90a 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -532,25 +532,12 @@ void init_array(py::module_& m) { m.attr("bfloat16") = py::cast(bfloat16); m.attr("complex64") = py::cast(complex64); - py::class_( + auto array_at_class = py::class_( m, "_ArrayAt", R"pbdoc( A helper object to apply updates at specific indices. - )pbdoc") - .def( - py::init([](const array& x) { return ArrayAt(x); }), - "x"_a, - R"pbdoc( - __init__(self, x: array) - )pbdoc") - .def("__getitem__", &ArrayAt::set_indices, "indices"_a) - .def("add", &ArrayAt::add, "value"_a) - .def("subtract", &ArrayAt::subtract, "value"_a) - .def("multiply", &ArrayAt::multiply, "value"_a) - .def("divide", &ArrayAt::divide, "value"_a) - .def("maximum", &ArrayAt::maximum, "value"_a) - .def("minimum", &ArrayAt::minimum, "value"_a); + )pbdoc"); auto array_class = py::class_( m, @@ -573,6 +560,21 @@ void init_array(py::module_& m) { )pbdoc"); } + array_at_class + .def( + py::init([](const array& x) { return ArrayAt(x); }), + "x"_a, + R"pbdoc( + __init__(self, x: array) + )pbdoc") + .def("__getitem__", &ArrayAt::set_indices, "indices"_a) + .def("add", &ArrayAt::add, "value"_a) + .def("subtract", &ArrayAt::subtract, "value"_a) + .def("multiply", &ArrayAt::multiply, "value"_a) + .def("divide", &ArrayAt::divide, "value"_a) + .def("maximum", &ArrayAt::maximum, "value"_a) + .def("minimum", &ArrayAt::minimum, "value"_a); + array_class .def_buffer([](array& a) { // Eval if not already evaled @@ -680,17 +682,17 @@ void init_array(py::module_& m) { * - array.at syntax - In-place syntax - * - ``x = x.at[idx].add(y)`` + * - ``x = x.at[idx].add(y)`` - ``x[idx] += y`` - * - ``x = x.at[idx].subtract(y)`` + * - ``x = x.at[idx].subtract(y)`` - ``x[idx] -= y`` - * - ``x = x.at[idx].multiply(y)`` + * - ``x = x.at[idx].multiply(y)`` - ``x[idx] *= y`` - * - ``x = x.at[idx].divide(y)`` + * - ``x = x.at[idx].divide(y)`` - ``x[idx] /= y`` - * - ``x = x.at[idx].maximum(y)`` + * - ``x = x.at[idx].maximum(y)`` - ``x[idx] = mx.maximum(x[idx], y)`` - * - ``x = x.at[idx].minimum(y)`` + * - ``x = x.at[idx].minimum(y)`` - ``x[idx] = mx.minimum(x[idx], y)`` )pbdoc") .def( diff --git a/python/src/device.cpp b/python/src/device.cpp index 50438c139..8c36f0f85 100644 --- a/python/src/device.cpp +++ b/python/src/device.cpp @@ -12,6 +12,7 @@ using namespace py::literals; using namespace mlx::core; void init_device(py::module_& m) { + auto device_class = py::class_(m, "Device"); py::enum_(m, "DeviceType") .value("cpu", Device::DeviceType::cpu) .value("gpu", Device::DeviceType::gpu) @@ -23,8 +24,7 @@ void init_device(py::module_& m) { }, py::prepend()); - py::class_(m, "Device") - .def(py::init(), "type"_a, "index"_a = 0) + device_class.def(py::init(), "type"_a, "index"_a = 0) .def_readonly("type", &Device::type) .def( "__repr__", diff --git a/python/src/transforms.cpp b/python/src/transforms.cpp index a1afebef9..0d217a9f6 100644 --- a/python/src/transforms.cpp +++ b/python/src/transforms.cpp @@ -438,6 +438,9 @@ auto py_vmap( } void init_transforms(py::module_& m) { + py::options options; + options.disable_function_signatures(); + m.def( "eval", [](const py::args& args) { @@ -445,6 +448,8 @@ void init_transforms(py::module_& m) { eval(arrays); }, R"pbdoc( + eval(*args) -> None + Evaluate an :class:`array` or tree of :class:`array`. Args: @@ -476,6 +481,9 @@ void init_transforms(py::module_& m) { "primals"_a, "tangents"_a, R"pbdoc( + jvp(fun: function, primals: List[array], tangents: List[array]) -> Tuple[List[array], List[array]] + + Compute the Jacobian-vector product. This computes the product of the Jacobian of a function ``fun`` evaluated @@ -517,6 +525,8 @@ void init_transforms(py::module_& m) { "primals"_a, "cotangents"_a, R"pbdoc( + vjp(fun: function, primals: List[array], cotangents: List[array]) -> Tuple[List[array], List[array]] + Compute the vector-Jacobian product. Computes the product of the ``cotangents`` with the Jacobian of a @@ -549,6 +559,8 @@ void init_transforms(py::module_& m) { "argnums"_a = std::nullopt, "argnames"_a = std::vector{}, R"pbdoc( + value_and_grad(fun: function, argnums: Optional[Union[int, List[int]]] = None, argnames: Union[str, List[str]] = []) -> function + Returns a function which computes the value and gradient of ``fun``. The function passed to :func:`value_and_grad` should return either @@ -615,6 +627,8 @@ void init_transforms(py::module_& m) { "argnums"_a = std::nullopt, "argnames"_a = std::vector{}, R"pbdoc( + grad(fun: function, argnums: Optional[Union[int, List[int]]] = None, argnames: Union[str, List[str]] = []) -> function + Returns a function which computes the gradient of ``fun``. Args: @@ -645,6 +659,8 @@ void init_transforms(py::module_& m) { "in_axes"_a = 0, "out_axes"_a = 0, R"pbdoc( + vmap(fun: function, in_axes: object = 0, out_axes: object = 0) -> function + Returns a vectorized version of ``fun``. Args: @@ -670,6 +686,8 @@ void init_transforms(py::module_& m) { simplify(arrays); }, R"pbdoc( + simplify(*args) -> None + Simplify the graph that computes the arrays. Run a few fast graph simplification operations to reuse computation and diff --git a/setup.py b/setup.py index 23283860c..3c53165c9 100644 --- a/setup.py +++ b/setup.py @@ -135,18 +135,6 @@ class GenerateStubs(Command): def run(self) -> None: subprocess.run(["pybind11-stubgen", "mlx.core", "-o", "python"]) - # Note, sed inplace on macos requires a backup prefix, delete the file after its generated - # this sed is needed to replace references from py::cpp_function to a generic Callable - subprocess.run( - [ - "sed", - "-i", - "''", - "s/cpp_function/typing.Callable/g", - "python/mlx/core/__init__.pyi", - ] - ) - subprocess.run(["rm", "python/mlx/core/__init__.pyi''"]) # Read the content of README.md