add mx.trace (#1143) (#1147)

* working c++ trace implementation

* updated throw + added overloads

* added python binding for trace function

* pre-commit reformatting

* add trace to docs

* resolve comments

* remove to_stream call
This commit is contained in:
Abe Leininger
2024-05-22 18:50:27 -04:00
committed by GitHub
parent e110ca11e2
commit 79ef49b2c2
6 changed files with 161 additions and 0 deletions

View File

@@ -2091,6 +2091,38 @@ class TestOps(mlx_tests.MLXTestCase):
expected = mx.array(np.diag(x, k=-1))
self.assertTrue(mx.array_equal(result, expected))
def test_trace(self):
a_mx = mx.arange(9, dtype=mx.int64).reshape((3, 3))
a_np = np.arange(9, dtype=np.int64).reshape((3, 3))
# Test 2D array
result = mx.trace(a_mx)
expected = np.trace(a_np)
self.assertEqualArray(result, mx.array(expected))
# Test dtype
result = mx.trace(a_mx, dtype=mx.float16)
expected = np.trace(a_np, dtype=np.float16)
self.assertEqualArray(result, mx.array(expected))
# Test offset
result = mx.trace(a_mx, offset=1)
expected = np.trace(a_np, offset=1)
self.assertEqualArray(result, mx.array(expected))
# Test axis1 and axis2
b_mx = mx.arange(27, dtype=mx.int64).reshape(3, 3, 3)
b_np = np.arange(27, dtype=np.int64).reshape(3, 3, 3)
result = mx.trace(b_mx, axis1=1, axis2=2)
expected = np.trace(b_np, axis1=1, axis2=2)
self.assertEqualArray(result, mx.array(expected))
# Test offset, axis1, axis2, and dtype
result = mx.trace(b_mx, offset=1, axis1=1, axis2=2, dtype=mx.float32)
expected = np.trace(b_np, offset=1, axis1=1, axis2=2, dtype=np.float32)
self.assertEqualArray(result, mx.array(expected))
def test_atleast_1d(self):
def compare_nested_lists(x, y):
if isinstance(x, list) and isinstance(y, list):