mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-04 06:44:40 +08:00
@@ -1253,6 +1253,54 @@ void init_ops(py::module_& m) {
|
||||
Returns:
|
||||
array: The output array filled with ones.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"eye",
|
||||
[](int n,
|
||||
py::object m_obj,
|
||||
py::object k_obj,
|
||||
Dtype dtype,
|
||||
StreamOrDevice s) {
|
||||
int m = m_obj.is_none() ? n : m_obj.cast<int>();
|
||||
int k = k_obj.is_none() ? 0 : k_obj.cast<int>();
|
||||
return eye(n, m, k, dtype, s);
|
||||
},
|
||||
"n"_a,
|
||||
"m"_a = py::none(),
|
||||
"k"_a = py::none(),
|
||||
"dtype"_a = std::nullopt,
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
Create an identity matrix or a general diagonal matrix.
|
||||
|
||||
Args:
|
||||
n (int): The number of rows in the output.
|
||||
m (int, optional): The number of columns in the output. Defaults to n.
|
||||
k (int, optional): Index of the diagonal. Defaults to 0 (main diagonal).
|
||||
dtype (Dtype, optional): Data type of the output array. Defaults to float32.
|
||||
stream (Stream, optional): Stream or device. Defaults to None.
|
||||
|
||||
Returns:
|
||||
array: An array where all elements are equal to zero, except for the k-th diagonal, whose values are equal to one.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"identity",
|
||||
&identity,
|
||||
"n"_a,
|
||||
"dtype"_a = std::nullopt,
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
Create a square identity matrix.
|
||||
|
||||
Args:
|
||||
n (int): The number of rows and columns in the output.
|
||||
dtype (Dtype, optional): Data type of the output array. Defaults to float32.
|
||||
stream (Stream, optional): Stream or device. Defaults to None.
|
||||
|
||||
Returns:
|
||||
array: An identity matrix of size n x n.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"allclose",
|
||||
&allclose,
|
||||
|
@@ -1311,5 +1311,28 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
self.assertEqual((a + b)[0, 0].item(), 2)
|
||||
|
||||
|
||||
def test_eye(self):
|
||||
eye_matrix = mx.eye(3)
|
||||
np_eye_matrix = np.eye(3)
|
||||
self.assertTrue(np.array_equal(eye_matrix, np_eye_matrix))
|
||||
|
||||
# Test for non-square matrix
|
||||
eye_matrix = mx.eye(3, 4)
|
||||
np_eye_matrix = np.eye(3, 4)
|
||||
self.assertTrue(np.array_equal(eye_matrix, np_eye_matrix))
|
||||
|
||||
# Test with positive k parameter
|
||||
eye_matrix = mx.eye(3, 4, k=1)
|
||||
np_eye_matrix = np.eye(3, 4, k=1)
|
||||
self.assertTrue(np.array_equal(eye_matrix, np_eye_matrix))
|
||||
|
||||
# Test with negative k parameter
|
||||
eye_matrix = mx.eye(5, 6, k=-2)
|
||||
np_eye_matrix = np.eye(5, 6, k=-2)
|
||||
self.assertTrue(np.array_equal(eye_matrix, np_eye_matrix))
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user