diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 4857f813d..8c8caf06f 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -158,7 +158,9 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "abs", - &mlx::core::abs, + [](const ScalarOrArray& a, StreamOrDevice s) { + return mlx::core::abs(to_array(a), s); + }, nb::arg(), nb::kw_only(), "stream"_a = nb::none(), @@ -175,7 +177,9 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "sign", - &sign, + [](const ScalarOrArray& a, StreamOrDevice s) { + return sign(to_array(a), s); + }, nb::arg(), nb::kw_only(), "stream"_a = nb::none(), @@ -192,7 +196,9 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "negative", - &negative, + [](const ScalarOrArray& a, StreamOrDevice s) { + return negative(to_array(a), s); + }, nb::arg(), nb::kw_only(), "stream"_a = nb::none(), @@ -600,7 +606,9 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "square", - &square, + [](const ScalarOrArray& a, StreamOrDevice s) { + return square(to_array(a), s); + }, nb::arg(), nb::kw_only(), "stream"_a = nb::none(), @@ -617,7 +625,9 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "sqrt", - &mlx::core::sqrt, + [](const ScalarOrArray& a, StreamOrDevice s) { + return mlx::core::sqrt(to_array(a), s); + }, nb::arg(), nb::kw_only(), "stream"_a = nb::none(), @@ -634,7 +644,9 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "rsqrt", - &rsqrt, + [](const ScalarOrArray& a, StreamOrDevice s) { + return rsqrt(to_array(a), s); + }, nb::arg(), nb::kw_only(), "stream"_a = nb::none(), @@ -651,7 +663,9 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "reciprocal", - &reciprocal, + [](const ScalarOrArray& a, StreamOrDevice s) { + return reciprocal(to_array(a), s); + }, nb::arg(), nb::kw_only(), "stream"_a = nb::none(), @@ -757,7 +771,9 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "exp", - &mlx::core::exp, + [](const ScalarOrArray& a, StreamOrDevice s) { + return mlx::core::exp(to_array(a), s); + }, nb::arg(), nb::kw_only(), "stream"_a = nb::none(), @@ -774,7 +790,9 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "expm1", - &mlx::core::expm1, + [](const ScalarOrArray& a, StreamOrDevice s) { + return mlx::core::expm1(to_array(a), s); + }, nb::arg(), nb::kw_only(), "stream"_a = nb::none(), @@ -793,7 +811,9 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "erf", - &mlx::core::erf, + [](const ScalarOrArray& a, StreamOrDevice s) { + return mlx::core::erf(to_array(a), s); + }, nb::arg(), nb::kw_only(), "stream"_a = nb::none(), @@ -813,7 +833,9 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "erfinv", - &mlx::core::erfinv, + [](const ScalarOrArray& a, StreamOrDevice s) { + return mlx::core::erfinv(to_array(a), s); + }, nb::arg(), nb::kw_only(), "stream"_a = nb::none(), @@ -830,7 +852,9 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "sin", - &mlx::core::sin, + [](const ScalarOrArray& a, StreamOrDevice s) { + return mlx::core::sin(to_array(a), s); + }, nb::arg(), nb::kw_only(), "stream"_a = nb::none(), @@ -847,7 +871,9 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "cos", - &mlx::core::cos, + [](const ScalarOrArray& a, StreamOrDevice s) { + return mlx::core::cos(to_array(a), s); + }, nb::arg(), nb::kw_only(), "stream"_a = nb::none(), @@ -864,7 +890,9 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "tan", - &mlx::core::tan, + [](const ScalarOrArray& a, StreamOrDevice s) { + return mlx::core::tan(to_array(a), s); + }, nb::arg(), nb::kw_only(), "stream"_a = nb::none(), @@ -881,7 +909,9 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "arcsin", - &mlx::core::arcsin, + [](const ScalarOrArray& a, StreamOrDevice s) { + return mlx::core::arcsin(to_array(a), s); + }, nb::arg(), nb::kw_only(), "stream"_a = nb::none(), @@ -898,7 +928,9 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "arccos", - &mlx::core::arccos, + [](const ScalarOrArray& a, StreamOrDevice s) { + return mlx::core::arccos(to_array(a), s); + }, nb::arg(), nb::kw_only(), "stream"_a = nb::none(), @@ -915,7 +947,9 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "arctan", - &mlx::core::arctan, + [](const ScalarOrArray& a, StreamOrDevice s) { + return mlx::core::arctan(to_array(a), s); + }, nb::arg(), nb::kw_only(), "stream"_a = nb::none(), @@ -951,7 +985,9 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "sinh", - &mlx::core::sinh, + [](const ScalarOrArray& a, StreamOrDevice s) { + return mlx::core::sinh(to_array(a), s); + }, nb::arg(), nb::kw_only(), "stream"_a = nb::none(), @@ -968,7 +1004,9 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "cosh", - &mlx::core::cosh, + [](const ScalarOrArray& a, StreamOrDevice s) { + return mlx::core::cosh(to_array(a), s); + }, nb::arg(), nb::kw_only(), "stream"_a = nb::none(), @@ -985,7 +1023,9 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "tanh", - &mlx::core::tanh, + [](const ScalarOrArray& a, StreamOrDevice s) { + return mlx::core::tanh(to_array(a), s); + }, nb::arg(), nb::kw_only(), "stream"_a = nb::none(), @@ -1002,7 +1042,9 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "arcsinh", - &mlx::core::arcsinh, + [](const ScalarOrArray& a, StreamOrDevice s) { + return mlx::core::arcsinh(to_array(a), s); + }, nb::arg(), nb::kw_only(), "stream"_a = nb::none(), @@ -1019,7 +1061,9 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "arccosh", - &mlx::core::arccosh, + [](const ScalarOrArray& a, StreamOrDevice s) { + return mlx::core::arccosh(to_array(a), s); + }, nb::arg(), nb::kw_only(), "stream"_a = nb::none(), @@ -1036,7 +1080,9 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "arctanh", - &mlx::core::arctanh, + [](const ScalarOrArray& a, StreamOrDevice s) { + return mlx::core::arctanh(to_array(a), s); + }, nb::arg(), nb::kw_only(), "stream"_a = nb::none(), @@ -1053,7 +1099,9 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "degrees", - &mlx::core::degrees, + [](const ScalarOrArray& a, StreamOrDevice s) { + return degrees(to_array(a), s); + }, nb::arg(), nb::kw_only(), "stream"_a = nb::none(), @@ -1070,7 +1118,9 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "radians", - &mlx::core::radians, + [](const ScalarOrArray& a, StreamOrDevice s) { + return mlx::core::radians(to_array(a), s); + }, nb::arg(), nb::kw_only(), "stream"_a = nb::none(), @@ -1087,7 +1137,9 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "log", - &mlx::core::log, + [](const ScalarOrArray& a, StreamOrDevice s) { + return mlx::core::log(to_array(a), s); + }, nb::arg(), nb::kw_only(), "stream"_a = nb::none(), @@ -1104,7 +1156,9 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "log2", - &mlx::core::log2, + [](const ScalarOrArray& a, StreamOrDevice s) { + return mlx::core::log2(to_array(a), s); + }, nb::arg(), nb::kw_only(), "stream"_a = nb::none(), @@ -1121,7 +1175,9 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "log10", - &mlx::core::log10, + [](const ScalarOrArray& a, StreamOrDevice s) { + return mlx::core::log10(to_array(a), s); + }, nb::arg(), nb::kw_only(), "stream"_a = nb::none(), @@ -1138,7 +1194,9 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "log1p", - &mlx::core::log1p, + [](const ScalarOrArray& a, StreamOrDevice s) { + return mlx::core::log1p(to_array(a), s); + }, nb::arg(), nb::kw_only(), "stream"_a = nb::none(), @@ -1176,7 +1234,9 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "sigmoid", - &sigmoid, + [](const ScalarOrArray& a, StreamOrDevice s) { + return sigmoid(to_array(a), s); + }, nb::arg(), nb::kw_only(), "stream"_a = nb::none(), @@ -1837,7 +1897,9 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "floor", - &mlx::core::floor, + [](const ScalarOrArray& a, StreamOrDevice s) { + return mlx::core::floor(to_array(a), s); + }, nb::arg(), nb::kw_only(), "stream"_a = nb::none(), @@ -1854,7 +1916,9 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "ceil", - &mlx::core::ceil, + [](const ScalarOrArray& a, StreamOrDevice s) { + return mlx::core::ceil(to_array(a), s); + }, nb::arg(), nb::kw_only(), "stream"_a = nb::none(), @@ -1871,7 +1935,9 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "isnan", - &mlx::core::isnan, + [](const ScalarOrArray& a, StreamOrDevice s) { + return mlx::core::isnan(to_array(a), s); + }, nb::arg(), nb::kw_only(), "stream"_a = nb::none(), @@ -1888,7 +1954,9 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "isinf", - &mlx::core::isinf, + [](const ScalarOrArray& a, StreamOrDevice s) { + return mlx::core::isinf(to_array(a), s); + }, nb::arg(), nb::kw_only(), "stream"_a = nb::none(), @@ -1905,7 +1973,9 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "isposinf", - &isposinf, + [](const ScalarOrArray& a, StreamOrDevice s) { + return isposinf(to_array(a), s); + }, nb::arg(), nb::kw_only(), "stream"_a = nb::none(), @@ -1923,7 +1993,9 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "isneginf", - &isneginf, + [](const ScalarOrArray& a, StreamOrDevice s) { + return isneginf(to_array(a), s); + }, nb::arg(), nb::kw_only(), "stream"_a = nb::none(), @@ -3409,8 +3481,8 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "round", - [](const array& a, int decimals, StreamOrDevice s) { - return round(a, decimals, s); + [](const ScalarOrArray& a, int decimals, StreamOrDevice s) { + return round(to_array(a), decimals, s); }, nb::arg(), "decimals"_a = 0, @@ -3750,15 +3822,15 @@ void init_ops(nb::module_& m) { R"pbdoc( Matrix multiplication with matrix-level gather. - Performs a gather of the operands with the given indices followed by a (possibly batched) matrix multiplication of two arrays. + Performs a gather of the operands with the given indices followed by a (possibly batched) matrix multiplication of two arrays. This operation is more efficient than explicitly applying a :func:``take`` followed by a :func:``matmul``. - + The indices ``lhs_indices`` and ``rhs_indices`` contain flat indices along the batch dimensions (i.e. all but the last two dimensions) of ``a`` and ``b`` respectively. - For ``a`` with shape ``(A1, A2, ..., AS, M, K)``, + For ``a`` with shape ``(A1, A2, ..., AS, M, K)``, ``lhs_indices`` contains indices from the range ``[0, A1 * A2 * ... * AS)`` - For ``b`` with shape ``(B1, B2, ..., BS, M, K)``, + For ``b`` with shape ``(B1, B2, ..., BS, M, K)``, ``rhs_indices`` contains indices from the range ``[0, B1 * B2 * ... * BS)`` Args: diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 76a77ccd2..b141e6e3c 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -1217,6 +1217,54 @@ class TestOps(mlx_tests.MLXTestCase): y_ = mx.array(x_) test_ops(getattr(np, op), getattr(mx, op), x_, y_, atol) + def test_unary_ops_from_non_array(self): + unary_ops = [ + "abs", + "exp", + "log", + "square", + "sqrt", + "sin", + "cos", + "tan", + "sinh", + "cosh", + "tanh", + "sign", + "negative", + "expm1", + "arcsin", + "arccos", + "arctan", + "arcsinh", + "arctanh", + "degrees", + "radians", + "log2", + "log10", + "log1p", + "floor", + "ceil", + ] + + x = 0.5 + x_np = np.random.rand(10).astype(np.float32) + for op in unary_ops: + with self.subTest(op=op): + # Test from scalar + expected = getattr(np, op)(x) + out = getattr(mx, op)(x) + + # Check close + self.assertTrue(np.allclose(expected, out, equal_nan=True)) + + # Test from NumPy + expected = getattr(np, op)(x_np) + out = getattr(mx, op)(x_np) + + # Check close + self.assertTrue(np.allclose(expected, np.array(out), equal_nan=True)) + def test_trig_ops(self): def test_ops(npop, mlxop, x, y, atol): r_np = npop(x)