Fix CI format + build issue (#137)

* fix ci

* Fix python bindings build

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
Awni Hannun 2023-12-11 15:01:41 -08:00 committed by GitHub
parent 3214629601
commit b9226c367c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 35 additions and 34 deletions

View File

@ -1361,21 +1361,21 @@ void init_ops(py::module_& m) {
m.def(
"eye",
[](int n,
py::object m_obj,
py::object k_obj,
Dtype dtype,
std::optional<int> m,
int k,
std::optional<Dtype> dtype,
StreamOrDevice s) {
int m = m_obj.is_none() ? n : m_obj.cast<int>();
int k = k_obj.is_none() ? 0 : k_obj.cast<int>();
return eye(n, m, k, dtype, s);
return eye(n, m.value_or(n), k, dtype.value_or(float32), s);
},
"n"_a,
"m"_a = py::none(),
"k"_a = py::none(),
"k"_a = 0,
"dtype"_a = std::nullopt,
py::kw_only(),
"stream"_a = none,
R"pbdoc(
eye(n: int, m: Optional[int] = None, k: int = 0, dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array
Create an identity matrix or a general diagonal matrix.
Args:
@ -1387,15 +1387,19 @@ void init_ops(py::module_& m) {
Returns:
array: An array where all elements are equal to zero, except for the k-th diagonal, whose values are equal to one.
)pbdoc");
)pbdoc");
m.def(
"identity",
&identity,
[](int n, std::optional<Dtype> dtype, StreamOrDevice s) {
return identity(n, dtype.value_or(float32), s);
},
"n"_a,
"dtype"_a = std::nullopt,
py::kw_only(),
"stream"_a = none,
R"pbdoc(
identity(n: int, dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array
Create a square identity matrix.
Args:
@ -1405,7 +1409,7 @@ void init_ops(py::module_& m) {
Returns:
array: An identity matrix of size n x n.
)pbdoc");
)pbdoc");
m.def(
"allclose",
&allclose,

View File

@ -1310,7 +1310,6 @@ class TestOps(mlx_tests.MLXTestCase):
b = mx.ones([2147484], mx.int8)
self.assertEqual((a + b)[0, 0].item(), 2)
def test_eye(self):
eye_matrix = mx.eye(3)
np_eye_matrix = np.eye(3)
@ -1332,7 +1331,5 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertTrue(np.array_equal(eye_matrix, np_eye_matrix))
if __name__ == "__main__":
unittest.main()