diff --git a/docs/src/python/ops.rst b/docs/src/python/ops.rst
index ced813e90..745d12d50 100644
--- a/docs/src/python/ops.rst
+++ b/docs/src/python/ops.rst
@@ -150,6 +150,7 @@ Operations
    tensordot
    tile
    topk
+   trace
    transpose
    tri
    tril
diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index 006789306..22680dd7f 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -4113,6 +4113,68 @@ array diag(const array& a, int k /* = 0 */, StreamOrDevice s /* = {} */) {
   }
 }
 
+// Sum the diagonal of `a` selected by `offset`, `axis1` and `axis2`.
+// The diagonal is cast to `dtype` before it is reduced.
+array trace(
+    const array& a,
+    int offset,
+    int axis1,
+    int axis2,
+    Dtype dtype,
+    StreamOrDevice s /* = {} */) {
+  int ndim = a.ndim();
+  if (ndim < 2) {
+    std::ostringstream msg;
+    msg << "[trace] Array must have at least two dimensions, but got " << ndim
+        << " dimensions.";
+    throw std::invalid_argument(msg.str());
+  }
+
+  // Negative axes count from the end; normalize before range-checking.
+  auto ax1 = (axis1 < 0) ? axis1 + ndim : axis1;
+  if (ax1 < 0 || ax1 >= ndim) {
+    std::ostringstream msg;
+    msg << "[trace] Invalid axis1 " << axis1 << " for array with " << ndim
+        << " dimensions.";
+    throw std::out_of_range(msg.str());
+  }
+
+  auto ax2 = (axis2 < 0) ? axis2 + ndim : axis2;
+  if (ax2 < 0 || ax2 >= ndim) {
+    std::ostringstream msg;
+    msg << "[trace] Invalid axis2 " << axis2 << " for array with " << ndim
+        << " dimensions.";
+    throw std::out_of_range(msg.str());
+  }
+
+  if (ax1 == ax2) {
+    throw std::invalid_argument(
+        "[trace] axis1 and axis2 cannot be the same axis");
+  }
+
+  // Use the normalized axes; cast to dtype first, then sum the last axis.
+  return sum(
+      astype(diagonal(a, offset, ax1, ax2, s), dtype, s),
+      /* axis = */ -1,
+      /* keepdims = */ false,
+      s);
+}
+// Overload that keeps the dtype of `a`.
+array trace(
+    const array& a,
+    int offset,
+    int axis1,
+    int axis2,
+    StreamOrDevice s /* = {} */) {
+  auto dtype = a.dtype();
+  return trace(a, offset, axis1, axis2, dtype, s);
+}
+// Overload for the main diagonal (offset 0) over axes 0 and 1.
+array trace(const array& a, StreamOrDevice s /* = {} */) {
+  auto dtype = a.dtype();
+  return trace(a, 0, 0, 1, dtype, s);
+}
+
 std::vector<array> depends(
     const std::vector<array>& inputs,
     const std::vector<array>& dependencies) {
diff --git a/mlx/ops.h b/mlx/ops.h
index 6a6b7ac05..17bc11255 100644
--- a/mlx/ops.h
+++ b/mlx/ops.h
@@ -1228,6 +1228,30 @@ array diagonal(
 /** Extract diagonal from a 2d array or create a diagonal matrix. */
 array diag(const array& a, int k = 0, StreamOrDevice s = {});
 
+/**
+ * Return the sum along a specified diagonal in the given array.
+ *
+ * `offset` selects the diagonal relative to the main one, `axis1` and `axis2`
+ * select the 2-D sub-arrays the diagonals are taken from, and `dtype` is the
+ * type the diagonal is cast to before it is summed.
+ */
+array trace(
+    const array& a,
+    int offset,
+    int axis1,
+    int axis2,
+    Dtype dtype,
+    StreamOrDevice s = {});
+/** Same as above, using the dtype of `a`. */
+array trace(
+    const array& a,
+    int offset,
+    int axis1,
+    int axis2,
+    StreamOrDevice s = {});
+/** Trace of the main diagonal over axes 0 and 1, using the dtype of `a`. */
+array trace(const array& a, StreamOrDevice s = {});
+
 /**
  * Implements the identity function but allows injecting dependencies to other
  * arrays. This ensures that these other arrays will have been computed
diff --git a/python/src/ops.cpp b/python/src/ops.cpp
index 8a4be83b0..c601ca0c5 100644
--- a/python/src/ops.cpp
+++ b/python/src/ops.cpp
@@ -4065,6 +4065,45 @@ void init_ops(nb::module_& m) {
         Returns:
             array: The extracted diagonal or the constructed diagonal matrix.
         )pbdoc");
+  m.def(
+      "trace",
+      [](const array& a,
+         int offset,
+         int axis1,
+         int axis2,
+         std::optional<Dtype> dtype,
+         StreamOrDevice s) {
+        if (!dtype.has_value()) {
+          return trace(a, offset, axis1, axis2, s);
+        }
+        return trace(a, offset, axis1, axis2, dtype.value(), s);
+      },
+      nb::arg(),
+      "offset"_a = 0,
+      "axis1"_a = 0,
+      "axis2"_a = 1,
+      "dtype"_a = nb::none(),
+      nb::kw_only(),
+      "stream"_a = nb::none(),
+      nb::sig(
+          "def trace(a: array, /, offset: int = 0, axis1: int = 0, axis2: int = 1, dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
+      R"pbdoc(
+        Return the sum along a specified diagonal in the given array.
+
+        Args:
+          a (array): Input array
+          offset (int, optional): Offset of the diagonal from the main diagonal.
+            Can be positive or negative. Default: ``0``.
+          axis1 (int, optional): The first axis of the 2-D sub-arrays from which
+              the diagonals should be taken. Default: ``0``.
+          axis2 (int, optional): The second axis of the 2-D sub-arrays from which
+              the diagonals should be taken. Default: ``1``.
+          dtype (Dtype, optional): Data type of the output array. If
+              unspecified the output type is inferred from the input array.
+
+        Returns:
+            array: Sum of specified diagonal.
+        )pbdoc");
   m.def(
       "atleast_1d",
       [](const nb::args& arys, StreamOrDevice s) -> nb::object {
diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py
index ea84a5007..0256154dd 100644
--- a/python/tests/test_ops.py
+++ b/python/tests/test_ops.py
@@ -2091,6 +2091,38 @@ class TestOps(mlx_tests.MLXTestCase):
         expected = mx.array(np.diag(x, k=-1))
         self.assertTrue(mx.array_equal(result, expected))
 
+    def test_trace(self):
+        a_mx = mx.arange(9, dtype=mx.int64).reshape((3, 3))
+        a_np = np.arange(9, dtype=np.int64).reshape((3, 3))
+
+        # Test 2D array
+        result = mx.trace(a_mx)
+        expected = np.trace(a_np)
+        self.assertEqualArray(result, mx.array(expected))
+
+        # Test dtype
+        result = mx.trace(a_mx, dtype=mx.float16)
+        expected = np.trace(a_np, dtype=np.float16)
+        self.assertEqualArray(result, mx.array(expected))
+
+        # Test offset
+        result = mx.trace(a_mx, offset=1)
+        expected = np.trace(a_np, offset=1)
+        self.assertEqualArray(result, mx.array(expected))
+
+        # Test axis1 and axis2
+        b_mx = mx.arange(27, dtype=mx.int64).reshape(3, 3, 3)
+        b_np = np.arange(27, dtype=np.int64).reshape(3, 3, 3)
+
+        result = mx.trace(b_mx, axis1=1, axis2=2)
+        expected = np.trace(b_np, axis1=1, axis2=2)
+        self.assertEqualArray(result, mx.array(expected))
+
+        # Test offset, axis1, axis2, and dtype
+        result = mx.trace(b_mx, offset=1, axis1=1, axis2=2, dtype=mx.float32)
+        expected = np.trace(b_np, offset=1, axis1=1, axis2=2, dtype=np.float32)
+        self.assertEqualArray(result, mx.array(expected))
+
     def test_atleast_1d(self):
         def compare_nested_lists(x, y):
             if isinstance(x, list) and isinstance(y, list):
diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp
index eb49cbc46..c1bbeb051 100644
--- a/tests/ops_tests.cpp
+++ b/tests/ops_tests.cpp
@@ -3327,3 +3327,20 @@ TEST_CASE("test conv1d") {
     CHECK(allclose(out, expected, /* rtol = */ 1.0e-3).item<bool>());
   }
 }
+
+TEST_CASE("test trace") {
+  auto in = eye(3);
+  auto out = trace(in).item<float>();
+  CHECK_EQ(out, 3.0);
+
+  in = array({1, 2, 3, 4, 5, 6, 7, 8, 9}, {3, 3}, int32);
+  auto out2 = trace(in).item<int>();
+  CHECK_EQ(out2, 15);
+
+  in = reshape(arange(8), {2, 2, 2});
+  auto out3 = trace(in, 0, 0, 1);
+  CHECK(array_equal(out3, array({6, 8}, {2})).item<bool>());
+
+  auto out4 = trace(in, 0, 1, 2, float32);
+  CHECK(array_equal(out4, array({3, 11}, {2})).item<bool>());
+}