mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-17 17:28:10 +08:00
added tri / tril / triu (#170)
* added tri / tril / triu * fixed tests * ctest tests * tri overload and simplified tests * changes from comment * more tests for m * ensure assert if not 2-D * remove broadcast_to * minor tweaks --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -1410,6 +1410,72 @@ void init_ops(py::module_& m) {
|
||||
Returns:
|
||||
array: An identity matrix of size n x n.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"tri",
|
||||
[](int n, std::optional<int> m, int k, Dtype dtype, StreamOrDevice s) {
|
||||
return tri(n, m.value_or(n), k, float32, s);
|
||||
},
|
||||
"n"_a,
|
||||
"m"_a = none,
|
||||
"k"_a = 0,
|
||||
"dtype"_a = float32,
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
tri(n: int, m: int, k: int, dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
An array with ones at and below the given diagonal and zeros elsewhere.
|
||||
|
||||
Args:
|
||||
n (int): The number of rows in the output.
|
||||
m (int, optional): The number of cols in the output. Defaults to ``None``.
|
||||
k (int, optional): The diagonal of the 2-D array. Defaults to ``0``.
|
||||
dtype (Dtype, optional): Data type of the output array. Defaults to ``float32``.
|
||||
stream (Stream, optional): Stream or device. Defaults to ``None``.
|
||||
|
||||
Returns:
|
||||
array: Array with its lower triangle filled with ones and zeros elsewhere
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"tril",
|
||||
&tril,
|
||||
"x"_a,
|
||||
"k"_a = 0,
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
tril(x: array, k: int, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Zeros the array above the given diagonal.
|
||||
|
||||
Args:
|
||||
x (array): input array.
|
||||
k (int, optional): The diagonal of the 2-D array. Defaults to ``0``.
|
||||
stream (Stream, optional): Stream or device. Defaults to ``None``.
|
||||
|
||||
Returns:
|
||||
array: Array zeroed above the given diagonal
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"triu",
|
||||
&triu,
|
||||
"x"_a,
|
||||
"k"_a = 0,
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
triu(x: array, k: int, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Zeros the array below the given diagonal.
|
||||
|
||||
Args:
|
||||
x (array): input array.
|
||||
k (int, optional): The diagonal of the 2-D array. Defaults to ``0``.
|
||||
stream (Stream, optional): Stream or device. Defaults to ``None``.
|
||||
|
||||
Returns:
|
||||
array: Array zeroed below the given diagonal
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"allclose",
|
||||
&allclose,
|
||||
@@ -2254,7 +2320,7 @@ void init_ops(py::module_& m) {
|
||||
Args:
|
||||
arrays (list(array)): A list of arrays to stack.
|
||||
axis (int, optional): The axis in the result array along which the
|
||||
input arrays are stacked. Defaults to ``0``.
|
||||
input arrays are stacked. Defaults to ``0``.
|
||||
stream (Stream, optional): Stream or device. Defaults to ``None``.
|
||||
|
||||
Returns:
|
||||
|
Reference in New Issue
Block a user