mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 10:38:10 +08:00 
			
		
		
		
	Async eval (#972)
This commit is contained in:
		@@ -595,6 +595,14 @@ class PyCheckpointedFun {
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
void init_transforms(nb::module_& m) {
 | 
			
		||||
  nb::class_<std::shared_future<void>>(
 | 
			
		||||
      m,
 | 
			
		||||
      "Synchronizer",
 | 
			
		||||
      R"pbdoc(
 | 
			
		||||
      A synchronization object returned by :func:`async_eval`.
 | 
			
		||||
      )pbdoc")
 | 
			
		||||
      .def("wait", [](const std::shared_future<void>& f) { f.wait(); });
 | 
			
		||||
 | 
			
		||||
  m.def(
 | 
			
		||||
      "eval",
 | 
			
		||||
      [](const nb::args& args) {
 | 
			
		||||
@@ -615,6 +623,38 @@ void init_transforms(nb::module_& m) {
 | 
			
		||||
              :class:`list`, :class:`tuple` or :class:`dict`. Leaves which are not
 | 
			
		||||
              arrays are ignored.
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
  m.def(
 | 
			
		||||
      "async_eval",
 | 
			
		||||
      [](const nb::args& args) {
 | 
			
		||||
        std::vector<array> arrays = tree_flatten(args, false);
 | 
			
		||||
        {
 | 
			
		||||
          nb::gil_scoped_release nogil;
 | 
			
		||||
          return async_eval(arrays);
 | 
			
		||||
        }
 | 
			
		||||
      },
 | 
			
		||||
      nb::arg(),
 | 
			
		||||
      nb::sig("def async_eval(*args) -> Synchronizer"),
 | 
			
		||||
      R"pbdoc(
 | 
			
		||||
        Asynchronously evaluate an :class:`array` or tree of :class:`array`.
 | 
			
		||||
 | 
			
		||||
        .. warning::
 | 
			
		||||
 | 
			
		||||
          You must call ``wait`` on the returned synchronization object before
 | 
			
		||||
          using any arrays that are asynchronously evaluated.
 | 
			
		||||
 | 
			
		||||
        .. note::
 | 
			
		||||
 | 
			
		||||
          This is an experimental API and may change in future versions.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            *args (arrays or trees of arrays): Each argument can be a single array
 | 
			
		||||
              or a tree of arrays. If a tree is given the nodes can be a Python
 | 
			
		||||
              :class:`list`, :class:`tuple` or :class:`dict`. Leaves which are not
 | 
			
		||||
              arrays are ignored.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            Synchronizer: A synchronization object.
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
  m.def(
 | 
			
		||||
      "jvp",
 | 
			
		||||
      [](const nb::callable& fun,
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user