mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
parent
69505b4e9b
commit
e080290ba4
@ -38,9 +38,11 @@ Operations
|
|||||||
erfinv
|
erfinv
|
||||||
exp
|
exp
|
||||||
expand_dims
|
expand_dims
|
||||||
|
eye
|
||||||
full
|
full
|
||||||
greater
|
greater
|
||||||
greater_equal
|
greater_equal
|
||||||
|
identity
|
||||||
less
|
less
|
||||||
less_equal
|
less_equal
|
||||||
load
|
load
|
||||||
|
24
mlx/ops.cpp
24
mlx/ops.cpp
@ -194,6 +194,30 @@ array ones_like(const array& a, StreamOrDevice s /* = {} */) {
|
|||||||
return ones(a.shape(), a.dtype(), to_stream(s));
|
return ones(a.shape(), a.dtype(), to_stream(s));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
array eye(int n, int m, int k, Dtype dtype, StreamOrDevice s /* = {} */) {
|
||||||
|
if (n <= 0 || m <= 0) {
|
||||||
|
throw std::invalid_argument("N and M must be positive integers.");
|
||||||
|
}
|
||||||
|
array result = zeros({n * m}, dtype, s);
|
||||||
|
if (k >= m || -k >= n) {
|
||||||
|
return reshape(result, {n, m}, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
int diagonal_length = k >= 0 ? std::min(n, m - k) : std::min(n + k, m);
|
||||||
|
int start_index = (k >= 0) ? k : -k * m;
|
||||||
|
|
||||||
|
array diag_indices_array = arange(
|
||||||
|
start_index, start_index + diagonal_length * (m + 1), m + 1, int32, s);
|
||||||
|
array ones_array = ones({diagonal_length, 1}, dtype, s);
|
||||||
|
result = scatter(result, diag_indices_array, ones_array, 0, s);
|
||||||
|
|
||||||
|
return reshape(result, {n, m}, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
array identity(int n, Dtype dtype, StreamOrDevice s /* = {} */) {
|
||||||
|
return eye(n, n, 0, dtype, s);
|
||||||
|
}
|
||||||
|
|
||||||
array reshape(
|
array reshape(
|
||||||
const array& a,
|
const array& a,
|
||||||
std::vector<int> shape,
|
std::vector<int> shape,
|
||||||
|
23
mlx/ops.h
23
mlx/ops.h
@ -87,6 +87,29 @@ inline array ones(const std::vector<int>& shape, StreamOrDevice s = {}) {
|
|||||||
}
|
}
|
||||||
array ones_like(const array& a, StreamOrDevice s = {});
|
array ones_like(const array& a, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Fill an array of the given shape (n,m) with ones in the specified diagonal
|
||||||
|
* k, and zeros everywhere else. */
|
||||||
|
array eye(int n, int m, int k, Dtype dtype, StreamOrDevice s = {});
|
||||||
|
inline array eye(int n, Dtype dtype, StreamOrDevice s = {}) {
|
||||||
|
return eye(n, n, 0, dtype, s);
|
||||||
|
}
|
||||||
|
inline array eye(int n, int m, StreamOrDevice s = {}) {
|
||||||
|
return eye(n, m, 0, float32, s);
|
||||||
|
}
|
||||||
|
inline array eye(int n, int m, int k, StreamOrDevice s = {}) {
|
||||||
|
return eye(n, m, k, float32, s);
|
||||||
|
}
|
||||||
|
inline array eye(int n, StreamOrDevice s = {}) {
|
||||||
|
return eye(n, n, 0, float32, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Create a square matrix of shape (n,n) of zeros, and ones in the major
|
||||||
|
* diagonal. */
|
||||||
|
array identity(int n, Dtype dtype, StreamOrDevice s = {});
|
||||||
|
inline array identity(int n, StreamOrDevice s = {}) {
|
||||||
|
return identity(n, float32, s);
|
||||||
|
}
|
||||||
|
|
||||||
/** array manipulation */
|
/** array manipulation */
|
||||||
|
|
||||||
/** Reshape an array to the given shape. */
|
/** Reshape an array to the given shape. */
|
||||||
|
@ -1253,6 +1253,54 @@ void init_ops(py::module_& m) {
|
|||||||
Returns:
|
Returns:
|
||||||
array: The output array filled with ones.
|
array: The output array filled with ones.
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"eye",
|
||||||
|
[](int n,
|
||||||
|
py::object m_obj,
|
||||||
|
py::object k_obj,
|
||||||
|
Dtype dtype,
|
||||||
|
StreamOrDevice s) {
|
||||||
|
int m = m_obj.is_none() ? n : m_obj.cast<int>();
|
||||||
|
int k = k_obj.is_none() ? 0 : k_obj.cast<int>();
|
||||||
|
return eye(n, m, k, dtype, s);
|
||||||
|
},
|
||||||
|
"n"_a,
|
||||||
|
"m"_a = py::none(),
|
||||||
|
"k"_a = py::none(),
|
||||||
|
"dtype"_a = std::nullopt,
|
||||||
|
py::kw_only(),
|
||||||
|
"stream"_a = none,
|
||||||
|
R"pbdoc(
|
||||||
|
Create an identity matrix or a general diagonal matrix.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
n (int): The number of rows in the output.
|
||||||
|
m (int, optional): The number of columns in the output. Defaults to n.
|
||||||
|
k (int, optional): Index of the diagonal. Defaults to 0 (main diagonal).
|
||||||
|
dtype (Dtype, optional): Data type of the output array. Defaults to float32.
|
||||||
|
stream (Stream, optional): Stream or device. Defaults to None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
array: An array where all elements are equal to zero, except for the k-th diagonal, whose values are equal to one.
|
||||||
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"identity",
|
||||||
|
&identity,
|
||||||
|
"n"_a,
|
||||||
|
"dtype"_a = std::nullopt,
|
||||||
|
py::kw_only(),
|
||||||
|
"stream"_a = none,
|
||||||
|
R"pbdoc(
|
||||||
|
Create a square identity matrix.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
n (int): The number of rows and columns in the output.
|
||||||
|
dtype (Dtype, optional): Data type of the output array. Defaults to float32.
|
||||||
|
stream (Stream, optional): Stream or device. Defaults to None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
array: An identity matrix of size n x n.
|
||||||
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"allclose",
|
"allclose",
|
||||||
&allclose,
|
&allclose,
|
||||||
|
@ -1311,5 +1311,28 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
self.assertEqual((a + b)[0, 0].item(), 2)
|
self.assertEqual((a + b)[0, 0].item(), 2)
|
||||||
|
|
||||||
|
|
||||||
|
def test_eye(self):
|
||||||
|
eye_matrix = mx.eye(3)
|
||||||
|
np_eye_matrix = np.eye(3)
|
||||||
|
self.assertTrue(np.array_equal(eye_matrix, np_eye_matrix))
|
||||||
|
|
||||||
|
# Test for non-square matrix
|
||||||
|
eye_matrix = mx.eye(3, 4)
|
||||||
|
np_eye_matrix = np.eye(3, 4)
|
||||||
|
self.assertTrue(np.array_equal(eye_matrix, np_eye_matrix))
|
||||||
|
|
||||||
|
# Test with positive k parameter
|
||||||
|
eye_matrix = mx.eye(3, 4, k=1)
|
||||||
|
np_eye_matrix = np.eye(3, 4, k=1)
|
||||||
|
self.assertTrue(np.array_equal(eye_matrix, np_eye_matrix))
|
||||||
|
|
||||||
|
# Test with negative k parameter
|
||||||
|
eye_matrix = mx.eye(5, 6, k=-2)
|
||||||
|
np_eye_matrix = np.eye(5, 6, k=-2)
|
||||||
|
self.assertTrue(np.array_equal(eye_matrix, np_eye_matrix))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
@ -1926,3 +1926,58 @@ TEST_CASE("test where") {
|
|||||||
expected = array({1, 2, 2, 1}, {2, 2});
|
expected = array({1, 2, 2, 1}, {2, 2});
|
||||||
CHECK(array_equal(where(condition, x, y), expected).item<bool>());
|
CHECK(array_equal(where(condition, x, y), expected).item<bool>());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test eye") {
|
||||||
|
auto eye_3 = eye(3);
|
||||||
|
CHECK_EQ(eye_3.shape(), std::vector<int>{3, 3});
|
||||||
|
auto expected_eye_3 =
|
||||||
|
array({1.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 1.0f}, {3, 3});
|
||||||
|
CHECK(array_equal(eye_3, expected_eye_3).item<bool>());
|
||||||
|
|
||||||
|
auto eye_3x2 = eye(3, 2);
|
||||||
|
CHECK_EQ(eye_3x2.shape(), std::vector<int>{3, 2});
|
||||||
|
auto expected_eye_3x2 = array({1.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f}, {3, 2});
|
||||||
|
CHECK(array_equal(eye_3x2, expected_eye_3x2).item<bool>());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test identity") {
|
||||||
|
auto id_4 = identity(4);
|
||||||
|
CHECK_EQ(id_4.shape(), std::vector<int>{4, 4});
|
||||||
|
auto expected_id_4 = array(
|
||||||
|
{1.0f,
|
||||||
|
0.0f,
|
||||||
|
0.0f,
|
||||||
|
0.0f,
|
||||||
|
0.0f,
|
||||||
|
1.0f,
|
||||||
|
0.0f,
|
||||||
|
0.0f,
|
||||||
|
0.0f,
|
||||||
|
0.0f,
|
||||||
|
1.0f,
|
||||||
|
0.0f,
|
||||||
|
0.0f,
|
||||||
|
0.0f,
|
||||||
|
0.0f,
|
||||||
|
1.0f},
|
||||||
|
{4, 4});
|
||||||
|
CHECK(array_equal(id_4, expected_id_4).item<bool>());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test eye with positive k offset") {
|
||||||
|
auto eye_3_k1 = eye(3, 4, 1);
|
||||||
|
CHECK_EQ(eye_3_k1.shape(), std::vector<int>{3, 4});
|
||||||
|
auto expected_eye_3_k1 = array(
|
||||||
|
{0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f},
|
||||||
|
{3, 4});
|
||||||
|
CHECK(array_equal(eye_3_k1, expected_eye_3_k1).item<bool>());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test eye with negative k offset") {
|
||||||
|
auto eye_4_k_minus1 = eye(4, 3, -1);
|
||||||
|
CHECK_EQ(eye_4_k_minus1.shape(), std::vector<int>{4, 3});
|
||||||
|
auto expected_eye_4_k_minus1 = array(
|
||||||
|
{0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 1.0f},
|
||||||
|
{4, 3});
|
||||||
|
CHECK(array_equal(eye_4_k_minus1, expected_eye_4_k_minus1).item<bool>());
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user