Allow unary ops to accept array like (#1093)

This commit is contained in:
Awni Hannun 2024-05-09 09:36:02 -07:00 committed by GitHub
parent cc05a281c4
commit b21242faf1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 162 additions and 42 deletions

View File

@ -158,7 +158,9 @@ void init_ops(nb::module_& m) {
)pbdoc"); )pbdoc");
m.def( m.def(
"abs", "abs",
&mlx::core::abs, [](const ScalarOrArray& a, StreamOrDevice s) {
return mlx::core::abs(to_array(a), s);
},
nb::arg(), nb::arg(),
nb::kw_only(), nb::kw_only(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
@ -175,7 +177,9 @@ void init_ops(nb::module_& m) {
)pbdoc"); )pbdoc");
m.def( m.def(
"sign", "sign",
&sign, [](const ScalarOrArray& a, StreamOrDevice s) {
return sign(to_array(a), s);
},
nb::arg(), nb::arg(),
nb::kw_only(), nb::kw_only(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
@ -192,7 +196,9 @@ void init_ops(nb::module_& m) {
)pbdoc"); )pbdoc");
m.def( m.def(
"negative", "negative",
&negative, [](const ScalarOrArray& a, StreamOrDevice s) {
return negative(to_array(a), s);
},
nb::arg(), nb::arg(),
nb::kw_only(), nb::kw_only(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
@ -600,7 +606,9 @@ void init_ops(nb::module_& m) {
)pbdoc"); )pbdoc");
m.def( m.def(
"square", "square",
&square, [](const ScalarOrArray& a, StreamOrDevice s) {
return square(to_array(a), s);
},
nb::arg(), nb::arg(),
nb::kw_only(), nb::kw_only(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
@ -617,7 +625,9 @@ void init_ops(nb::module_& m) {
)pbdoc"); )pbdoc");
m.def( m.def(
"sqrt", "sqrt",
&mlx::core::sqrt, [](const ScalarOrArray& a, StreamOrDevice s) {
return mlx::core::sqrt(to_array(a), s);
},
nb::arg(), nb::arg(),
nb::kw_only(), nb::kw_only(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
@ -634,7 +644,9 @@ void init_ops(nb::module_& m) {
)pbdoc"); )pbdoc");
m.def( m.def(
"rsqrt", "rsqrt",
&rsqrt, [](const ScalarOrArray& a, StreamOrDevice s) {
return rsqrt(to_array(a), s);
},
nb::arg(), nb::arg(),
nb::kw_only(), nb::kw_only(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
@ -651,7 +663,9 @@ void init_ops(nb::module_& m) {
)pbdoc"); )pbdoc");
m.def( m.def(
"reciprocal", "reciprocal",
&reciprocal, [](const ScalarOrArray& a, StreamOrDevice s) {
return reciprocal(to_array(a), s);
},
nb::arg(), nb::arg(),
nb::kw_only(), nb::kw_only(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
@ -757,7 +771,9 @@ void init_ops(nb::module_& m) {
)pbdoc"); )pbdoc");
m.def( m.def(
"exp", "exp",
&mlx::core::exp, [](const ScalarOrArray& a, StreamOrDevice s) {
return mlx::core::exp(to_array(a), s);
},
nb::arg(), nb::arg(),
nb::kw_only(), nb::kw_only(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
@ -774,7 +790,9 @@ void init_ops(nb::module_& m) {
)pbdoc"); )pbdoc");
m.def( m.def(
"expm1", "expm1",
&mlx::core::expm1, [](const ScalarOrArray& a, StreamOrDevice s) {
return mlx::core::expm1(to_array(a), s);
},
nb::arg(), nb::arg(),
nb::kw_only(), nb::kw_only(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
@ -793,7 +811,9 @@ void init_ops(nb::module_& m) {
)pbdoc"); )pbdoc");
m.def( m.def(
"erf", "erf",
&mlx::core::erf, [](const ScalarOrArray& a, StreamOrDevice s) {
return mlx::core::erf(to_array(a), s);
},
nb::arg(), nb::arg(),
nb::kw_only(), nb::kw_only(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
@ -813,7 +833,9 @@ void init_ops(nb::module_& m) {
)pbdoc"); )pbdoc");
m.def( m.def(
"erfinv", "erfinv",
&mlx::core::erfinv, [](const ScalarOrArray& a, StreamOrDevice s) {
return mlx::core::erfinv(to_array(a), s);
},
nb::arg(), nb::arg(),
nb::kw_only(), nb::kw_only(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
@ -830,7 +852,9 @@ void init_ops(nb::module_& m) {
)pbdoc"); )pbdoc");
m.def( m.def(
"sin", "sin",
&mlx::core::sin, [](const ScalarOrArray& a, StreamOrDevice s) {
return mlx::core::sin(to_array(a), s);
},
nb::arg(), nb::arg(),
nb::kw_only(), nb::kw_only(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
@ -847,7 +871,9 @@ void init_ops(nb::module_& m) {
)pbdoc"); )pbdoc");
m.def( m.def(
"cos", "cos",
&mlx::core::cos, [](const ScalarOrArray& a, StreamOrDevice s) {
return mlx::core::cos(to_array(a), s);
},
nb::arg(), nb::arg(),
nb::kw_only(), nb::kw_only(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
@ -864,7 +890,9 @@ void init_ops(nb::module_& m) {
)pbdoc"); )pbdoc");
m.def( m.def(
"tan", "tan",
&mlx::core::tan, [](const ScalarOrArray& a, StreamOrDevice s) {
return mlx::core::tan(to_array(a), s);
},
nb::arg(), nb::arg(),
nb::kw_only(), nb::kw_only(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
@ -881,7 +909,9 @@ void init_ops(nb::module_& m) {
)pbdoc"); )pbdoc");
m.def( m.def(
"arcsin", "arcsin",
&mlx::core::arcsin, [](const ScalarOrArray& a, StreamOrDevice s) {
return mlx::core::arcsin(to_array(a), s);
},
nb::arg(), nb::arg(),
nb::kw_only(), nb::kw_only(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
@ -898,7 +928,9 @@ void init_ops(nb::module_& m) {
)pbdoc"); )pbdoc");
m.def( m.def(
"arccos", "arccos",
&mlx::core::arccos, [](const ScalarOrArray& a, StreamOrDevice s) {
return mlx::core::arccos(to_array(a), s);
},
nb::arg(), nb::arg(),
nb::kw_only(), nb::kw_only(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
@ -915,7 +947,9 @@ void init_ops(nb::module_& m) {
)pbdoc"); )pbdoc");
m.def( m.def(
"arctan", "arctan",
&mlx::core::arctan, [](const ScalarOrArray& a, StreamOrDevice s) {
return mlx::core::arctan(to_array(a), s);
},
nb::arg(), nb::arg(),
nb::kw_only(), nb::kw_only(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
@ -951,7 +985,9 @@ void init_ops(nb::module_& m) {
)pbdoc"); )pbdoc");
m.def( m.def(
"sinh", "sinh",
&mlx::core::sinh, [](const ScalarOrArray& a, StreamOrDevice s) {
return mlx::core::sinh(to_array(a), s);
},
nb::arg(), nb::arg(),
nb::kw_only(), nb::kw_only(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
@ -968,7 +1004,9 @@ void init_ops(nb::module_& m) {
)pbdoc"); )pbdoc");
m.def( m.def(
"cosh", "cosh",
&mlx::core::cosh, [](const ScalarOrArray& a, StreamOrDevice s) {
return mlx::core::cosh(to_array(a), s);
},
nb::arg(), nb::arg(),
nb::kw_only(), nb::kw_only(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
@ -985,7 +1023,9 @@ void init_ops(nb::module_& m) {
)pbdoc"); )pbdoc");
m.def( m.def(
"tanh", "tanh",
&mlx::core::tanh, [](const ScalarOrArray& a, StreamOrDevice s) {
return mlx::core::tanh(to_array(a), s);
},
nb::arg(), nb::arg(),
nb::kw_only(), nb::kw_only(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
@ -1002,7 +1042,9 @@ void init_ops(nb::module_& m) {
)pbdoc"); )pbdoc");
m.def( m.def(
"arcsinh", "arcsinh",
&mlx::core::arcsinh, [](const ScalarOrArray& a, StreamOrDevice s) {
return mlx::core::arcsinh(to_array(a), s);
},
nb::arg(), nb::arg(),
nb::kw_only(), nb::kw_only(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
@ -1019,7 +1061,9 @@ void init_ops(nb::module_& m) {
)pbdoc"); )pbdoc");
m.def( m.def(
"arccosh", "arccosh",
&mlx::core::arccosh, [](const ScalarOrArray& a, StreamOrDevice s) {
return mlx::core::arccosh(to_array(a), s);
},
nb::arg(), nb::arg(),
nb::kw_only(), nb::kw_only(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
@ -1036,7 +1080,9 @@ void init_ops(nb::module_& m) {
)pbdoc"); )pbdoc");
m.def( m.def(
"arctanh", "arctanh",
&mlx::core::arctanh, [](const ScalarOrArray& a, StreamOrDevice s) {
return mlx::core::arctanh(to_array(a), s);
},
nb::arg(), nb::arg(),
nb::kw_only(), nb::kw_only(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
@ -1053,7 +1099,9 @@ void init_ops(nb::module_& m) {
)pbdoc"); )pbdoc");
m.def( m.def(
"degrees", "degrees",
&mlx::core::degrees, [](const ScalarOrArray& a, StreamOrDevice s) {
return degrees(to_array(a), s);
},
nb::arg(), nb::arg(),
nb::kw_only(), nb::kw_only(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
@ -1070,7 +1118,9 @@ void init_ops(nb::module_& m) {
)pbdoc"); )pbdoc");
m.def( m.def(
"radians", "radians",
&mlx::core::radians, [](const ScalarOrArray& a, StreamOrDevice s) {
return mlx::core::radians(to_array(a), s);
},
nb::arg(), nb::arg(),
nb::kw_only(), nb::kw_only(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
@ -1087,7 +1137,9 @@ void init_ops(nb::module_& m) {
)pbdoc"); )pbdoc");
m.def( m.def(
"log", "log",
&mlx::core::log, [](const ScalarOrArray& a, StreamOrDevice s) {
return mlx::core::log(to_array(a), s);
},
nb::arg(), nb::arg(),
nb::kw_only(), nb::kw_only(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
@ -1104,7 +1156,9 @@ void init_ops(nb::module_& m) {
)pbdoc"); )pbdoc");
m.def( m.def(
"log2", "log2",
&mlx::core::log2, [](const ScalarOrArray& a, StreamOrDevice s) {
return mlx::core::log2(to_array(a), s);
},
nb::arg(), nb::arg(),
nb::kw_only(), nb::kw_only(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
@ -1121,7 +1175,9 @@ void init_ops(nb::module_& m) {
)pbdoc"); )pbdoc");
m.def( m.def(
"log10", "log10",
&mlx::core::log10, [](const ScalarOrArray& a, StreamOrDevice s) {
return mlx::core::log10(to_array(a), s);
},
nb::arg(), nb::arg(),
nb::kw_only(), nb::kw_only(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
@ -1138,7 +1194,9 @@ void init_ops(nb::module_& m) {
)pbdoc"); )pbdoc");
m.def( m.def(
"log1p", "log1p",
&mlx::core::log1p, [](const ScalarOrArray& a, StreamOrDevice s) {
return mlx::core::log1p(to_array(a), s);
},
nb::arg(), nb::arg(),
nb::kw_only(), nb::kw_only(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
@ -1176,7 +1234,9 @@ void init_ops(nb::module_& m) {
)pbdoc"); )pbdoc");
m.def( m.def(
"sigmoid", "sigmoid",
&sigmoid, [](const ScalarOrArray& a, StreamOrDevice s) {
return sigmoid(to_array(a), s);
},
nb::arg(), nb::arg(),
nb::kw_only(), nb::kw_only(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
@ -1837,7 +1897,9 @@ void init_ops(nb::module_& m) {
)pbdoc"); )pbdoc");
m.def( m.def(
"floor", "floor",
&mlx::core::floor, [](const ScalarOrArray& a, StreamOrDevice s) {
return mlx::core::floor(to_array(a), s);
},
nb::arg(), nb::arg(),
nb::kw_only(), nb::kw_only(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
@ -1854,7 +1916,9 @@ void init_ops(nb::module_& m) {
)pbdoc"); )pbdoc");
m.def( m.def(
"ceil", "ceil",
&mlx::core::ceil, [](const ScalarOrArray& a, StreamOrDevice s) {
return mlx::core::ceil(to_array(a), s);
},
nb::arg(), nb::arg(),
nb::kw_only(), nb::kw_only(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
@ -1871,7 +1935,9 @@ void init_ops(nb::module_& m) {
)pbdoc"); )pbdoc");
m.def( m.def(
"isnan", "isnan",
&mlx::core::isnan, [](const ScalarOrArray& a, StreamOrDevice s) {
return mlx::core::isnan(to_array(a), s);
},
nb::arg(), nb::arg(),
nb::kw_only(), nb::kw_only(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
@ -1888,7 +1954,9 @@ void init_ops(nb::module_& m) {
)pbdoc"); )pbdoc");
m.def( m.def(
"isinf", "isinf",
&mlx::core::isinf, [](const ScalarOrArray& a, StreamOrDevice s) {
return mlx::core::isinf(to_array(a), s);
},
nb::arg(), nb::arg(),
nb::kw_only(), nb::kw_only(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
@ -1905,7 +1973,9 @@ void init_ops(nb::module_& m) {
)pbdoc"); )pbdoc");
m.def( m.def(
"isposinf", "isposinf",
&isposinf, [](const ScalarOrArray& a, StreamOrDevice s) {
return isposinf(to_array(a), s);
},
nb::arg(), nb::arg(),
nb::kw_only(), nb::kw_only(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
@ -1923,7 +1993,9 @@ void init_ops(nb::module_& m) {
)pbdoc"); )pbdoc");
m.def( m.def(
"isneginf", "isneginf",
&isneginf, [](const ScalarOrArray& a, StreamOrDevice s) {
return isneginf(to_array(a), s);
},
nb::arg(), nb::arg(),
nb::kw_only(), nb::kw_only(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
@ -3409,8 +3481,8 @@ void init_ops(nb::module_& m) {
)pbdoc"); )pbdoc");
m.def( m.def(
"round", "round",
[](const array& a, int decimals, StreamOrDevice s) { [](const ScalarOrArray& a, int decimals, StreamOrDevice s) {
return round(a, decimals, s); return round(to_array(a), decimals, s);
}, },
nb::arg(), nb::arg(),
"decimals"_a = 0, "decimals"_a = 0,

View File

@ -1217,6 +1217,54 @@ class TestOps(mlx_tests.MLXTestCase):
y_ = mx.array(x_) y_ = mx.array(x_)
test_ops(getattr(np, op), getattr(mx, op), x_, y_, atol) test_ops(getattr(np, op), getattr(mx, op), x_, y_, atol)
def test_unary_ops_from_non_array(self):
unary_ops = [
"abs",
"exp",
"log",
"square",
"sqrt",
"sin",
"cos",
"tan",
"sinh",
"cosh",
"tanh",
"sign",
"negative",
"expm1",
"arcsin",
"arccos",
"arctan",
"arcsinh",
"arctanh",
"degrees",
"radians",
"log2",
"log10",
"log1p",
"floor",
"ceil",
]
x = 0.5
x_np = np.random.rand(10).astype(np.float32)
for op in unary_ops:
with self.subTest(op=op):
# Test from scalar
expected = getattr(np, op)(x)
out = getattr(mx, op)(x)
# Check close
self.assertTrue(np.allclose(expected, out, equal_nan=True))
# Test from NumPy
expected = getattr(np, op)(x_np)
out = getattr(mx, op)(x_np)
# Check close
self.assertTrue(np.allclose(expected, np.array(out), equal_nan=True))
def test_trig_ops(self): def test_trig_ops(self):
def test_ops(npop, mlxop, x, y, atol): def test_ops(npop, mlxop, x, y, atol):
r_np = npop(x) r_np = npop(x)