mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 10:38:10 +08:00 
			
		
		
		
	* working c++ trace implementation * updated throw + added overloads * added python binding for trace function * pre-commit reformatting * add trace to docs * resolve comments * remove to_stream call
This commit is contained in:
		@@ -4065,6 +4065,45 @@ void init_ops(nb::module_& m) {
 | 
			
		||||
        Returns:
 | 
			
		||||
            array: The extracted diagonal or the constructed diagonal matrix.
 | 
			
		||||
        )pbdoc");
 | 
			
		||||
  m.def(
 | 
			
		||||
      "trace",
 | 
			
		||||
      [](const array& a,
 | 
			
		||||
         int offset,
 | 
			
		||||
         int axis1,
 | 
			
		||||
         int axis2,
 | 
			
		||||
         std::optional<Dtype> dtype,
 | 
			
		||||
         StreamOrDevice s) {
 | 
			
		||||
        if (!dtype.has_value()) {
 | 
			
		||||
          return trace(a, offset, axis1, axis2, s);
 | 
			
		||||
        }
 | 
			
		||||
        return trace(a, offset, axis1, axis2, dtype.value(), s);
 | 
			
		||||
      },
 | 
			
		||||
      nb::arg(),
 | 
			
		||||
      "offset"_a = 0,
 | 
			
		||||
      "axis1"_a = 0,
 | 
			
		||||
      "axis2"_a = 1,
 | 
			
		||||
      "dtype"_a = nb::none(),
 | 
			
		||||
      nb::kw_only(),
 | 
			
		||||
      "stream"_a = nb::none(),
 | 
			
		||||
      nb::sig(
 | 
			
		||||
          "def trace(a: array, /, offset: int = 0, axis1: int = 0, axis2: int = 1, dtype = Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
 | 
			
		||||
      R"pbdoc(
 | 
			
		||||
        Return the sum along a specified diagonal in the given array.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
          a (array): Input array
 | 
			
		||||
          offset (int, optional): Offset of the diagonal from the main diagonal.
 | 
			
		||||
            Can be positive or negative. Default: ``0``.
 | 
			
		||||
          axis1 (int, optional): The first axis of the 2-D sub-arrays from which
 | 
			
		||||
              the diagonals should be taken. Default: ``0``.
 | 
			
		||||
          axis2 (int, optional): The second axis of the 2-D sub-arrays from which
 | 
			
		||||
              the diagonals should be taken. Default: ``1``.
 | 
			
		||||
          dtype (Dtype, optional): Data type of the output array. If
 | 
			
		||||
              unspecified the output type is inferred from the input array.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            array: Sum of specified diagonal.
 | 
			
		||||
        )pbdoc");
 | 
			
		||||
  m.def(
 | 
			
		||||
      "atleast_1d",
 | 
			
		||||
      [](const nb::args& arys, StreamOrDevice s) -> nb::object {
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user