Added eye/identity ops (#119)

`eye` and `identity` C++ and Python ops
This commit is contained in:
Cyril Zakka, MD
2023-12-11 12:38:17 -08:00
committed by GitHub
parent 69505b4e9b
commit e080290ba4
6 changed files with 175 additions and 0 deletions

View File

@@ -1253,6 +1253,54 @@ void init_ops(py::module_& m) {
Returns:
array: The output array filled with ones.
)pbdoc");
m.def(
"eye",
[](int n,
py::object m_obj,
py::object k_obj,
Dtype dtype,
StreamOrDevice s) {
int m = m_obj.is_none() ? n : m_obj.cast<int>();
int k = k_obj.is_none() ? 0 : k_obj.cast<int>();
return eye(n, m, k, dtype, s);
},
"n"_a,
"m"_a = py::none(),
"k"_a = py::none(),
"dtype"_a = std::nullopt,
py::kw_only(),
"stream"_a = none,
R"pbdoc(
Create an identity matrix or a general diagonal matrix.
Args:
n (int): The number of rows in the output.
m (int, optional): The number of columns in the output. Defaults to n.
k (int, optional): Index of the diagonal. Defaults to 0 (main diagonal).
dtype (Dtype, optional): Data type of the output array. Defaults to float32.
stream (Stream, optional): Stream or device. Defaults to None.
Returns:
array: An array where all elements are equal to zero, except for the k-th diagonal, whose values are equal to one.
)pbdoc");
m.def(
"identity",
&identity,
"n"_a,
"dtype"_a = std::nullopt,
py::kw_only(),
"stream"_a = none,
R"pbdoc(
Create a square identity matrix.
Args:
n (int): The number of rows and columns in the output.
dtype (Dtype, optional): Data type of the output array. Defaults to float32.
stream (Stream, optional): Stream or device. Defaults to None.
Returns:
array: An identity matrix of size n x n.
)pbdoc");
m.def(
"allclose",
&allclose,

View File

@@ -1311,5 +1311,28 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertEqual((a + b)[0, 0].item(), 2)
def test_eye(self):
eye_matrix = mx.eye(3)
np_eye_matrix = np.eye(3)
self.assertTrue(np.array_equal(eye_matrix, np_eye_matrix))
# Test for non-square matrix
eye_matrix = mx.eye(3, 4)
np_eye_matrix = np.eye(3, 4)
self.assertTrue(np.array_equal(eye_matrix, np_eye_matrix))
# Test with positive k parameter
eye_matrix = mx.eye(3, 4, k=1)
np_eye_matrix = np.eye(3, 4, k=1)
self.assertTrue(np.array_equal(eye_matrix, np_eye_matrix))
# Test with negative k parameter
eye_matrix = mx.eye(5, 6, k=-2)
np_eye_matrix = np.eye(5, 6, k=-2)
self.assertTrue(np.array_equal(eye_matrix, np_eye_matrix))
if __name__ == "__main__":
unittest.main()