mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 10:38:10 +08:00 
			
		
		
		
	* working c++ trace implementation * updated throw + added overloads * added python binding for trace function * pre-commit reformatting * add trace to docs * resolve comments * remove to_stream call
This commit is contained in:
		@@ -150,6 +150,7 @@ Operations
 | 
			
		||||
   tensordot
 | 
			
		||||
   tile
 | 
			
		||||
   topk
 | 
			
		||||
   trace
 | 
			
		||||
   transpose
 | 
			
		||||
   tri
 | 
			
		||||
   tril
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										56
									
								
								mlx/ops.cpp
									
									
									
									
									
								
							
							
						
						
									
										56
									
								
								mlx/ops.cpp
									
									
									
									
									
								
							@@ -4113,6 +4113,62 @@ array diag(const array& a, int k /* = 0 */, StreamOrDevice s /* = {} */) {
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
array trace(
 | 
			
		||||
    const array& a,
 | 
			
		||||
    int offset,
 | 
			
		||||
    int axis1,
 | 
			
		||||
    int axis2,
 | 
			
		||||
    Dtype dtype,
 | 
			
		||||
    StreamOrDevice s /* = {} */) {
 | 
			
		||||
  int ndim = a.ndim();
 | 
			
		||||
  if (ndim < 2) {
 | 
			
		||||
    std::ostringstream msg;
 | 
			
		||||
    msg << "[trace] Array must have at least two dimensions, but got " << ndim
 | 
			
		||||
        << " dimensions.";
 | 
			
		||||
    throw std::invalid_argument(msg.str());
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  auto ax1 = (axis1 < 0) ? axis1 + ndim : axis1;
 | 
			
		||||
  if (ax1 < 0 || ax1 >= ndim) {
 | 
			
		||||
    std::ostringstream msg;
 | 
			
		||||
    msg << "[trace] Invalid axis1 " << axis1 << " for array with " << ndim
 | 
			
		||||
        << " dimensions.";
 | 
			
		||||
    throw std::out_of_range(msg.str());
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  auto ax2 = (axis2 < 0) ? axis2 + ndim : axis2;
 | 
			
		||||
  if (ax2 < 0 || ax2 >= ndim) {
 | 
			
		||||
    std::ostringstream msg;
 | 
			
		||||
    msg << "[trace] Invalid axis2 " << axis2 << " for array with " << ndim
 | 
			
		||||
        << " dimensions.";
 | 
			
		||||
    throw std::out_of_range(msg.str());
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  if (ax1 == ax2) {
 | 
			
		||||
    throw std::invalid_argument(
 | 
			
		||||
        "[trace] axis1 and axis2 cannot be the same axis");
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  return sum(
 | 
			
		||||
      astype(diagonal(a, offset, axis1, axis2, s), dtype, s),
 | 
			
		||||
      /* axis = */ -1,
 | 
			
		||||
      /* keepdims = */ false,
 | 
			
		||||
      s);
 | 
			
		||||
}
 | 
			
		||||
array trace(
 | 
			
		||||
    const array& a,
 | 
			
		||||
    int offset,
 | 
			
		||||
    int axis1,
 | 
			
		||||
    int axis2,
 | 
			
		||||
    StreamOrDevice s /* = {} */) {
 | 
			
		||||
  auto dtype = a.dtype();
 | 
			
		||||
  return trace(a, offset, axis1, axis2, dtype, s);
 | 
			
		||||
}
 | 
			
		||||
array trace(const array& a, StreamOrDevice s /* = {} */) {
 | 
			
		||||
  auto dtype = a.dtype();
 | 
			
		||||
  return trace(a, 0, 0, 1, dtype, s);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
std::vector<array> depends(
 | 
			
		||||
    const std::vector<array>& inputs,
 | 
			
		||||
    const std::vector<array>& dependencies) {
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										16
									
								
								mlx/ops.h
									
									
									
									
									
								
							
							
						
						
									
										16
									
								
								mlx/ops.h
									
									
									
									
									
								
							@@ -1228,6 +1228,22 @@ array diagonal(
 | 
			
		||||
/** Extract diagonal from a 2d array or create a diagonal matrix. */
 | 
			
		||||
array diag(const array& a, int k = 0, StreamOrDevice s = {});
 | 
			
		||||
 | 
			
		||||
/** Return the sum along a specified diagonal in the given array. */
 | 
			
		||||
array trace(
 | 
			
		||||
    const array& a,
 | 
			
		||||
    int offset,
 | 
			
		||||
    int axis1,
 | 
			
		||||
    int axis2,
 | 
			
		||||
    Dtype dtype,
 | 
			
		||||
    StreamOrDevice s = {});
 | 
			
		||||
array trace(
 | 
			
		||||
    const array& a,
 | 
			
		||||
    int offset,
 | 
			
		||||
    int axis1,
 | 
			
		||||
    int axis2,
 | 
			
		||||
    StreamOrDevice s = {});
 | 
			
		||||
array trace(const array& a, StreamOrDevice s = {});
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * Implements the identity function but allows injecting dependencies to other
 | 
			
		||||
 * arrays. This ensures that these other arrays will have been computed
 | 
			
		||||
 
 | 
			
		||||
@@ -4065,6 +4065,45 @@ void init_ops(nb::module_& m) {
 | 
			
		||||
        Returns:
 | 
			
		||||
            array: The extracted diagonal or the constructed diagonal matrix.
 | 
			
		||||
        )pbdoc");
 | 
			
		||||
  m.def(
 | 
			
		||||
      "trace",
 | 
			
		||||
      [](const array& a,
 | 
			
		||||
         int offset,
 | 
			
		||||
         int axis1,
 | 
			
		||||
         int axis2,
 | 
			
		||||
         std::optional<Dtype> dtype,
 | 
			
		||||
         StreamOrDevice s) {
 | 
			
		||||
        if (!dtype.has_value()) {
 | 
			
		||||
          return trace(a, offset, axis1, axis2, s);
 | 
			
		||||
        }
 | 
			
		||||
        return trace(a, offset, axis1, axis2, dtype.value(), s);
 | 
			
		||||
      },
 | 
			
		||||
      nb::arg(),
 | 
			
		||||
      "offset"_a = 0,
 | 
			
		||||
      "axis1"_a = 0,
 | 
			
		||||
      "axis2"_a = 1,
 | 
			
		||||
      "dtype"_a = nb::none(),
 | 
			
		||||
      nb::kw_only(),
 | 
			
		||||
      "stream"_a = nb::none(),
 | 
			
		||||
      nb::sig(
 | 
			
		||||
          "def trace(a: array, /, offset: int = 0, axis1: int = 0, axis2: int = 1, dtype = Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
 | 
			
		||||
      R"pbdoc(
 | 
			
		||||
        Return the sum along a specified diagonal in the given array.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
          a (array): Input array
 | 
			
		||||
          offset (int, optional): Offset of the diagonal from the main diagonal.
 | 
			
		||||
            Can be positive or negative. Default: ``0``.
 | 
			
		||||
          axis1 (int, optional): The first axis of the 2-D sub-arrays from which
 | 
			
		||||
              the diagonals should be taken. Default: ``0``.
 | 
			
		||||
          axis2 (int, optional): The second axis of the 2-D sub-arrays from which
 | 
			
		||||
              the diagonals should be taken. Default: ``1``.
 | 
			
		||||
          dtype (Dtype, optional): Data type of the output array. If
 | 
			
		||||
              unspecified the output type is inferred from the input array.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            array: Sum of specified diagonal.
 | 
			
		||||
        )pbdoc");
 | 
			
		||||
  m.def(
 | 
			
		||||
      "atleast_1d",
 | 
			
		||||
      [](const nb::args& arys, StreamOrDevice s) -> nb::object {
 | 
			
		||||
 
 | 
			
		||||
@@ -2091,6 +2091,38 @@ class TestOps(mlx_tests.MLXTestCase):
 | 
			
		||||
        expected = mx.array(np.diag(x, k=-1))
 | 
			
		||||
        self.assertTrue(mx.array_equal(result, expected))
 | 
			
		||||
 | 
			
		||||
    def test_trace(self):
 | 
			
		||||
        a_mx = mx.arange(9, dtype=mx.int64).reshape((3, 3))
 | 
			
		||||
        a_np = np.arange(9, dtype=np.int64).reshape((3, 3))
 | 
			
		||||
 | 
			
		||||
        # Test 2D array
 | 
			
		||||
        result = mx.trace(a_mx)
 | 
			
		||||
        expected = np.trace(a_np)
 | 
			
		||||
        self.assertEqualArray(result, mx.array(expected))
 | 
			
		||||
 | 
			
		||||
        # Test dtype
 | 
			
		||||
        result = mx.trace(a_mx, dtype=mx.float16)
 | 
			
		||||
        expected = np.trace(a_np, dtype=np.float16)
 | 
			
		||||
        self.assertEqualArray(result, mx.array(expected))
 | 
			
		||||
 | 
			
		||||
        # Test offset
 | 
			
		||||
        result = mx.trace(a_mx, offset=1)
 | 
			
		||||
        expected = np.trace(a_np, offset=1)
 | 
			
		||||
        self.assertEqualArray(result, mx.array(expected))
 | 
			
		||||
 | 
			
		||||
        # Test axis1 and axis2
 | 
			
		||||
        b_mx = mx.arange(27, dtype=mx.int64).reshape(3, 3, 3)
 | 
			
		||||
        b_np = np.arange(27, dtype=np.int64).reshape(3, 3, 3)
 | 
			
		||||
 | 
			
		||||
        result = mx.trace(b_mx, axis1=1, axis2=2)
 | 
			
		||||
        expected = np.trace(b_np, axis1=1, axis2=2)
 | 
			
		||||
        self.assertEqualArray(result, mx.array(expected))
 | 
			
		||||
 | 
			
		||||
        # Test offset, axis1, axis2, and dtype
 | 
			
		||||
        result = mx.trace(b_mx, offset=1, axis1=1, axis2=2, dtype=mx.float32)
 | 
			
		||||
        expected = np.trace(b_np, offset=1, axis1=1, axis2=2, dtype=np.float32)
 | 
			
		||||
        self.assertEqualArray(result, mx.array(expected))
 | 
			
		||||
 | 
			
		||||
    def test_atleast_1d(self):
 | 
			
		||||
        def compare_nested_lists(x, y):
 | 
			
		||||
            if isinstance(x, list) and isinstance(y, list):
 | 
			
		||||
 
 | 
			
		||||
@@ -3327,3 +3327,20 @@ TEST_CASE("test conv1d") {
 | 
			
		||||
    CHECK(allclose(out, expected, /* rtol = */ 1.0e-3).item<bool>());
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_CASE("test trace") {
 | 
			
		||||
  auto in = eye(3);
 | 
			
		||||
  auto out = trace(in).item<float>();
 | 
			
		||||
  CHECK_EQ(out, 3.0);
 | 
			
		||||
 | 
			
		||||
  in = array({1, 2, 3, 4, 5, 6, 7, 8, 9}, {3, 3}, int32);
 | 
			
		||||
  auto out2 = trace(in).item<int>();
 | 
			
		||||
  CHECK_EQ(out2, 15);
 | 
			
		||||
 | 
			
		||||
  in = reshape(arange(8), {2, 2, 2});
 | 
			
		||||
  auto out3 = trace(in, 0, 0, 1);
 | 
			
		||||
  CHECK(array_equal(out3, array({6, 8}, {2})).item<bool>());
 | 
			
		||||
 | 
			
		||||
  auto out4 = trace(in, 0, 1, 2, float32);
 | 
			
		||||
  CHECK(array_equal(out4, array({3, 11}, {2})).item<bool>());
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user