mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-22 13:28:11 +08:00
@@ -1253,6 +1253,54 @@ void init_ops(py::module_& m) {
|
||||
Returns:
|
||||
array: The output array filled with ones.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"eye",
|
||||
[](int n,
|
||||
py::object m_obj,
|
||||
py::object k_obj,
|
||||
Dtype dtype,
|
||||
StreamOrDevice s) {
|
||||
int m = m_obj.is_none() ? n : m_obj.cast<int>();
|
||||
int k = k_obj.is_none() ? 0 : k_obj.cast<int>();
|
||||
return eye(n, m, k, dtype, s);
|
||||
},
|
||||
"n"_a,
|
||||
"m"_a = py::none(),
|
||||
"k"_a = py::none(),
|
||||
"dtype"_a = std::nullopt,
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
Create an identity matrix or a general diagonal matrix.
|
||||
|
||||
Args:
|
||||
n (int): The number of rows in the output.
|
||||
m (int, optional): The number of columns in the output. Defaults to n.
|
||||
k (int, optional): Index of the diagonal. Defaults to 0 (main diagonal).
|
||||
dtype (Dtype, optional): Data type of the output array. Defaults to float32.
|
||||
stream (Stream, optional): Stream or device. Defaults to None.
|
||||
|
||||
Returns:
|
||||
array: An array where all elements are equal to zero, except for the k-th diagonal, whose values are equal to one.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"identity",
|
||||
&identity,
|
||||
"n"_a,
|
||||
"dtype"_a = std::nullopt,
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
Create a square identity matrix.
|
||||
|
||||
Args:
|
||||
n (int): The number of rows and columns in the output.
|
||||
dtype (Dtype, optional): Data type of the output array. Defaults to float32.
|
||||
stream (Stream, optional): Stream or device. Defaults to None.
|
||||
|
||||
Returns:
|
||||
array: An identity matrix of size n x n.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"allclose",
|
||||
&allclose,
|
||||
|
Reference in New Issue
Block a user