From b9226c367c515eaab1614d16ff72b8d71ba1f933 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 11 Dec 2023 15:01:41 -0800 Subject: [PATCH] Fix CI format + build issue (#137) * fix ci * Fix python bindings build --------- Co-authored-by: Angelos Katharopoulos --- python/src/ops.cpp | 66 +++++++++++++++++++++------------------- python/tests/test_ops.py | 3 -- 2 files changed, 35 insertions(+), 34 deletions(-) diff --git a/python/src/ops.cpp b/python/src/ops.cpp index b9eacea98..a5ae163f4 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -1361,21 +1361,21 @@ void init_ops(py::module_& m) { m.def( "eye", [](int n, - py::object m_obj, - py::object k_obj, - Dtype dtype, + std::optional m, + int k, + std::optional dtype, StreamOrDevice s) { - int m = m_obj.is_none() ? n : m_obj.cast(); - int k = k_obj.is_none() ? 0 : k_obj.cast(); - return eye(n, m, k, dtype, s); + return eye(n, m.value_or(n), k, dtype.value_or(float32), s); }, "n"_a, "m"_a = py::none(), - "k"_a = py::none(), + "k"_a = 0, "dtype"_a = std::nullopt, py::kw_only(), "stream"_a = none, R"pbdoc( + eye(n: int, m: Optional[int] = None, k: int = 0, dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array + Create an identity matrix or a general diagonal matrix. Args: @@ -1387,15 +1387,19 @@ void init_ops(py::module_& m) { Returns: array: An array where all elements are equal to zero, except for the k-th diagonal, whose values are equal to one. - )pbdoc"); + )pbdoc"); m.def( "identity", - &identity, + [](int n, std::optional dtype, StreamOrDevice s) { + return identity(n, dtype.value_or(float32), s); + }, "n"_a, "dtype"_a = std::nullopt, py::kw_only(), "stream"_a = none, R"pbdoc( + identity(n: int, dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array + Create a square identity matrix. Args: @@ -1405,7 +1409,7 @@ void init_ops(py::module_& m) { Returns: array: An identity matrix of size n x n. - )pbdoc"); + )pbdoc"); m.def( "allclose", &allclose, @@ -1918,13 +1922,13 @@ void init_ops(py::module_& m) { "stream"_a = none, R"pbdoc( sort(a: array, /, axis: Union[None, int] = -1, *, stream: Union[None, Stream, Device] = None) -> array - + Returns a sorted copy of the array. Args: a (array): Input array. - axis (int or None, optional): Optional axis to sort over. - If ``None``, this sorts over the flattened array. + axis (int or None, optional): Optional axis to sort over. + If ``None``, this sorts over the flattened array. If unspecified, it defaults to -1 (sorting over the last axis). Returns: @@ -1951,8 +1955,8 @@ void init_ops(py::module_& m) { Args: a (array): Input array. - axis (int or None, optional): Optional axis to sort over. - If ``None``, this sorts over the flattened array. + axis (int or None, optional): Optional axis to sort over. + If ``None``, this sorts over the flattened array. If unspecified, it defaults to -1 (sorting over the last axis). Returns: @@ -1983,12 +1987,12 @@ void init_ops(py::module_& m) { Args: a (array): Input array. - kth (int): Element at the ``kth`` index will be in its sorted - position in the output. All elements before the kth index will - be less or equal to the ``kth`` element and all elements after + kth (int): Element at the ``kth`` index will be in its sorted + position in the output. All elements before the kth index will + be less or equal to the ``kth`` element and all elements after will be greater or equal to the ``kth`` element in the output. - axis (int or None, optional): Optional axis to partition over. - If ``None``, this partitions over the flattened array. + axis (int or None, optional): Optional axis to partition over. + If ``None``, this partitions over the flattened array. If unspecified, it defaults to ``-1``. Returns: @@ -2021,11 +2025,11 @@ void init_ops(py::module_& m) { a (array): Input array. kth (int): Element index at the ``kth`` position in the output will give the sorted position. All indices before the ``kth`` position - will be of elements less or equal to the element at the ``kth`` + will be of elements less or equal to the element at the ``kth`` index and all indices after will be of elements greater or equal to the element at the ``kth`` index. - axis (int or None, optional): Optional axis to partiton over. - If ``None``, this partitions over the flattened array. + axis (int or None, optional): Optional axis to partiton over. + If ``None``, this partitions over the flattened array. If unspecified, it defaults to ``-1``. Returns: @@ -2056,8 +2060,8 @@ void init_ops(py::module_& m) { Args: a (array): Input array. k (int): ``k`` top elements to be returned - axis (int or None, optional): Optional axis to select over. - If ``None``, this selects the top ``k`` elements over the + axis (int or None, optional): Optional axis to select over. + If ``None``, this selects the top ``k`` elements over the flattened array. If unspecified, it defaults to ``-1``. Returns: @@ -2539,14 +2543,14 @@ void init_ops(py::module_& m) { Args: input (array): input array of shape ``(N, H, W, C_in)`` weight (array): weight array of shape ``(C_out, H, W, C_in)`` - stride (int or tuple(int), optional): :obj:`tuple` of size 2 with - kernel strides. All spatial dimensions get the same stride if + stride (int or tuple(int), optional): :obj:`tuple` of size 2 with + kernel strides. All spatial dimensions get the same stride if only one number is specified. Default: ``1``. padding (int or tuple(int), optional): :obj:`tuple` of size 2 with - symmetric input padding. All spatial dimensions get the same + symmetric input padding. All spatial dimensions get the same padding if only one number is specified. Default: ``0``. - dilation (int or tuple(int), optional): :obj:`tuple` of size 2 with - kernel dilation. All spatial dimensions get the same dilation + dilation (int or tuple(int), optional): :obj:`tuple` of size 2 with + kernel dilation. All spatial dimensions get the same dilation if only one number is specified. Default: ``1`` groups (int, optional): input feature groups. Default: ``1``. @@ -2583,7 +2587,7 @@ void init_ops(py::module_& m) { py::kw_only(), R"pbdoc( savez(file: str, *args, **kwargs) - + Save several arrays to a binary file in uncompressed ``.npz`` format. .. code-block:: python diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 57c6ea2b0..0ee234cf4 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -1310,7 +1310,6 @@ class TestOps(mlx_tests.MLXTestCase): b = mx.ones([2147484], mx.int8) self.assertEqual((a + b)[0, 0].item(), 2) - def test_eye(self): eye_matrix = mx.eye(3) np_eye_matrix = np.eye(3) @@ -1331,8 +1330,6 @@ class TestOps(mlx_tests.MLXTestCase): np_eye_matrix = np.eye(5, 6, k=-2) self.assertTrue(np.array_equal(eye_matrix, np_eye_matrix)) - - if __name__ == "__main__": unittest.main()