Mlx array accessor (#128)

* Add an accessor to interoperate with custom types
* Change the docs to custom signatures
This commit is contained in:
Angelos Katharopoulos 2023-12-11 13:42:55 -08:00 committed by GitHub
parent 072044e28f
commit 3214629601
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 342 additions and 133 deletions

View File

@ -458,38 +458,54 @@ void init_array(py::module_& m) {
m.attr("bfloat16") = py::cast(bfloat16); m.attr("bfloat16") = py::cast(bfloat16);
m.attr("complex64") = py::cast(complex64); m.attr("complex64") = py::cast(complex64);
py::class_<array>(m, "array", R"pbdoc(An N-dimensional array object.)pbdoc") auto array_class = py::class_<array>(
.def( m, "array", R"pbdoc(An N-dimensional array object.)pbdoc");
py::init([](ScalarOrArray v, std::optional<Dtype> t) {
auto arr = to_array(v, t); {
return astype(arr, t.value_or(arr.dtype())); py::options options;
}), options.disable_function_signatures();
"val"_a,
"dtype"_a = std::nullopt) array_class.def(
.def( py::init([](std::variant<
py::init([](std::variant<py::list, py::tuple> pl, py::bool_,
std::optional<Dtype> dtype) { py::int_,
if (auto pv = std::get_if<py::list>(&pl); pv) { py::float_,
return array_from_list(*pv, dtype); std::complex<float>,
py::list,
py::tuple,
py::array,
py::buffer,
py::object> v,
std::optional<Dtype> t) {
if (auto pv = std::get_if<py::bool_>(&v); pv) {
return array(py::cast<bool>(*pv), t.value_or(bool_));
} else if (auto pv = std::get_if<py::int_>(&v); pv) {
return array(py::cast<int>(*pv), t.value_or(int32));
} else if (auto pv = std::get_if<py::float_>(&v); pv) {
return array(py::cast<float>(*pv), t.value_or(float32));
} else if (auto pv = std::get_if<std::complex<float>>(&v); pv) {
return array(static_cast<complex64_t>(*pv), t.value_or(complex64));
} else if (auto pv = std::get_if<py::list>(&v); pv) {
return array_from_list(*pv, t);
} else if (auto pv = std::get_if<py::tuple>(&v); pv) {
return array_from_list(*pv, t);
} else if (auto pv = std::get_if<py::array>(&v); pv) {
return np_array_to_mlx(*pv, t);
} else if (auto pv = std::get_if<py::buffer>(&v); pv) {
return np_array_to_mlx(*pv, t);
} else { } else {
auto v = std::get<py::tuple>(pl); auto arr = to_array_with_accessor(std::get<py::object>(v));
return array_from_list(v, dtype); return astype(arr, t.value_or(arr.dtype()));
} }
}), }),
"vals"_a, "val"_a,
"dtype"_a = std::nullopt) "dtype"_a = std::nullopt,
.def( R"pbdoc(
py::init([](py::array np_array, std::optional<Dtype> dtype) { __init__(self: array, val: Union[scalar, list, tuple, numpy.ndarray, array], dtype: Optional[Dtype] = None)
return np_array_to_mlx(np_array, dtype); )pbdoc");
}), }
"vals"_a,
"dtype"_a = std::nullopt) array_class
.def(
py::init([](py::buffer np_buffer, std::optional<Dtype> dtype) {
return np_array_to_mlx(np_buffer, dtype);
}),
"vals"_a,
"dtype"_a = std::nullopt)
.def_property_readonly( .def_property_readonly(
"size", &array::size, R"pbdoc(Number of elments in the array.)pbdoc") "size", &array::size, R"pbdoc(Number of elments in the array.)pbdoc")
.def_property_readonly( .def_property_readonly(

View File

@ -36,6 +36,9 @@ double scalar_to_double(Scalar s) {
} }
void init_ops(py::module_& m) { void init_ops(py::module_& m) {
py::options options;
options.disable_function_signatures();
m.def( m.def(
"reshape", "reshape",
&reshape, &reshape,
@ -45,6 +48,8 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
reshape(a: array, /, shape: List[int], *, stream: Union[None, Stream, Device] = None) -> array
Reshape an array while preserving the size. Reshape an array while preserving the size.
Args: Args:
@ -73,6 +78,8 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
squeeze(a: array, /, axis: Union[None, int, List[int]] = None, *, stream: Union[None, Stream, Device] = None) -> array
Remove length one axes from an array. Remove length one axes from an array.
Args: Args:
@ -100,6 +107,8 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
expand_dims(a: array, /, axis: Union[int, List[int]], *, stream: Union[None, Stream, Device] = None) -> array
Add a size one dimension at the given axis. Add a size one dimension at the given axis.
Args: Args:
@ -117,6 +126,8 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
abs(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
Element-wise absolute value. Element-wise absolute value.
Args: Args:
@ -133,6 +144,8 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
sign(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
Element-wise sign. Element-wise sign.
Args: Args:
@ -149,6 +162,8 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
negative(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
Element-wise negation. Element-wise negation.
Args: Args:
@ -169,6 +184,8 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
add(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array
Element-wise addition. Element-wise addition.
Add two arrays with numpy-style broadcasting semantics. Either or both input arrays Add two arrays with numpy-style broadcasting semantics. Either or both input arrays
@ -193,6 +210,8 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
subtract(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array
Element-wise subtraction. Element-wise subtraction.
Subtract one array from another with numpy-style broadcasting semantics. Either or both Subtract one array from another with numpy-style broadcasting semantics. Either or both
@ -217,6 +236,8 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
multiply(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array
Element-wise multiplication. Element-wise multiplication.
Multiply two arrays with numpy-style broadcasting semantics. Either or both Multiply two arrays with numpy-style broadcasting semantics. Either or both
@ -241,6 +262,8 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
divide(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array
Element-wise division. Element-wise division.
Divide two arrays with numpy-style broadcasting semantics. Either or both Divide two arrays with numpy-style broadcasting semantics. Either or both
@ -265,6 +288,8 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
remainder(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array
Element-wise remainder of division. Element-wise remainder of division.
Computes the remainder of dividing a with b with numpy-style Computes the remainder of dividing a with b with numpy-style
@ -290,6 +315,8 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
equal(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array
Element-wise equality. Element-wise equality.
Equality comparison on two arrays with numpy-style broadcasting semantics. Equality comparison on two arrays with numpy-style broadcasting semantics.
@ -314,6 +341,8 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
not_equal(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array
Element-wise not equal. Element-wise not equal.
Not equal comparison on two arrays with numpy-style broadcasting semantics. Not equal comparison on two arrays with numpy-style broadcasting semantics.
@ -338,6 +367,8 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
less(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array
Element-wise less than. Element-wise less than.
Strict less than on two arrays with numpy-style broadcasting semantics. Strict less than on two arrays with numpy-style broadcasting semantics.
@ -362,6 +393,8 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
less_equal(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array
Element-wise less than or equal. Element-wise less than or equal.
Less than or equal on two arrays with numpy-style broadcasting semantics. Less than or equal on two arrays with numpy-style broadcasting semantics.
@ -386,6 +419,8 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
greater(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array
Element-wise greater than. Element-wise greater than.
Strict greater than on two arrays with numpy-style broadcasting semantics. Strict greater than on two arrays with numpy-style broadcasting semantics.
@ -410,6 +445,8 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
greater_equal(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array
Element-wise greater or equal. Element-wise greater or equal.
Greater than or equal on two arrays with numpy-style broadcasting semantics. Greater than or equal on two arrays with numpy-style broadcasting semantics.
@ -438,6 +475,8 @@ void init_ops(py::module_& m) {
"equal_nan"_a = false, "equal_nan"_a = false,
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
array_equal(a: Union[scalar, array], b: Union[scalar, array], equal_nan: bool = False, stream: Union[None, Stream, Device] = None) -> array
Array equality check. Array equality check.
Compare two arrays for equality. Returns ``True`` if and only if the arrays Compare two arrays for equality. Returns ``True`` if and only if the arrays
@ -462,6 +501,8 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
matmul(a: array, b: array, /, *, stream: Union[None, Stream, Device] = None) -> array
Matrix multiplication. Matrix multiplication.
Perform the (possibly batched) matrix multiplication of two arrays. This function supports Perform the (possibly batched) matrix multiplication of two arrays. This function supports
@ -492,6 +533,8 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
square(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
Element-wise square. Element-wise square.
Args: Args:
@ -508,6 +551,8 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
sqrt(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
Element-wise square root. Element-wise square root.
Args: Args:
@ -524,6 +569,8 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
rsqrt(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
Element-wise reciprocal and square root. Element-wise reciprocal and square root.
Args: Args:
@ -540,6 +587,8 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
reciprocal(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
Element-wise reciprocal. Element-wise reciprocal.
Args: Args:
@ -558,6 +607,8 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
logical_not(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
Element-wise logical not. Element-wise logical not.
Args: Args:
@ -578,6 +629,8 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
logaddexp(a: Union[scalar, array], b: Union[scalar, array], /, *, stream: Union[None, Stream, Device] = None) -> array
Element-wise log-add-exp. Element-wise log-add-exp.
This is a numerically stable log-add-exp of two arrays with numpy-style This is a numerically stable log-add-exp of two arrays with numpy-style
@ -600,6 +653,8 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
exp(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
Element-wise exponential. Element-wise exponential.
Args: Args:
@ -616,6 +671,8 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
erf(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
Element-wise error function. Element-wise error function.
.. math:: .. math::
@ -635,6 +692,8 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
erfinv(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
Element-wise inverse of :func:`erf`. Element-wise inverse of :func:`erf`.
Args: Args:
@ -651,6 +710,8 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
sin(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
Element-wise sine. Element-wise sine.
Args: Args:
@ -667,6 +728,8 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
cos(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
Element-wise cosine. Element-wise cosine.
Args: Args:
@ -683,6 +746,8 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
tan(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
Element-wise tangent. Element-wise tangent.
Args: Args:
@ -699,6 +764,8 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
arcsin(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
Element-wise inverse sine. Element-wise inverse sine.
Args: Args:
@ -715,6 +782,8 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
arccos(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
Element-wise inverse cosine. Element-wise inverse cosine.
Args: Args:
@ -731,6 +800,8 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
arctan(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
Element-wise inverse tangent. Element-wise inverse tangent.
Args: Args:
@ -747,6 +818,8 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
sinh(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
Element-wise hyperbolic sine. Element-wise hyperbolic sine.
Args: Args:
@ -763,6 +836,8 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
cosh(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
Element-wise hyperbolic cosine. Element-wise hyperbolic cosine.
Args: Args:
@ -779,6 +854,8 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
tanh(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
Element-wise hyperbolic tangent. Element-wise hyperbolic tangent.
Args: Args:
@ -795,6 +872,8 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
arcsinh(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
Element-wise inverse hyperbolic sine. Element-wise inverse hyperbolic sine.
Args: Args:
@ -811,6 +890,8 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
arccosh(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
Element-wise inverse hyperbolic cosine. Element-wise inverse hyperbolic cosine.
Args: Args:
@ -827,6 +908,8 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
arctanh(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
Element-wise inverse hyperbolic tangent. Element-wise inverse hyperbolic tangent.
Args: Args:
@ -843,6 +926,8 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
log(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
Element-wise natural logarithm. Element-wise natural logarithm.
Args: Args:
@ -859,6 +944,8 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
log2(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
Element-wise base-2 logarithm. Element-wise base-2 logarithm.
Args: Args:
@ -875,6 +962,8 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
log10(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
Element-wise base-10 logarithm. Element-wise base-10 logarithm.
Args: Args:
@ -891,6 +980,8 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
log1p(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
Element-wise natural log of one plus the array. Element-wise natural log of one plus the array.
Args: Args:
@ -907,6 +998,8 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
stop_gradient(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
Stop gradients from being computed. Stop gradients from being computed.
The operation is the identity but it prevents gradients from flowing The operation is the identity but it prevents gradients from flowing
@ -927,6 +1020,8 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
sigmoid(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
Element-wise logistic sigmoid. Element-wise logistic sigmoid.
The logistic sigmoid function is: The logistic sigmoid function is:
@ -952,6 +1047,8 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
power(a: Union[scalar, array], b: Union[scalar, array], /, *, stream: Union[None, Stream, Device] = None) -> array
Element-wise power operation. Element-wise power operation.
Raise the elements of a to the powers in elements of b with numpy-style Raise the elements of a to the powers in elements of b with numpy-style
@ -964,10 +1061,6 @@ void init_ops(py::module_& m) {
Returns: Returns:
array: Bases of ``a`` raised to powers in ``b``. array: Bases of ``a`` raised to powers in ``b``.
)pbdoc"); )pbdoc");
{
// Disable function signature just for arange which we write manually
py::options options;
options.disable_function_signatures();
m.def( m.def(
"arange", "arange",
[](Scalar stop, std::optional<Dtype> dtype_, StreamOrDevice s) { [](Scalar stop, std::optional<Dtype> dtype_, StreamOrDevice s) {
@ -1024,8 +1117,7 @@ void init_ops(py::module_& m) {
? dtype_.value() ? dtype_.value()
: promote_types( : promote_types(
scalar_to_dtype(start), scalar_to_dtype(start),
promote_types( promote_types(scalar_to_dtype(stop), scalar_to_dtype(step)));
scalar_to_dtype(stop), scalar_to_dtype(step)));
return arange( return arange(
scalar_to_double(start), scalar_to_double(start),
@ -1065,7 +1157,6 @@ void init_ops(py::module_& m) {
This can lead to unexpected results for example if `start + step` This can lead to unexpected results for example if `start + step`
is a fractional value and the `dtype` is integral. is a fractional value and the `dtype` is integral.
)pbdoc"); )pbdoc");
}
m.def( m.def(
"take", "take",
[](const array& a, [](const array& a,
@ -1085,6 +1176,8 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
take(a: array, /, indices: array, axis: Optional[int] = None, *, stream: Union[None, Stream, Device] = None) -> array
Take elements along an axis. Take elements along an axis.
The elements are taken from ``indices`` along the specified axis. The elements are taken from ``indices`` along the specified axis.
@ -1121,6 +1214,8 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
take_along_axis(a: array, /, indices: array, axis: Optional[int] = None, *, stream: Union[None, Stream, Device] = None) -> array
Take values along an axis at the specified indices. Take values along an axis at the specified indices.
Args: Args:
@ -1153,6 +1248,8 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
full(shape: Union[int, List[int]], vals: Union[scalar, array], dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array
Construct an array with the given value. Construct an array with the given value.
Constructs an array of size ``shape`` filled with ``vals``. If ``vals`` Constructs an array of size ``shape`` filled with ``vals``. If ``vals``
@ -1184,6 +1281,8 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
zeros(shape: Union[int, List[int]], dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array
Construct an array of zeros. Construct an array of zeros.
Args: Args:
@ -1202,6 +1301,8 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
zeros_like(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
An array of zeros like the input. An array of zeros like the input.
Args: Args:
@ -1227,6 +1328,8 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
ones(shape: Union[int, List[int]], dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array
Construct an array of ones. Construct an array of ones.
Args: Args:
@ -1245,6 +1348,8 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
ones_like(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
An array of ones like the input. An array of ones like the input.
Args: Args:
@ -1312,6 +1417,8 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
allclose(a: array, b: array, /, rtol: float = 1e-05, atol: float = 1e-08, *, stream: Union[None, Stream, Device] = None) -> array
Approximate comparison of two arrays. Approximate comparison of two arrays.
The arrays are considered equal if: The arrays are considered equal if:
@ -1347,6 +1454,8 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
all(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array
An `and` reduction over the given axes. An `and` reduction over the given axes.
Args: Args:
@ -1375,6 +1484,8 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
any(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array
An `or` reduction over the given axes. An `or` reduction over the given axes.
Args: Args:
@ -1397,8 +1508,11 @@ void init_ops(py::module_& m) {
"a"_a, "a"_a,
"b"_a, "b"_a,
py::pos_only(), py::pos_only(),
py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
minimum(a: Union[scalar, array], b: Union[scalar, array], /, *, stream: Union[None, Stream, Device] = None) -> array
Element-wise minimum. Element-wise minimum.
Take the element-wise min of two arrays with numpy-style broadcasting Take the element-wise min of two arrays with numpy-style broadcasting
@ -1423,6 +1537,8 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
maximum(a: Union[scalar, array], b: Union[scalar, array], /, *, stream: Union[None, Stream, Device] = None) -> array
Element-wise maximum. Element-wise maximum.
Take the element-wise max of two arrays with numpy-style broadcasting Take the element-wise max of two arrays with numpy-style broadcasting
@ -1452,6 +1568,8 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
transpose(a: array, /, axes: Optional[List[int]] = None, *, stream: Union[None, Stream, Device] = None) -> array
Transpose the dimensions of the array. Transpose the dimensions of the array.
Args: Args:
@ -1477,6 +1595,8 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
sum(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array
Sum reduce the array over the given axes. Sum reduce the array over the given axes.
Args: Args:
@ -1505,6 +1625,8 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
prod(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array
An product reduction over the given axes. An product reduction over the given axes.
Args: Args:
@ -1533,6 +1655,8 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
min(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array
An `min` reduction over the given axes. An `min` reduction over the given axes.
Args: Args:
@ -1561,6 +1685,8 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
max(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array
An `max` reduction over the given axes. An `max` reduction over the given axes.
Args: Args:
@ -1589,6 +1715,8 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
logsumexp(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array
A `log-sum-exp` reduction over the given axes. A `log-sum-exp` reduction over the given axes.
The log-sum-exp reduction is a numerically stable version of: The log-sum-exp reduction is a numerically stable version of:
@ -1623,6 +1751,8 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
mean(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array
Compute the mean(s) over the given axes. Compute the mean(s) over the given axes.
Args: Args:
@ -1653,6 +1783,8 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
var(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, ddof: int = 0, *, stream: Union[None, Stream, Device] = None) -> array
Compute the variance(s) over the given axes. Compute the variance(s) over the given axes.
Args: Args:
@ -1688,6 +1820,8 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
split(a: array, /, indices_or_sections: Union[int, List[int]], axis: int = 0, *, stream: Union[None, Stream, Device] = None) -> array
Split an array along a given axis. Split an array along a given axis.
Args: Args:
@ -1721,6 +1855,8 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
argmin(a: array, /, axis: Union[None, int] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array
Indices of the minimum values along the axis. Indices of the minimum values along the axis.
Args: Args:
@ -1752,6 +1888,8 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
argmax(a: array, /, axis: Union[None, int] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array
Indices of the maximum values along the axis. Indices of the maximum values along the axis.
Args: Args:
@ -1779,6 +1917,8 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
sort(a: array, /, axis: Union[None, int] = -1, *, stream: Union[None, Stream, Device] = None) -> array
Returns a sorted copy of the array. Returns a sorted copy of the array.
Args: Args:
@ -1805,6 +1945,8 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
argsort(a: array, /, axis: Union[None, int] = -1, *, stream: Union[None, Stream, Device] = None) -> array
Returns the indices that sort the array. Returns the indices that sort the array.
Args: Args:
@ -1832,6 +1974,8 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
partition(a: array, /, kth: int, axis: Union[None, int] = -1, *, stream: Union[None, Stream, Device] = None) -> array
Returns a partitioned copy of the array such that the smaller ``kth`` Returns a partitioned copy of the array such that the smaller ``kth``
elements are first. elements are first.
@ -1866,6 +2010,8 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
argpartition(a: array, /, kth: int, axis: Union[None, int] = -1, *, stream: Union[None, Stream, Device] = None) -> array
Returns the indices that partition the array. Returns the indices that partition the array.
The ordering of the elements within a partition in given by the indices The ordering of the elements within a partition in given by the indices
@ -1901,6 +2047,8 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
topk(a: array, /, k: int, axis: Union[None, int] = -1, *, stream: Union[None, Stream, Device] = None) -> array
Returns the ``k`` largest elements from the input along a given axis. Returns the ``k`` largest elements from the input along a given axis.
The elements will not necessarily be in sorted order. The elements will not necessarily be in sorted order.
@ -1926,6 +2074,8 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
broadcast_to(a: Union[scalar, array], /, shape: List[int], *, stream: Union[None, Stream, Device] = None) -> array
Broadcast an array to the given shape. Broadcast an array to the given shape.
The broadcasting semantics are the same as Numpy. The broadcasting semantics are the same as Numpy.
@ -1948,6 +2098,8 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
softmax(a: array, /, axis: Union[None, int, List[int]] = None, *, stream: Union[None, Stream, Device] = None) -> array
Perform the softmax along the given axis. Perform the softmax along the given axis.
This operation is a numerically stable version of: This operation is a numerically stable version of:
@ -1982,6 +2134,8 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
concatenate(arrays: List[array], axis: Optional[int] = 0, *, stream: Union[None, Stream, Device] = None) -> array
Concatenate the arrays along the given axis. Concatenate the arrays along the given axis.
Args: Args:
@ -2024,6 +2178,8 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
pad(a: array, pad_with: Union[int, Tuple[int], Tuple[int, int], List[Tuple[int, int]]], constant_values: Union[scalar, array] = 0, *, stream: Union[None, Stream, Device] = None) -> array
Pad an array with a constant value Pad an array with a constant value
Args: Args:
@ -2067,6 +2223,8 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
as_strided(a: array, /, shape: Optional[List[int]] = None, strides: Optional[List[int]] = None, offset: int = 0, *, stream: Union[None, Stream, Device] = None) -> array
Create a view into the array with the given shape and strides. Create a view into the array with the given shape and strides.
The resulting array will always be as if the provided array was row The resulting array will always be as if the provided array was row
@ -2113,6 +2271,8 @@ void init_ops(py::module_& m) {
"inclusive"_a = true, "inclusive"_a = true,
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
cumsum(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, stream: Union[None, Stream, Device] = None) -> array
Return the cumulative sum of the elements along the given axis. Return the cumulative sum of the elements along the given axis.
Args: Args:
@ -2145,6 +2305,8 @@ void init_ops(py::module_& m) {
"inclusive"_a = true, "inclusive"_a = true,
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
cumprod(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, stream: Union[None, Stream, Device] = None) -> array
Return the cumulative product of the elements along the given axis. Return the cumulative product of the elements along the given axis.
Args: Args:
@ -2177,6 +2339,8 @@ void init_ops(py::module_& m) {
"inclusive"_a = true, "inclusive"_a = true,
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
cummax(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, stream: Union[None, Stream, Device] = None) -> array
Return the cumulative maximum of the elements along the given axis. Return the cumulative maximum of the elements along the given axis.
Args: Args:
@ -2209,6 +2373,8 @@ void init_ops(py::module_& m) {
"inclusive"_a = true, "inclusive"_a = true,
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
cummin(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, stream: Union[None, Stream, Device] = None) -> array
Return the cumulative minimum of the elements along the given axis. Return the cumulative minimum of the elements along the given axis.
Args: Args:
@ -2275,6 +2441,8 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
convolve(a: array, v: array, /, mode: str = "full", *, stream: Union[None, Stream, Device] = None) -> array
The discrete convolution of 1D arrays. The discrete convolution of 1D arrays.
If ``v`` is longer than ``a``, then they are swapped. If ``v`` is longer than ``a``, then they are swapped.
@ -2301,6 +2469,8 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
conv1d(input: array, weight: array, /, stride: int = 1, padding: int = 0, dilation: int = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array
1D convolution over an input with several channels 1D convolution over an input with several channels
Note: Only the default ``groups=1`` is currently supported. Note: Only the default ``groups=1`` is currently supported.
@ -2360,6 +2530,8 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
conv2d(input: array, weight: array, /, stride: Union[int, Tuple[int, int]] = 1, padding: Union[int, Tuple[int, int]] = 0, dilation: Union[int, Tuple[int, int]] = 1, groups: Union[int, Tuple[int, int]] = 1, *, stream: Union[None, Stream, Device] = None) -> array
2D convolution over an input with several channels 2D convolution over an input with several channels
Note: Only the default ``groups=1`` is currently supported. Note: Only the default ``groups=1`` is currently supported.
@ -2390,6 +2562,8 @@ void init_ops(py::module_& m) {
"retain_graph"_a = true, "retain_graph"_a = true,
py::kw_only(), py::kw_only(),
R"pbdoc( R"pbdoc(
save(file: str, arr: array, / , retain_graph: bool = True)
Save the array to a binary file in ``.npy`` format. Save the array to a binary file in ``.npy`` format.
Args: Args:
@ -2408,6 +2582,8 @@ void init_ops(py::module_& m) {
py::pos_only(), py::pos_only(),
py::kw_only(), py::kw_only(),
R"pbdoc( R"pbdoc(
savez(file: str, *args, **kwargs)
Save several arrays to a binary file in uncompressed ``.npz`` format. Save several arrays to a binary file in uncompressed ``.npz`` format.
.. code-block:: python .. code-block:: python
@ -2440,6 +2616,8 @@ void init_ops(py::module_& m) {
py::pos_only(), py::pos_only(),
py::kw_only(), py::kw_only(),
R"pbdoc( R"pbdoc(
savez_compressed(file: str, *args, **kwargs)
Save several arrays to a binary file in compressed ``.npz`` format. Save several arrays to a binary file in compressed ``.npz`` format.
Args: Args:
@ -2457,6 +2635,8 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
load(file: str, /, *, stream: Union[None, Stream, Device] = None) -> Union[array, Dict[str, array]]
Load array(s) from a binary file in ``.npy`` or ``.npz`` format. Load array(s) from a binary file in ``.npy`` or ``.npz`` format.
Args: Args:
@ -2481,6 +2661,8 @@ void init_ops(py::module_& m) {
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
where(condition: Union[scalar, array], x: Union[scalar, array], y: Union[scalar, array], /, *, stream: Union[None, Stream, Device] = None) -> array
Select from ``x`` or ``y`` according to ``condition``. Select from ``x`` or ``y`` according to ``condition``.
The condition and input arrays must be the same shape or broadcastable The condition and input arrays must be the same shape or broadcastable

View File

@ -15,8 +15,8 @@ namespace py = pybind11;
using namespace mlx::core; using namespace mlx::core;
using IntOrVec = std::variant<std::monostate, int, std::vector<int>>; using IntOrVec = std::variant<std::monostate, int, std::vector<int>>;
using ScalarOrArray = using ScalarOrArray = std::
std::variant<py::bool_, py::int_, py::float_, std::complex<float>, array>; variant<py::bool_, py::int_, py::float_, std::complex<float>, py::object>;
static constexpr std::monostate none{}; static constexpr std::monostate none{};
inline std::vector<int> get_reduce_axes(const IntOrVec& v, int dims) { inline std::vector<int> get_reduce_axes(const IntOrVec& v, int dims) {
@ -32,6 +32,14 @@ inline std::vector<int> get_reduce_axes(const IntOrVec& v, int dims) {
return axes; return axes;
} }
inline array to_array_with_accessor(py::object obj) {
if (py::hasattr(obj, "__mlx_array__")) {
return obj.attr("__mlx_array__")().cast<array>();
} else {
return obj.cast<array>();
}
}
inline array to_array( inline array to_array(
const ScalarOrArray& v, const ScalarOrArray& v,
std::optional<Dtype> dtype = std::nullopt) { std::optional<Dtype> dtype = std::nullopt) {
@ -48,7 +56,7 @@ inline array to_array(
} else if (auto pv = std::get_if<std::complex<float>>(&v); pv) { } else if (auto pv = std::get_if<std::complex<float>>(&v); pv) {
return array(static_cast<complex64_t>(*pv), complex64); return array(static_cast<complex64_t>(*pv), complex64);
} else { } else {
return std::get<array>(v); return to_array_with_accessor(std::get<py::object>(v));
} }
} }
@ -60,13 +68,16 @@ inline std::pair<array, array> to_arrays(
// - If a is an array but b is not, treat b as a weak python type // - If a is an array but b is not, treat b as a weak python type
// - If b is an array but a is not, treat a as a weak python type // - If b is an array but a is not, treat a as a weak python type
// - If neither is an array convert to arrays but leave their types alone // - If neither is an array convert to arrays but leave their types alone
if (auto pa = std::get_if<array>(&a); pa) { if (auto pa = std::get_if<py::object>(&a); pa) {
if (auto pb = std::get_if<array>(&b); pb) { auto arr_a = to_array_with_accessor(*pa);
return {*pa, *pb}; if (auto pb = std::get_if<py::object>(&b); pb) {
auto arr_b = to_array_with_accessor(*pb);
return {arr_a, arr_b};
} }
return {*pa, to_array(b, pa->dtype())}; return {arr_a, to_array(b, arr_a.dtype())};
} else if (auto pb = std::get_if<array>(&b); pb) { } else if (auto pb = std::get_if<py::object>(&b); pb) {
return {to_array(a, pb->dtype()), *pb}; auto arr_b = to_array_with_accessor(*pb);
return {to_array(a, arr_b.dtype()), arr_b};
} else { } else {
return {to_array(a), to_array(b)}; return {to_array(a), to_array(b)};
} }