mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 10:38:10 +08:00 
			
		
		
		
	nice tensordot for mlx c (#782)
This commit is contained in:
		@@ -3555,17 +3555,17 @@ void init_ops(py::module_& m) {
 | 
			
		||||
      "tensordot",
 | 
			
		||||
      [](const array& a,
 | 
			
		||||
         const array& b,
 | 
			
		||||
         const std::variant<int, std::vector<std::vector<int>>>& dims,
 | 
			
		||||
         const std::variant<int, std::vector<std::vector<int>>>& axes,
 | 
			
		||||
         StreamOrDevice s) {
 | 
			
		||||
        if (auto pv = std::get_if<int>(&dims); pv) {
 | 
			
		||||
        if (auto pv = std::get_if<int>(&axes); pv) {
 | 
			
		||||
          return tensordot(a, b, *pv, s);
 | 
			
		||||
        } else {
 | 
			
		||||
          auto x = std::get<std::vector<std::vector<int>>>(dims);
 | 
			
		||||
          auto& x = std::get<std::vector<std::vector<int>>>(axes);
 | 
			
		||||
          if (x.size() != 2) {
 | 
			
		||||
            throw std::invalid_argument(
 | 
			
		||||
                "[tensordot] dims must be a list of two lists.");
 | 
			
		||||
                "[tensordot] axes must be a list of two lists.");
 | 
			
		||||
          }
 | 
			
		||||
          return tensordot(a, b, {x[0], x[1]}, s);
 | 
			
		||||
          return tensordot(a, b, x[0], x[1], s);
 | 
			
		||||
        }
 | 
			
		||||
      },
 | 
			
		||||
      "a"_a,
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user