diff --git a/docs/src/python/ops.rst b/docs/src/python/ops.rst
index 450588536..b9a4c9066 100644
--- a/docs/src/python/ops.rst
+++ b/docs/src/python/ops.rst
@@ -38,9 +38,11 @@ Operations
    erfinv
    exp
    expand_dims
+   eye
    full
    greater
    greater_equal
+   identity
    less
    less_equal
    load
diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index 16f4bcd9a..e85683a48 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -194,6 +194,35 @@ array ones_like(const array& a, StreamOrDevice s /* = {} */) {
   return ones(a.shape(), a.dtype(), to_stream(s));
 }
 
+// Build an (n, m) matrix with ones on the k-th diagonal by scattering into a
+// flat zero buffer: consecutive diagonal entries are m + 1 apart in row-major
+// order.
+array eye(int n, int m, int k, Dtype dtype, StreamOrDevice s /* = {} */) {
+  if (n <= 0 || m <= 0) {
+    throw std::invalid_argument("N and M must be positive integers.");
+  }
+  array result = zeros({n * m}, dtype, s);
+  // The requested diagonal lies entirely outside the matrix.
+  if (k >= m || -k >= n) {
+    return reshape(result, {n, m}, s);
+  }
+
+  int diagonal_length = k >= 0 ? std::min(n, m - k) : std::min(n + k, m);
+  // Flat index of the first diagonal element: (0, k) for k >= 0, (-k, 0) else.
+  int start_index = (k >= 0) ? k : -k * m;
+
+  array diag_indices_array = arange(
+      start_index, start_index + diagonal_length * (m + 1), m + 1, int32, s);
+  array ones_array = ones({diagonal_length, 1}, dtype, s);
+  result = scatter(result, diag_indices_array, ones_array, 0, s);
+
+  return reshape(result, {n, m}, s);
+}
+
+array identity(int n, Dtype dtype, StreamOrDevice s /* = {} */) {
+  return eye(n, n, 0, dtype, s);
+}
+
 array reshape(
     const array& a,
     std::vector<int> shape,
diff --git a/mlx/ops.h b/mlx/ops.h
index 52d89ed73..560697259 100644
--- a/mlx/ops.h
+++ b/mlx/ops.h
@@ -87,6 +87,29 @@ inline array ones(const std::vector<int>& shape, StreamOrDevice s = {}) {
 }
 array ones_like(const array& a, StreamOrDevice s = {});
 
+/** Fill an array of the given shape (n,m) with ones in the specified diagonal
+ * k, and zeros everywhere else. */
+array eye(int n, int m, int k, Dtype dtype, StreamOrDevice s = {});
+inline array eye(int n, Dtype dtype, StreamOrDevice s = {}) {
+  return eye(n, n, 0, dtype, s);
+}
+inline array eye(int n, int m, StreamOrDevice s = {}) {
+  return eye(n, m, 0, float32, s);
+}
+inline array eye(int n, int m, int k, StreamOrDevice s = {}) {
+  return eye(n, m, k, float32, s);
+}
+inline array eye(int n, StreamOrDevice s = {}) {
+  return eye(n, n, 0, float32, s);
+}
+
+/** Create a square matrix of shape (n,n) of zeros, and ones in the major
+ * diagonal. */
+array identity(int n, Dtype dtype, StreamOrDevice s = {});
+inline array identity(int n, StreamOrDevice s = {}) {
+  return identity(n, float32, s);
+}
+
 /** array manipulation */
 
 /** Reshape an array to the given shape. */
diff --git a/python/src/ops.cpp b/python/src/ops.cpp
index 6103186c4..e25da7f38 100644
--- a/python/src/ops.cpp
+++ b/python/src/ops.cpp
@@ -1253,6 +1253,57 @@ void init_ops(py::module_& m) {
         Returns:
             array: The output array filled with ones.
       )pbdoc");
+  m.def(
+      "eye",
+      [](int n,
+         py::object m_obj,
+         py::object k_obj,
+         std::optional<Dtype> dtype,
+         StreamOrDevice s) {
+        int m = m_obj.is_none() ? n : m_obj.cast<int>();
+        int k = k_obj.is_none() ? 0 : k_obj.cast<int>();
+        return eye(n, m, k, dtype.value_or(float32), s);
+      },
+      "n"_a,
+      "m"_a = py::none(),
+      "k"_a = py::none(),
+      "dtype"_a = std::nullopt,
+      py::kw_only(),
+      "stream"_a = none,
+      R"pbdoc(
+        Create an identity matrix or a general diagonal matrix.
+
+        Args:
+            n (int): The number of rows in the output.
+            m (int, optional): The number of columns in the output. Defaults to n.
+            k (int, optional): Index of the diagonal. Defaults to 0 (main diagonal).
+            dtype (Dtype, optional): Data type of the output array. Defaults to float32.
+            stream (Stream, optional): Stream or device. Defaults to None.
+
+        Returns:
+            array: An array where all elements are equal to zero, except for the k-th diagonal, whose values are equal to one.
+      )pbdoc");
+  m.def(
+      "identity",
+      // `identity` is overloaded, so its address cannot be passed directly.
+      [](int n, std::optional<Dtype> dtype, StreamOrDevice s) {
+        return identity(n, dtype.value_or(float32), s);
+      },
+      "n"_a,
+      "dtype"_a = std::nullopt,
+      py::kw_only(),
+      "stream"_a = none,
+      R"pbdoc(
+        Create a square identity matrix.
+
+        Args:
+            n (int): The number of rows and columns in the output.
+            dtype (Dtype, optional): Data type of the output array. Defaults to float32.
+            stream (Stream, optional): Stream or device. Defaults to None.
+
+        Returns:
+            array: An identity matrix of size n x n.
+      )pbdoc");
   m.def(
       "allclose",
       &allclose,
diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py
index 65b9daed5..57c6ea2b0 100644
--- a/python/tests/test_ops.py
+++ b/python/tests/test_ops.py
@@ -1311,5 +1311,25 @@ class TestOps(mlx_tests.MLXTestCase):
         self.assertEqual((a + b)[0, 0].item(), 2)
 
+    def test_eye(self):
+        eye_matrix = mx.eye(3)
+        np_eye_matrix = np.eye(3)
+        self.assertTrue(np.array_equal(eye_matrix, np_eye_matrix))
+
+        # Test for non-square matrix
+        eye_matrix = mx.eye(3, 4)
+        np_eye_matrix = np.eye(3, 4)
+        self.assertTrue(np.array_equal(eye_matrix, np_eye_matrix))
+
+        # Test with positive k parameter
+        eye_matrix = mx.eye(3, 4, k=1)
+        np_eye_matrix = np.eye(3, 4, k=1)
+        self.assertTrue(np.array_equal(eye_matrix, np_eye_matrix))
+
+        # Test with negative k parameter
+        eye_matrix = mx.eye(5, 6, k=-2)
+        np_eye_matrix = np.eye(5, 6, k=-2)
+        self.assertTrue(np.array_equal(eye_matrix, np_eye_matrix))
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp
index 5dcf8658d..b3b32ea1d 100644
--- a/tests/ops_tests.cpp
+++ b/tests/ops_tests.cpp
@@ -1926,3 +1926,58 @@ TEST_CASE("test where") {
   expected = array({1, 2, 2, 1}, {2, 2});
   CHECK(array_equal(where(condition, x, y), expected).item<bool>());
 }
+
+TEST_CASE("test eye") {
+  auto eye_3 = eye(3);
+  CHECK_EQ(eye_3.shape(), std::vector<int>{3, 3});
+  auto expected_eye_3 =
+      array({1.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 1.0f}, {3, 3});
+  CHECK(array_equal(eye_3, expected_eye_3).item<bool>());
+
+  auto eye_3x2 = eye(3, 2);
+  CHECK_EQ(eye_3x2.shape(), std::vector<int>{3, 2});
+  auto expected_eye_3x2 = array({1.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f}, {3, 2});
+  CHECK(array_equal(eye_3x2, expected_eye_3x2).item<bool>());
+}
+
+TEST_CASE("test identity") {
+  auto id_4 = identity(4);
+  CHECK_EQ(id_4.shape(), std::vector<int>{4, 4});
+  auto expected_id_4 = array(
+      {1.0f,
+       0.0f,
+       0.0f,
+       0.0f,
+       0.0f,
+       1.0f,
+       0.0f,
+       0.0f,
+       0.0f,
+       0.0f,
+       1.0f,
+       0.0f,
+       0.0f,
+       0.0f,
+       0.0f,
+       1.0f},
+      {4, 4});
+  CHECK(array_equal(id_4, expected_id_4).item<bool>());
+}
+
+TEST_CASE("test eye with positive k offset") {
+  auto eye_3_k1 = eye(3, 4, 1);
+  CHECK_EQ(eye_3_k1.shape(), std::vector<int>{3, 4});
+  auto expected_eye_3_k1 = array(
+      {0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f},
+      {3, 4});
+  CHECK(array_equal(eye_3_k1, expected_eye_3_k1).item<bool>());
+}
+
+TEST_CASE("test eye with negative k offset") {
+  auto eye_4_k_minus1 = eye(4, 3, -1);
+  CHECK_EQ(eye_4_k_minus1.shape(), std::vector<int>{4, 3});
+  auto expected_eye_4_k_minus1 = array(
+      {0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 1.0f},
+      {4, 3});
+  CHECK(array_equal(eye_4_k_minus1, expected_eye_4_k_minus1).item<bool>());
+}