Added eye/identity ops (#119)

`eye` and `identity` C++ and Python ops
This commit is contained in:
Cyril Zakka, MD 2023-12-11 12:38:17 -08:00 committed by GitHub
parent 69505b4e9b
commit e080290ba4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 175 additions and 0 deletions

View File

@ -38,9 +38,11 @@ Operations
erfinv erfinv
exp exp
expand_dims expand_dims
eye
full full
greater greater
greater_equal greater_equal
identity
less less
less_equal less_equal
load load

View File

@ -194,6 +194,30 @@ array ones_like(const array& a, StreamOrDevice s /* = {} */) {
return ones(a.shape(), a.dtype(), to_stream(s)); return ones(a.shape(), a.dtype(), to_stream(s));
} }
array eye(int n, int m, int k, Dtype dtype, StreamOrDevice s /* = {} */) {
if (n <= 0 || m <= 0) {
throw std::invalid_argument("N and M must be positive integers.");
}
array result = zeros({n * m}, dtype, s);
if (k >= m || -k >= n) {
return reshape(result, {n, m}, s);
}
int diagonal_length = k >= 0 ? std::min(n, m - k) : std::min(n + k, m);
int start_index = (k >= 0) ? k : -k * m;
array diag_indices_array = arange(
start_index, start_index + diagonal_length * (m + 1), m + 1, int32, s);
array ones_array = ones({diagonal_length, 1}, dtype, s);
result = scatter(result, diag_indices_array, ones_array, 0, s);
return reshape(result, {n, m}, s);
}
array identity(int n, Dtype dtype, StreamOrDevice s /* = {} */) {
return eye(n, n, 0, dtype, s);
}
array reshape( array reshape(
const array& a, const array& a,
std::vector<int> shape, std::vector<int> shape,

View File

@ -87,6 +87,29 @@ inline array ones(const std::vector<int>& shape, StreamOrDevice s = {}) {
} }
array ones_like(const array& a, StreamOrDevice s = {}); array ones_like(const array& a, StreamOrDevice s = {});
/** Fill an array of the given shape (n,m) with ones in the specified diagonal
* k, and zeros everywhere else. */
array eye(int n, int m, int k, Dtype dtype, StreamOrDevice s = {});
inline array eye(int n, Dtype dtype, StreamOrDevice s = {}) {
return eye(n, n, 0, dtype, s);
}
inline array eye(int n, int m, StreamOrDevice s = {}) {
return eye(n, m, 0, float32, s);
}
inline array eye(int n, int m, int k, StreamOrDevice s = {}) {
return eye(n, m, k, float32, s);
}
inline array eye(int n, StreamOrDevice s = {}) {
return eye(n, n, 0, float32, s);
}
/** Create a square matrix of shape (n,n) of zeros, and ones in the major
* diagonal. */
array identity(int n, Dtype dtype, StreamOrDevice s = {});
inline array identity(int n, StreamOrDevice s = {}) {
return identity(n, float32, s);
}
/** array manipulation */ /** array manipulation */
/** Reshape an array to the given shape. */ /** Reshape an array to the given shape. */

View File

@ -1253,6 +1253,54 @@ void init_ops(py::module_& m) {
Returns: Returns:
array: The output array filled with ones. array: The output array filled with ones.
)pbdoc"); )pbdoc");
m.def(
"eye",
[](int n,
py::object m_obj,
py::object k_obj,
Dtype dtype,
StreamOrDevice s) {
int m = m_obj.is_none() ? n : m_obj.cast<int>();
int k = k_obj.is_none() ? 0 : k_obj.cast<int>();
return eye(n, m, k, dtype, s);
},
"n"_a,
"m"_a = py::none(),
"k"_a = py::none(),
"dtype"_a = std::nullopt,
py::kw_only(),
"stream"_a = none,
R"pbdoc(
Create an identity matrix or a general diagonal matrix.
Args:
n (int): The number of rows in the output.
m (int, optional): The number of columns in the output. Defaults to n.
k (int, optional): Index of the diagonal. Defaults to 0 (main diagonal).
dtype (Dtype, optional): Data type of the output array. Defaults to float32.
stream (Stream, optional): Stream or device. Defaults to None.
Returns:
array: An array where all elements are equal to zero, except for the k-th diagonal, whose values are equal to one.
)pbdoc");
m.def(
"identity",
&identity,
"n"_a,
"dtype"_a = std::nullopt,
py::kw_only(),
"stream"_a = none,
R"pbdoc(
Create a square identity matrix.
Args:
n (int): The number of rows and columns in the output.
dtype (Dtype, optional): Data type of the output array. Defaults to float32.
stream (Stream, optional): Stream or device. Defaults to None.
Returns:
array: An identity matrix of size n x n.
)pbdoc");
m.def( m.def(
"allclose", "allclose",
&allclose, &allclose,

View File

@ -1311,5 +1311,28 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertEqual((a + b)[0, 0].item(), 2) self.assertEqual((a + b)[0, 0].item(), 2)
def test_eye(self):
eye_matrix = mx.eye(3)
np_eye_matrix = np.eye(3)
self.assertTrue(np.array_equal(eye_matrix, np_eye_matrix))
# Test for non-square matrix
eye_matrix = mx.eye(3, 4)
np_eye_matrix = np.eye(3, 4)
self.assertTrue(np.array_equal(eye_matrix, np_eye_matrix))
# Test with positive k parameter
eye_matrix = mx.eye(3, 4, k=1)
np_eye_matrix = np.eye(3, 4, k=1)
self.assertTrue(np.array_equal(eye_matrix, np_eye_matrix))
# Test with negative k parameter
eye_matrix = mx.eye(5, 6, k=-2)
np_eye_matrix = np.eye(5, 6, k=-2)
self.assertTrue(np.array_equal(eye_matrix, np_eye_matrix))
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -1926,3 +1926,58 @@ TEST_CASE("test where") {
expected = array({1, 2, 2, 1}, {2, 2}); expected = array({1, 2, 2, 1}, {2, 2});
CHECK(array_equal(where(condition, x, y), expected).item<bool>()); CHECK(array_equal(where(condition, x, y), expected).item<bool>());
} }
TEST_CASE("test eye") {
auto eye_3 = eye(3);
CHECK_EQ(eye_3.shape(), std::vector<int>{3, 3});
auto expected_eye_3 =
array({1.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 1.0f}, {3, 3});
CHECK(array_equal(eye_3, expected_eye_3).item<bool>());
auto eye_3x2 = eye(3, 2);
CHECK_EQ(eye_3x2.shape(), std::vector<int>{3, 2});
auto expected_eye_3x2 = array({1.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f}, {3, 2});
CHECK(array_equal(eye_3x2, expected_eye_3x2).item<bool>());
}
TEST_CASE("test identity") {
auto id_4 = identity(4);
CHECK_EQ(id_4.shape(), std::vector<int>{4, 4});
auto expected_id_4 = array(
{1.0f,
0.0f,
0.0f,
0.0f,
0.0f,
1.0f,
0.0f,
0.0f,
0.0f,
0.0f,
1.0f,
0.0f,
0.0f,
0.0f,
0.0f,
1.0f},
{4, 4});
CHECK(array_equal(id_4, expected_id_4).item<bool>());
}
TEST_CASE("test eye with positive k offset") {
auto eye_3_k1 = eye(3, 4, 1);
CHECK_EQ(eye_3_k1.shape(), std::vector<int>{3, 4});
auto expected_eye_3_k1 = array(
{0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f},
{3, 4});
CHECK(array_equal(eye_3_k1, expected_eye_3_k1).item<bool>());
}
TEST_CASE("test eye with negative k offset") {
auto eye_4_k_minus1 = eye(4, 3, -1);
CHECK_EQ(eye_4_k_minus1.shape(), std::vector<int>{4, 3});
auto expected_eye_4_k_minus1 = array(
{0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 1.0f},
{4, 3});
CHECK(array_equal(eye_4_k_minus1, expected_eye_4_k_minus1).item<bool>());
}