Mirror of https://github.com/ml-explore/mlx.git
Custom transforms (#1246)

Committed by: GitHub
Parent: a3c287354f
Commit: 5c1fa64fb0

@@ -593,7 +593,454 @@ class PyCheckpointedFun {
  nb::callable fun_;
};

/**
 * PyCustomFunction is the class that implements the python decorator
 * `mx.custom_function`.
 *
 * It implements a callable that, instead of simply calling `fun`, creates a
 * CustomTransforms primitive via the `custom_function` C++ op, which allows us
 * to redefine the vjp, jvp and vmap transformations.
 *
 * The implementation is verbose due to explicit handling of the destruction of
 * various python objects to make sure that there is no double-free and that
 * all of them are deleted while holding the GIL.
 *
 * Namely, for every one of the functions passed to the C++ `custom_function`
 * we create a callable struct that holds the following python objects (when
 * needed).
 *
 *    - An nb::callable which holds the passed function or transform
 *    - An nb::object holding the input structure, namely the `(args, kwargs)`
 *      passed to the function, in order to be able to recreate the arguments
 *      from the input arrays.
 *    - A std::shared_ptr<nb::object> holding the output structure, namely the
 *      structure of the return value of `fun`. It is a shared_ptr so that it
 *      can be set when the function is called and then used in the `vjp`
 *      transform. We delete the object only when the shared_ptr is about to be
 *      deleted (see `output_structure_.use_count() == 1`) to make sure that the
 *      object is deleted under the GIL.
 */
class PyCustomFunction {
 public:
  PyCustomFunction(nb::callable fun) : fun_(std::move(fun)) {}
  ~PyCustomFunction() {
    nb::gil_scoped_acquire gil;

    fun_.release().dec_ref();
    if (vjp_fun_.has_value()) {
      (*vjp_fun_).release().dec_ref();
    }
    if (jvp_fun_.has_value()) {
      (*jvp_fun_).release().dec_ref();
    }
    if (vmap_fun_.has_value()) {
      (*vmap_fun_).release().dec_ref();
    }
  }

  struct InnerFunction {
    nb::callable fun_;
    nb::object input_structure_;
    std::shared_ptr<nb::object> output_structure_;

    InnerFunction(
        nb::callable fun,
        nb::object input_structure,
        std::shared_ptr<nb::object> output_structure)
        : fun_(std::move(fun)),
          input_structure_(std::move(input_structure)),
          output_structure_(std::move(output_structure)) {}
    ~InnerFunction() {
      nb::gil_scoped_acquire gil;

      fun_.release().dec_ref();
      input_structure_.release().dec_ref();
      if (output_structure_.use_count() == 1) {
        output_structure_->release().dec_ref();
      }
    }

    std::vector<array> operator()(const std::vector<array>& inputs) {
      nb::gil_scoped_acquire gil;

      auto new_inputs = nb::cast<nb::tuple>(
          tree_unflatten_from_structure(input_structure_, inputs));
      std::vector<array> outputs;
      std::tie(outputs, *output_structure_) =
          tree_flatten_with_structure(fun_(*new_inputs[0], **new_inputs[1]));
      return outputs;
    }
  };

  struct InnerVJPFunction {
    nb::callable vjp_fun_;
    nb::object input_structure_;
    std::shared_ptr<nb::object> output_structure_;

    InnerVJPFunction(
        nb::callable vjp_fun,
        nb::object input_structure,
        std::shared_ptr<nb::object> output_structure)
        : vjp_fun_(std::move(vjp_fun)),
          input_structure_(std::move(input_structure)),
          output_structure_(std::move(output_structure)) {}
    ~InnerVJPFunction() {
      nb::gil_scoped_acquire gil;

      vjp_fun_.release().dec_ref();
      input_structure_.release().dec_ref();
      if (output_structure_.use_count() == 1) {
        output_structure_->release().dec_ref();
      }
    }

    std::vector<array> operator()(
        const std::vector<array>& primals,
        const std::vector<array>& cotangents,
        const std::vector<array>& outputs) {
      nb::gil_scoped_acquire gil;

      auto new_inputs = nb::cast<nb::tuple>(
          tree_unflatten_from_structure(input_structure_, primals));
      auto args = nb::cast<nb::tuple>(new_inputs[0]);
      auto new_cotangents =
          tree_unflatten_from_structure(*output_structure_, cotangents);
      auto new_outputs =
          tree_unflatten_from_structure(*output_structure_, outputs);

      if (args.size() == 1) {
        return tree_flatten(
            vjp_fun_(args[0], new_cotangents, new_outputs, **new_inputs[1]),
            false);
      } else {
        return tree_flatten(
            vjp_fun_(args, new_cotangents, new_outputs, **new_inputs[1]),
            false);
      }
    }
  };

  struct InnerJVPFunction {
    nb::callable jvp_fun_;
    nb::object input_structure_;

    InnerJVPFunction(nb::callable jvp_fun, nb::object input_structure)
        : jvp_fun_(std::move(jvp_fun)),
          input_structure_(std::move(input_structure)) {}
    ~InnerJVPFunction() {
      nb::gil_scoped_acquire gil;

      jvp_fun_.release().dec_ref();
      input_structure_.release().dec_ref();
    }

    std::vector<array> operator()(
        const std::vector<array>& primals,
        const std::vector<array>& tangents,
        const std::vector<int>& argnums) {
      nb::gil_scoped_acquire gil;

      auto new_inputs = nb::cast<nb::tuple>(
          tree_unflatten_from_structure(input_structure_, primals));
      auto args = nb::cast<nb::tuple>(new_inputs[0]);
      auto kwargs = nb::cast<nb::dict>(new_inputs[1]);
      if (kwargs.size() > 0) {
        throw std::invalid_argument(
            "[custom jvp] Function should only accept positional arguments");
      }

      // Make a new pytree which has tangents or None when a tangent is not
      // available.
      std::vector<bool> have_tangents(primals.size(), false);
      for (auto arg : argnums) {
        have_tangents[arg] = true;
      }
      int array_index = 0;
      int tangent_index = 0;
      auto new_tangents =
          nb::cast<nb::tuple>(tree_map(args, [&](nb::handle element) {
            if (nb::isinstance<array>(element) &&
                have_tangents[array_index++]) {
              return nb::cast(tangents[tangent_index++]);
            } else {
              return nb::none();
            }
          }));

      if (args.size() == 1) {
        return tree_flatten(jvp_fun_(args[0], new_tangents[0]), false);
      } else {
        return tree_flatten(jvp_fun_(args, new_tangents), false);
      }
    }
  };

  struct InnerVmapFunction {
    nb::callable vmap_fun_;
    nb::object input_structure_;

    InnerVmapFunction(nb::callable vmap_fun, nb::object input_structure)
        : vmap_fun_(std::move(vmap_fun)),
          input_structure_(std::move(input_structure)) {}
    ~InnerVmapFunction() {
      nb::gil_scoped_acquire gil;

      vmap_fun_.release().dec_ref();
      input_structure_.release().dec_ref();
    }

    std::pair<std::vector<array>, std::vector<int>> operator()(
        const std::vector<array>& inputs,
        const std::vector<int>& axes) {
      nb::gil_scoped_acquire gil;

      auto new_inputs = nb::cast<nb::tuple>(
          tree_unflatten_from_structure(input_structure_, inputs));
      auto args = nb::cast<nb::tuple>(new_inputs[0]);
      auto kwargs = nb::cast<nb::dict>(new_inputs[1]);
      if (kwargs.size() > 0) {
        throw std::invalid_argument(
            "[custom vmap] Function should only accept positional arguments");
      }

      int arr_index = 0;
      auto new_axes =
          nb::cast<nb::tuple>(tree_map(args, [&](nb::handle element) {
            int axis = axes[arr_index++];
            if (nb::isinstance<array>(element) && axis >= 0) {
              return nb::cast(axis);
            } else {
              return nb::none();
            }
          }));

      nb::object result;
      if (args.size() == 1) {
        result = vmap_fun_(args[0], new_axes[0]);
      } else {
        result = vmap_fun_(args, new_axes);
      }

      if (!nb::isinstance<nb::tuple>(result)) {
        throw std::invalid_argument(
            "[custom vmap] Vmap function should return a tuple with 2 items.");
      }
      nb::tuple result_tuple = nb::cast<nb::tuple>(result);
      if (result_tuple.size() != 2) {
        throw std::invalid_argument(
            "[custom vmap] Vmap function should return a tuple with 2 items.");
      }

      std::vector<array> outputs;
      std::vector<int> output_axes;
      tree_visit({result_tuple[0], result_tuple[1]}, [&](auto objects) {
        if (nb::isinstance<array>(objects[0])) {
          outputs.push_back(nb::cast<array>(objects[0]));
          output_axes.push_back(
              objects[1].is_none() ? -1 : nb::cast<int>(objects[1]));
        }
      });

      return {outputs, output_axes};
    }
  };

  nb::object call_impl(const nb::args& args, const nb::kwargs& kwargs) {
    if (!vjp_fun_.has_value() && !jvp_fun_.has_value() &&
        !vmap_fun_.has_value()) {
      return fun_(*args, **kwargs);
    }

    // Extract the inputs and their structure in capturable vars
    std::vector<array> input_arrays;
    nb::object input_structure;
    auto full_args = nb::make_tuple(args, kwargs);
    std::tie(input_arrays, input_structure) =
        tree_flatten_with_structure(full_args, false);

    // The output structure will be stored here to be used in the custom vjp
    // function
    auto output_structure = std::make_shared<nb::object>();

    // Make a function that calls fun_ in the forward pass and vjp_ in the
    // backward pass. Then call it immediately and return the results.
    auto f = custom_function(
        InnerFunction(fun_, input_structure, output_structure),
        make_vjp_function(input_structure, output_structure),
        make_jvp_function(input_structure),
        make_vmap_function(input_structure));

    auto outputs = f(input_arrays);
    return tree_unflatten_from_structure(*output_structure, outputs);
  }

  PyCustomFunction& set_vjp(nb::callable vjp_fun) {
    vjp_fun_ = vjp_fun;
    return *this;
  }

  PyCustomFunction& set_jvp(nb::callable jvp_fun) {
    jvp_fun_ = jvp_fun;
    return *this;
  }

  PyCustomFunction& set_vmap(nb::callable vmap_fun) {
    vmap_fun_ = vmap_fun;
    return *this;
  }

 private:
  std::optional<InnerVJPFunction> make_vjp_function(
      nb::object input_structure,
      std::shared_ptr<nb::object> output_structure) {
    if (!vjp_fun_.has_value()) {
      return std::nullopt;
    }

    return InnerVJPFunction(*vjp_fun_, input_structure, output_structure);
  }

  std::optional<InnerJVPFunction> make_jvp_function(
      nb::object input_structure) {
    if (!jvp_fun_.has_value()) {
      return std::nullopt;
    }

    return InnerJVPFunction(*jvp_fun_, input_structure);
  }

  std::optional<InnerVmapFunction> make_vmap_function(
      nb::object input_structure) {
    if (!vmap_fun_.has_value()) {
      return std::nullopt;
    }

    return InnerVmapFunction(*vmap_fun_, input_structure);
  }

  nb::callable fun_;
  std::optional<nb::callable> vjp_fun_;
  std::optional<nb::callable> jvp_fun_;
  std::optional<nb::callable> vmap_fun_;
};

void init_transforms(nb::module_& m) {
  nb::class_<PyCustomFunction>(
      m,
      "custom_function",
      R"pbdoc(
      Set up a function for custom gradient and vmap definitions.

      This class is meant to be used as a function decorator. Instances are
      callables that behave identically to the wrapped function. However, when
      a function transformation is used (e.g. computing gradients using
      :func:`value_and_grad`) then the functions defined via :method:`vjp`,
      :method:`jvp` and :method:`vmap` are used instead of the default
      transformation.

      Note, all custom transformations are optional. Undefined transformations
      fall back to the default behaviour.

      Example usage:

      .. code-block:: python

          import mlx.core as mx

          @mx.custom_function
          def f(x, y):
              return mx.sin(x) * y

          @f.vjp
          def f_vjp(primals, cotangent, output):
              x, y = primals
              return cotangent * mx.cos(x) * y, cotangent * mx.sin(x)

          @f.jvp
          def f_jvp(primals, tangents):
            x, y = primals
            dx, dy = tangents
            return dx * mx.cos(x) * y + dy * mx.sin(x)

          @f.vmap
          def f_vmap(inputs, axes):
            x, y = inputs
            ax, ay = axes
            if ay != ax and ax is not None:
                y = y.swapaxes(ay, ax)
            return mx.sin(x) * y, (ax or ay)
      )pbdoc")
      .def(
          nb::init<nb::callable>(),
          "f"_a,
          nb::sig("def __init__(self, f: callable)"))
      .def("__call__", &PyCustomFunction::call_impl)
      .def(
          "vjp",
          &PyCustomFunction::set_vjp,
          "f"_a,
          nb::sig("def vjp(self, f_vjp: callable)"),
          R"pbdoc(
            Define a custom vjp for the wrapped function.

            The vjp function takes three arguments:

            - *primals*: A pytree that contains all the positional arguments to
              the function. It could be a single array, a tuple of arrays or a
              full blown tuple of dicts of arrays etc.
            - *cotangents*: A pytree that matches the structure of the output
              but contains the cotangents (usually the gradients of the loss
              function with respect to the outputs).
            - *outputs*: The outputs of the function to be used to avoid
              recomputing them for the gradient computation.

            The vjp function should return the same pytree structure as the
            primals but containing the corresponding computed cotangents.
          )pbdoc")
      .def(
          "jvp",
          &PyCustomFunction::set_jvp,
          "f"_a,
          nb::sig("def jvp(self, f_jvp: callable)"),
          R"pbdoc(
            Define a custom jvp for the wrapped function.

            The jvp function takes two arguments:

            - *primals*: A pytree that contains all the positional arguments to
              the function. It could be a single array, a tuple of arrays or a
              full blown tuple of dicts of arrays etc.
            - *tangents*: A pytree that matches the structure of the inputs but
              instead contains the gradients with respect to each input. Tangents could
              be ``None`` if some inputs don't have an associated gradient.

            The jvp function should return the same pytree structure as the
            outputs of the function but containing the tangents.
          )pbdoc")
      .def(
          "vmap",
          &PyCustomFunction::set_vmap,
          "f"_a,
          nb::sig("def vmap(self, f_vmap: callable)"),
          R"pbdoc(
            Define a custom vectorization transformation for the wrapped function.

            The vmap function takes two arguments:

            - *inputs*: A pytree that contains all the positional arguments to
              the function. It could be a single array, a tuple of arrays or a
              full blown tuple of dicts of arrays etc.
            - *axes*: A pytree that matches the structure of the inputs but
              instead contains the vectorization axis for each input or
              ``None`` if an input is not vectorized.

            The vmap function should return the outputs of the original
            function but vectorized over the provided axes. It should also
            return a pytree with the vectorization axes of each output. If some
            outputs are no longer vectorized, then their vectorization axis
            should be ``None``.
          )pbdoc");

  m.def(
      "eval",
      [](const nb::args& args) {
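
A minimal end-to-end sketch of how the decorator added in this hunk is meant to be used from Python, based on the docstrings above. The concrete input values and the use of `mx.grad` are illustrative assumptions rather than part of the commit:

    import mlx.core as mx

    @mx.custom_function
    def f(x, y):
        return mx.sin(x) * y

    # Custom reverse-mode rule: receives the primals, the cotangents of the
    # output and the outputs themselves, and returns one cotangent per primal.
    @f.vjp
    def f_vjp(primals, cotangents, outputs):
        x, y = primals
        return cotangents * mx.cos(x) * y, cotangents * mx.sin(x)

    # Called directly, f behaves exactly like the wrapped function.
    out = f(mx.array(1.0), mx.array(2.0))

    # Under a gradient transform the custom vjp above is used instead of the
    # default differentiation rule; jvp and vmap fall back to the defaults
    # since they were not defined.
    dx, dy = mx.grad(f, argnums=[0, 1])(mx.array(1.0), mx.array(2.0))
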
@@ -888,8 +1335,10 @@ void init_transforms(nb::module_& m) {
          const nb::object& outputs,
          bool shapeless) {
         //  Try to get the name
-        auto n = fun.attr("__name__");
-        auto name = n.is_none() ? "compiled" : nb::cast<std::string>(n);
+        auto n =
+            nb::hasattr(fun, "__name__") ? fun.attr("__name__") : nb::none();
+        auto name = n.is_none() ? "compiled"
+                                : nb::cast<std::string>(fun.attr("__name__"));

         // Try to get the signature
         std::ostringstream sig;
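
The second hunk makes `compile` tolerate callables that have no `__name__` attribute by falling back to the placeholder name "compiled". A hedged Python illustration of the kind of callable this guards against; whether such callables are otherwise fully supported by `compile` is an assumption here, not something the diff states:

    import functools
    import mlx.core as mx

    def scaled_sin(x, scale):
        return scale * mx.sin(x)

    # functools.partial objects carry no __name__, so reading
    # fun.attr("__name__") unconditionally would raise; with the hasattr
    # fallback the compiled function is simply named "compiled".
    g = mx.compile(functools.partial(scaled_sin, scale=2.0))
    print(g(mx.array(0.5)))
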