add mx.trace (#1143) (#1147)

* working c++ trace implementation

* updated throw + added overloads

* added python binding for trace function

* pre-commit reformatting

* add trace to docs

* resolve comments

* remove to_stream call
This commit is contained in:
Abe Leininger 2024-05-22 18:50:27 -04:00 committed by GitHub
parent e110ca11e2
commit 79ef49b2c2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 161 additions and 0 deletions

View File

@ -150,6 +150,7 @@ Operations
tensordot
tile
topk
trace
transpose
tri
tril

View File

@ -4113,6 +4113,62 @@ array diag(const array& a, int k /* = 0 */, StreamOrDevice s /* = {} */) {
}
}
array trace(
const array& a,
int offset,
int axis1,
int axis2,
Dtype dtype,
StreamOrDevice s /* = {} */) {
int ndim = a.ndim();
if (ndim < 2) {
std::ostringstream msg;
msg << "[trace] Array must have at least two dimensions, but got " << ndim
<< " dimensions.";
throw std::invalid_argument(msg.str());
}
auto ax1 = (axis1 < 0) ? axis1 + ndim : axis1;
if (ax1 < 0 || ax1 >= ndim) {
std::ostringstream msg;
msg << "[trace] Invalid axis1 " << axis1 << " for array with " << ndim
<< " dimensions.";
throw std::out_of_range(msg.str());
}
auto ax2 = (axis2 < 0) ? axis2 + ndim : axis2;
if (ax2 < 0 || ax2 >= ndim) {
std::ostringstream msg;
msg << "[trace] Invalid axis2 " << axis2 << " for array with " << ndim
<< " dimensions.";
throw std::out_of_range(msg.str());
}
if (ax1 == ax2) {
throw std::invalid_argument(
"[trace] axis1 and axis2 cannot be the same axis");
}
return sum(
astype(diagonal(a, offset, axis1, axis2, s), dtype, s),
/* axis = */ -1,
/* keepdims = */ false,
s);
}
array trace(
const array& a,
int offset,
int axis1,
int axis2,
StreamOrDevice s /* = {} */) {
auto dtype = a.dtype();
return trace(a, offset, axis1, axis2, dtype, s);
}
array trace(const array& a, StreamOrDevice s /* = {} */) {
auto dtype = a.dtype();
return trace(a, 0, 0, 1, dtype, s);
}
std::vector<array> depends(
const std::vector<array>& inputs,
const std::vector<array>& dependencies) {

View File

@ -1228,6 +1228,22 @@ array diagonal(
/** Extract diagonal from a 2d array or create a diagonal matrix. */
array diag(const array& a, int k = 0, StreamOrDevice s = {});
/** Return the sum along a specified diagonal in the given array. */
array trace(
const array& a,
int offset,
int axis1,
int axis2,
Dtype dtype,
StreamOrDevice s = {});
array trace(
const array& a,
int offset,
int axis1,
int axis2,
StreamOrDevice s = {});
array trace(const array& a, StreamOrDevice s = {});
/**
* Implements the identity function but allows injecting dependencies to other
* arrays. This ensures that these other arrays will have been computed

View File

@ -4065,6 +4065,45 @@ void init_ops(nb::module_& m) {
Returns:
array: The extracted diagonal or the constructed diagonal matrix.
)pbdoc");
m.def(
"trace",
[](const array& a,
int offset,
int axis1,
int axis2,
std::optional<Dtype> dtype,
StreamOrDevice s) {
if (!dtype.has_value()) {
return trace(a, offset, axis1, axis2, s);
}
return trace(a, offset, axis1, axis2, dtype.value(), s);
},
nb::arg(),
"offset"_a = 0,
"axis1"_a = 0,
"axis2"_a = 1,
"dtype"_a = nb::none(),
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def trace(a: array, /, offset: int = 0, axis1: int = 0, axis2: int = 1, dtype = Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Return the sum along a specified diagonal in the given array.
Args:
a (array): Input array
offset (int, optional): Offset of the diagonal from the main diagonal.
Can be positive or negative. Default: ``0``.
axis1 (int, optional): The first axis of the 2-D sub-arrays from which
the diagonals should be taken. Default: ``0``.
axis2 (int, optional): The second axis of the 2-D sub-arrays from which
the diagonals should be taken. Default: ``1``.
dtype (Dtype, optional): Data type of the output array. If
unspecified the output type is inferred from the input array.
Returns:
array: Sum of specified diagonal.
)pbdoc");
m.def(
"atleast_1d",
[](const nb::args& arys, StreamOrDevice s) -> nb::object {

View File

@ -2091,6 +2091,38 @@ class TestOps(mlx_tests.MLXTestCase):
expected = mx.array(np.diag(x, k=-1))
self.assertTrue(mx.array_equal(result, expected))
def test_trace(self):
a_mx = mx.arange(9, dtype=mx.int64).reshape((3, 3))
a_np = np.arange(9, dtype=np.int64).reshape((3, 3))
# Test 2D array
result = mx.trace(a_mx)
expected = np.trace(a_np)
self.assertEqualArray(result, mx.array(expected))
# Test dtype
result = mx.trace(a_mx, dtype=mx.float16)
expected = np.trace(a_np, dtype=np.float16)
self.assertEqualArray(result, mx.array(expected))
# Test offset
result = mx.trace(a_mx, offset=1)
expected = np.trace(a_np, offset=1)
self.assertEqualArray(result, mx.array(expected))
# Test axis1 and axis2
b_mx = mx.arange(27, dtype=mx.int64).reshape(3, 3, 3)
b_np = np.arange(27, dtype=np.int64).reshape(3, 3, 3)
result = mx.trace(b_mx, axis1=1, axis2=2)
expected = np.trace(b_np, axis1=1, axis2=2)
self.assertEqualArray(result, mx.array(expected))
# Test offset, axis1, axis2, and dtype
result = mx.trace(b_mx, offset=1, axis1=1, axis2=2, dtype=mx.float32)
expected = np.trace(b_np, offset=1, axis1=1, axis2=2, dtype=np.float32)
self.assertEqualArray(result, mx.array(expected))
def test_atleast_1d(self):
def compare_nested_lists(x, y):
if isinstance(x, list) and isinstance(y, list):

View File

@ -3327,3 +3327,20 @@ TEST_CASE("test conv1d") {
CHECK(allclose(out, expected, /* rtol = */ 1.0e-3).item<bool>());
}
}
TEST_CASE("test trace") {
auto in = eye(3);
auto out = trace(in).item<float>();
CHECK_EQ(out, 3.0);
in = array({1, 2, 3, 4, 5, 6, 7, 8, 9}, {3, 3}, int32);
auto out2 = trace(in).item<int>();
CHECK_EQ(out2, 15);
in = reshape(arange(8), {2, 2, 2});
auto out3 = trace(in, 0, 0, 1);
CHECK(array_equal(out3, array({6, 8}, {2})).item<bool>());
auto out4 = trace(in, 0, 1, 2, float32);
CHECK(array_equal(out4, array({3, 11}, {2})).item<bool>());
}