mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-19 15:41:13 +08:00
* working c++ trace implementation * updated throw + added overloads * added python binding for trace function * pre-commit reformatting * add trace to docs * resolve comments * remove to_stream call
This commit is contained in:
parent
e110ca11e2
commit
79ef49b2c2
@ -150,6 +150,7 @@ Operations
|
||||
tensordot
|
||||
tile
|
||||
topk
|
||||
trace
|
||||
transpose
|
||||
tri
|
||||
tril
|
||||
|
56
mlx/ops.cpp
56
mlx/ops.cpp
@ -4113,6 +4113,62 @@ array diag(const array& a, int k /* = 0 */, StreamOrDevice s /* = {} */) {
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Sum the entries of a diagonal of `a`.
 *
 * The diagonal is extracted with `diagonal(a, offset, axis1, axis2, s)`, cast
 * to `dtype`, and then reduced over its last axis.
 *
 * @param a      Input array; must have at least two dimensions.
 * @param offset Diagonal offset forwarded to `diagonal`.
 * @param axis1  First axis of the 2-D sub-arrays (may be negative).
 * @param axis2  Second axis of the 2-D sub-arrays (may be negative).
 * @param dtype  Type the diagonal is cast to before it is summed.
 * @param s      Stream or device forwarded to every op used here.
 *
 * @throws std::invalid_argument if `a` has fewer than two dimensions or if
 *         both axes refer to the same dimension.
 * @throws std::out_of_range if either axis is outside `[-ndim, ndim)`.
 */
array trace(
    const array& a,
    int offset,
    int axis1,
    int axis2,
    Dtype dtype,
    StreamOrDevice s /* = {} */) {
  const int ndim = a.ndim();
  if (ndim < 2) {
    std::ostringstream msg;
    msg << "[trace] Array must have at least two dimensions, but got " << ndim
        << " dimensions.";
    throw std::invalid_argument(msg.str());
  }

  // Wrap a possibly negative axis into [0, ndim) and reject anything that
  // still falls outside that range. `label` only appears in the message.
  auto resolve_axis = [ndim](int axis, const char* label) {
    const int wrapped = (axis < 0) ? axis + ndim : axis;
    if (wrapped < 0 || wrapped >= ndim) {
      std::ostringstream msg;
      msg << "[trace] Invalid " << label << " " << axis << " for array with "
          << ndim << " dimensions.";
      throw std::out_of_range(msg.str());
    }
    return wrapped;
  };

  const int first = resolve_axis(axis1, "axis1");
  const int second = resolve_axis(axis2, "axis2");
  if (first == second) {
    throw std::invalid_argument(
        "[trace] axis1 and axis2 cannot be the same axis");
  }

  // The resolved axes are used for validation only; the caller's original
  // values are handed to `diagonal` unchanged.
  auto diag_values = diagonal(a, offset, axis1, axis2, s);
  auto converted = astype(diag_values, dtype, s);
  return sum(converted, /* axis = */ -1, /* keepdims = */ false, s);
}
|
||||
/**
 * Overload of `trace` without an explicit dtype: forwards to the dtype-taking
 * overload using the dtype of `a`.
 */
array trace(
    const array& a,
    int offset,
    int axis1,
    int axis2,
    StreamOrDevice s /* = {} */) {
  return trace(a, offset, axis1, axis2, a.dtype(), s);
}
|
||||
/**
 * Overload of `trace` with default arguments: offset 0 over axes 0 and 1,
 * using the dtype of `a`.
 */
array trace(const array& a, StreamOrDevice s /* = {} */) {
  return trace(
      a, /* offset = */ 0, /* axis1 = */ 0, /* axis2 = */ 1, a.dtype(), s);
}
|
||||
|
||||
std::vector<array> depends(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<array>& dependencies) {
|
||||
|
16
mlx/ops.h
16
mlx/ops.h
@ -1228,6 +1228,22 @@ array diagonal(
|
||||
/** Extract diagonal from a 2d array or create a diagonal matrix. */
|
||||
array diag(const array& a, int k = 0, StreamOrDevice s = {});
|
||||
|
||||
/** Return the sum along a specified diagonal in the given array. */
|
||||
array trace(
|
||||
const array& a,
|
||||
int offset,
|
||||
int axis1,
|
||||
int axis2,
|
||||
Dtype dtype,
|
||||
StreamOrDevice s = {});
|
||||
array trace(
|
||||
const array& a,
|
||||
int offset,
|
||||
int axis1,
|
||||
int axis2,
|
||||
StreamOrDevice s = {});
|
||||
array trace(const array& a, StreamOrDevice s = {});
|
||||
|
||||
/**
|
||||
* Implements the identity function but allows injecting dependencies to other
|
||||
* arrays. This ensures that these other arrays will have been computed
|
||||
|
@ -4065,6 +4065,45 @@ void init_ops(nb::module_& m) {
|
||||
Returns:
|
||||
array: The extracted diagonal or the constructed diagonal matrix.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"trace",
|
||||
[](const array& a,
|
||||
int offset,
|
||||
int axis1,
|
||||
int axis2,
|
||||
std::optional<Dtype> dtype,
|
||||
StreamOrDevice s) {
|
||||
if (!dtype.has_value()) {
|
||||
return trace(a, offset, axis1, axis2, s);
|
||||
}
|
||||
return trace(a, offset, axis1, axis2, dtype.value(), s);
|
||||
},
|
||||
nb::arg(),
|
||||
"offset"_a = 0,
|
||||
"axis1"_a = 0,
|
||||
"axis2"_a = 1,
|
||||
"dtype"_a = nb::none(),
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def trace(a: array, /, offset: int = 0, axis1: int = 0, axis2: int = 1, dtype = Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Return the sum along a specified diagonal in the given array.
|
||||
|
||||
Args:
|
||||
a (array): Input array
|
||||
offset (int, optional): Offset of the diagonal from the main diagonal.
|
||||
Can be positive or negative. Default: ``0``.
|
||||
axis1 (int, optional): The first axis of the 2-D sub-arrays from which
|
||||
the diagonals should be taken. Default: ``0``.
|
||||
axis2 (int, optional): The second axis of the 2-D sub-arrays from which
|
||||
the diagonals should be taken. Default: ``1``.
|
||||
dtype (Dtype, optional): Data type of the output array. If
|
||||
unspecified the output type is inferred from the input array.
|
||||
|
||||
Returns:
|
||||
array: Sum of specified diagonal.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"atleast_1d",
|
||||
[](const nb::args& arys, StreamOrDevice s) -> nb::object {
|
||||
|
@ -2091,6 +2091,38 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
expected = mx.array(np.diag(x, k=-1))
|
||||
self.assertTrue(mx.array_equal(result, expected))
|
||||
|
||||
def test_trace(self):
    """Compare ``mx.trace`` with ``np.trace`` on 2-D and 3-D int64 inputs."""
    mat_mx = mx.arange(9, dtype=mx.int64).reshape((3, 3))
    mat_np = np.arange(9, dtype=np.int64).reshape((3, 3))
    cube_mx = mx.arange(27, dtype=mx.int64).reshape(3, 3, 3)
    cube_np = np.arange(27, dtype=np.int64).reshape(3, 3, 3)

    # Each case: (mlx input, numpy input, mlx kwargs, numpy kwargs).
    # The kwargs differ only in which library's dtype object is passed.
    cases = [
        # 2D array with defaults
        (mat_mx, mat_np, {}, {}),
        # dtype
        (mat_mx, mat_np, {"dtype": mx.float16}, {"dtype": np.float16}),
        # offset
        (mat_mx, mat_np, {"offset": 1}, {"offset": 1}),
        # axis1 and axis2
        (
            cube_mx,
            cube_np,
            {"axis1": 1, "axis2": 2},
            {"axis1": 1, "axis2": 2},
        ),
        # offset, axis1, axis2, and dtype together
        (
            cube_mx,
            cube_np,
            {"offset": 1, "axis1": 1, "axis2": 2, "dtype": mx.float32},
            {"offset": 1, "axis1": 1, "axis2": 2, "dtype": np.float32},
        ),
    ]

    for x_mx, x_np, mx_kwargs, np_kwargs in cases:
        result = mx.trace(x_mx, **mx_kwargs)
        expected = np.trace(x_np, **np_kwargs)
        self.assertEqualArray(result, mx.array(expected))
|
||||
|
||||
def test_atleast_1d(self):
|
||||
def compare_nested_lists(x, y):
|
||||
if isinstance(x, list) and isinstance(y, list):
|
||||
|
@ -3327,3 +3327,20 @@ TEST_CASE("test conv1d") {
|
||||
CHECK(allclose(out, expected, /* rtol = */ 1.0e-3).item<bool>());
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("test trace") {
  // 3x3 identity from eye(3): the trace is expected to be 3.
  auto identity = eye(3);
  auto identity_trace = trace(identity).item<float>();
  CHECK_EQ(identity_trace, 3.0);

  // Integer 3x3 matrix holding 1..9: main diagonal is 1 + 5 + 9.
  auto grid = array({1, 2, 3, 4, 5, 6, 7, 8, 9}, {3, 3}, int32);
  auto grid_trace = trace(grid).item<int>();
  CHECK_EQ(grid_trace, 15);

  // 2x2x2 input: tracing over axes (0, 1) leaves the last axis.
  auto cube = reshape(arange(8), {2, 2, 2});
  auto trace_01 = trace(cube, 0, 0, 1);
  CHECK(array_equal(trace_01, array({6, 8}, {2})).item<bool>());

  // Tracing over axes (1, 2) with an explicit float32 output dtype.
  auto trace_12 = trace(cube, 0, 1, 2, float32);
  CHECK(array_equal(trace_12, array({3, 11}, {2})).item<bool>());
}
|
||||
|
Loading…
Reference in New Issue
Block a user