mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
parent
69505b4e9b
commit
e080290ba4
@ -38,9 +38,11 @@ Operations
|
||||
erfinv
|
||||
exp
|
||||
expand_dims
|
||||
eye
|
||||
full
|
||||
greater
|
||||
greater_equal
|
||||
identity
|
||||
less
|
||||
less_equal
|
||||
load
|
||||
|
24
mlx/ops.cpp
24
mlx/ops.cpp
@ -194,6 +194,30 @@ array ones_like(const array& a, StreamOrDevice s /* = {} */) {
|
||||
return ones(a.shape(), a.dtype(), to_stream(s));
|
||||
}
|
||||
|
||||
array eye(int n, int m, int k, Dtype dtype, StreamOrDevice s /* = {} */) {
|
||||
if (n <= 0 || m <= 0) {
|
||||
throw std::invalid_argument("N and M must be positive integers.");
|
||||
}
|
||||
array result = zeros({n * m}, dtype, s);
|
||||
if (k >= m || -k >= n) {
|
||||
return reshape(result, {n, m}, s);
|
||||
}
|
||||
|
||||
int diagonal_length = k >= 0 ? std::min(n, m - k) : std::min(n + k, m);
|
||||
int start_index = (k >= 0) ? k : -k * m;
|
||||
|
||||
array diag_indices_array = arange(
|
||||
start_index, start_index + diagonal_length * (m + 1), m + 1, int32, s);
|
||||
array ones_array = ones({diagonal_length, 1}, dtype, s);
|
||||
result = scatter(result, diag_indices_array, ones_array, 0, s);
|
||||
|
||||
return reshape(result, {n, m}, s);
|
||||
}
|
||||
|
||||
array identity(int n, Dtype dtype, StreamOrDevice s /* = {} */) {
|
||||
return eye(n, n, 0, dtype, s);
|
||||
}
|
||||
|
||||
array reshape(
|
||||
const array& a,
|
||||
std::vector<int> shape,
|
||||
|
23
mlx/ops.h
23
mlx/ops.h
@ -87,6 +87,29 @@ inline array ones(const std::vector<int>& shape, StreamOrDevice s = {}) {
|
||||
}
|
||||
array ones_like(const array& a, StreamOrDevice s = {});
|
||||
|
||||
/** Fill an array of the given shape (n,m) with ones in the specified diagonal
|
||||
* k, and zeros everywhere else. */
|
||||
array eye(int n, int m, int k, Dtype dtype, StreamOrDevice s = {});
|
||||
inline array eye(int n, Dtype dtype, StreamOrDevice s = {}) {
|
||||
return eye(n, n, 0, dtype, s);
|
||||
}
|
||||
inline array eye(int n, int m, StreamOrDevice s = {}) {
|
||||
return eye(n, m, 0, float32, s);
|
||||
}
|
||||
inline array eye(int n, int m, int k, StreamOrDevice s = {}) {
|
||||
return eye(n, m, k, float32, s);
|
||||
}
|
||||
inline array eye(int n, StreamOrDevice s = {}) {
|
||||
return eye(n, n, 0, float32, s);
|
||||
}
|
||||
|
||||
/** Create a square matrix of shape (n,n) of zeros, and ones in the major
|
||||
* diagonal. */
|
||||
array identity(int n, Dtype dtype, StreamOrDevice s = {});
|
||||
inline array identity(int n, StreamOrDevice s = {}) {
|
||||
return identity(n, float32, s);
|
||||
}
|
||||
|
||||
/** array manipulation */
|
||||
|
||||
/** Reshape an array to the given shape. */
|
||||
|
@ -1253,6 +1253,54 @@ void init_ops(py::module_& m) {
|
||||
Returns:
|
||||
array: The output array filled with ones.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"eye",
|
||||
[](int n,
|
||||
py::object m_obj,
|
||||
py::object k_obj,
|
||||
Dtype dtype,
|
||||
StreamOrDevice s) {
|
||||
int m = m_obj.is_none() ? n : m_obj.cast<int>();
|
||||
int k = k_obj.is_none() ? 0 : k_obj.cast<int>();
|
||||
return eye(n, m, k, dtype, s);
|
||||
},
|
||||
"n"_a,
|
||||
"m"_a = py::none(),
|
||||
"k"_a = py::none(),
|
||||
"dtype"_a = std::nullopt,
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
Create an identity matrix or a general diagonal matrix.
|
||||
|
||||
Args:
|
||||
n (int): The number of rows in the output.
|
||||
m (int, optional): The number of columns in the output. Defaults to n.
|
||||
k (int, optional): Index of the diagonal. Defaults to 0 (main diagonal).
|
||||
dtype (Dtype, optional): Data type of the output array. Defaults to float32.
|
||||
stream (Stream, optional): Stream or device. Defaults to None.
|
||||
|
||||
Returns:
|
||||
array: An array where all elements are equal to zero, except for the k-th diagonal, whose values are equal to one.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"identity",
|
||||
&identity,
|
||||
"n"_a,
|
||||
"dtype"_a = std::nullopt,
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
Create a square identity matrix.
|
||||
|
||||
Args:
|
||||
n (int): The number of rows and columns in the output.
|
||||
dtype (Dtype, optional): Data type of the output array. Defaults to float32.
|
||||
stream (Stream, optional): Stream or device. Defaults to None.
|
||||
|
||||
Returns:
|
||||
array: An identity matrix of size n x n.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"allclose",
|
||||
&allclose,
|
||||
|
@ -1311,5 +1311,28 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
self.assertEqual((a + b)[0, 0].item(), 2)
|
||||
|
||||
|
||||
def test_eye(self):
|
||||
eye_matrix = mx.eye(3)
|
||||
np_eye_matrix = np.eye(3)
|
||||
self.assertTrue(np.array_equal(eye_matrix, np_eye_matrix))
|
||||
|
||||
# Test for non-square matrix
|
||||
eye_matrix = mx.eye(3, 4)
|
||||
np_eye_matrix = np.eye(3, 4)
|
||||
self.assertTrue(np.array_equal(eye_matrix, np_eye_matrix))
|
||||
|
||||
# Test with positive k parameter
|
||||
eye_matrix = mx.eye(3, 4, k=1)
|
||||
np_eye_matrix = np.eye(3, 4, k=1)
|
||||
self.assertTrue(np.array_equal(eye_matrix, np_eye_matrix))
|
||||
|
||||
# Test with negative k parameter
|
||||
eye_matrix = mx.eye(5, 6, k=-2)
|
||||
np_eye_matrix = np.eye(5, 6, k=-2)
|
||||
self.assertTrue(np.array_equal(eye_matrix, np_eye_matrix))
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
@ -1926,3 +1926,58 @@ TEST_CASE("test where") {
|
||||
expected = array({1, 2, 2, 1}, {2, 2});
|
||||
CHECK(array_equal(where(condition, x, y), expected).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test eye") {
|
||||
auto eye_3 = eye(3);
|
||||
CHECK_EQ(eye_3.shape(), std::vector<int>{3, 3});
|
||||
auto expected_eye_3 =
|
||||
array({1.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 1.0f}, {3, 3});
|
||||
CHECK(array_equal(eye_3, expected_eye_3).item<bool>());
|
||||
|
||||
auto eye_3x2 = eye(3, 2);
|
||||
CHECK_EQ(eye_3x2.shape(), std::vector<int>{3, 2});
|
||||
auto expected_eye_3x2 = array({1.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f}, {3, 2});
|
||||
CHECK(array_equal(eye_3x2, expected_eye_3x2).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test identity") {
|
||||
auto id_4 = identity(4);
|
||||
CHECK_EQ(id_4.shape(), std::vector<int>{4, 4});
|
||||
auto expected_id_4 = array(
|
||||
{1.0f,
|
||||
0.0f,
|
||||
0.0f,
|
||||
0.0f,
|
||||
0.0f,
|
||||
1.0f,
|
||||
0.0f,
|
||||
0.0f,
|
||||
0.0f,
|
||||
0.0f,
|
||||
1.0f,
|
||||
0.0f,
|
||||
0.0f,
|
||||
0.0f,
|
||||
0.0f,
|
||||
1.0f},
|
||||
{4, 4});
|
||||
CHECK(array_equal(id_4, expected_id_4).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test eye with positive k offset") {
|
||||
auto eye_3_k1 = eye(3, 4, 1);
|
||||
CHECK_EQ(eye_3_k1.shape(), std::vector<int>{3, 4});
|
||||
auto expected_eye_3_k1 = array(
|
||||
{0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f},
|
||||
{3, 4});
|
||||
CHECK(array_equal(eye_3_k1, expected_eye_3_k1).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test eye with negative k offset") {
|
||||
auto eye_4_k_minus1 = eye(4, 3, -1);
|
||||
CHECK_EQ(eye_4_k_minus1.shape(), std::vector<int>{4, 3});
|
||||
auto expected_eye_4_k_minus1 = array(
|
||||
{0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 1.0f},
|
||||
{4, 3});
|
||||
CHECK(array_equal(eye_4_k_minus1, expected_eye_4_k_minus1).item<bool>());
|
||||
}
|
Loading…
Reference in New Issue
Block a user