mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	nice tensordot for mlx c (#782)
This commit is contained in:
		
							
								
								
									
										43
									
								
								mlx/ops.cpp
									
									
									
									
									
								
							
							
						
						
									
										43
									
								
								mlx/ops.cpp
									
									
									
									
									
								
							| @@ -3190,42 +3190,41 @@ array dequantize( | ||||
| array tensordot( | ||||
|     const array& a, | ||||
|     const array& b, | ||||
|     const int dims /* = 2 */, | ||||
|     const int axis /* = 2 */, | ||||
|     StreamOrDevice s /* = {} */ | ||||
| ) { | ||||
|   if (dims < 0) { | ||||
|   if (axis < 0) { | ||||
|     throw std::invalid_argument( | ||||
|         "[tensordot] dims must be greater or equal to 0."); | ||||
|         "[tensordot] axis must be greater or equal to 0."); | ||||
|   } | ||||
|   if (dims > std::min(a.ndim(), b.ndim())) { | ||||
|   if (axis > std::min(a.ndim(), b.ndim())) { | ||||
|     throw std::invalid_argument( | ||||
|         "[tensordot] dims must be less than the number of dimensions of a and b."); | ||||
|         "[tensordot] axis must be less than the number of dimensions of a and b."); | ||||
|   } | ||||
|   std::vector<int> adims; | ||||
|   std::vector<int> bdims; | ||||
|   for (int i = 0; i < dims; i++) { | ||||
|   for (int i = 0; i < axis; i++) { | ||||
|     bdims.emplace_back(i); | ||||
|     adims.emplace_back(-dims + i); | ||||
|     adims.emplace_back(i - axis); | ||||
|   } | ||||
|   return tensordot(a, b, {adims, bdims}, s); | ||||
|   return tensordot(a, b, {adims}, {bdims}, s); | ||||
| } | ||||
|  | ||||
| array tensordot( | ||||
|     const array& a, | ||||
|     const array& b, | ||||
|     const std::pair<std::vector<int>, std::vector<int>>& dims, | ||||
|     StreamOrDevice s /* = {} */ | ||||
| ) { | ||||
|   if (dims.first.size() != dims.second.size()) { | ||||
|     throw std::invalid_argument( | ||||
|         "[tensordot] dims[0] and dims[1] must have the same number of dimensions."); | ||||
|     const std::vector<int>& axes_a, | ||||
|     const std::vector<int>& axes_b, | ||||
|     StreamOrDevice s /* = {} */) { | ||||
|   if (axes_a.size() != axes_b.size()) { | ||||
|     throw std::invalid_argument("[tensordot] axes must have the same size."); | ||||
|   } | ||||
|   int csize = 1; | ||||
|   auto x = a; | ||||
|   auto y = b; | ||||
|   for (int i = 0; i < dims.first.size(); i++) { | ||||
|     if (x.shape(dims.first.at(i)) == y.shape(dims.second.at(i))) { | ||||
|       csize *= x.shape(dims.first.at(i)); | ||||
|   for (int i = 0; i < axes_a.size(); i++) { | ||||
|     if (x.shape(axes_a.at(i)) == y.shape(axes_b.at(i))) { | ||||
|       csize *= x.shape(axes_a.at(i)); | ||||
|     } else { | ||||
|       throw std::invalid_argument( | ||||
|           "[tensordot] a and b must have the same shape on the contracted axes."); | ||||
| @@ -3234,11 +3233,11 @@ array tensordot( | ||||
|  | ||||
|   std::vector<bool> cdims1(x.ndim(), false); | ||||
|   std::vector<bool> cdims2(y.ndim(), false); | ||||
|   for (const auto n : dims.first) { | ||||
|   for (const auto n : axes_a) { | ||||
|     int n_ = (n < 0) ? n + x.ndim() : n; | ||||
|     cdims1[n_] = true; | ||||
|   } | ||||
|   for (const auto n : dims.second) { | ||||
|   for (const auto n : axes_b) { | ||||
|     int n_ = (n < 0) ? n + y.ndim() : n; | ||||
|     cdims2[n_] = true; | ||||
|   } | ||||
| @@ -3255,10 +3254,10 @@ array tensordot( | ||||
|       rshape.emplace_back(a.shape(i)); | ||||
|     } | ||||
|   } | ||||
|   for (const auto x : dims.first) { | ||||
|   for (const auto x : axes_a) { | ||||
|     t1.emplace_back(x); | ||||
|   } | ||||
|   for (const auto x : dims.second) { | ||||
|   for (const auto x : axes_b) { | ||||
|     t2.emplace_back(x); | ||||
|   } | ||||
|   for (int i = 0; i < b.ndim(); i++) { | ||||
| @@ -3287,7 +3286,7 @@ array inner(const array& a, const array& b, StreamOrDevice s /* = {} */) { | ||||
|         "[inner] a and b must have the same last dimension."); | ||||
|   } | ||||
|  | ||||
|   return tensordot(a, b, {{-1}, {-1}}, s); | ||||
|   return tensordot(a, b, {-1}, {-1}, s); | ||||
| } | ||||
|  | ||||
| /** Compute D = beta * C + alpha * (A @ B) */ | ||||
|   | ||||
| @@ -1110,17 +1110,18 @@ array dequantize( | ||||
|     int bits = 4, | ||||
|     StreamOrDevice s = {}); | ||||
|  | ||||
| /** TensorDot returns a contraction of a and b over multiple dimensions. */ | ||||
| /** Returns a contraction of a and b over multiple dimensions. */ | ||||
| array tensordot( | ||||
|     const array& a, | ||||
|     const array& b, | ||||
|     const int dims = 2, | ||||
|     const int axis = 2, | ||||
|     StreamOrDevice s = {}); | ||||
|  | ||||
| array tensordot( | ||||
|     const array& a, | ||||
|     const array& b, | ||||
|     const std::pair<std::vector<int>, std::vector<int>>& dims, | ||||
|     const std::vector<int>& axes_a, | ||||
|     const std::vector<int>& axes_b, | ||||
|     StreamOrDevice s = {}); | ||||
|  | ||||
| /** Compute the outer product of two vectors. */ | ||||
|   | ||||
| @@ -3555,17 +3555,17 @@ void init_ops(py::module_& m) { | ||||
|       "tensordot", | ||||
|       [](const array& a, | ||||
|          const array& b, | ||||
|          const std::variant<int, std::vector<std::vector<int>>>& dims, | ||||
|          const std::variant<int, std::vector<std::vector<int>>>& axes, | ||||
|          StreamOrDevice s) { | ||||
|         if (auto pv = std::get_if<int>(&dims); pv) { | ||||
|         if (auto pv = std::get_if<int>(&axes); pv) { | ||||
|           return tensordot(a, b, *pv, s); | ||||
|         } else { | ||||
|           auto x = std::get<std::vector<std::vector<int>>>(dims); | ||||
|           auto& x = std::get<std::vector<std::vector<int>>>(axes); | ||||
|           if (x.size() != 2) { | ||||
|             throw std::invalid_argument( | ||||
|                 "[tensordot] dims must be a list of two lists."); | ||||
|                 "[tensordot] axes must be a list of two lists."); | ||||
|           } | ||||
|           return tensordot(a, b, {x[0], x[1]}, s); | ||||
|           return tensordot(a, b, x[0], x[1], s); | ||||
|         } | ||||
|       }, | ||||
|       "a"_a, | ||||
|   | ||||
| @@ -2554,14 +2554,13 @@ TEST_CASE("tile") { | ||||
| TEST_CASE("tensordot") { | ||||
|   auto x = reshape(arange(60.), {3, 4, 5}); | ||||
|   auto y = reshape(arange(24.), {4, 3, 2}); | ||||
|   auto z = tensordot(x, y, {{1, 0}, {0, 1}}); | ||||
|   auto z = tensordot(x, y, {1, 0}, {0, 1}); | ||||
|   auto expected = array( | ||||
|       {4400, 4730, 4532, 4874, 4664, 5018, 4796, 5162, 4928, 5306}, {5, 2}); | ||||
|   CHECK(array_equal(z, expected).item<bool>()); | ||||
|   x = reshape(arange(360.), {3, 4, 5, 6}); | ||||
|   y = reshape(arange(360.), {6, 4, 5, 3}); | ||||
|   CHECK_THROWS_AS( | ||||
|       tensordot(x, y, {{2, 1, 3}, {1, 2, 0}}), std::invalid_argument); | ||||
|   CHECK_THROWS_AS(tensordot(x, y, {2, 1, 3}, {1, 2, 0}), std::invalid_argument); | ||||
|   x = reshape(arange(60.), {3, 4, 5}); | ||||
|   y = reshape(arange(120.), {4, 5, 6}); | ||||
|   z = tensordot(x, y, 2); | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun