mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 10:38:10 +08:00 
			
		
		
		
	Add Tensordot op (#344)
This commit is contained in:
		@@ -3194,4 +3194,44 @@ void init_ops(py::module_& m) {
 | 
			
		||||
        Returns:
 | 
			
		||||
          result (array): The dequantized version of ``w``
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
  m.def(
 | 
			
		||||
      "tensordot",
 | 
			
		||||
      [](const array& a,
 | 
			
		||||
         const array& b,
 | 
			
		||||
         const std::variant<int, std::vector<std::vector<int>>>& dims,
 | 
			
		||||
         StreamOrDevice s) {
 | 
			
		||||
        if (auto pv = std::get_if<int>(&dims); pv) {
 | 
			
		||||
          return tensordot(a, b, *pv, s);
 | 
			
		||||
        } else {
 | 
			
		||||
          auto x = std::get<std::vector<std::vector<int>>>(dims);
 | 
			
		||||
          if (x.size() != 2) {
 | 
			
		||||
            throw std::invalid_argument(
 | 
			
		||||
                "[tensordot] dims must be a list of two lists.");
 | 
			
		||||
          }
 | 
			
		||||
          return tensordot(a, b, {x[0], x[1]}, s);
 | 
			
		||||
        }
 | 
			
		||||
      },
 | 
			
		||||
      "a"_a,
 | 
			
		||||
      "b"_a,
 | 
			
		||||
      py::pos_only(),
 | 
			
		||||
      "dims"_a = 2,
 | 
			
		||||
      py::kw_only(),
 | 
			
		||||
      "stream"_a = none,
 | 
			
		||||
      R"pbdoc(
 | 
			
		||||
        tensordot(a: array, b: array, /, dims: Union[int, List[List[int]]] = 2, *, stream: Union[None, Stream, Device] = None) -> array
 | 
			
		||||
 | 
			
		||||
        Compute the tensor dot product along the specified axes.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
          a (array): Input array
 | 
			
		||||
          b (array): Input array
 | 
			
		||||
          dims (int or list(list(int)), optional): The number of dimensions to
 | 
			
		||||
            sum over. If an integer is provided, then sum over the last
 | 
			
		||||
            ``dims`` dimensions of ``a`` and the first ``dims`` dimensions of
 | 
			
		||||
            ``b``. If a list of lists is provided, then sum over the
 | 
			
		||||
            corresponding dimensions of ``a`` and ``b``. (default: 2)
 | 
			
		||||
        
 | 
			
		||||
        Returns:
 | 
			
		||||
          result (array): The tensor dot product.
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user