mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	* working c++ trace implementation * updated throw + added overloads * added python binding for trace function * pre-commit reformatting * add trace to docs * resolve comments * remove to_stream call
This commit is contained in:
		| @@ -2091,6 +2091,38 @@ class TestOps(mlx_tests.MLXTestCase): | ||||
|         expected = mx.array(np.diag(x, k=-1)) | ||||
|         self.assertTrue(mx.array_equal(result, expected)) | ||||
|  | ||||
|     def test_trace(self): | ||||
|         a_mx = mx.arange(9, dtype=mx.int64).reshape((3, 3)) | ||||
|         a_np = np.arange(9, dtype=np.int64).reshape((3, 3)) | ||||
|  | ||||
|         # Test 2D array | ||||
|         result = mx.trace(a_mx) | ||||
|         expected = np.trace(a_np) | ||||
|         self.assertEqualArray(result, mx.array(expected)) | ||||
|  | ||||
|         # Test dtype | ||||
|         result = mx.trace(a_mx, dtype=mx.float16) | ||||
|         expected = np.trace(a_np, dtype=np.float16) | ||||
|         self.assertEqualArray(result, mx.array(expected)) | ||||
|  | ||||
|         # Test offset | ||||
|         result = mx.trace(a_mx, offset=1) | ||||
|         expected = np.trace(a_np, offset=1) | ||||
|         self.assertEqualArray(result, mx.array(expected)) | ||||
|  | ||||
|         # Test axis1 and axis2 | ||||
|         b_mx = mx.arange(27, dtype=mx.int64).reshape(3, 3, 3) | ||||
|         b_np = np.arange(27, dtype=np.int64).reshape(3, 3, 3) | ||||
|  | ||||
|         result = mx.trace(b_mx, axis1=1, axis2=2) | ||||
|         expected = np.trace(b_np, axis1=1, axis2=2) | ||||
|         self.assertEqualArray(result, mx.array(expected)) | ||||
|  | ||||
|         # Test offset, axis1, axis2, and dtype | ||||
|         result = mx.trace(b_mx, offset=1, axis1=1, axis2=2, dtype=mx.float32) | ||||
|         expected = np.trace(b_np, offset=1, axis1=1, axis2=2, dtype=np.float32) | ||||
|         self.assertEqualArray(result, mx.array(expected)) | ||||
|  | ||||
|     def test_atleast_1d(self): | ||||
|         def compare_nested_lists(x, y): | ||||
|             if isinstance(x, list) and isinstance(y, list): | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Abe Leininger
					Abe Leininger