mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 07:58:14 +08:00 
			
		
		
		
	Added mx.stack c++ frontend impl (#123)
* stack C++ operation + python bindings
This commit is contained in:
		| @@ -2230,6 +2230,36 @@ void init_ops(py::module_& m) { | ||||
|         Returns: | ||||
|             array: The concatenated array. | ||||
|       )pbdoc"); | ||||
|   m.def( | ||||
|       "stack", | ||||
|       [](const std::vector<array>& arrays, | ||||
|          std::optional<int> axis, | ||||
|          StreamOrDevice s) { | ||||
|         if (axis.has_value()) { | ||||
|           return stack(arrays, axis.value(), s); | ||||
|         } else { | ||||
|           return stack(arrays, s); | ||||
|         } | ||||
|       }, | ||||
|       "arrays"_a, | ||||
|       py::pos_only(), | ||||
|       "axis"_a = 0, | ||||
|       py::kw_only(), | ||||
|       "stream"_a = none, | ||||
|       R"pbdoc( | ||||
|       stack(arrays: List[array], axis: Optional[int] = 0, *, stream: Union[None, Stream, Device] = None) -> array | ||||
|  | ||||
|       Stacks the arrays along a new axis. | ||||
|  | ||||
|       Args: | ||||
|           arrays (list(array)): A list of arrays to stack. | ||||
|           axis (int, optional): The axis in the result array along which the | ||||
|             input arrays are stacked. Defaults to ``0``.  | ||||
|           stream (Stream, optional): Stream or device. Defaults to ``None``. | ||||
|  | ||||
|       Returns: | ||||
|           array: The resulting stacked array. | ||||
|     )pbdoc"); | ||||
|   m.def( | ||||
|       "pad", | ||||
|       [](const array& a, | ||||
|   | ||||
| @@ -1371,6 +1371,37 @@ class TestOps(mlx_tests.MLXTestCase): | ||||
|         np_eye_matrix = np.eye(5, 6, k=-2) | ||||
|         self.assertTrue(np.array_equal(eye_matrix, np_eye_matrix)) | ||||
|  | ||||
|     def test_stack(self): | ||||
|         a = mx.ones((2,)) | ||||
|         np_a = np.ones((2,)) | ||||
|         b = mx.ones((2,)) | ||||
|         np_b = np.ones((2,)) | ||||
|  | ||||
|         # One dimensional stack axis=0 | ||||
|         c = mx.stack([a, b]) | ||||
|         np_c = np.stack([np_a, np_b]) | ||||
|         self.assertTrue(np.array_equal(c, np_c)) | ||||
|  | ||||
|         # One dimensional stack axis=1 | ||||
|         c = mx.stack([a, b], axis=1) | ||||
|         np_c = np.stack([np_a, np_b], axis=1) | ||||
|         self.assertTrue(np.array_equal(c, np_c)) | ||||
|  | ||||
|         a = mx.ones((1, 2)) | ||||
|         np_a = np.ones((1, 2)) | ||||
|         b = mx.ones((1, 2)) | ||||
|         np_b = np.ones((1, 2)) | ||||
|  | ||||
|         # Two dimensional stack axis=0 | ||||
|         c = mx.stack([a, b]) | ||||
|         np_c = np.stack([np_a, np_b]) | ||||
|         self.assertTrue(np.array_equal(c, np_c)) | ||||
|  | ||||
|         # Two dimensional stack axis=1 | ||||
|         c = mx.stack([a, b], axis=1) | ||||
|         np_c = np.stack([np_a, np_b], axis=1) | ||||
|         self.assertTrue(np.array_equal(c, np_c)) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Jason
					Jason