mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-21 08:41:13 +08:00
Allow unary ops to accept array like (#1093)
This commit is contained in:
parent
cc05a281c4
commit
b21242faf1
@ -158,7 +158,9 @@ void init_ops(nb::module_& m) {
|
|||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"abs",
|
"abs",
|
||||||
&mlx::core::abs,
|
[](const ScalarOrArray& a, StreamOrDevice s) {
|
||||||
|
return mlx::core::abs(to_array(a), s);
|
||||||
|
},
|
||||||
nb::arg(),
|
nb::arg(),
|
||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
@ -175,7 +177,9 @@ void init_ops(nb::module_& m) {
|
|||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"sign",
|
"sign",
|
||||||
&sign,
|
[](const ScalarOrArray& a, StreamOrDevice s) {
|
||||||
|
return sign(to_array(a), s);
|
||||||
|
},
|
||||||
nb::arg(),
|
nb::arg(),
|
||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
@ -192,7 +196,9 @@ void init_ops(nb::module_& m) {
|
|||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"negative",
|
"negative",
|
||||||
&negative,
|
[](const ScalarOrArray& a, StreamOrDevice s) {
|
||||||
|
return negative(to_array(a), s);
|
||||||
|
},
|
||||||
nb::arg(),
|
nb::arg(),
|
||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
@ -600,7 +606,9 @@ void init_ops(nb::module_& m) {
|
|||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"square",
|
"square",
|
||||||
&square,
|
[](const ScalarOrArray& a, StreamOrDevice s) {
|
||||||
|
return square(to_array(a), s);
|
||||||
|
},
|
||||||
nb::arg(),
|
nb::arg(),
|
||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
@ -617,7 +625,9 @@ void init_ops(nb::module_& m) {
|
|||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"sqrt",
|
"sqrt",
|
||||||
&mlx::core::sqrt,
|
[](const ScalarOrArray& a, StreamOrDevice s) {
|
||||||
|
return mlx::core::sqrt(to_array(a), s);
|
||||||
|
},
|
||||||
nb::arg(),
|
nb::arg(),
|
||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
@ -634,7 +644,9 @@ void init_ops(nb::module_& m) {
|
|||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"rsqrt",
|
"rsqrt",
|
||||||
&rsqrt,
|
[](const ScalarOrArray& a, StreamOrDevice s) {
|
||||||
|
return rsqrt(to_array(a), s);
|
||||||
|
},
|
||||||
nb::arg(),
|
nb::arg(),
|
||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
@ -651,7 +663,9 @@ void init_ops(nb::module_& m) {
|
|||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"reciprocal",
|
"reciprocal",
|
||||||
&reciprocal,
|
[](const ScalarOrArray& a, StreamOrDevice s) {
|
||||||
|
return reciprocal(to_array(a), s);
|
||||||
|
},
|
||||||
nb::arg(),
|
nb::arg(),
|
||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
@ -757,7 +771,9 @@ void init_ops(nb::module_& m) {
|
|||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"exp",
|
"exp",
|
||||||
&mlx::core::exp,
|
[](const ScalarOrArray& a, StreamOrDevice s) {
|
||||||
|
return mlx::core::exp(to_array(a), s);
|
||||||
|
},
|
||||||
nb::arg(),
|
nb::arg(),
|
||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
@ -774,7 +790,9 @@ void init_ops(nb::module_& m) {
|
|||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"expm1",
|
"expm1",
|
||||||
&mlx::core::expm1,
|
[](const ScalarOrArray& a, StreamOrDevice s) {
|
||||||
|
return mlx::core::expm1(to_array(a), s);
|
||||||
|
},
|
||||||
nb::arg(),
|
nb::arg(),
|
||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
@ -793,7 +811,9 @@ void init_ops(nb::module_& m) {
|
|||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"erf",
|
"erf",
|
||||||
&mlx::core::erf,
|
[](const ScalarOrArray& a, StreamOrDevice s) {
|
||||||
|
return mlx::core::erf(to_array(a), s);
|
||||||
|
},
|
||||||
nb::arg(),
|
nb::arg(),
|
||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
@ -813,7 +833,9 @@ void init_ops(nb::module_& m) {
|
|||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"erfinv",
|
"erfinv",
|
||||||
&mlx::core::erfinv,
|
[](const ScalarOrArray& a, StreamOrDevice s) {
|
||||||
|
return mlx::core::erfinv(to_array(a), s);
|
||||||
|
},
|
||||||
nb::arg(),
|
nb::arg(),
|
||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
@ -830,7 +852,9 @@ void init_ops(nb::module_& m) {
|
|||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"sin",
|
"sin",
|
||||||
&mlx::core::sin,
|
[](const ScalarOrArray& a, StreamOrDevice s) {
|
||||||
|
return mlx::core::sin(to_array(a), s);
|
||||||
|
},
|
||||||
nb::arg(),
|
nb::arg(),
|
||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
@ -847,7 +871,9 @@ void init_ops(nb::module_& m) {
|
|||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"cos",
|
"cos",
|
||||||
&mlx::core::cos,
|
[](const ScalarOrArray& a, StreamOrDevice s) {
|
||||||
|
return mlx::core::cos(to_array(a), s);
|
||||||
|
},
|
||||||
nb::arg(),
|
nb::arg(),
|
||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
@ -864,7 +890,9 @@ void init_ops(nb::module_& m) {
|
|||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"tan",
|
"tan",
|
||||||
&mlx::core::tan,
|
[](const ScalarOrArray& a, StreamOrDevice s) {
|
||||||
|
return mlx::core::tan(to_array(a), s);
|
||||||
|
},
|
||||||
nb::arg(),
|
nb::arg(),
|
||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
@ -881,7 +909,9 @@ void init_ops(nb::module_& m) {
|
|||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"arcsin",
|
"arcsin",
|
||||||
&mlx::core::arcsin,
|
[](const ScalarOrArray& a, StreamOrDevice s) {
|
||||||
|
return mlx::core::arcsin(to_array(a), s);
|
||||||
|
},
|
||||||
nb::arg(),
|
nb::arg(),
|
||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
@ -898,7 +928,9 @@ void init_ops(nb::module_& m) {
|
|||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"arccos",
|
"arccos",
|
||||||
&mlx::core::arccos,
|
[](const ScalarOrArray& a, StreamOrDevice s) {
|
||||||
|
return mlx::core::arccos(to_array(a), s);
|
||||||
|
},
|
||||||
nb::arg(),
|
nb::arg(),
|
||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
@ -915,7 +947,9 @@ void init_ops(nb::module_& m) {
|
|||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"arctan",
|
"arctan",
|
||||||
&mlx::core::arctan,
|
[](const ScalarOrArray& a, StreamOrDevice s) {
|
||||||
|
return mlx::core::arctan(to_array(a), s);
|
||||||
|
},
|
||||||
nb::arg(),
|
nb::arg(),
|
||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
@ -951,7 +985,9 @@ void init_ops(nb::module_& m) {
|
|||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"sinh",
|
"sinh",
|
||||||
&mlx::core::sinh,
|
[](const ScalarOrArray& a, StreamOrDevice s) {
|
||||||
|
return mlx::core::sinh(to_array(a), s);
|
||||||
|
},
|
||||||
nb::arg(),
|
nb::arg(),
|
||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
@ -968,7 +1004,9 @@ void init_ops(nb::module_& m) {
|
|||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"cosh",
|
"cosh",
|
||||||
&mlx::core::cosh,
|
[](const ScalarOrArray& a, StreamOrDevice s) {
|
||||||
|
return mlx::core::cosh(to_array(a), s);
|
||||||
|
},
|
||||||
nb::arg(),
|
nb::arg(),
|
||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
@ -985,7 +1023,9 @@ void init_ops(nb::module_& m) {
|
|||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"tanh",
|
"tanh",
|
||||||
&mlx::core::tanh,
|
[](const ScalarOrArray& a, StreamOrDevice s) {
|
||||||
|
return mlx::core::tanh(to_array(a), s);
|
||||||
|
},
|
||||||
nb::arg(),
|
nb::arg(),
|
||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
@ -1002,7 +1042,9 @@ void init_ops(nb::module_& m) {
|
|||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"arcsinh",
|
"arcsinh",
|
||||||
&mlx::core::arcsinh,
|
[](const ScalarOrArray& a, StreamOrDevice s) {
|
||||||
|
return mlx::core::arcsinh(to_array(a), s);
|
||||||
|
},
|
||||||
nb::arg(),
|
nb::arg(),
|
||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
@ -1019,7 +1061,9 @@ void init_ops(nb::module_& m) {
|
|||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"arccosh",
|
"arccosh",
|
||||||
&mlx::core::arccosh,
|
[](const ScalarOrArray& a, StreamOrDevice s) {
|
||||||
|
return mlx::core::arccosh(to_array(a), s);
|
||||||
|
},
|
||||||
nb::arg(),
|
nb::arg(),
|
||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
@ -1036,7 +1080,9 @@ void init_ops(nb::module_& m) {
|
|||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"arctanh",
|
"arctanh",
|
||||||
&mlx::core::arctanh,
|
[](const ScalarOrArray& a, StreamOrDevice s) {
|
||||||
|
return mlx::core::arctanh(to_array(a), s);
|
||||||
|
},
|
||||||
nb::arg(),
|
nb::arg(),
|
||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
@ -1053,7 +1099,9 @@ void init_ops(nb::module_& m) {
|
|||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"degrees",
|
"degrees",
|
||||||
&mlx::core::degrees,
|
[](const ScalarOrArray& a, StreamOrDevice s) {
|
||||||
|
return degrees(to_array(a), s);
|
||||||
|
},
|
||||||
nb::arg(),
|
nb::arg(),
|
||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
@ -1070,7 +1118,9 @@ void init_ops(nb::module_& m) {
|
|||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"radians",
|
"radians",
|
||||||
&mlx::core::radians,
|
[](const ScalarOrArray& a, StreamOrDevice s) {
|
||||||
|
return mlx::core::radians(to_array(a), s);
|
||||||
|
},
|
||||||
nb::arg(),
|
nb::arg(),
|
||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
@ -1087,7 +1137,9 @@ void init_ops(nb::module_& m) {
|
|||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"log",
|
"log",
|
||||||
&mlx::core::log,
|
[](const ScalarOrArray& a, StreamOrDevice s) {
|
||||||
|
return mlx::core::log(to_array(a), s);
|
||||||
|
},
|
||||||
nb::arg(),
|
nb::arg(),
|
||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
@ -1104,7 +1156,9 @@ void init_ops(nb::module_& m) {
|
|||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"log2",
|
"log2",
|
||||||
&mlx::core::log2,
|
[](const ScalarOrArray& a, StreamOrDevice s) {
|
||||||
|
return mlx::core::log2(to_array(a), s);
|
||||||
|
},
|
||||||
nb::arg(),
|
nb::arg(),
|
||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
@ -1121,7 +1175,9 @@ void init_ops(nb::module_& m) {
|
|||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"log10",
|
"log10",
|
||||||
&mlx::core::log10,
|
[](const ScalarOrArray& a, StreamOrDevice s) {
|
||||||
|
return mlx::core::log10(to_array(a), s);
|
||||||
|
},
|
||||||
nb::arg(),
|
nb::arg(),
|
||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
@ -1138,7 +1194,9 @@ void init_ops(nb::module_& m) {
|
|||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"log1p",
|
"log1p",
|
||||||
&mlx::core::log1p,
|
[](const ScalarOrArray& a, StreamOrDevice s) {
|
||||||
|
return mlx::core::log1p(to_array(a), s);
|
||||||
|
},
|
||||||
nb::arg(),
|
nb::arg(),
|
||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
@ -1176,7 +1234,9 @@ void init_ops(nb::module_& m) {
|
|||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"sigmoid",
|
"sigmoid",
|
||||||
&sigmoid,
|
[](const ScalarOrArray& a, StreamOrDevice s) {
|
||||||
|
return sigmoid(to_array(a), s);
|
||||||
|
},
|
||||||
nb::arg(),
|
nb::arg(),
|
||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
@ -1837,7 +1897,9 @@ void init_ops(nb::module_& m) {
|
|||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"floor",
|
"floor",
|
||||||
&mlx::core::floor,
|
[](const ScalarOrArray& a, StreamOrDevice s) {
|
||||||
|
return mlx::core::floor(to_array(a), s);
|
||||||
|
},
|
||||||
nb::arg(),
|
nb::arg(),
|
||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
@ -1854,7 +1916,9 @@ void init_ops(nb::module_& m) {
|
|||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"ceil",
|
"ceil",
|
||||||
&mlx::core::ceil,
|
[](const ScalarOrArray& a, StreamOrDevice s) {
|
||||||
|
return mlx::core::ceil(to_array(a), s);
|
||||||
|
},
|
||||||
nb::arg(),
|
nb::arg(),
|
||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
@ -1871,7 +1935,9 @@ void init_ops(nb::module_& m) {
|
|||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"isnan",
|
"isnan",
|
||||||
&mlx::core::isnan,
|
[](const ScalarOrArray& a, StreamOrDevice s) {
|
||||||
|
return mlx::core::isnan(to_array(a), s);
|
||||||
|
},
|
||||||
nb::arg(),
|
nb::arg(),
|
||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
@ -1888,7 +1954,9 @@ void init_ops(nb::module_& m) {
|
|||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"isinf",
|
"isinf",
|
||||||
&mlx::core::isinf,
|
[](const ScalarOrArray& a, StreamOrDevice s) {
|
||||||
|
return mlx::core::isinf(to_array(a), s);
|
||||||
|
},
|
||||||
nb::arg(),
|
nb::arg(),
|
||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
@ -1905,7 +1973,9 @@ void init_ops(nb::module_& m) {
|
|||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"isposinf",
|
"isposinf",
|
||||||
&isposinf,
|
[](const ScalarOrArray& a, StreamOrDevice s) {
|
||||||
|
return isposinf(to_array(a), s);
|
||||||
|
},
|
||||||
nb::arg(),
|
nb::arg(),
|
||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
@ -1923,7 +1993,9 @@ void init_ops(nb::module_& m) {
|
|||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"isneginf",
|
"isneginf",
|
||||||
&isneginf,
|
[](const ScalarOrArray& a, StreamOrDevice s) {
|
||||||
|
return isneginf(to_array(a), s);
|
||||||
|
},
|
||||||
nb::arg(),
|
nb::arg(),
|
||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
@ -3409,8 +3481,8 @@ void init_ops(nb::module_& m) {
|
|||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"round",
|
"round",
|
||||||
[](const array& a, int decimals, StreamOrDevice s) {
|
[](const ScalarOrArray& a, int decimals, StreamOrDevice s) {
|
||||||
return round(a, decimals, s);
|
return round(to_array(a), decimals, s);
|
||||||
},
|
},
|
||||||
nb::arg(),
|
nb::arg(),
|
||||||
"decimals"_a = 0,
|
"decimals"_a = 0,
|
||||||
|
@ -1217,6 +1217,54 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
y_ = mx.array(x_)
|
y_ = mx.array(x_)
|
||||||
test_ops(getattr(np, op), getattr(mx, op), x_, y_, atol)
|
test_ops(getattr(np, op), getattr(mx, op), x_, y_, atol)
|
||||||
|
|
||||||
|
def test_unary_ops_from_non_array(self):
|
||||||
|
unary_ops = [
|
||||||
|
"abs",
|
||||||
|
"exp",
|
||||||
|
"log",
|
||||||
|
"square",
|
||||||
|
"sqrt",
|
||||||
|
"sin",
|
||||||
|
"cos",
|
||||||
|
"tan",
|
||||||
|
"sinh",
|
||||||
|
"cosh",
|
||||||
|
"tanh",
|
||||||
|
"sign",
|
||||||
|
"negative",
|
||||||
|
"expm1",
|
||||||
|
"arcsin",
|
||||||
|
"arccos",
|
||||||
|
"arctan",
|
||||||
|
"arcsinh",
|
||||||
|
"arctanh",
|
||||||
|
"degrees",
|
||||||
|
"radians",
|
||||||
|
"log2",
|
||||||
|
"log10",
|
||||||
|
"log1p",
|
||||||
|
"floor",
|
||||||
|
"ceil",
|
||||||
|
]
|
||||||
|
|
||||||
|
x = 0.5
|
||||||
|
x_np = np.random.rand(10).astype(np.float32)
|
||||||
|
for op in unary_ops:
|
||||||
|
with self.subTest(op=op):
|
||||||
|
# Test from scalar
|
||||||
|
expected = getattr(np, op)(x)
|
||||||
|
out = getattr(mx, op)(x)
|
||||||
|
|
||||||
|
# Check close
|
||||||
|
self.assertTrue(np.allclose(expected, out, equal_nan=True))
|
||||||
|
|
||||||
|
# Test from NumPy
|
||||||
|
expected = getattr(np, op)(x_np)
|
||||||
|
out = getattr(mx, op)(x_np)
|
||||||
|
|
||||||
|
# Check close
|
||||||
|
self.assertTrue(np.allclose(expected, np.array(out), equal_nan=True))
|
||||||
|
|
||||||
def test_trig_ops(self):
|
def test_trig_ops(self):
|
||||||
def test_ops(npop, mlxop, x, y, atol):
|
def test_ops(npop, mlxop, x, y, atol):
|
||||||
r_np = npop(x)
|
r_np = npop(x)
|
||||||
|
Loading…
Reference in New Issue
Block a user