Allow unary ops to accept array like (#1093)

This commit is contained in:
Awni Hannun
2024-05-09 09:36:02 -07:00
committed by GitHub
parent cc05a281c4
commit b21242faf1
2 changed files with 162 additions and 42 deletions

View File

@@ -158,7 +158,9 @@ void init_ops(nb::module_& m) {
)pbdoc");
m.def(
"abs",
&mlx::core::abs,
[](const ScalarOrArray& a, StreamOrDevice s) {
return mlx::core::abs(to_array(a), s);
},
nb::arg(),
nb::kw_only(),
"stream"_a = nb::none(),
@@ -175,7 +177,9 @@ void init_ops(nb::module_& m) {
)pbdoc");
m.def(
"sign",
&sign,
[](const ScalarOrArray& a, StreamOrDevice s) {
return sign(to_array(a), s);
},
nb::arg(),
nb::kw_only(),
"stream"_a = nb::none(),
@@ -192,7 +196,9 @@ void init_ops(nb::module_& m) {
)pbdoc");
m.def(
"negative",
&negative,
[](const ScalarOrArray& a, StreamOrDevice s) {
return negative(to_array(a), s);
},
nb::arg(),
nb::kw_only(),
"stream"_a = nb::none(),
@@ -600,7 +606,9 @@ void init_ops(nb::module_& m) {
)pbdoc");
m.def(
"square",
&square,
[](const ScalarOrArray& a, StreamOrDevice s) {
return square(to_array(a), s);
},
nb::arg(),
nb::kw_only(),
"stream"_a = nb::none(),
@@ -617,7 +625,9 @@ void init_ops(nb::module_& m) {
)pbdoc");
m.def(
"sqrt",
&mlx::core::sqrt,
[](const ScalarOrArray& a, StreamOrDevice s) {
return mlx::core::sqrt(to_array(a), s);
},
nb::arg(),
nb::kw_only(),
"stream"_a = nb::none(),
@@ -634,7 +644,9 @@ void init_ops(nb::module_& m) {
)pbdoc");
m.def(
"rsqrt",
&rsqrt,
[](const ScalarOrArray& a, StreamOrDevice s) {
return rsqrt(to_array(a), s);
},
nb::arg(),
nb::kw_only(),
"stream"_a = nb::none(),
@@ -651,7 +663,9 @@ void init_ops(nb::module_& m) {
)pbdoc");
m.def(
"reciprocal",
&reciprocal,
[](const ScalarOrArray& a, StreamOrDevice s) {
return reciprocal(to_array(a), s);
},
nb::arg(),
nb::kw_only(),
"stream"_a = nb::none(),
@@ -757,7 +771,9 @@ void init_ops(nb::module_& m) {
)pbdoc");
m.def(
"exp",
&mlx::core::exp,
[](const ScalarOrArray& a, StreamOrDevice s) {
return mlx::core::exp(to_array(a), s);
},
nb::arg(),
nb::kw_only(),
"stream"_a = nb::none(),
@@ -774,7 +790,9 @@ void init_ops(nb::module_& m) {
)pbdoc");
m.def(
"expm1",
&mlx::core::expm1,
[](const ScalarOrArray& a, StreamOrDevice s) {
return mlx::core::expm1(to_array(a), s);
},
nb::arg(),
nb::kw_only(),
"stream"_a = nb::none(),
@@ -793,7 +811,9 @@ void init_ops(nb::module_& m) {
)pbdoc");
m.def(
"erf",
&mlx::core::erf,
[](const ScalarOrArray& a, StreamOrDevice s) {
return mlx::core::erf(to_array(a), s);
},
nb::arg(),
nb::kw_only(),
"stream"_a = nb::none(),
@@ -813,7 +833,9 @@ void init_ops(nb::module_& m) {
)pbdoc");
m.def(
"erfinv",
&mlx::core::erfinv,
[](const ScalarOrArray& a, StreamOrDevice s) {
return mlx::core::erfinv(to_array(a), s);
},
nb::arg(),
nb::kw_only(),
"stream"_a = nb::none(),
@@ -830,7 +852,9 @@ void init_ops(nb::module_& m) {
)pbdoc");
m.def(
"sin",
&mlx::core::sin,
[](const ScalarOrArray& a, StreamOrDevice s) {
return mlx::core::sin(to_array(a), s);
},
nb::arg(),
nb::kw_only(),
"stream"_a = nb::none(),
@@ -847,7 +871,9 @@ void init_ops(nb::module_& m) {
)pbdoc");
m.def(
"cos",
&mlx::core::cos,
[](const ScalarOrArray& a, StreamOrDevice s) {
return mlx::core::cos(to_array(a), s);
},
nb::arg(),
nb::kw_only(),
"stream"_a = nb::none(),
@@ -864,7 +890,9 @@ void init_ops(nb::module_& m) {
)pbdoc");
m.def(
"tan",
&mlx::core::tan,
[](const ScalarOrArray& a, StreamOrDevice s) {
return mlx::core::tan(to_array(a), s);
},
nb::arg(),
nb::kw_only(),
"stream"_a = nb::none(),
@@ -881,7 +909,9 @@ void init_ops(nb::module_& m) {
)pbdoc");
m.def(
"arcsin",
&mlx::core::arcsin,
[](const ScalarOrArray& a, StreamOrDevice s) {
return mlx::core::arcsin(to_array(a), s);
},
nb::arg(),
nb::kw_only(),
"stream"_a = nb::none(),
@@ -898,7 +928,9 @@ void init_ops(nb::module_& m) {
)pbdoc");
m.def(
"arccos",
&mlx::core::arccos,
[](const ScalarOrArray& a, StreamOrDevice s) {
return mlx::core::arccos(to_array(a), s);
},
nb::arg(),
nb::kw_only(),
"stream"_a = nb::none(),
@@ -915,7 +947,9 @@ void init_ops(nb::module_& m) {
)pbdoc");
m.def(
"arctan",
&mlx::core::arctan,
[](const ScalarOrArray& a, StreamOrDevice s) {
return mlx::core::arctan(to_array(a), s);
},
nb::arg(),
nb::kw_only(),
"stream"_a = nb::none(),
@@ -951,7 +985,9 @@ void init_ops(nb::module_& m) {
)pbdoc");
m.def(
"sinh",
&mlx::core::sinh,
[](const ScalarOrArray& a, StreamOrDevice s) {
return mlx::core::sinh(to_array(a), s);
},
nb::arg(),
nb::kw_only(),
"stream"_a = nb::none(),
@@ -968,7 +1004,9 @@ void init_ops(nb::module_& m) {
)pbdoc");
m.def(
"cosh",
&mlx::core::cosh,
[](const ScalarOrArray& a, StreamOrDevice s) {
return mlx::core::cosh(to_array(a), s);
},
nb::arg(),
nb::kw_only(),
"stream"_a = nb::none(),
@@ -985,7 +1023,9 @@ void init_ops(nb::module_& m) {
)pbdoc");
m.def(
"tanh",
&mlx::core::tanh,
[](const ScalarOrArray& a, StreamOrDevice s) {
return mlx::core::tanh(to_array(a), s);
},
nb::arg(),
nb::kw_only(),
"stream"_a = nb::none(),
@@ -1002,7 +1042,9 @@ void init_ops(nb::module_& m) {
)pbdoc");
m.def(
"arcsinh",
&mlx::core::arcsinh,
[](const ScalarOrArray& a, StreamOrDevice s) {
return mlx::core::arcsinh(to_array(a), s);
},
nb::arg(),
nb::kw_only(),
"stream"_a = nb::none(),
@@ -1019,7 +1061,9 @@ void init_ops(nb::module_& m) {
)pbdoc");
m.def(
"arccosh",
&mlx::core::arccosh,
[](const ScalarOrArray& a, StreamOrDevice s) {
return mlx::core::arccosh(to_array(a), s);
},
nb::arg(),
nb::kw_only(),
"stream"_a = nb::none(),
@@ -1036,7 +1080,9 @@ void init_ops(nb::module_& m) {
)pbdoc");
m.def(
"arctanh",
&mlx::core::arctanh,
[](const ScalarOrArray& a, StreamOrDevice s) {
return mlx::core::arctanh(to_array(a), s);
},
nb::arg(),
nb::kw_only(),
"stream"_a = nb::none(),
@@ -1053,7 +1099,9 @@ void init_ops(nb::module_& m) {
)pbdoc");
m.def(
"degrees",
&mlx::core::degrees,
[](const ScalarOrArray& a, StreamOrDevice s) {
return degrees(to_array(a), s);
},
nb::arg(),
nb::kw_only(),
"stream"_a = nb::none(),
@@ -1070,7 +1118,9 @@ void init_ops(nb::module_& m) {
)pbdoc");
m.def(
"radians",
&mlx::core::radians,
[](const ScalarOrArray& a, StreamOrDevice s) {
return mlx::core::radians(to_array(a), s);
},
nb::arg(),
nb::kw_only(),
"stream"_a = nb::none(),
@@ -1087,7 +1137,9 @@ void init_ops(nb::module_& m) {
)pbdoc");
m.def(
"log",
&mlx::core::log,
[](const ScalarOrArray& a, StreamOrDevice s) {
return mlx::core::log(to_array(a), s);
},
nb::arg(),
nb::kw_only(),
"stream"_a = nb::none(),
@@ -1104,7 +1156,9 @@ void init_ops(nb::module_& m) {
)pbdoc");
m.def(
"log2",
&mlx::core::log2,
[](const ScalarOrArray& a, StreamOrDevice s) {
return mlx::core::log2(to_array(a), s);
},
nb::arg(),
nb::kw_only(),
"stream"_a = nb::none(),
@@ -1121,7 +1175,9 @@ void init_ops(nb::module_& m) {
)pbdoc");
m.def(
"log10",
&mlx::core::log10,
[](const ScalarOrArray& a, StreamOrDevice s) {
return mlx::core::log10(to_array(a), s);
},
nb::arg(),
nb::kw_only(),
"stream"_a = nb::none(),
@@ -1138,7 +1194,9 @@ void init_ops(nb::module_& m) {
)pbdoc");
m.def(
"log1p",
&mlx::core::log1p,
[](const ScalarOrArray& a, StreamOrDevice s) {
return mlx::core::log1p(to_array(a), s);
},
nb::arg(),
nb::kw_only(),
"stream"_a = nb::none(),
@@ -1176,7 +1234,9 @@ void init_ops(nb::module_& m) {
)pbdoc");
m.def(
"sigmoid",
&sigmoid,
[](const ScalarOrArray& a, StreamOrDevice s) {
return sigmoid(to_array(a), s);
},
nb::arg(),
nb::kw_only(),
"stream"_a = nb::none(),
@@ -1837,7 +1897,9 @@ void init_ops(nb::module_& m) {
)pbdoc");
m.def(
"floor",
&mlx::core::floor,
[](const ScalarOrArray& a, StreamOrDevice s) {
return mlx::core::floor(to_array(a), s);
},
nb::arg(),
nb::kw_only(),
"stream"_a = nb::none(),
@@ -1854,7 +1916,9 @@ void init_ops(nb::module_& m) {
)pbdoc");
m.def(
"ceil",
&mlx::core::ceil,
[](const ScalarOrArray& a, StreamOrDevice s) {
return mlx::core::ceil(to_array(a), s);
},
nb::arg(),
nb::kw_only(),
"stream"_a = nb::none(),
@@ -1871,7 +1935,9 @@ void init_ops(nb::module_& m) {
)pbdoc");
m.def(
"isnan",
&mlx::core::isnan,
[](const ScalarOrArray& a, StreamOrDevice s) {
return mlx::core::isnan(to_array(a), s);
},
nb::arg(),
nb::kw_only(),
"stream"_a = nb::none(),
@@ -1888,7 +1954,9 @@ void init_ops(nb::module_& m) {
)pbdoc");
m.def(
"isinf",
&mlx::core::isinf,
[](const ScalarOrArray& a, StreamOrDevice s) {
return mlx::core::isinf(to_array(a), s);
},
nb::arg(),
nb::kw_only(),
"stream"_a = nb::none(),
@@ -1905,7 +1973,9 @@ void init_ops(nb::module_& m) {
)pbdoc");
m.def(
"isposinf",
&isposinf,
[](const ScalarOrArray& a, StreamOrDevice s) {
return isposinf(to_array(a), s);
},
nb::arg(),
nb::kw_only(),
"stream"_a = nb::none(),
@@ -1923,7 +1993,9 @@ void init_ops(nb::module_& m) {
)pbdoc");
m.def(
"isneginf",
&isneginf,
[](const ScalarOrArray& a, StreamOrDevice s) {
return isneginf(to_array(a), s);
},
nb::arg(),
nb::kw_only(),
"stream"_a = nb::none(),
@@ -3409,8 +3481,8 @@ void init_ops(nb::module_& m) {
)pbdoc");
m.def(
"round",
[](const array& a, int decimals, StreamOrDevice s) {
return round(a, decimals, s);
[](const ScalarOrArray& a, int decimals, StreamOrDevice s) {
return round(to_array(a), decimals, s);
},
nb::arg(),
"decimals"_a = 0,
@@ -3750,15 +3822,15 @@ void init_ops(nb::module_& m) {
R"pbdoc(
Matrix multiplication with matrix-level gather.
Performs a gather of the operands with the given indices followed by a (possibly batched) matrix multiplication of two arrays.
Performs a gather of the operands with the given indices followed by a (possibly batched) matrix multiplication of two arrays.
This operation is more efficient than explicitly applying a :func:``take`` followed by a :func:``matmul``.
The indices ``lhs_indices`` and ``rhs_indices`` contain flat indices along the batch dimensions (i.e. all but the last two dimensions) of ``a`` and ``b`` respectively.
For ``a`` with shape ``(A1, A2, ..., AS, M, K)``,
For ``a`` with shape ``(A1, A2, ..., AS, M, K)``,
``lhs_indices`` contains indices from the range ``[0, A1 * A2 * ... * AS)``
For ``b`` with shape ``(B1, B2, ..., BS, M, K)``,
For ``b`` with shape ``(B1, B2, ..., BS, M, K)``,
``rhs_indices`` contains indices from the range ``[0, B1 * B2 * ... * BS)``
Args: