From 32146296015c98ebcc5891abd2fbeb351952b3f6 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Mon, 11 Dec 2023 13:42:55 -0800 Subject: [PATCH] Mlx array accessor (#128) * Add an accessor to interoperate with custom types * Change the docs to custom signatures --- python/src/array.cpp | 78 +++++---- python/src/ops.cpp | 368 ++++++++++++++++++++++++++++++++----------- python/src/utils.h | 29 ++-- 3 files changed, 342 insertions(+), 133 deletions(-) diff --git a/python/src/array.cpp b/python/src/array.cpp index 86253f3f9..6a3ffbf96 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -458,38 +458,54 @@ void init_array(py::module_& m) { m.attr("bfloat16") = py::cast(bfloat16); m.attr("complex64") = py::cast(complex64); - py::class_(m, "array", R"pbdoc(An N-dimensional array object.)pbdoc") - .def( - py::init([](ScalarOrArray v, std::optional t) { - auto arr = to_array(v, t); + auto array_class = py::class_( + m, "array", R"pbdoc(An N-dimensional array object.)pbdoc"); + + { + py::options options; + options.disable_function_signatures(); + + array_class.def( + py::init([](std::variant< + py::bool_, + py::int_, + py::float_, + std::complex, + py::list, + py::tuple, + py::array, + py::buffer, + py::object> v, + std::optional t) { + if (auto pv = std::get_if(&v); pv) { + return array(py::cast(*pv), t.value_or(bool_)); + } else if (auto pv = std::get_if(&v); pv) { + return array(py::cast(*pv), t.value_or(int32)); + } else if (auto pv = std::get_if(&v); pv) { + return array(py::cast(*pv), t.value_or(float32)); + } else if (auto pv = std::get_if>(&v); pv) { + return array(static_cast(*pv), t.value_or(complex64)); + } else if (auto pv = std::get_if(&v); pv) { + return array_from_list(*pv, t); + } else if (auto pv = std::get_if(&v); pv) { + return array_from_list(*pv, t); + } else if (auto pv = std::get_if(&v); pv) { + return np_array_to_mlx(*pv, t); + } else if (auto pv = std::get_if(&v); pv) { + return np_array_to_mlx(*pv, t); + } else { + auto arr = to_array_with_accessor(std::get(v)); return astype(arr, t.value_or(arr.dtype())); - }), - "val"_a, - "dtype"_a = std::nullopt) - .def( - py::init([](std::variant pl, - std::optional dtype) { - if (auto pv = std::get_if(&pl); pv) { - return array_from_list(*pv, dtype); - } else { - auto v = std::get(pl); - return array_from_list(v, dtype); - } - }), - "vals"_a, - "dtype"_a = std::nullopt) - .def( - py::init([](py::array np_array, std::optional dtype) { - return np_array_to_mlx(np_array, dtype); - }), - "vals"_a, - "dtype"_a = std::nullopt) - .def( - py::init([](py::buffer np_buffer, std::optional dtype) { - return np_array_to_mlx(np_buffer, dtype); - }), - "vals"_a, - "dtype"_a = std::nullopt) + } + }), + "val"_a, + "dtype"_a = std::nullopt, + R"pbdoc( + __init__(self: array, val: Union[scalar, list, tuple, numpy.ndarray, array], dtype: Optional[Dtype] = None) + )pbdoc"); + } + + array_class .def_property_readonly( "size", &array::size, R"pbdoc(Number of elments in the array.)pbdoc") .def_property_readonly( diff --git a/python/src/ops.cpp b/python/src/ops.cpp index e25da7f38..b9eacea98 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -36,6 +36,9 @@ double scalar_to_double(Scalar s) { } void init_ops(py::module_& m) { + py::options options; + options.disable_function_signatures(); + m.def( "reshape", &reshape, @@ -45,6 +48,8 @@ void init_ops(py::module_& m) { py::kw_only(), "stream"_a = none, R"pbdoc( + reshape(a: array, /, shape: List[int], *, stream: Union[None, Stream, Device] = None) -> array + Reshape an array while preserving the size. Args: @@ -73,6 +78,8 @@ void init_ops(py::module_& m) { py::kw_only(), "stream"_a = none, R"pbdoc( + squeeze(a: array, /, axis: Union[None, int, List[int]] = None, *, stream: Union[None, Stream, Device] = None) -> array + Remove length one axes from an array. Args: @@ -100,6 +107,8 @@ void init_ops(py::module_& m) { py::kw_only(), "stream"_a = none, R"pbdoc( + expand_dims(a: array, /, axis: Union[int, List[int]], *, stream: Union[None, Stream, Device] = None) -> array + Add a size one dimension at the given axis. Args: @@ -117,6 +126,8 @@ void init_ops(py::module_& m) { py::kw_only(), "stream"_a = none, R"pbdoc( + abs(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array + Element-wise absolute value. Args: @@ -133,6 +144,8 @@ void init_ops(py::module_& m) { py::kw_only(), "stream"_a = none, R"pbdoc( + sign(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array + Element-wise sign. Args: @@ -149,6 +162,8 @@ void init_ops(py::module_& m) { py::kw_only(), "stream"_a = none, R"pbdoc( + negative(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array + Element-wise negation. Args: @@ -169,6 +184,8 @@ void init_ops(py::module_& m) { py::kw_only(), "stream"_a = none, R"pbdoc( + add(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array + Element-wise addition. Add two arrays with numpy-style broadcasting semantics. Either or both input arrays @@ -193,6 +210,8 @@ void init_ops(py::module_& m) { py::kw_only(), "stream"_a = none, R"pbdoc( + subtract(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array + Element-wise subtraction. Subtract one array from another with numpy-style broadcasting semantics. Either or both @@ -217,6 +236,8 @@ void init_ops(py::module_& m) { py::kw_only(), "stream"_a = none, R"pbdoc( + multiply(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array + Element-wise multiplication. Multiply two arrays with numpy-style broadcasting semantics. Either or both @@ -241,6 +262,8 @@ void init_ops(py::module_& m) { py::kw_only(), "stream"_a = none, R"pbdoc( + divide(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array + Element-wise division. Divide two arrays with numpy-style broadcasting semantics. Either or both @@ -265,6 +288,8 @@ void init_ops(py::module_& m) { py::kw_only(), "stream"_a = none, R"pbdoc( + remainder(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array + Element-wise remainder of division. Computes the remainder of dividing a with b with numpy-style @@ -290,6 +315,8 @@ void init_ops(py::module_& m) { py::kw_only(), "stream"_a = none, R"pbdoc( + equal(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array + Element-wise equality. Equality comparison on two arrays with numpy-style broadcasting semantics. @@ -314,6 +341,8 @@ void init_ops(py::module_& m) { py::kw_only(), "stream"_a = none, R"pbdoc( + not_equal(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array + Element-wise not equal. Not equal comparison on two arrays with numpy-style broadcasting semantics. @@ -338,6 +367,8 @@ void init_ops(py::module_& m) { py::kw_only(), "stream"_a = none, R"pbdoc( + less(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array + Element-wise less than. Strict less than on two arrays with numpy-style broadcasting semantics. @@ -362,6 +393,8 @@ void init_ops(py::module_& m) { py::kw_only(), "stream"_a = none, R"pbdoc( + less_equal(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array + Element-wise less than or equal. Less than or equal on two arrays with numpy-style broadcasting semantics. @@ -386,6 +419,8 @@ void init_ops(py::module_& m) { py::kw_only(), "stream"_a = none, R"pbdoc( + greater(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array + Element-wise greater than. Strict greater than on two arrays with numpy-style broadcasting semantics. @@ -410,6 +445,8 @@ void init_ops(py::module_& m) { py::kw_only(), "stream"_a = none, R"pbdoc( + greater_equal(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array + Element-wise greater or equal. Greater than or equal on two arrays with numpy-style broadcasting semantics. @@ -438,6 +475,8 @@ void init_ops(py::module_& m) { "equal_nan"_a = false, "stream"_a = none, R"pbdoc( + array_equal(a: Union[scalar, array], b: Union[scalar, array], equal_nan: bool = False, stream: Union[None, Stream, Device] = None) -> array + Array equality check. Compare two arrays for equality. Returns ``True`` if and only if the arrays @@ -462,6 +501,8 @@ void init_ops(py::module_& m) { py::kw_only(), "stream"_a = none, R"pbdoc( + matmul(a: array, b: array, /, *, stream: Union[None, Stream, Device] = None) -> array + Matrix multiplication. Perform the (possibly batched) matrix multiplication of two arrays. This function supports @@ -492,6 +533,8 @@ void init_ops(py::module_& m) { py::kw_only(), "stream"_a = none, R"pbdoc( + square(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array + Element-wise square. Args: @@ -508,6 +551,8 @@ void init_ops(py::module_& m) { py::kw_only(), "stream"_a = none, R"pbdoc( + sqrt(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array + Element-wise square root. Args: @@ -524,6 +569,8 @@ void init_ops(py::module_& m) { py::kw_only(), "stream"_a = none, R"pbdoc( + rsqrt(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array + Element-wise reciprocal and square root. Args: @@ -540,6 +587,8 @@ void init_ops(py::module_& m) { py::kw_only(), "stream"_a = none, R"pbdoc( + reciprocal(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array + Element-wise reciprocal. Args: @@ -558,6 +607,8 @@ void init_ops(py::module_& m) { py::kw_only(), "stream"_a = none, R"pbdoc( + logical_not(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array + Element-wise logical not. Args: @@ -578,6 +629,8 @@ void init_ops(py::module_& m) { py::kw_only(), "stream"_a = none, R"pbdoc( + logaddexp(a: Union[scalar, array], b: Union[scalar, array], /, *, stream: Union[None, Stream, Device] = None) -> array + Element-wise log-add-exp. This is a numerically stable log-add-exp of two arrays with numpy-style @@ -600,6 +653,8 @@ void init_ops(py::module_& m) { py::kw_only(), "stream"_a = none, R"pbdoc( + exp(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array + Element-wise exponential. Args: @@ -616,6 +671,8 @@ void init_ops(py::module_& m) { py::kw_only(), "stream"_a = none, R"pbdoc( + erf(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array + Element-wise error function. .. math:: @@ -635,6 +692,8 @@ void init_ops(py::module_& m) { py::kw_only(), "stream"_a = none, R"pbdoc( + erfinv(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array + Element-wise inverse of :func:`erf`. Args: @@ -651,6 +710,8 @@ void init_ops(py::module_& m) { py::kw_only(), "stream"_a = none, R"pbdoc( + sin(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array + Element-wise sine. Args: @@ -667,6 +728,8 @@ void init_ops(py::module_& m) { py::kw_only(), "stream"_a = none, R"pbdoc( + cos(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array + Element-wise cosine. Args: @@ -683,6 +746,8 @@ void init_ops(py::module_& m) { py::kw_only(), "stream"_a = none, R"pbdoc( + tan(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array + Element-wise tangent. Args: @@ -699,6 +764,8 @@ void init_ops(py::module_& m) { py::kw_only(), "stream"_a = none, R"pbdoc( + arcsin(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array + Element-wise inverse sine. Args: @@ -715,6 +782,8 @@ void init_ops(py::module_& m) { py::kw_only(), "stream"_a = none, R"pbdoc( + arccos(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array + Element-wise inverse cosine. Args: @@ -731,6 +800,8 @@ void init_ops(py::module_& m) { py::kw_only(), "stream"_a = none, R"pbdoc( + arctan(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array + Element-wise inverse tangent. Args: @@ -747,6 +818,8 @@ void init_ops(py::module_& m) { py::kw_only(), "stream"_a = none, R"pbdoc( + sinh(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array + Element-wise hyperbolic sine. Args: @@ -763,6 +836,8 @@ void init_ops(py::module_& m) { py::kw_only(), "stream"_a = none, R"pbdoc( + cosh(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array + Element-wise hyperbolic cosine. Args: @@ -779,6 +854,8 @@ void init_ops(py::module_& m) { py::kw_only(), "stream"_a = none, R"pbdoc( + tanh(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array + Element-wise hyperbolic tangent. Args: @@ -795,6 +872,8 @@ void init_ops(py::module_& m) { py::kw_only(), "stream"_a = none, R"pbdoc( + arcsinh(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array + Element-wise inverse hyperbolic sine. Args: @@ -811,6 +890,8 @@ void init_ops(py::module_& m) { py::kw_only(), "stream"_a = none, R"pbdoc( + arccosh(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array + Element-wise inverse hyperbolic cosine. Args: @@ -827,6 +908,8 @@ void init_ops(py::module_& m) { py::kw_only(), "stream"_a = none, R"pbdoc( + arctanh(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array + Element-wise inverse hyperbolic tangent. Args: @@ -843,6 +926,8 @@ void init_ops(py::module_& m) { py::kw_only(), "stream"_a = none, R"pbdoc( + log(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array + Element-wise natural logarithm. Args: @@ -859,6 +944,8 @@ void init_ops(py::module_& m) { py::kw_only(), "stream"_a = none, R"pbdoc( + log2(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array + Element-wise base-2 logarithm. Args: @@ -875,6 +962,8 @@ void init_ops(py::module_& m) { py::kw_only(), "stream"_a = none, R"pbdoc( + log10(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array + Element-wise base-10 logarithm. Args: @@ -891,6 +980,8 @@ void init_ops(py::module_& m) { py::kw_only(), "stream"_a = none, R"pbdoc( + log1p(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array + Element-wise natural log of one plus the array. Args: @@ -907,6 +998,8 @@ void init_ops(py::module_& m) { py::kw_only(), "stream"_a = none, R"pbdoc( + stop_gradient(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array + Stop gradients from being computed. The operation is the identity but it prevents gradients from flowing @@ -927,6 +1020,8 @@ void init_ops(py::module_& m) { py::kw_only(), "stream"_a = none, R"pbdoc( + sigmoid(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array + Element-wise logistic sigmoid. The logistic sigmoid function is: @@ -952,6 +1047,8 @@ void init_ops(py::module_& m) { py::kw_only(), "stream"_a = none, R"pbdoc( + power(a: Union[scalar, array], b: Union[scalar, array], /, *, stream: Union[None, Stream, Device] = None) -> array + Element-wise power operation. Raise the elements of a to the powers in elements of b with numpy-style @@ -964,108 +1061,102 @@ void init_ops(py::module_& m) { Returns: array: Bases of ``a`` raised to powers in ``b``. )pbdoc"); - { - // Disable function signature just for arange which we write manually - py::options options; - options.disable_function_signatures(); - m.def( - "arange", - [](Scalar stop, std::optional dtype_, StreamOrDevice s) { - Dtype dtype = - dtype_.has_value() ? dtype_.value() : scalar_to_dtype(stop); + m.def( + "arange", + [](Scalar stop, std::optional dtype_, StreamOrDevice s) { + Dtype dtype = + dtype_.has_value() ? dtype_.value() : scalar_to_dtype(stop); - return arange(0.0, scalar_to_double(stop), 1.0, dtype, s); - }, - "stop"_a, - "dtype"_a = none, - "stream"_a = none); - m.def( - "arange", - [](Scalar start, - Scalar stop, - std::optional dtype_, - StreamOrDevice s) { - Dtype dtype = dtype_.has_value() - ? dtype_.value() - : promote_types(scalar_to_dtype(start), scalar_to_dtype(stop)); - return arange( - scalar_to_double(start), scalar_to_double(stop), dtype, s); - }, - "start"_a, - "stop"_a, - "dtype"_a = none, - "stream"_a = none); - m.def( - "arange", - [](Scalar stop, - Scalar step, - std::optional dtype_, - StreamOrDevice s) { - Dtype dtype = dtype_.has_value() - ? dtype_.value() - : promote_types(scalar_to_dtype(stop), scalar_to_dtype(step)); + return arange(0.0, scalar_to_double(stop), 1.0, dtype, s); + }, + "stop"_a, + "dtype"_a = none, + "stream"_a = none); + m.def( + "arange", + [](Scalar start, + Scalar stop, + std::optional dtype_, + StreamOrDevice s) { + Dtype dtype = dtype_.has_value() + ? dtype_.value() + : promote_types(scalar_to_dtype(start), scalar_to_dtype(stop)); + return arange( + scalar_to_double(start), scalar_to_double(stop), dtype, s); + }, + "start"_a, + "stop"_a, + "dtype"_a = none, + "stream"_a = none); + m.def( + "arange", + [](Scalar stop, + Scalar step, + std::optional dtype_, + StreamOrDevice s) { + Dtype dtype = dtype_.has_value() + ? dtype_.value() + : promote_types(scalar_to_dtype(stop), scalar_to_dtype(step)); - return arange( - 0.0, scalar_to_double(stop), scalar_to_double(step), dtype, s); - }, - "stop"_a, - "step"_a, - "dtype"_a = none, - "stream"_a = none); - m.def( - "arange", - [](Scalar start, - Scalar stop, - Scalar step, - std::optional dtype_, - StreamOrDevice s) { - // Determine the final dtype based on input types - Dtype dtype = dtype_.has_value() - ? dtype_.value() - : promote_types( - scalar_to_dtype(start), - promote_types( - scalar_to_dtype(stop), scalar_to_dtype(step))); + return arange( + 0.0, scalar_to_double(stop), scalar_to_double(step), dtype, s); + }, + "stop"_a, + "step"_a, + "dtype"_a = none, + "stream"_a = none); + m.def( + "arange", + [](Scalar start, + Scalar stop, + Scalar step, + std::optional dtype_, + StreamOrDevice s) { + // Determine the final dtype based on input types + Dtype dtype = dtype_.has_value() + ? dtype_.value() + : promote_types( + scalar_to_dtype(start), + promote_types(scalar_to_dtype(stop), scalar_to_dtype(step))); - return arange( - scalar_to_double(start), - scalar_to_double(stop), - scalar_to_double(step), - dtype, - s); - }, - "start"_a, - "stop"_a, - "step"_a, - "dtype"_a = none, - "stream"_a = none, - R"pbdoc( - arange(start, stop, step, dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array + return arange( + scalar_to_double(start), + scalar_to_double(stop), + scalar_to_double(step), + dtype, + s); + }, + "start"_a, + "stop"_a, + "step"_a, + "dtype"_a = none, + "stream"_a = none, + R"pbdoc( + arange(start, stop, step, dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array - Generates ranges of numbers. + Generates ranges of numbers. - Generate numbers in the half-open interval ``[start, stop)`` in - increments of ``step``. + Generate numbers in the half-open interval ``[start, stop)`` in + increments of ``step``. - Args: - start (float or int, optional): Starting value which defaults to ``0``. - stop (float or int): Stopping value. - step (float or int, optional): Increment which defaults to ``1``. - dtype (Dtype, optional): Specifies the data type of the output. - If unspecified will default to ``float32`` if any of ``start``, - ``stop``, or ``step`` are ``float``. Otherwise will default to - ``int32``. + Args: + start (float or int, optional): Starting value which defaults to ``0``. + stop (float or int): Stopping value. + step (float or int, optional): Increment which defaults to ``1``. + dtype (Dtype, optional): Specifies the data type of the output. + If unspecified will default to ``float32`` if any of ``start``, + ``stop``, or ``step`` are ``float``. Otherwise will default to + ``int32``. - Returns: - array: The range of values. + Returns: + array: The range of values. - Note: - Following the Numpy convention the actual increment used to - generate numbers is ``dtype(start + step) - dtype(start)``. - This can lead to unexpected results for example if `start + step` - is a fractional value and the `dtype` is integral. + Note: + Following the Numpy convention the actual increment used to + generate numbers is ``dtype(start + step) - dtype(start)``. + This can lead to unexpected results for example if `start + step` + is a fractional value and the `dtype` is integral. )pbdoc"); - } m.def( "take", [](const array& a, @@ -1085,6 +1176,8 @@ void init_ops(py::module_& m) { py::kw_only(), "stream"_a = none, R"pbdoc( + take(a: array, /, indices: array, axis: Optional[int] = None, *, stream: Union[None, Stream, Device] = None) -> array + Take elements along an axis. The elements are taken from ``indices`` along the specified axis. @@ -1121,6 +1214,8 @@ void init_ops(py::module_& m) { py::kw_only(), "stream"_a = none, R"pbdoc( + take_along_axis(a: array, /, indices: array, axis: Optional[int] = None, *, stream: Union[None, Stream, Device] = None) -> array + Take values along an axis at the specified indices. Args: @@ -1153,6 +1248,8 @@ void init_ops(py::module_& m) { py::kw_only(), "stream"_a = none, R"pbdoc( + full(shape: Union[int, List[int]], vals: Union[scalar, array], dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array + Construct an array with the given value. Constructs an array of size ``shape`` filled with ``vals``. If ``vals`` @@ -1184,6 +1281,8 @@ void init_ops(py::module_& m) { py::kw_only(), "stream"_a = none, R"pbdoc( + zeros(shape: Union[int, List[int]], dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array + Construct an array of zeros. Args: @@ -1202,6 +1301,8 @@ void init_ops(py::module_& m) { py::kw_only(), "stream"_a = none, R"pbdoc( + zeros_like(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array + An array of zeros like the input. Args: @@ -1227,6 +1328,8 @@ void init_ops(py::module_& m) { py::kw_only(), "stream"_a = none, R"pbdoc( + ones(shape: Union[int, List[int]], dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array + Construct an array of ones. Args: @@ -1245,6 +1348,8 @@ void init_ops(py::module_& m) { py::kw_only(), "stream"_a = none, R"pbdoc( + ones_like(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array + An array of ones like the input. Args: @@ -1312,6 +1417,8 @@ void init_ops(py::module_& m) { py::kw_only(), "stream"_a = none, R"pbdoc( + allclose(a: array, b: array, /, rtol: float = 1e-05, atol: float = 1e-08, *, stream: Union[None, Stream, Device] = None) -> array + Approximate comparison of two arrays. The arrays are considered equal if: @@ -1347,6 +1454,8 @@ void init_ops(py::module_& m) { py::kw_only(), "stream"_a = none, R"pbdoc( + all(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array + An `and` reduction over the given axes. Args: @@ -1375,6 +1484,8 @@ void init_ops(py::module_& m) { py::kw_only(), "stream"_a = none, R"pbdoc( + any(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array + An `or` reduction over the given axes. Args: @@ -1397,8 +1508,11 @@ void init_ops(py::module_& m) { "a"_a, "b"_a, py::pos_only(), + py::kw_only(), "stream"_a = none, R"pbdoc( + minimum(a: Union[scalar, array], b: Union[scalar, array], /, *, stream: Union[None, Stream, Device] = None) -> array + Element-wise minimum. Take the element-wise min of two arrays with numpy-style broadcasting @@ -1423,6 +1537,8 @@ void init_ops(py::module_& m) { py::kw_only(), "stream"_a = none, R"pbdoc( + maximum(a: Union[scalar, array], b: Union[scalar, array], /, *, stream: Union[None, Stream, Device] = None) -> array + Element-wise maximum. Take the element-wise max of two arrays with numpy-style broadcasting @@ -1452,6 +1568,8 @@ void init_ops(py::module_& m) { py::kw_only(), "stream"_a = none, R"pbdoc( + transpose(a: array, /, axes: Optional[List[int]] = None, *, stream: Union[None, Stream, Device] = None) -> array + Transpose the dimensions of the array. Args: @@ -1477,6 +1595,8 @@ void init_ops(py::module_& m) { py::kw_only(), "stream"_a = none, R"pbdoc( + sum(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array + Sum reduce the array over the given axes. Args: @@ -1505,6 +1625,8 @@ void init_ops(py::module_& m) { py::kw_only(), "stream"_a = none, R"pbdoc( + prod(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array + An product reduction over the given axes. Args: @@ -1533,6 +1655,8 @@ void init_ops(py::module_& m) { py::kw_only(), "stream"_a = none, R"pbdoc( + min(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array + An `min` reduction over the given axes. Args: @@ -1561,6 +1685,8 @@ void init_ops(py::module_& m) { py::kw_only(), "stream"_a = none, R"pbdoc( + max(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array + An `max` reduction over the given axes. Args: @@ -1589,6 +1715,8 @@ void init_ops(py::module_& m) { py::kw_only(), "stream"_a = none, R"pbdoc( + logsumexp(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array + A `log-sum-exp` reduction over the given axes. The log-sum-exp reduction is a numerically stable version of: @@ -1623,6 +1751,8 @@ void init_ops(py::module_& m) { py::kw_only(), "stream"_a = none, R"pbdoc( + mean(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array + Compute the mean(s) over the given axes. Args: @@ -1653,6 +1783,8 @@ void init_ops(py::module_& m) { py::kw_only(), "stream"_a = none, R"pbdoc( + var(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, ddof: int = 0, *, stream: Union[None, Stream, Device] = None) -> array + Compute the variance(s) over the given axes. Args: @@ -1688,6 +1820,8 @@ void init_ops(py::module_& m) { py::kw_only(), "stream"_a = none, R"pbdoc( + split(a: array, /, indices_or_sections: Union[int, List[int]], axis: int = 0, *, stream: Union[None, Stream, Device] = None) -> array + Split an array along a given axis. Args: @@ -1721,6 +1855,8 @@ void init_ops(py::module_& m) { py::kw_only(), "stream"_a = none, R"pbdoc( + argmin(a: array, /, axis: Union[None, int] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array + Indices of the minimum values along the axis. Args: @@ -1752,6 +1888,8 @@ void init_ops(py::module_& m) { py::kw_only(), "stream"_a = none, R"pbdoc( + argmax(a: array, /, axis: Union[None, int] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array + Indices of the maximum values along the axis. Args: @@ -1779,6 +1917,8 @@ void init_ops(py::module_& m) { py::kw_only(), "stream"_a = none, R"pbdoc( + sort(a: array, /, axis: Union[None, int] = -1, *, stream: Union[None, Stream, Device] = None) -> array + Returns a sorted copy of the array. Args: @@ -1805,6 +1945,8 @@ void init_ops(py::module_& m) { py::kw_only(), "stream"_a = none, R"pbdoc( + argsort(a: array, /, axis: Union[None, int] = -1, *, stream: Union[None, Stream, Device] = None) -> array + Returns the indices that sort the array. Args: @@ -1832,6 +1974,8 @@ void init_ops(py::module_& m) { py::kw_only(), "stream"_a = none, R"pbdoc( + partition(a: array, /, kth: int, axis: Union[None, int] = -1, *, stream: Union[None, Stream, Device] = None) -> array + Returns a partitioned copy of the array such that the smaller ``kth`` elements are first. @@ -1866,6 +2010,8 @@ void init_ops(py::module_& m) { py::kw_only(), "stream"_a = none, R"pbdoc( + argpartition(a: array, /, kth: int, axis: Union[None, int] = -1, *, stream: Union[None, Stream, Device] = None) -> array + Returns the indices that partition the array. The ordering of the elements within a partition in given by the indices @@ -1901,6 +2047,8 @@ void init_ops(py::module_& m) { py::kw_only(), "stream"_a = none, R"pbdoc( + topk(a: array, /, k: int, axis: Union[None, int] = -1, *, stream: Union[None, Stream, Device] = None) -> array + Returns the ``k`` largest elements from the input along a given axis. The elements will not necessarily be in sorted order. @@ -1926,6 +2074,8 @@ void init_ops(py::module_& m) { py::kw_only(), "stream"_a = none, R"pbdoc( + broadcast_to(a: Union[scalar, array], /, shape: List[int], *, stream: Union[None, Stream, Device] = None) -> array + Broadcast an array to the given shape. The broadcasting semantics are the same as Numpy. @@ -1948,6 +2098,8 @@ void init_ops(py::module_& m) { py::kw_only(), "stream"_a = none, R"pbdoc( + softmax(a: array, /, axis: Union[None, int, List[int]] = None, *, stream: Union[None, Stream, Device] = None) -> array + Perform the softmax along the given axis. This operation is a numerically stable version of: @@ -1982,6 +2134,8 @@ void init_ops(py::module_& m) { py::kw_only(), "stream"_a = none, R"pbdoc( + concatenate(arrays: List[array], axis: Optional[int] = 0, *, stream: Union[None, Stream, Device] = None) -> array + Concatenate the arrays along the given axis. Args: @@ -2024,6 +2178,8 @@ void init_ops(py::module_& m) { py::kw_only(), "stream"_a = none, R"pbdoc( + pad(a: array, pad_with: Union[int, Tuple[int], Tuple[int, int], List[Tuple[int, int]]], constant_values: Union[scalar, array] = 0, *, stream: Union[None, Stream, Device] = None) -> array + Pad an array with a constant value Args: @@ -2067,6 +2223,8 @@ void init_ops(py::module_& m) { py::kw_only(), "stream"_a = none, R"pbdoc( + as_strided(a: array, /, shape: Optional[List[int]] = None, strides: Optional[List[int]] = None, offset: int = 0, *, stream: Union[None, Stream, Device] = None) -> array + Create a view into the array with the given shape and strides. The resulting array will always be as if the provided array was row @@ -2113,6 +2271,8 @@ void init_ops(py::module_& m) { "inclusive"_a = true, "stream"_a = none, R"pbdoc( + cumsum(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, stream: Union[None, Stream, Device] = None) -> array + Return the cumulative sum of the elements along the given axis. Args: @@ -2145,6 +2305,8 @@ void init_ops(py::module_& m) { "inclusive"_a = true, "stream"_a = none, R"pbdoc( + cumprod(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, stream: Union[None, Stream, Device] = None) -> array + Return the cumulative product of the elements along the given axis. Args: @@ -2177,6 +2339,8 @@ void init_ops(py::module_& m) { "inclusive"_a = true, "stream"_a = none, R"pbdoc( + cummax(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, stream: Union[None, Stream, Device] = None) -> array + Return the cumulative maximum of the elements along the given axis. Args: @@ -2209,6 +2373,8 @@ void init_ops(py::module_& m) { "inclusive"_a = true, "stream"_a = none, R"pbdoc( + cummin(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, stream: Union[None, Stream, Device] = None) -> array + Return the cumulative minimum of the elements along the given axis. Args: @@ -2275,6 +2441,8 @@ void init_ops(py::module_& m) { py::kw_only(), "stream"_a = none, R"pbdoc( + convolve(a: array, v: array, /, mode: str = "full", *, stream: Union[None, Stream, Device] = None) -> array + The discrete convolution of 1D arrays. If ``v`` is longer than ``a``, then they are swapped. @@ -2301,6 +2469,8 @@ void init_ops(py::module_& m) { py::kw_only(), "stream"_a = none, R"pbdoc( + conv1d(input: array, weight: array, /, stride: int = 1, padding: int = 0, dilation: int = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array + 1D convolution over an input with several channels Note: Only the default ``groups=1`` is currently supported. @@ -2360,6 +2530,8 @@ void init_ops(py::module_& m) { py::kw_only(), "stream"_a = none, R"pbdoc( + conv2d(input: array, weight: array, /, stride: Union[int, Tuple[int, int]] = 1, padding: Union[int, Tuple[int, int]] = 0, dilation: Union[int, Tuple[int, int]] = 1, groups: Union[int, Tuple[int, int]] = 1, *, stream: Union[None, Stream, Device] = None) -> array + 2D convolution over an input with several channels Note: Only the default ``groups=1`` is currently supported. @@ -2390,6 +2562,8 @@ void init_ops(py::module_& m) { "retain_graph"_a = true, py::kw_only(), R"pbdoc( + save(file: str, arr: array, / , retain_graph: bool = True) + Save the array to a binary file in ``.npy`` format. Args: @@ -2408,6 +2582,8 @@ void init_ops(py::module_& m) { py::pos_only(), py::kw_only(), R"pbdoc( + savez(file: str, *args, **kwargs) + Save several arrays to a binary file in uncompressed ``.npz`` format. .. code-block:: python @@ -2440,6 +2616,8 @@ void init_ops(py::module_& m) { py::pos_only(), py::kw_only(), R"pbdoc( + savez_compressed(file: str, *args, **kwargs) + Save several arrays to a binary file in compressed ``.npz`` format. Args: @@ -2457,6 +2635,8 @@ void init_ops(py::module_& m) { py::kw_only(), "stream"_a = none, R"pbdoc( + load(file: str, /, *, stream: Union[None, Stream, Device] = None) -> Union[array, Dict[str, array]] + Load array(s) from a binary file in ``.npy`` or ``.npz`` format. Args: @@ -2481,6 +2661,8 @@ void init_ops(py::module_& m) { py::kw_only(), "stream"_a = none, R"pbdoc( + where(condition: Union[scalar, array], x: Union[scalar, array], y: Union[scalar, array], /, *, stream: Union[None, Stream, Device] = None) -> array + Select from ``x`` or ``y`` according to ``condition``. The condition and input arrays must be the same shape or broadcastable diff --git a/python/src/utils.h b/python/src/utils.h index 142f9a985..5ac878979 100644 --- a/python/src/utils.h +++ b/python/src/utils.h @@ -15,8 +15,8 @@ namespace py = pybind11; using namespace mlx::core; using IntOrVec = std::variant>; -using ScalarOrArray = - std::variant, array>; +using ScalarOrArray = std:: + variant, py::object>; static constexpr std::monostate none{}; inline std::vector get_reduce_axes(const IntOrVec& v, int dims) { @@ -32,6 +32,14 @@ inline std::vector get_reduce_axes(const IntOrVec& v, int dims) { return axes; } +inline array to_array_with_accessor(py::object obj) { + if (py::hasattr(obj, "__mlx_array__")) { + return obj.attr("__mlx_array__")().cast(); + } else { + return obj.cast(); + } +} + inline array to_array( const ScalarOrArray& v, std::optional dtype = std::nullopt) { @@ -48,7 +56,7 @@ inline array to_array( } else if (auto pv = std::get_if>(&v); pv) { return array(static_cast(*pv), complex64); } else { - return std::get(v); + return to_array_with_accessor(std::get(v)); } } @@ -60,13 +68,16 @@ inline std::pair to_arrays( // - If a is an array but b is not, treat b as a weak python type // - If b is an array but a is not, treat a as a weak python type // - If neither is an array convert to arrays but leave their types alone - if (auto pa = std::get_if(&a); pa) { - if (auto pb = std::get_if(&b); pb) { - return {*pa, *pb}; + if (auto pa = std::get_if(&a); pa) { + auto arr_a = to_array_with_accessor(*pa); + if (auto pb = std::get_if(&b); pb) { + auto arr_b = to_array_with_accessor(*pb); + return {arr_a, arr_b}; } - return {*pa, to_array(b, pa->dtype())}; - } else if (auto pb = std::get_if(&b); pb) { - return {to_array(a, pb->dtype()), *pb}; + return {arr_a, to_array(b, arr_a.dtype())}; + } else if (auto pb = std::get_if(&b); pb) { + auto arr_b = to_array_with_accessor(*pb); + return {to_array(a, arr_b.dtype()), arr_b}; } else { return {to_array(a), to_array(b)}; }