mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-03 09:58:17 +08:00
* working c++ trace implementation * updated throw + added overloads * added python binding for trace function * pre-commit reformatting * add trace to docs * resolve comments * remove to_stream call
This commit is contained in:
@@ -2091,6 +2091,38 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
expected = mx.array(np.diag(x, k=-1))
|
||||
self.assertTrue(mx.array_equal(result, expected))
|
||||
|
||||
def test_trace(self):
|
||||
a_mx = mx.arange(9, dtype=mx.int64).reshape((3, 3))
|
||||
a_np = np.arange(9, dtype=np.int64).reshape((3, 3))
|
||||
|
||||
# Test 2D array
|
||||
result = mx.trace(a_mx)
|
||||
expected = np.trace(a_np)
|
||||
self.assertEqualArray(result, mx.array(expected))
|
||||
|
||||
# Test dtype
|
||||
result = mx.trace(a_mx, dtype=mx.float16)
|
||||
expected = np.trace(a_np, dtype=np.float16)
|
||||
self.assertEqualArray(result, mx.array(expected))
|
||||
|
||||
# Test offset
|
||||
result = mx.trace(a_mx, offset=1)
|
||||
expected = np.trace(a_np, offset=1)
|
||||
self.assertEqualArray(result, mx.array(expected))
|
||||
|
||||
# Test axis1 and axis2
|
||||
b_mx = mx.arange(27, dtype=mx.int64).reshape(3, 3, 3)
|
||||
b_np = np.arange(27, dtype=np.int64).reshape(3, 3, 3)
|
||||
|
||||
result = mx.trace(b_mx, axis1=1, axis2=2)
|
||||
expected = np.trace(b_np, axis1=1, axis2=2)
|
||||
self.assertEqualArray(result, mx.array(expected))
|
||||
|
||||
# Test offset, axis1, axis2, and dtype
|
||||
result = mx.trace(b_mx, offset=1, axis1=1, axis2=2, dtype=mx.float32)
|
||||
expected = np.trace(b_np, offset=1, axis1=1, axis2=2, dtype=np.float32)
|
||||
self.assertEqualArray(result, mx.array(expected))
|
||||
|
||||
def test_atleast_1d(self):
|
||||
def compare_nested_lists(x, y):
|
||||
if isinstance(x, list) and isinstance(y, list):
|
||||
|
||||
Reference in New Issue
Block a user