mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-14 04:11:12 +08:00
Fix CI format + build issue (#137)
* fix ci * Fix python bindings build --------- Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
parent
3214629601
commit
b9226c367c
@ -1361,21 +1361,21 @@ void init_ops(py::module_& m) {
|
|||||||
m.def(
|
m.def(
|
||||||
"eye",
|
"eye",
|
||||||
[](int n,
|
[](int n,
|
||||||
py::object m_obj,
|
std::optional<int> m,
|
||||||
py::object k_obj,
|
int k,
|
||||||
Dtype dtype,
|
std::optional<Dtype> dtype,
|
||||||
StreamOrDevice s) {
|
StreamOrDevice s) {
|
||||||
int m = m_obj.is_none() ? n : m_obj.cast<int>();
|
return eye(n, m.value_or(n), k, dtype.value_or(float32), s);
|
||||||
int k = k_obj.is_none() ? 0 : k_obj.cast<int>();
|
|
||||||
return eye(n, m, k, dtype, s);
|
|
||||||
},
|
},
|
||||||
"n"_a,
|
"n"_a,
|
||||||
"m"_a = py::none(),
|
"m"_a = py::none(),
|
||||||
"k"_a = py::none(),
|
"k"_a = 0,
|
||||||
"dtype"_a = std::nullopt,
|
"dtype"_a = std::nullopt,
|
||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
eye(n: int, m: Optional[int] = None, k: int = 0, dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Create an identity matrix or a general diagonal matrix.
|
Create an identity matrix or a general diagonal matrix.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -1387,15 +1387,19 @@ void init_ops(py::module_& m) {
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
array: An array where all elements are equal to zero, except for the k-th diagonal, whose values are equal to one.
|
array: An array where all elements are equal to zero, except for the k-th diagonal, whose values are equal to one.
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"identity",
|
"identity",
|
||||||
&identity,
|
[](int n, std::optional<Dtype> dtype, StreamOrDevice s) {
|
||||||
|
return identity(n, dtype.value_or(float32), s);
|
||||||
|
},
|
||||||
"n"_a,
|
"n"_a,
|
||||||
"dtype"_a = std::nullopt,
|
"dtype"_a = std::nullopt,
|
||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
identity(n: int, dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Create a square identity matrix.
|
Create a square identity matrix.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -1405,7 +1409,7 @@ void init_ops(py::module_& m) {
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
array: An identity matrix of size n x n.
|
array: An identity matrix of size n x n.
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"allclose",
|
"allclose",
|
||||||
&allclose,
|
&allclose,
|
||||||
@ -1918,13 +1922,13 @@ void init_ops(py::module_& m) {
|
|||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
sort(a: array, /, axis: Union[None, int] = -1, *, stream: Union[None, Stream, Device] = None) -> array
|
sort(a: array, /, axis: Union[None, int] = -1, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Returns a sorted copy of the array.
|
Returns a sorted copy of the array.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
a (array): Input array.
|
a (array): Input array.
|
||||||
axis (int or None, optional): Optional axis to sort over.
|
axis (int or None, optional): Optional axis to sort over.
|
||||||
If ``None``, this sorts over the flattened array.
|
If ``None``, this sorts over the flattened array.
|
||||||
If unspecified, it defaults to -1 (sorting over the last axis).
|
If unspecified, it defaults to -1 (sorting over the last axis).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -1951,8 +1955,8 @@ void init_ops(py::module_& m) {
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
a (array): Input array.
|
a (array): Input array.
|
||||||
axis (int or None, optional): Optional axis to sort over.
|
axis (int or None, optional): Optional axis to sort over.
|
||||||
If ``None``, this sorts over the flattened array.
|
If ``None``, this sorts over the flattened array.
|
||||||
If unspecified, it defaults to -1 (sorting over the last axis).
|
If unspecified, it defaults to -1 (sorting over the last axis).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -1983,12 +1987,12 @@ void init_ops(py::module_& m) {
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
a (array): Input array.
|
a (array): Input array.
|
||||||
kth (int): Element at the ``kth`` index will be in its sorted
|
kth (int): Element at the ``kth`` index will be in its sorted
|
||||||
position in the output. All elements before the kth index will
|
position in the output. All elements before the kth index will
|
||||||
be less or equal to the ``kth`` element and all elements after
|
be less or equal to the ``kth`` element and all elements after
|
||||||
will be greater or equal to the ``kth`` element in the output.
|
will be greater or equal to the ``kth`` element in the output.
|
||||||
axis (int or None, optional): Optional axis to partition over.
|
axis (int or None, optional): Optional axis to partition over.
|
||||||
If ``None``, this partitions over the flattened array.
|
If ``None``, this partitions over the flattened array.
|
||||||
If unspecified, it defaults to ``-1``.
|
If unspecified, it defaults to ``-1``.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -2021,11 +2025,11 @@ void init_ops(py::module_& m) {
|
|||||||
a (array): Input array.
|
a (array): Input array.
|
||||||
kth (int): Element index at the ``kth`` position in the output will
|
kth (int): Element index at the ``kth`` position in the output will
|
||||||
give the sorted position. All indices before the ``kth`` position
|
give the sorted position. All indices before the ``kth`` position
|
||||||
will be of elements less or equal to the element at the ``kth``
|
will be of elements less or equal to the element at the ``kth``
|
||||||
index and all indices after will be of elements greater or equal
|
index and all indices after will be of elements greater or equal
|
||||||
to the element at the ``kth`` index.
|
to the element at the ``kth`` index.
|
||||||
axis (int or None, optional): Optional axis to partiton over.
|
axis (int or None, optional): Optional axis to partiton over.
|
||||||
If ``None``, this partitions over the flattened array.
|
If ``None``, this partitions over the flattened array.
|
||||||
If unspecified, it defaults to ``-1``.
|
If unspecified, it defaults to ``-1``.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -2056,8 +2060,8 @@ void init_ops(py::module_& m) {
|
|||||||
Args:
|
Args:
|
||||||
a (array): Input array.
|
a (array): Input array.
|
||||||
k (int): ``k`` top elements to be returned
|
k (int): ``k`` top elements to be returned
|
||||||
axis (int or None, optional): Optional axis to select over.
|
axis (int or None, optional): Optional axis to select over.
|
||||||
If ``None``, this selects the top ``k`` elements over the
|
If ``None``, this selects the top ``k`` elements over the
|
||||||
flattened array. If unspecified, it defaults to ``-1``.
|
flattened array. If unspecified, it defaults to ``-1``.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -2539,14 +2543,14 @@ void init_ops(py::module_& m) {
|
|||||||
Args:
|
Args:
|
||||||
input (array): input array of shape ``(N, H, W, C_in)``
|
input (array): input array of shape ``(N, H, W, C_in)``
|
||||||
weight (array): weight array of shape ``(C_out, H, W, C_in)``
|
weight (array): weight array of shape ``(C_out, H, W, C_in)``
|
||||||
stride (int or tuple(int), optional): :obj:`tuple` of size 2 with
|
stride (int or tuple(int), optional): :obj:`tuple` of size 2 with
|
||||||
kernel strides. All spatial dimensions get the same stride if
|
kernel strides. All spatial dimensions get the same stride if
|
||||||
only one number is specified. Default: ``1``.
|
only one number is specified. Default: ``1``.
|
||||||
padding (int or tuple(int), optional): :obj:`tuple` of size 2 with
|
padding (int or tuple(int), optional): :obj:`tuple` of size 2 with
|
||||||
symmetric input padding. All spatial dimensions get the same
|
symmetric input padding. All spatial dimensions get the same
|
||||||
padding if only one number is specified. Default: ``0``.
|
padding if only one number is specified. Default: ``0``.
|
||||||
dilation (int or tuple(int), optional): :obj:`tuple` of size 2 with
|
dilation (int or tuple(int), optional): :obj:`tuple` of size 2 with
|
||||||
kernel dilation. All spatial dimensions get the same dilation
|
kernel dilation. All spatial dimensions get the same dilation
|
||||||
if only one number is specified. Default: ``1``
|
if only one number is specified. Default: ``1``
|
||||||
groups (int, optional): input feature groups. Default: ``1``.
|
groups (int, optional): input feature groups. Default: ``1``.
|
||||||
|
|
||||||
@ -2583,7 +2587,7 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
savez(file: str, *args, **kwargs)
|
savez(file: str, *args, **kwargs)
|
||||||
|
|
||||||
Save several arrays to a binary file in uncompressed ``.npz`` format.
|
Save several arrays to a binary file in uncompressed ``.npz`` format.
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
@ -1310,7 +1310,6 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
b = mx.ones([2147484], mx.int8)
|
b = mx.ones([2147484], mx.int8)
|
||||||
self.assertEqual((a + b)[0, 0].item(), 2)
|
self.assertEqual((a + b)[0, 0].item(), 2)
|
||||||
|
|
||||||
|
|
||||||
def test_eye(self):
|
def test_eye(self):
|
||||||
eye_matrix = mx.eye(3)
|
eye_matrix = mx.eye(3)
|
||||||
np_eye_matrix = np.eye(3)
|
np_eye_matrix = np.eye(3)
|
||||||
@ -1331,8 +1330,6 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
np_eye_matrix = np.eye(5, 6, k=-2)
|
np_eye_matrix = np.eye(5, 6, k=-2)
|
||||||
self.assertTrue(np.array_equal(eye_matrix, np_eye_matrix))
|
self.assertTrue(np.array_equal(eye_matrix, np_eye_matrix))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user