mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-01 00:28:11 +08:00 
			
		
		
		
	Fix CI format + build issue (#137)
* fix ci * Fix python bindings build --------- Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
		| @@ -1361,21 +1361,21 @@ void init_ops(py::module_& m) { | ||||
|   m.def( | ||||
|       "eye", | ||||
|       [](int n, | ||||
|          py::object m_obj, | ||||
|          py::object k_obj, | ||||
|          Dtype dtype, | ||||
|          std::optional<int> m, | ||||
|          int k, | ||||
|          std::optional<Dtype> dtype, | ||||
|          StreamOrDevice s) { | ||||
|         int m = m_obj.is_none() ? n : m_obj.cast<int>(); | ||||
|         int k = k_obj.is_none() ? 0 : k_obj.cast<int>(); | ||||
|         return eye(n, m, k, dtype, s); | ||||
|         return eye(n, m.value_or(n), k, dtype.value_or(float32), s); | ||||
|       }, | ||||
|       "n"_a, | ||||
|       "m"_a = py::none(), | ||||
|       "k"_a = py::none(), | ||||
|       "k"_a = 0, | ||||
|       "dtype"_a = std::nullopt, | ||||
|       py::kw_only(), | ||||
|       "stream"_a = none, | ||||
|       R"pbdoc( | ||||
|       eye(n: int, m: Optional[int] = None, k: int = 0, dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array | ||||
|  | ||||
|       Create an identity matrix or a general diagonal matrix. | ||||
|  | ||||
|       Args: | ||||
| @@ -1387,15 +1387,19 @@ void init_ops(py::module_& m) { | ||||
|  | ||||
|       Returns: | ||||
|           array: An array where all elements are equal to zero, except for the k-th diagonal, whose values are equal to one. | ||||
|     )pbdoc"); | ||||
|       )pbdoc"); | ||||
|   m.def( | ||||
|       "identity", | ||||
|       &identity, | ||||
|       [](int n, std::optional<Dtype> dtype, StreamOrDevice s) { | ||||
|         return identity(n, dtype.value_or(float32), s); | ||||
|       }, | ||||
|       "n"_a, | ||||
|       "dtype"_a = std::nullopt, | ||||
|       py::kw_only(), | ||||
|       "stream"_a = none, | ||||
|       R"pbdoc( | ||||
|       identity(n: int, dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array | ||||
|  | ||||
|       Create a square identity matrix. | ||||
|  | ||||
|       Args: | ||||
| @@ -1405,7 +1409,7 @@ void init_ops(py::module_& m) { | ||||
|  | ||||
|       Returns: | ||||
|           array: An identity matrix of size n x n. | ||||
|     )pbdoc"); | ||||
|       )pbdoc"); | ||||
|   m.def( | ||||
|       "allclose", | ||||
|       &allclose, | ||||
| @@ -1918,13 +1922,13 @@ void init_ops(py::module_& m) { | ||||
|       "stream"_a = none, | ||||
|       R"pbdoc( | ||||
|         sort(a: array, /, axis: Union[None, int] = -1, *, stream: Union[None, Stream, Device] = None) -> array | ||||
|          | ||||
|  | ||||
|         Returns a sorted copy of the array. | ||||
|  | ||||
|         Args: | ||||
|             a (array): Input array. | ||||
|             axis (int or None, optional): Optional axis to sort over.  | ||||
|               If ``None``, this sorts over the flattened array.  | ||||
|             axis (int or None, optional): Optional axis to sort over. | ||||
|               If ``None``, this sorts over the flattened array. | ||||
|               If unspecified, it defaults to -1 (sorting over the last axis). | ||||
|  | ||||
|         Returns: | ||||
| @@ -1951,8 +1955,8 @@ void init_ops(py::module_& m) { | ||||
|  | ||||
|         Args: | ||||
|             a (array): Input array. | ||||
|             axis (int or None, optional): Optional axis to sort over.  | ||||
|               If ``None``, this sorts over the flattened array.  | ||||
|             axis (int or None, optional): Optional axis to sort over. | ||||
|               If ``None``, this sorts over the flattened array. | ||||
|               If unspecified, it defaults to -1 (sorting over the last axis). | ||||
|  | ||||
|         Returns: | ||||
| @@ -1983,12 +1987,12 @@ void init_ops(py::module_& m) { | ||||
|  | ||||
|         Args: | ||||
|             a (array): Input array. | ||||
|             kth (int): Element at the ``kth`` index will be in its sorted  | ||||
|               position in the output. All elements before the kth index will  | ||||
|               be less or equal to the ``kth`` element and all elements after  | ||||
|             kth (int): Element at the ``kth`` index will be in its sorted | ||||
|               position in the output. All elements before the kth index will | ||||
|               be less or equal to the ``kth`` element and all elements after | ||||
|               will be greater or equal to the ``kth`` element in the output. | ||||
|             axis (int or None, optional): Optional axis to partition over.  | ||||
|               If ``None``, this partitions over the flattened array.  | ||||
|             axis (int or None, optional): Optional axis to partition over. | ||||
|               If ``None``, this partitions over the flattened array. | ||||
|               If unspecified, it defaults to ``-1``. | ||||
|  | ||||
|         Returns: | ||||
| @@ -2021,11 +2025,11 @@ void init_ops(py::module_& m) { | ||||
|             a (array): Input array. | ||||
|             kth (int): Element index at the ``kth`` position in the output will | ||||
|               give the sorted position. All indices before the ``kth`` position | ||||
|               will be of elements less or equal to the element at the ``kth``  | ||||
|               will be of elements less or equal to the element at the ``kth`` | ||||
|               index and all indices after will be of elements greater or equal | ||||
|               to the element at the ``kth`` index. | ||||
|             axis (int or None, optional): Optional axis to partiton over.  | ||||
|               If ``None``, this partitions over the flattened array.  | ||||
|             axis (int or None, optional): Optional axis to partiton over. | ||||
|               If ``None``, this partitions over the flattened array. | ||||
|               If unspecified, it defaults to ``-1``. | ||||
|  | ||||
|         Returns: | ||||
| @@ -2056,8 +2060,8 @@ void init_ops(py::module_& m) { | ||||
|         Args: | ||||
|             a (array): Input array. | ||||
|             k (int): ``k`` top elements to be returned | ||||
|             axis (int or None, optional): Optional axis to select over.  | ||||
|               If ``None``, this selects the top ``k`` elements over the  | ||||
|             axis (int or None, optional): Optional axis to select over. | ||||
|               If ``None``, this selects the top ``k`` elements over the | ||||
|               flattened array. If unspecified, it defaults to ``-1``. | ||||
|  | ||||
|         Returns: | ||||
| @@ -2539,14 +2543,14 @@ void init_ops(py::module_& m) { | ||||
|         Args: | ||||
|             input (array): input array of shape ``(N, H, W, C_in)`` | ||||
|             weight (array): weight array of shape ``(C_out, H, W, C_in)`` | ||||
|             stride (int or tuple(int), optional): :obj:`tuple` of size 2 with  | ||||
|                 kernel strides. All spatial dimensions get the same stride if  | ||||
|             stride (int or tuple(int), optional): :obj:`tuple` of size 2 with | ||||
|                 kernel strides. All spatial dimensions get the same stride if | ||||
|                 only one number is specified. Default: ``1``. | ||||
|             padding (int or tuple(int), optional): :obj:`tuple` of size 2 with | ||||
|                 symmetric input padding. All spatial dimensions get the same  | ||||
|                 symmetric input padding. All spatial dimensions get the same | ||||
|                 padding if only one number is specified. Default: ``0``. | ||||
|             dilation (int or tuple(int), optional): :obj:`tuple` of size 2 with  | ||||
|                 kernel dilation. All spatial dimensions get the same dilation  | ||||
|             dilation (int or tuple(int), optional): :obj:`tuple` of size 2 with | ||||
|                 kernel dilation. All spatial dimensions get the same dilation | ||||
|                 if only one number is specified. Default: ``1`` | ||||
|             groups (int, optional): input feature groups. Default: ``1``. | ||||
|  | ||||
| @@ -2583,7 +2587,7 @@ void init_ops(py::module_& m) { | ||||
|       py::kw_only(), | ||||
|       R"pbdoc( | ||||
|         savez(file: str, *args, **kwargs) | ||||
|          | ||||
|  | ||||
|         Save several arrays to a binary file in uncompressed ``.npz`` format. | ||||
|  | ||||
|         .. code-block:: python | ||||
|   | ||||
| @@ -1310,7 +1310,6 @@ class TestOps(mlx_tests.MLXTestCase): | ||||
|         b = mx.ones([2147484], mx.int8) | ||||
|         self.assertEqual((a + b)[0, 0].item(), 2) | ||||
|  | ||||
|  | ||||
|     def test_eye(self): | ||||
|         eye_matrix = mx.eye(3) | ||||
|         np_eye_matrix = np.eye(3) | ||||
| @@ -1331,8 +1330,6 @@ class TestOps(mlx_tests.MLXTestCase): | ||||
|         np_eye_matrix = np.eye(5, 6, k=-2) | ||||
|         self.assertTrue(np.array_equal(eye_matrix, np_eye_matrix)) | ||||
|  | ||||
|      | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun