mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 10:38:10 +08:00 
			
		
		
		
	Fix stub generation, change graph exporting for arrows to go to outputs (#455)
This commit is contained in:
		@@ -532,25 +532,12 @@ void init_array(py::module_& m) {
 | 
			
		||||
  m.attr("bfloat16") = py::cast(bfloat16);
 | 
			
		||||
  m.attr("complex64") = py::cast(complex64);
 | 
			
		||||
 | 
			
		||||
  py::class_<ArrayAt>(
 | 
			
		||||
  auto array_at_class = py::class_<ArrayAt>(
 | 
			
		||||
      m,
 | 
			
		||||
      "_ArrayAt",
 | 
			
		||||
      R"pbdoc(
 | 
			
		||||
      A helper object to apply updates at specific indices.
 | 
			
		||||
      )pbdoc")
 | 
			
		||||
      .def(
 | 
			
		||||
          py::init([](const array& x) { return ArrayAt(x); }),
 | 
			
		||||
          "x"_a,
 | 
			
		||||
          R"pbdoc(
 | 
			
		||||
          __init__(self, x: array)
 | 
			
		||||
        )pbdoc")
 | 
			
		||||
      .def("__getitem__", &ArrayAt::set_indices, "indices"_a)
 | 
			
		||||
      .def("add", &ArrayAt::add, "value"_a)
 | 
			
		||||
      .def("subtract", &ArrayAt::subtract, "value"_a)
 | 
			
		||||
      .def("multiply", &ArrayAt::multiply, "value"_a)
 | 
			
		||||
      .def("divide", &ArrayAt::divide, "value"_a)
 | 
			
		||||
      .def("maximum", &ArrayAt::maximum, "value"_a)
 | 
			
		||||
      .def("minimum", &ArrayAt::minimum, "value"_a);
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
 | 
			
		||||
  auto array_class = py::class_<array>(
 | 
			
		||||
      m,
 | 
			
		||||
@@ -573,6 +560,21 @@ void init_array(py::module_& m) {
 | 
			
		||||
          )pbdoc");
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  array_at_class
 | 
			
		||||
      .def(
 | 
			
		||||
          py::init([](const array& x) { return ArrayAt(x); }),
 | 
			
		||||
          "x"_a,
 | 
			
		||||
          R"pbdoc(
 | 
			
		||||
          __init__(self, x: array)
 | 
			
		||||
        )pbdoc")
 | 
			
		||||
      .def("__getitem__", &ArrayAt::set_indices, "indices"_a)
 | 
			
		||||
      .def("add", &ArrayAt::add, "value"_a)
 | 
			
		||||
      .def("subtract", &ArrayAt::subtract, "value"_a)
 | 
			
		||||
      .def("multiply", &ArrayAt::multiply, "value"_a)
 | 
			
		||||
      .def("divide", &ArrayAt::divide, "value"_a)
 | 
			
		||||
      .def("maximum", &ArrayAt::maximum, "value"_a)
 | 
			
		||||
      .def("minimum", &ArrayAt::minimum, "value"_a);
 | 
			
		||||
 | 
			
		||||
  array_class
 | 
			
		||||
      .def_buffer([](array& a) {
 | 
			
		||||
        // Eval if not already evaled
 | 
			
		||||
@@ -680,17 +682,17 @@ void init_array(py::module_& m) {
 | 
			
		||||
 | 
			
		||||
               * - array.at syntax
 | 
			
		||||
                 - In-place syntax
 | 
			
		||||
               * - ``x = x.at[idx].add(y)`` 
 | 
			
		||||
               * - ``x = x.at[idx].add(y)``
 | 
			
		||||
                 - ``x[idx] += y``
 | 
			
		||||
               * - ``x = x.at[idx].subtract(y)`` 
 | 
			
		||||
               * - ``x = x.at[idx].subtract(y)``
 | 
			
		||||
                 - ``x[idx] -= y``
 | 
			
		||||
               * - ``x = x.at[idx].multiply(y)`` 
 | 
			
		||||
               * - ``x = x.at[idx].multiply(y)``
 | 
			
		||||
                 - ``x[idx] *= y``
 | 
			
		||||
               * - ``x = x.at[idx].divide(y)`` 
 | 
			
		||||
               * - ``x = x.at[idx].divide(y)``
 | 
			
		||||
                 - ``x[idx] /= y``
 | 
			
		||||
               * - ``x = x.at[idx].maximum(y)`` 
 | 
			
		||||
               * - ``x = x.at[idx].maximum(y)``
 | 
			
		||||
                 - ``x[idx] = mx.maximum(x[idx], y)``
 | 
			
		||||
               * - ``x = x.at[idx].minimum(y)`` 
 | 
			
		||||
               * - ``x = x.at[idx].minimum(y)``
 | 
			
		||||
                - ``x[idx] = mx.minimum(x[idx], y)``
 | 
			
		||||
          )pbdoc")
 | 
			
		||||
      .def(
 | 
			
		||||
 
 | 
			
		||||
@@ -12,6 +12,7 @@ using namespace py::literals;
 | 
			
		||||
using namespace mlx::core;
 | 
			
		||||
 | 
			
		||||
void init_device(py::module_& m) {
 | 
			
		||||
  auto device_class = py::class_<Device>(m, "Device");
 | 
			
		||||
  py::enum_<Device::DeviceType>(m, "DeviceType")
 | 
			
		||||
      .value("cpu", Device::DeviceType::cpu)
 | 
			
		||||
      .value("gpu", Device::DeviceType::gpu)
 | 
			
		||||
@@ -23,8 +24,7 @@ void init_device(py::module_& m) {
 | 
			
		||||
          },
 | 
			
		||||
          py::prepend());
 | 
			
		||||
 | 
			
		||||
  py::class_<Device>(m, "Device")
 | 
			
		||||
      .def(py::init<Device::DeviceType, int>(), "type"_a, "index"_a = 0)
 | 
			
		||||
  device_class.def(py::init<Device::DeviceType, int>(), "type"_a, "index"_a = 0)
 | 
			
		||||
      .def_readonly("type", &Device::type)
 | 
			
		||||
      .def(
 | 
			
		||||
          "__repr__",
 | 
			
		||||
 
 | 
			
		||||
@@ -438,6 +438,9 @@ auto py_vmap(
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void init_transforms(py::module_& m) {
 | 
			
		||||
  py::options options;
 | 
			
		||||
  options.disable_function_signatures();
 | 
			
		||||
 | 
			
		||||
  m.def(
 | 
			
		||||
      "eval",
 | 
			
		||||
      [](const py::args& args) {
 | 
			
		||||
@@ -445,6 +448,8 @@ void init_transforms(py::module_& m) {
 | 
			
		||||
        eval(arrays);
 | 
			
		||||
      },
 | 
			
		||||
      R"pbdoc(
 | 
			
		||||
        eval(*args) -> None
 | 
			
		||||
 | 
			
		||||
        Evaluate an :class:`array` or tree of :class:`array`.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
@@ -476,6 +481,9 @@ void init_transforms(py::module_& m) {
 | 
			
		||||
      "primals"_a,
 | 
			
		||||
      "tangents"_a,
 | 
			
		||||
      R"pbdoc(
 | 
			
		||||
        jvp(fun: function, primals: List[array], tangents: List[array]) -> Tuple[List[array], List[array]]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
        Compute the Jacobian-vector product.
 | 
			
		||||
 | 
			
		||||
        This computes the product of the Jacobian of a function ``fun`` evaluated
 | 
			
		||||
@@ -517,6 +525,8 @@ void init_transforms(py::module_& m) {
 | 
			
		||||
      "primals"_a,
 | 
			
		||||
      "cotangents"_a,
 | 
			
		||||
      R"pbdoc(
 | 
			
		||||
        vjp(fun: function, primals: List[array], cotangents: List[array]) -> Tuple[List[array], List[array]]
 | 
			
		||||
 | 
			
		||||
        Compute the vector-Jacobian product.
 | 
			
		||||
 | 
			
		||||
        Computes the product of the ``cotangents`` with the Jacobian of a
 | 
			
		||||
@@ -549,6 +559,8 @@ void init_transforms(py::module_& m) {
 | 
			
		||||
      "argnums"_a = std::nullopt,
 | 
			
		||||
      "argnames"_a = std::vector<std::string>{},
 | 
			
		||||
      R"pbdoc(
 | 
			
		||||
        value_and_grad(fun: function, argnums: Optional[Union[int, List[int]]] = None, argnames: Union[str, List[str]] = []) -> function
 | 
			
		||||
 | 
			
		||||
        Returns a function which computes the value and gradient of ``fun``.
 | 
			
		||||
 | 
			
		||||
        The function passed to :func:`value_and_grad` should return either
 | 
			
		||||
@@ -615,6 +627,8 @@ void init_transforms(py::module_& m) {
 | 
			
		||||
      "argnums"_a = std::nullopt,
 | 
			
		||||
      "argnames"_a = std::vector<std::string>{},
 | 
			
		||||
      R"pbdoc(
 | 
			
		||||
        grad(fun: function, argnums: Optional[Union[int, List[int]]] = None, argnames: Union[str, List[str]] = []) -> function
 | 
			
		||||
 | 
			
		||||
        Returns a function which computes the gradient of ``fun``.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
@@ -645,6 +659,8 @@ void init_transforms(py::module_& m) {
 | 
			
		||||
      "in_axes"_a = 0,
 | 
			
		||||
      "out_axes"_a = 0,
 | 
			
		||||
      R"pbdoc(
 | 
			
		||||
        vmap(fun: function, in_axes: object = 0, out_axes: object = 0) -> function
 | 
			
		||||
 | 
			
		||||
        Returns a vectorized version of ``fun``.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
@@ -670,6 +686,8 @@ void init_transforms(py::module_& m) {
 | 
			
		||||
        simplify(arrays);
 | 
			
		||||
      },
 | 
			
		||||
      R"pbdoc(
 | 
			
		||||
        simplify(*args) -> None
 | 
			
		||||
 | 
			
		||||
        Simplify the graph that computes the arrays.
 | 
			
		||||
 | 
			
		||||
        Run a few fast graph simplification operations to reuse computation and
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user