Fix CI format + build issue (#137)

* fix ci

* Fix python bindings build

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
Awni Hannun 2023-12-11 15:01:41 -08:00 committed by GitHub
parent 3214629601
commit b9226c367c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 35 additions and 34 deletions

View File

@ -1361,21 +1361,21 @@ void init_ops(py::module_& m) {
m.def( m.def(
"eye", "eye",
[](int n, [](int n,
py::object m_obj, std::optional<int> m,
py::object k_obj, int k,
Dtype dtype, std::optional<Dtype> dtype,
StreamOrDevice s) { StreamOrDevice s) {
int m = m_obj.is_none() ? n : m_obj.cast<int>(); return eye(n, m.value_or(n), k, dtype.value_or(float32), s);
int k = k_obj.is_none() ? 0 : k_obj.cast<int>();
return eye(n, m, k, dtype, s);
}, },
"n"_a, "n"_a,
"m"_a = py::none(), "m"_a = py::none(),
"k"_a = py::none(), "k"_a = 0,
"dtype"_a = std::nullopt, "dtype"_a = std::nullopt,
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
eye(n: int, m: Optional[int] = None, k: int = 0, dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array
Create an identity matrix or a general diagonal matrix. Create an identity matrix or a general diagonal matrix.
Args: Args:
@ -1387,15 +1387,19 @@ void init_ops(py::module_& m) {
Returns: Returns:
array: An array where all elements are equal to zero, except for the k-th diagonal, whose values are equal to one. array: An array where all elements are equal to zero, except for the k-th diagonal, whose values are equal to one.
)pbdoc"); )pbdoc");
m.def( m.def(
"identity", "identity",
&identity, [](int n, std::optional<Dtype> dtype, StreamOrDevice s) {
return identity(n, dtype.value_or(float32), s);
},
"n"_a, "n"_a,
"dtype"_a = std::nullopt, "dtype"_a = std::nullopt,
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
identity(n: int, dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array
Create a square identity matrix. Create a square identity matrix.
Args: Args:
@ -1405,7 +1409,7 @@ void init_ops(py::module_& m) {
Returns: Returns:
array: An identity matrix of size n x n. array: An identity matrix of size n x n.
)pbdoc"); )pbdoc");
m.def( m.def(
"allclose", "allclose",
&allclose, &allclose,
@ -1918,13 +1922,13 @@ void init_ops(py::module_& m) {
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
sort(a: array, /, axis: Union[None, int] = -1, *, stream: Union[None, Stream, Device] = None) -> array sort(a: array, /, axis: Union[None, int] = -1, *, stream: Union[None, Stream, Device] = None) -> array
Returns a sorted copy of the array. Returns a sorted copy of the array.
Args: Args:
a (array): Input array. a (array): Input array.
axis (int or None, optional): Optional axis to sort over. axis (int or None, optional): Optional axis to sort over.
If ``None``, this sorts over the flattened array. If ``None``, this sorts over the flattened array.
If unspecified, it defaults to -1 (sorting over the last axis). If unspecified, it defaults to -1 (sorting over the last axis).
Returns: Returns:
@ -1951,8 +1955,8 @@ void init_ops(py::module_& m) {
Args: Args:
a (array): Input array. a (array): Input array.
axis (int or None, optional): Optional axis to sort over. axis (int or None, optional): Optional axis to sort over.
If ``None``, this sorts over the flattened array. If ``None``, this sorts over the flattened array.
If unspecified, it defaults to -1 (sorting over the last axis). If unspecified, it defaults to -1 (sorting over the last axis).
Returns: Returns:
@ -1983,12 +1987,12 @@ void init_ops(py::module_& m) {
Args: Args:
a (array): Input array. a (array): Input array.
kth (int): Element at the ``kth`` index will be in its sorted kth (int): Element at the ``kth`` index will be in its sorted
position in the output. All elements before the kth index will position in the output. All elements before the kth index will
be less or equal to the ``kth`` element and all elements after be less or equal to the ``kth`` element and all elements after
will be greater or equal to the ``kth`` element in the output. will be greater or equal to the ``kth`` element in the output.
axis (int or None, optional): Optional axis to partition over. axis (int or None, optional): Optional axis to partition over.
If ``None``, this partitions over the flattened array. If ``None``, this partitions over the flattened array.
If unspecified, it defaults to ``-1``. If unspecified, it defaults to ``-1``.
Returns: Returns:
@ -2021,11 +2025,11 @@ void init_ops(py::module_& m) {
a (array): Input array. a (array): Input array.
kth (int): Element index at the ``kth`` position in the output will kth (int): Element index at the ``kth`` position in the output will
give the sorted position. All indices before the ``kth`` position give the sorted position. All indices before the ``kth`` position
will be of elements less or equal to the element at the ``kth`` will be of elements less or equal to the element at the ``kth``
index and all indices after will be of elements greater or equal index and all indices after will be of elements greater or equal
to the element at the ``kth`` index. to the element at the ``kth`` index.
axis (int or None, optional): Optional axis to partiton over. axis (int or None, optional): Optional axis to partiton over.
If ``None``, this partitions over the flattened array. If ``None``, this partitions over the flattened array.
If unspecified, it defaults to ``-1``. If unspecified, it defaults to ``-1``.
Returns: Returns:
@ -2056,8 +2060,8 @@ void init_ops(py::module_& m) {
Args: Args:
a (array): Input array. a (array): Input array.
k (int): ``k`` top elements to be returned k (int): ``k`` top elements to be returned
axis (int or None, optional): Optional axis to select over. axis (int or None, optional): Optional axis to select over.
If ``None``, this selects the top ``k`` elements over the If ``None``, this selects the top ``k`` elements over the
flattened array. If unspecified, it defaults to ``-1``. flattened array. If unspecified, it defaults to ``-1``.
Returns: Returns:
@ -2539,14 +2543,14 @@ void init_ops(py::module_& m) {
Args: Args:
input (array): input array of shape ``(N, H, W, C_in)`` input (array): input array of shape ``(N, H, W, C_in)``
weight (array): weight array of shape ``(C_out, H, W, C_in)`` weight (array): weight array of shape ``(C_out, H, W, C_in)``
stride (int or tuple(int), optional): :obj:`tuple` of size 2 with stride (int or tuple(int), optional): :obj:`tuple` of size 2 with
kernel strides. All spatial dimensions get the same stride if kernel strides. All spatial dimensions get the same stride if
only one number is specified. Default: ``1``. only one number is specified. Default: ``1``.
padding (int or tuple(int), optional): :obj:`tuple` of size 2 with padding (int or tuple(int), optional): :obj:`tuple` of size 2 with
symmetric input padding. All spatial dimensions get the same symmetric input padding. All spatial dimensions get the same
padding if only one number is specified. Default: ``0``. padding if only one number is specified. Default: ``0``.
dilation (int or tuple(int), optional): :obj:`tuple` of size 2 with dilation (int or tuple(int), optional): :obj:`tuple` of size 2 with
kernel dilation. All spatial dimensions get the same dilation kernel dilation. All spatial dimensions get the same dilation
if only one number is specified. Default: ``1`` if only one number is specified. Default: ``1``
groups (int, optional): input feature groups. Default: ``1``. groups (int, optional): input feature groups. Default: ``1``.
@ -2583,7 +2587,7 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
R"pbdoc( R"pbdoc(
savez(file: str, *args, **kwargs) savez(file: str, *args, **kwargs)
Save several arrays to a binary file in uncompressed ``.npz`` format. Save several arrays to a binary file in uncompressed ``.npz`` format.
.. code-block:: python .. code-block:: python

View File

@ -1310,7 +1310,6 @@ class TestOps(mlx_tests.MLXTestCase):
b = mx.ones([2147484], mx.int8) b = mx.ones([2147484], mx.int8)
self.assertEqual((a + b)[0, 0].item(), 2) self.assertEqual((a + b)[0, 0].item(), 2)
def test_eye(self): def test_eye(self):
eye_matrix = mx.eye(3) eye_matrix = mx.eye(3)
np_eye_matrix = np.eye(3) np_eye_matrix = np.eye(3)
@ -1331,8 +1330,6 @@ class TestOps(mlx_tests.MLXTestCase):
np_eye_matrix = np.eye(5, 6, k=-2) np_eye_matrix = np.eye(5, 6, k=-2)
self.assertTrue(np.array_equal(eye_matrix, np_eye_matrix)) self.assertTrue(np.array_equal(eye_matrix, np_eye_matrix))
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()