Mlx array accessor (#128)

* Add an accessor to interoperate with custom types
* Change the docs to custom signatures
This commit is contained in:
Angelos Katharopoulos 2023-12-11 13:42:55 -08:00 committed by GitHub
parent 072044e28f
commit 3214629601
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 342 additions and 133 deletions

View File

@ -458,38 +458,54 @@ void init_array(py::module_& m) {
m.attr("bfloat16") = py::cast(bfloat16);
m.attr("complex64") = py::cast(complex64);
py::class_<array>(m, "array", R"pbdoc(An N-dimensional array object.)pbdoc")
.def(
py::init([](ScalarOrArray v, std::optional<Dtype> t) {
auto arr = to_array(v, t);
auto array_class = py::class_<array>(
m, "array", R"pbdoc(An N-dimensional array object.)pbdoc");
{
py::options options;
options.disable_function_signatures();
array_class.def(
py::init([](std::variant<
py::bool_,
py::int_,
py::float_,
std::complex<float>,
py::list,
py::tuple,
py::array,
py::buffer,
py::object> v,
std::optional<Dtype> t) {
if (auto pv = std::get_if<py::bool_>(&v); pv) {
return array(py::cast<bool>(*pv), t.value_or(bool_));
} else if (auto pv = std::get_if<py::int_>(&v); pv) {
return array(py::cast<int>(*pv), t.value_or(int32));
} else if (auto pv = std::get_if<py::float_>(&v); pv) {
return array(py::cast<float>(*pv), t.value_or(float32));
} else if (auto pv = std::get_if<std::complex<float>>(&v); pv) {
return array(static_cast<complex64_t>(*pv), t.value_or(complex64));
} else if (auto pv = std::get_if<py::list>(&v); pv) {
return array_from_list(*pv, t);
} else if (auto pv = std::get_if<py::tuple>(&v); pv) {
return array_from_list(*pv, t);
} else if (auto pv = std::get_if<py::array>(&v); pv) {
return np_array_to_mlx(*pv, t);
} else if (auto pv = std::get_if<py::buffer>(&v); pv) {
return np_array_to_mlx(*pv, t);
} else {
auto arr = to_array_with_accessor(std::get<py::object>(v));
return astype(arr, t.value_or(arr.dtype()));
}),
"val"_a,
"dtype"_a = std::nullopt)
.def(
py::init([](std::variant<py::list, py::tuple> pl,
std::optional<Dtype> dtype) {
if (auto pv = std::get_if<py::list>(&pl); pv) {
return array_from_list(*pv, dtype);
} else {
auto v = std::get<py::tuple>(pl);
return array_from_list(v, dtype);
}
}),
"vals"_a,
"dtype"_a = std::nullopt)
.def(
py::init([](py::array np_array, std::optional<Dtype> dtype) {
return np_array_to_mlx(np_array, dtype);
}),
"vals"_a,
"dtype"_a = std::nullopt)
.def(
py::init([](py::buffer np_buffer, std::optional<Dtype> dtype) {
return np_array_to_mlx(np_buffer, dtype);
}),
"vals"_a,
"dtype"_a = std::nullopt)
}
}),
"val"_a,
"dtype"_a = std::nullopt,
R"pbdoc(
__init__(self: array, val: Union[scalar, list, tuple, numpy.ndarray, array], dtype: Optional[Dtype] = None)
)pbdoc");
}
array_class
.def_property_readonly(
"size", &array::size, R"pbdoc(Number of elments in the array.)pbdoc")
.def_property_readonly(

View File

@ -36,6 +36,9 @@ double scalar_to_double(Scalar s) {
}
void init_ops(py::module_& m) {
py::options options;
options.disable_function_signatures();
m.def(
"reshape",
&reshape,
@ -45,6 +48,8 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
reshape(a: array, /, shape: List[int], *, stream: Union[None, Stream, Device] = None) -> array
Reshape an array while preserving the size.
Args:
@ -73,6 +78,8 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
squeeze(a: array, /, axis: Union[None, int, List[int]] = None, *, stream: Union[None, Stream, Device] = None) -> array
Remove length one axes from an array.
Args:
@ -100,6 +107,8 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
expand_dims(a: array, /, axis: Union[int, List[int]], *, stream: Union[None, Stream, Device] = None) -> array
Add a size one dimension at the given axis.
Args:
@ -117,6 +126,8 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
abs(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
Element-wise absolute value.
Args:
@ -133,6 +144,8 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
sign(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
Element-wise sign.
Args:
@ -149,6 +162,8 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
negative(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
Element-wise negation.
Args:
@ -169,6 +184,8 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
add(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array
Element-wise addition.
Add two arrays with numpy-style broadcasting semantics. Either or both input arrays
@ -193,6 +210,8 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
subtract(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array
Element-wise subtraction.
Subtract one array from another with numpy-style broadcasting semantics. Either or both
@ -217,6 +236,8 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
multiply(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array
Element-wise multiplication.
Multiply two arrays with numpy-style broadcasting semantics. Either or both
@ -241,6 +262,8 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
divide(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array
Element-wise division.
Divide two arrays with numpy-style broadcasting semantics. Either or both
@ -265,6 +288,8 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
remainder(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array
Element-wise remainder of division.
Computes the remainder of dividing a with b with numpy-style
@ -290,6 +315,8 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
equal(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array
Element-wise equality.
Equality comparison on two arrays with numpy-style broadcasting semantics.
@ -314,6 +341,8 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
not_equal(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array
Element-wise not equal.
Not equal comparison on two arrays with numpy-style broadcasting semantics.
@ -338,6 +367,8 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
less(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array
Element-wise less than.
Strict less than on two arrays with numpy-style broadcasting semantics.
@ -362,6 +393,8 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
less_equal(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array
Element-wise less than or equal.
Less than or equal on two arrays with numpy-style broadcasting semantics.
@ -386,6 +419,8 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
greater(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array
Element-wise greater than.
Strict greater than on two arrays with numpy-style broadcasting semantics.
@ -410,6 +445,8 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
greater_equal(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array
Element-wise greater or equal.
Greater than or equal on two arrays with numpy-style broadcasting semantics.
@ -438,6 +475,8 @@ void init_ops(py::module_& m) {
"equal_nan"_a = false,
"stream"_a = none,
R"pbdoc(
array_equal(a: Union[scalar, array], b: Union[scalar, array], equal_nan: bool = False, stream: Union[None, Stream, Device] = None) -> array
Array equality check.
Compare two arrays for equality. Returns ``True`` if and only if the arrays
@ -462,6 +501,8 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
matmul(a: array, b: array, /, *, stream: Union[None, Stream, Device] = None) -> array
Matrix multiplication.
Perform the (possibly batched) matrix multiplication of two arrays. This function supports
@ -492,6 +533,8 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
square(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
Element-wise square.
Args:
@ -508,6 +551,8 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
sqrt(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
Element-wise square root.
Args:
@ -524,6 +569,8 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
rsqrt(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
Element-wise reciprocal and square root.
Args:
@ -540,6 +587,8 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
reciprocal(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
Element-wise reciprocal.
Args:
@ -558,6 +607,8 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
logical_not(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
Element-wise logical not.
Args:
@ -578,6 +629,8 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
logaddexp(a: Union[scalar, array], b: Union[scalar, array], /, *, stream: Union[None, Stream, Device] = None) -> array
Element-wise log-add-exp.
This is a numerically stable log-add-exp of two arrays with numpy-style
@ -600,6 +653,8 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
exp(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
Element-wise exponential.
Args:
@ -616,6 +671,8 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
erf(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
Element-wise error function.
.. math::
@ -635,6 +692,8 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
erfinv(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
Element-wise inverse of :func:`erf`.
Args:
@ -651,6 +710,8 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
sin(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
Element-wise sine.
Args:
@ -667,6 +728,8 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
cos(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
Element-wise cosine.
Args:
@ -683,6 +746,8 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
tan(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
Element-wise tangent.
Args:
@ -699,6 +764,8 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
arcsin(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
Element-wise inverse sine.
Args:
@ -715,6 +782,8 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
arccos(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
Element-wise inverse cosine.
Args:
@ -731,6 +800,8 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
arctan(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
Element-wise inverse tangent.
Args:
@ -747,6 +818,8 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
sinh(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
Element-wise hyperbolic sine.
Args:
@ -763,6 +836,8 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
cosh(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
Element-wise hyperbolic cosine.
Args:
@ -779,6 +854,8 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
tanh(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
Element-wise hyperbolic tangent.
Args:
@ -795,6 +872,8 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
arcsinh(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
Element-wise inverse hyperbolic sine.
Args:
@ -811,6 +890,8 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
arccosh(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
Element-wise inverse hyperbolic cosine.
Args:
@ -827,6 +908,8 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
arctanh(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
Element-wise inverse hyperbolic tangent.
Args:
@ -843,6 +926,8 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
log(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
Element-wise natural logarithm.
Args:
@ -859,6 +944,8 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
log2(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
Element-wise base-2 logarithm.
Args:
@ -875,6 +962,8 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
log10(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
Element-wise base-10 logarithm.
Args:
@ -891,6 +980,8 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
log1p(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
Element-wise natural log of one plus the array.
Args:
@ -907,6 +998,8 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
stop_gradient(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
Stop gradients from being computed.
The operation is the identity but it prevents gradients from flowing
@ -927,6 +1020,8 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
sigmoid(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
Element-wise logistic sigmoid.
The logistic sigmoid function is:
@ -952,6 +1047,8 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
power(a: Union[scalar, array], b: Union[scalar, array], /, *, stream: Union[None, Stream, Device] = None) -> array
Element-wise power operation.
Raise the elements of a to the powers in elements of b with numpy-style
@ -964,108 +1061,102 @@ void init_ops(py::module_& m) {
Returns:
array: Bases of ``a`` raised to powers in ``b``.
)pbdoc");
{
// Disable function signature just for arange which we write manually
py::options options;
options.disable_function_signatures();
m.def(
"arange",
[](Scalar stop, std::optional<Dtype> dtype_, StreamOrDevice s) {
Dtype dtype =
dtype_.has_value() ? dtype_.value() : scalar_to_dtype(stop);
m.def(
"arange",
[](Scalar stop, std::optional<Dtype> dtype_, StreamOrDevice s) {
Dtype dtype =
dtype_.has_value() ? dtype_.value() : scalar_to_dtype(stop);
return arange(0.0, scalar_to_double(stop), 1.0, dtype, s);
},
"stop"_a,
"dtype"_a = none,
"stream"_a = none);
m.def(
"arange",
[](Scalar start,
Scalar stop,
std::optional<Dtype> dtype_,
StreamOrDevice s) {
Dtype dtype = dtype_.has_value()
? dtype_.value()
: promote_types(scalar_to_dtype(start), scalar_to_dtype(stop));
return arange(
scalar_to_double(start), scalar_to_double(stop), dtype, s);
},
"start"_a,
"stop"_a,
"dtype"_a = none,
"stream"_a = none);
m.def(
"arange",
[](Scalar stop,
Scalar step,
std::optional<Dtype> dtype_,
StreamOrDevice s) {
Dtype dtype = dtype_.has_value()
? dtype_.value()
: promote_types(scalar_to_dtype(stop), scalar_to_dtype(step));
return arange(0.0, scalar_to_double(stop), 1.0, dtype, s);
},
"stop"_a,
"dtype"_a = none,
"stream"_a = none);
m.def(
"arange",
[](Scalar start,
Scalar stop,
std::optional<Dtype> dtype_,
StreamOrDevice s) {
Dtype dtype = dtype_.has_value()
? dtype_.value()
: promote_types(scalar_to_dtype(start), scalar_to_dtype(stop));
return arange(
scalar_to_double(start), scalar_to_double(stop), dtype, s);
},
"start"_a,
"stop"_a,
"dtype"_a = none,
"stream"_a = none);
m.def(
"arange",
[](Scalar stop,
Scalar step,
std::optional<Dtype> dtype_,
StreamOrDevice s) {
Dtype dtype = dtype_.has_value()
? dtype_.value()
: promote_types(scalar_to_dtype(stop), scalar_to_dtype(step));
return arange(
0.0, scalar_to_double(stop), scalar_to_double(step), dtype, s);
},
"stop"_a,
"step"_a,
"dtype"_a = none,
"stream"_a = none);
m.def(
"arange",
[](Scalar start,
Scalar stop,
Scalar step,
std::optional<Dtype> dtype_,
StreamOrDevice s) {
// Determine the final dtype based on input types
Dtype dtype = dtype_.has_value()
? dtype_.value()
: promote_types(
scalar_to_dtype(start),
promote_types(
scalar_to_dtype(stop), scalar_to_dtype(step)));
return arange(
0.0, scalar_to_double(stop), scalar_to_double(step), dtype, s);
},
"stop"_a,
"step"_a,
"dtype"_a = none,
"stream"_a = none);
m.def(
"arange",
[](Scalar start,
Scalar stop,
Scalar step,
std::optional<Dtype> dtype_,
StreamOrDevice s) {
// Determine the final dtype based on input types
Dtype dtype = dtype_.has_value()
? dtype_.value()
: promote_types(
scalar_to_dtype(start),
promote_types(scalar_to_dtype(stop), scalar_to_dtype(step)));
return arange(
scalar_to_double(start),
scalar_to_double(stop),
scalar_to_double(step),
dtype,
s);
},
"start"_a,
"stop"_a,
"step"_a,
"dtype"_a = none,
"stream"_a = none,
R"pbdoc(
arange(start, stop, step, dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array
return arange(
scalar_to_double(start),
scalar_to_double(stop),
scalar_to_double(step),
dtype,
s);
},
"start"_a,
"stop"_a,
"step"_a,
"dtype"_a = none,
"stream"_a = none,
R"pbdoc(
arange(start, stop, step, dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array
Generates ranges of numbers.
Generates ranges of numbers.
Generate numbers in the half-open interval ``[start, stop)`` in
increments of ``step``.
Generate numbers in the half-open interval ``[start, stop)`` in
increments of ``step``.
Args:
start (float or int, optional): Starting value which defaults to ``0``.
stop (float or int): Stopping value.
step (float or int, optional): Increment which defaults to ``1``.
dtype (Dtype, optional): Specifies the data type of the output.
If unspecified will default to ``float32`` if any of ``start``,
``stop``, or ``step`` are ``float``. Otherwise will default to
``int32``.
Args:
start (float or int, optional): Starting value which defaults to ``0``.
stop (float or int): Stopping value.
step (float or int, optional): Increment which defaults to ``1``.
dtype (Dtype, optional): Specifies the data type of the output.
If unspecified will default to ``float32`` if any of ``start``,
``stop``, or ``step`` are ``float``. Otherwise will default to
``int32``.
Returns:
array: The range of values.
Returns:
array: The range of values.
Note:
Following the Numpy convention the actual increment used to
generate numbers is ``dtype(start + step) - dtype(start)``.
This can lead to unexpected results for example if `start + step`
is a fractional value and the `dtype` is integral.
Note:
Following the Numpy convention the actual increment used to
generate numbers is ``dtype(start + step) - dtype(start)``.
This can lead to unexpected results for example if `start + step`
is a fractional value and the `dtype` is integral.
)pbdoc");
}
m.def(
"take",
[](const array& a,
@ -1085,6 +1176,8 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
take(a: array, /, indices: array, axis: Optional[int] = None, *, stream: Union[None, Stream, Device] = None) -> array
Take elements along an axis.
The elements are taken from ``indices`` along the specified axis.
@ -1121,6 +1214,8 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
take_along_axis(a: array, /, indices: array, axis: Optional[int] = None, *, stream: Union[None, Stream, Device] = None) -> array
Take values along an axis at the specified indices.
Args:
@ -1153,6 +1248,8 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
full(shape: Union[int, List[int]], vals: Union[scalar, array], dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array
Construct an array with the given value.
Constructs an array of size ``shape`` filled with ``vals``. If ``vals``
@ -1184,6 +1281,8 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
zeros(shape: Union[int, List[int]], dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array
Construct an array of zeros.
Args:
@ -1202,6 +1301,8 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
zeros_like(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
An array of zeros like the input.
Args:
@ -1227,6 +1328,8 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
ones(shape: Union[int, List[int]], dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array
Construct an array of ones.
Args:
@ -1245,6 +1348,8 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
ones_like(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
An array of ones like the input.
Args:
@ -1312,6 +1417,8 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
allclose(a: array, b: array, /, rtol: float = 1e-05, atol: float = 1e-08, *, stream: Union[None, Stream, Device] = None) -> array
Approximate comparison of two arrays.
The arrays are considered equal if:
@ -1347,6 +1454,8 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
all(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array
An `and` reduction over the given axes.
Args:
@ -1375,6 +1484,8 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
any(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array
An `or` reduction over the given axes.
Args:
@ -1397,8 +1508,11 @@ void init_ops(py::module_& m) {
"a"_a,
"b"_a,
py::pos_only(),
py::kw_only(),
"stream"_a = none,
R"pbdoc(
minimum(a: Union[scalar, array], b: Union[scalar, array], /, *, stream: Union[None, Stream, Device] = None) -> array
Element-wise minimum.
Take the element-wise min of two arrays with numpy-style broadcasting
@ -1423,6 +1537,8 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
maximum(a: Union[scalar, array], b: Union[scalar, array], /, *, stream: Union[None, Stream, Device] = None) -> array
Element-wise maximum.
Take the element-wise max of two arrays with numpy-style broadcasting
@ -1452,6 +1568,8 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
transpose(a: array, /, axes: Optional[List[int]] = None, *, stream: Union[None, Stream, Device] = None) -> array
Transpose the dimensions of the array.
Args:
@ -1477,6 +1595,8 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
sum(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array
Sum reduce the array over the given axes.
Args:
@ -1505,6 +1625,8 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
prod(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array
An product reduction over the given axes.
Args:
@ -1533,6 +1655,8 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
min(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array
An `min` reduction over the given axes.
Args:
@ -1561,6 +1685,8 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
max(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array
An `max` reduction over the given axes.
Args:
@ -1589,6 +1715,8 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
logsumexp(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array
A `log-sum-exp` reduction over the given axes.
The log-sum-exp reduction is a numerically stable version of:
@ -1623,6 +1751,8 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
mean(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array
Compute the mean(s) over the given axes.
Args:
@ -1653,6 +1783,8 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
var(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, ddof: int = 0, *, stream: Union[None, Stream, Device] = None) -> array
Compute the variance(s) over the given axes.
Args:
@ -1688,6 +1820,8 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
split(a: array, /, indices_or_sections: Union[int, List[int]], axis: int = 0, *, stream: Union[None, Stream, Device] = None) -> array
Split an array along a given axis.
Args:
@ -1721,6 +1855,8 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
argmin(a: array, /, axis: Union[None, int] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array
Indices of the minimum values along the axis.
Args:
@ -1752,6 +1888,8 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
argmax(a: array, /, axis: Union[None, int] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array
Indices of the maximum values along the axis.
Args:
@ -1779,6 +1917,8 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
sort(a: array, /, axis: Union[None, int] = -1, *, stream: Union[None, Stream, Device] = None) -> array
Returns a sorted copy of the array.
Args:
@ -1805,6 +1945,8 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
argsort(a: array, /, axis: Union[None, int] = -1, *, stream: Union[None, Stream, Device] = None) -> array
Returns the indices that sort the array.
Args:
@ -1832,6 +1974,8 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
partition(a: array, /, kth: int, axis: Union[None, int] = -1, *, stream: Union[None, Stream, Device] = None) -> array
Returns a partitioned copy of the array such that the smaller ``kth``
elements are first.
@ -1866,6 +2010,8 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
argpartition(a: array, /, kth: int, axis: Union[None, int] = -1, *, stream: Union[None, Stream, Device] = None) -> array
Returns the indices that partition the array.
The ordering of the elements within a partition in given by the indices
@ -1901,6 +2047,8 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
topk(a: array, /, k: int, axis: Union[None, int] = -1, *, stream: Union[None, Stream, Device] = None) -> array
Returns the ``k`` largest elements from the input along a given axis.
The elements will not necessarily be in sorted order.
@ -1926,6 +2074,8 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
broadcast_to(a: Union[scalar, array], /, shape: List[int], *, stream: Union[None, Stream, Device] = None) -> array
Broadcast an array to the given shape.
The broadcasting semantics are the same as Numpy.
@ -1948,6 +2098,8 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
softmax(a: array, /, axis: Union[None, int, List[int]] = None, *, stream: Union[None, Stream, Device] = None) -> array
Perform the softmax along the given axis.
This operation is a numerically stable version of:
@ -1982,6 +2134,8 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
concatenate(arrays: List[array], axis: Optional[int] = 0, *, stream: Union[None, Stream, Device] = None) -> array
Concatenate the arrays along the given axis.
Args:
@ -2024,6 +2178,8 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
pad(a: array, pad_with: Union[int, Tuple[int], Tuple[int, int], List[Tuple[int, int]]], constant_values: Union[scalar, array] = 0, *, stream: Union[None, Stream, Device] = None) -> array
Pad an array with a constant value
Args:
@ -2067,6 +2223,8 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
as_strided(a: array, /, shape: Optional[List[int]] = None, strides: Optional[List[int]] = None, offset: int = 0, *, stream: Union[None, Stream, Device] = None) -> array
Create a view into the array with the given shape and strides.
The resulting array will always be as if the provided array was row
@ -2113,6 +2271,8 @@ void init_ops(py::module_& m) {
"inclusive"_a = true,
"stream"_a = none,
R"pbdoc(
cumsum(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, stream: Union[None, Stream, Device] = None) -> array
Return the cumulative sum of the elements along the given axis.
Args:
@ -2145,6 +2305,8 @@ void init_ops(py::module_& m) {
"inclusive"_a = true,
"stream"_a = none,
R"pbdoc(
cumprod(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, stream: Union[None, Stream, Device] = None) -> array
Return the cumulative product of the elements along the given axis.
Args:
@ -2177,6 +2339,8 @@ void init_ops(py::module_& m) {
"inclusive"_a = true,
"stream"_a = none,
R"pbdoc(
cummax(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, stream: Union[None, Stream, Device] = None) -> array
Return the cumulative maximum of the elements along the given axis.
Args:
@ -2209,6 +2373,8 @@ void init_ops(py::module_& m) {
"inclusive"_a = true,
"stream"_a = none,
R"pbdoc(
cummin(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, stream: Union[None, Stream, Device] = None) -> array
Return the cumulative minimum of the elements along the given axis.
Args:
@ -2275,6 +2441,8 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
convolve(a: array, v: array, /, mode: str = "full", *, stream: Union[None, Stream, Device] = None) -> array
The discrete convolution of 1D arrays.
If ``v`` is longer than ``a``, then they are swapped.
@ -2301,6 +2469,8 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
conv1d(input: array, weight: array, /, stride: int = 1, padding: int = 0, dilation: int = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array
1D convolution over an input with several channels
Note: Only the default ``groups=1`` is currently supported.
@ -2360,6 +2530,8 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
conv2d(input: array, weight: array, /, stride: Union[int, Tuple[int, int]] = 1, padding: Union[int, Tuple[int, int]] = 0, dilation: Union[int, Tuple[int, int]] = 1, groups: Union[int, Tuple[int, int]] = 1, *, stream: Union[None, Stream, Device] = None) -> array
2D convolution over an input with several channels
Note: Only the default ``groups=1`` is currently supported.
@ -2390,6 +2562,8 @@ void init_ops(py::module_& m) {
"retain_graph"_a = true,
py::kw_only(),
R"pbdoc(
save(file: str, arr: array, / , retain_graph: bool = True)
Save the array to a binary file in ``.npy`` format.
Args:
@ -2408,6 +2582,8 @@ void init_ops(py::module_& m) {
py::pos_only(),
py::kw_only(),
R"pbdoc(
savez(file: str, *args, **kwargs)
Save several arrays to a binary file in uncompressed ``.npz`` format.
.. code-block:: python
@ -2440,6 +2616,8 @@ void init_ops(py::module_& m) {
py::pos_only(),
py::kw_only(),
R"pbdoc(
savez_compressed(file: str, *args, **kwargs)
Save several arrays to a binary file in compressed ``.npz`` format.
Args:
@ -2457,6 +2635,8 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
load(file: str, /, *, stream: Union[None, Stream, Device] = None) -> Union[array, Dict[str, array]]
Load array(s) from a binary file in ``.npy`` or ``.npz`` format.
Args:
@ -2481,6 +2661,8 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
where(condition: Union[scalar, array], x: Union[scalar, array], y: Union[scalar, array], /, *, stream: Union[None, Stream, Device] = None) -> array
Select from ``x`` or ``y`` according to ``condition``.
The condition and input arrays must be the same shape or broadcastable

View File

@ -15,8 +15,8 @@ namespace py = pybind11;
using namespace mlx::core;
using IntOrVec = std::variant<std::monostate, int, std::vector<int>>;
using ScalarOrArray =
std::variant<py::bool_, py::int_, py::float_, std::complex<float>, array>;
using ScalarOrArray = std::
variant<py::bool_, py::int_, py::float_, std::complex<float>, py::object>;
static constexpr std::monostate none{};
inline std::vector<int> get_reduce_axes(const IntOrVec& v, int dims) {
@ -32,6 +32,14 @@ inline std::vector<int> get_reduce_axes(const IntOrVec& v, int dims) {
return axes;
}
inline array to_array_with_accessor(py::object obj) {
if (py::hasattr(obj, "__mlx_array__")) {
return obj.attr("__mlx_array__")().cast<array>();
} else {
return obj.cast<array>();
}
}
inline array to_array(
const ScalarOrArray& v,
std::optional<Dtype> dtype = std::nullopt) {
@ -48,7 +56,7 @@ inline array to_array(
} else if (auto pv = std::get_if<std::complex<float>>(&v); pv) {
return array(static_cast<complex64_t>(*pv), complex64);
} else {
return std::get<array>(v);
return to_array_with_accessor(std::get<py::object>(v));
}
}
@ -60,13 +68,16 @@ inline std::pair<array, array> to_arrays(
// - If a is an array but b is not, treat b as a weak python type
// - If b is an array but a is not, treat a as a weak python type
// - If neither is an array convert to arrays but leave their types alone
if (auto pa = std::get_if<array>(&a); pa) {
if (auto pb = std::get_if<array>(&b); pb) {
return {*pa, *pb};
if (auto pa = std::get_if<py::object>(&a); pa) {
auto arr_a = to_array_with_accessor(*pa);
if (auto pb = std::get_if<py::object>(&b); pb) {
auto arr_b = to_array_with_accessor(*pb);
return {arr_a, arr_b};
}
return {*pa, to_array(b, pa->dtype())};
} else if (auto pb = std::get_if<array>(&b); pb) {
return {to_array(a, pb->dtype()), *pb};
return {arr_a, to_array(b, arr_a.dtype())};
} else if (auto pb = std::get_if<py::object>(&b); pb) {
auto arr_b = to_array_with_accessor(*pb);
return {to_array(a, arr_b.dtype()), arr_b};
} else {
return {to_array(a), to_array(b)};
}