mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-19 23:51:14 +08:00
Mlx array accessor (#128)
* Add an accessor to interoperate with custom types * Change the docs to custom signatures
This commit is contained in:
parent
072044e28f
commit
3214629601
@ -458,38 +458,54 @@ void init_array(py::module_& m) {
|
||||
m.attr("bfloat16") = py::cast(bfloat16);
|
||||
m.attr("complex64") = py::cast(complex64);
|
||||
|
||||
py::class_<array>(m, "array", R"pbdoc(An N-dimensional array object.)pbdoc")
|
||||
.def(
|
||||
py::init([](ScalarOrArray v, std::optional<Dtype> t) {
|
||||
auto arr = to_array(v, t);
|
||||
auto array_class = py::class_<array>(
|
||||
m, "array", R"pbdoc(An N-dimensional array object.)pbdoc");
|
||||
|
||||
{
|
||||
py::options options;
|
||||
options.disable_function_signatures();
|
||||
|
||||
array_class.def(
|
||||
py::init([](std::variant<
|
||||
py::bool_,
|
||||
py::int_,
|
||||
py::float_,
|
||||
std::complex<float>,
|
||||
py::list,
|
||||
py::tuple,
|
||||
py::array,
|
||||
py::buffer,
|
||||
py::object> v,
|
||||
std::optional<Dtype> t) {
|
||||
if (auto pv = std::get_if<py::bool_>(&v); pv) {
|
||||
return array(py::cast<bool>(*pv), t.value_or(bool_));
|
||||
} else if (auto pv = std::get_if<py::int_>(&v); pv) {
|
||||
return array(py::cast<int>(*pv), t.value_or(int32));
|
||||
} else if (auto pv = std::get_if<py::float_>(&v); pv) {
|
||||
return array(py::cast<float>(*pv), t.value_or(float32));
|
||||
} else if (auto pv = std::get_if<std::complex<float>>(&v); pv) {
|
||||
return array(static_cast<complex64_t>(*pv), t.value_or(complex64));
|
||||
} else if (auto pv = std::get_if<py::list>(&v); pv) {
|
||||
return array_from_list(*pv, t);
|
||||
} else if (auto pv = std::get_if<py::tuple>(&v); pv) {
|
||||
return array_from_list(*pv, t);
|
||||
} else if (auto pv = std::get_if<py::array>(&v); pv) {
|
||||
return np_array_to_mlx(*pv, t);
|
||||
} else if (auto pv = std::get_if<py::buffer>(&v); pv) {
|
||||
return np_array_to_mlx(*pv, t);
|
||||
} else {
|
||||
auto arr = to_array_with_accessor(std::get<py::object>(v));
|
||||
return astype(arr, t.value_or(arr.dtype()));
|
||||
}),
|
||||
"val"_a,
|
||||
"dtype"_a = std::nullopt)
|
||||
.def(
|
||||
py::init([](std::variant<py::list, py::tuple> pl,
|
||||
std::optional<Dtype> dtype) {
|
||||
if (auto pv = std::get_if<py::list>(&pl); pv) {
|
||||
return array_from_list(*pv, dtype);
|
||||
} else {
|
||||
auto v = std::get<py::tuple>(pl);
|
||||
return array_from_list(v, dtype);
|
||||
}
|
||||
}),
|
||||
"vals"_a,
|
||||
"dtype"_a = std::nullopt)
|
||||
.def(
|
||||
py::init([](py::array np_array, std::optional<Dtype> dtype) {
|
||||
return np_array_to_mlx(np_array, dtype);
|
||||
}),
|
||||
"vals"_a,
|
||||
"dtype"_a = std::nullopt)
|
||||
.def(
|
||||
py::init([](py::buffer np_buffer, std::optional<Dtype> dtype) {
|
||||
return np_array_to_mlx(np_buffer, dtype);
|
||||
}),
|
||||
"vals"_a,
|
||||
"dtype"_a = std::nullopt)
|
||||
}
|
||||
}),
|
||||
"val"_a,
|
||||
"dtype"_a = std::nullopt,
|
||||
R"pbdoc(
|
||||
__init__(self: array, val: Union[scalar, list, tuple, numpy.ndarray, array], dtype: Optional[Dtype] = None)
|
||||
)pbdoc");
|
||||
}
|
||||
|
||||
array_class
|
||||
.def_property_readonly(
|
||||
"size", &array::size, R"pbdoc(Number of elments in the array.)pbdoc")
|
||||
.def_property_readonly(
|
||||
|
@ -36,6 +36,9 @@ double scalar_to_double(Scalar s) {
|
||||
}
|
||||
|
||||
void init_ops(py::module_& m) {
|
||||
py::options options;
|
||||
options.disable_function_signatures();
|
||||
|
||||
m.def(
|
||||
"reshape",
|
||||
&reshape,
|
||||
@ -45,6 +48,8 @@ void init_ops(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
reshape(a: array, /, shape: List[int], *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Reshape an array while preserving the size.
|
||||
|
||||
Args:
|
||||
@ -73,6 +78,8 @@ void init_ops(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
squeeze(a: array, /, axis: Union[None, int, List[int]] = None, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Remove length one axes from an array.
|
||||
|
||||
Args:
|
||||
@ -100,6 +107,8 @@ void init_ops(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
expand_dims(a: array, /, axis: Union[int, List[int]], *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Add a size one dimension at the given axis.
|
||||
|
||||
Args:
|
||||
@ -117,6 +126,8 @@ void init_ops(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
abs(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Element-wise absolute value.
|
||||
|
||||
Args:
|
||||
@ -133,6 +144,8 @@ void init_ops(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
sign(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Element-wise sign.
|
||||
|
||||
Args:
|
||||
@ -149,6 +162,8 @@ void init_ops(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
negative(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Element-wise negation.
|
||||
|
||||
Args:
|
||||
@ -169,6 +184,8 @@ void init_ops(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
add(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Element-wise addition.
|
||||
|
||||
Add two arrays with numpy-style broadcasting semantics. Either or both input arrays
|
||||
@ -193,6 +210,8 @@ void init_ops(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
subtract(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Element-wise subtraction.
|
||||
|
||||
Subtract one array from another with numpy-style broadcasting semantics. Either or both
|
||||
@ -217,6 +236,8 @@ void init_ops(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
multiply(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Element-wise multiplication.
|
||||
|
||||
Multiply two arrays with numpy-style broadcasting semantics. Either or both
|
||||
@ -241,6 +262,8 @@ void init_ops(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
divide(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Element-wise division.
|
||||
|
||||
Divide two arrays with numpy-style broadcasting semantics. Either or both
|
||||
@ -265,6 +288,8 @@ void init_ops(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
remainder(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Element-wise remainder of division.
|
||||
|
||||
Computes the remainder of dividing a with b with numpy-style
|
||||
@ -290,6 +315,8 @@ void init_ops(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
equal(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Element-wise equality.
|
||||
|
||||
Equality comparison on two arrays with numpy-style broadcasting semantics.
|
||||
@ -314,6 +341,8 @@ void init_ops(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
not_equal(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Element-wise not equal.
|
||||
|
||||
Not equal comparison on two arrays with numpy-style broadcasting semantics.
|
||||
@ -338,6 +367,8 @@ void init_ops(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
less(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Element-wise less than.
|
||||
|
||||
Strict less than on two arrays with numpy-style broadcasting semantics.
|
||||
@ -362,6 +393,8 @@ void init_ops(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
less_equal(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Element-wise less than or equal.
|
||||
|
||||
Less than or equal on two arrays with numpy-style broadcasting semantics.
|
||||
@ -386,6 +419,8 @@ void init_ops(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
greater(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Element-wise greater than.
|
||||
|
||||
Strict greater than on two arrays with numpy-style broadcasting semantics.
|
||||
@ -410,6 +445,8 @@ void init_ops(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
greater_equal(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Element-wise greater or equal.
|
||||
|
||||
Greater than or equal on two arrays with numpy-style broadcasting semantics.
|
||||
@ -438,6 +475,8 @@ void init_ops(py::module_& m) {
|
||||
"equal_nan"_a = false,
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
array_equal(a: Union[scalar, array], b: Union[scalar, array], equal_nan: bool = False, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Array equality check.
|
||||
|
||||
Compare two arrays for equality. Returns ``True`` if and only if the arrays
|
||||
@ -462,6 +501,8 @@ void init_ops(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
matmul(a: array, b: array, /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Matrix multiplication.
|
||||
|
||||
Perform the (possibly batched) matrix multiplication of two arrays. This function supports
|
||||
@ -492,6 +533,8 @@ void init_ops(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
square(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Element-wise square.
|
||||
|
||||
Args:
|
||||
@ -508,6 +551,8 @@ void init_ops(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
sqrt(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Element-wise square root.
|
||||
|
||||
Args:
|
||||
@ -524,6 +569,8 @@ void init_ops(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
rsqrt(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Element-wise reciprocal and square root.
|
||||
|
||||
Args:
|
||||
@ -540,6 +587,8 @@ void init_ops(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
reciprocal(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Element-wise reciprocal.
|
||||
|
||||
Args:
|
||||
@ -558,6 +607,8 @@ void init_ops(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
logical_not(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Element-wise logical not.
|
||||
|
||||
Args:
|
||||
@ -578,6 +629,8 @@ void init_ops(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
logaddexp(a: Union[scalar, array], b: Union[scalar, array], /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Element-wise log-add-exp.
|
||||
|
||||
This is a numerically stable log-add-exp of two arrays with numpy-style
|
||||
@ -600,6 +653,8 @@ void init_ops(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
exp(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Element-wise exponential.
|
||||
|
||||
Args:
|
||||
@ -616,6 +671,8 @@ void init_ops(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
erf(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Element-wise error function.
|
||||
|
||||
.. math::
|
||||
@ -635,6 +692,8 @@ void init_ops(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
erfinv(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Element-wise inverse of :func:`erf`.
|
||||
|
||||
Args:
|
||||
@ -651,6 +710,8 @@ void init_ops(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
sin(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Element-wise sine.
|
||||
|
||||
Args:
|
||||
@ -667,6 +728,8 @@ void init_ops(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
cos(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Element-wise cosine.
|
||||
|
||||
Args:
|
||||
@ -683,6 +746,8 @@ void init_ops(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
tan(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Element-wise tangent.
|
||||
|
||||
Args:
|
||||
@ -699,6 +764,8 @@ void init_ops(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
arcsin(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Element-wise inverse sine.
|
||||
|
||||
Args:
|
||||
@ -715,6 +782,8 @@ void init_ops(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
arccos(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Element-wise inverse cosine.
|
||||
|
||||
Args:
|
||||
@ -731,6 +800,8 @@ void init_ops(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
arctan(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Element-wise inverse tangent.
|
||||
|
||||
Args:
|
||||
@ -747,6 +818,8 @@ void init_ops(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
sinh(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Element-wise hyperbolic sine.
|
||||
|
||||
Args:
|
||||
@ -763,6 +836,8 @@ void init_ops(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
cosh(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Element-wise hyperbolic cosine.
|
||||
|
||||
Args:
|
||||
@ -779,6 +854,8 @@ void init_ops(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
tanh(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Element-wise hyperbolic tangent.
|
||||
|
||||
Args:
|
||||
@ -795,6 +872,8 @@ void init_ops(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
arcsinh(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Element-wise inverse hyperbolic sine.
|
||||
|
||||
Args:
|
||||
@ -811,6 +890,8 @@ void init_ops(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
arccosh(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Element-wise inverse hyperbolic cosine.
|
||||
|
||||
Args:
|
||||
@ -827,6 +908,8 @@ void init_ops(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
arctanh(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Element-wise inverse hyperbolic tangent.
|
||||
|
||||
Args:
|
||||
@ -843,6 +926,8 @@ void init_ops(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
log(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Element-wise natural logarithm.
|
||||
|
||||
Args:
|
||||
@ -859,6 +944,8 @@ void init_ops(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
log2(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Element-wise base-2 logarithm.
|
||||
|
||||
Args:
|
||||
@ -875,6 +962,8 @@ void init_ops(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
log10(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Element-wise base-10 logarithm.
|
||||
|
||||
Args:
|
||||
@ -891,6 +980,8 @@ void init_ops(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
log1p(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Element-wise natural log of one plus the array.
|
||||
|
||||
Args:
|
||||
@ -907,6 +998,8 @@ void init_ops(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
stop_gradient(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Stop gradients from being computed.
|
||||
|
||||
The operation is the identity but it prevents gradients from flowing
|
||||
@ -927,6 +1020,8 @@ void init_ops(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
sigmoid(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Element-wise logistic sigmoid.
|
||||
|
||||
The logistic sigmoid function is:
|
||||
@ -952,6 +1047,8 @@ void init_ops(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
power(a: Union[scalar, array], b: Union[scalar, array], /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Element-wise power operation.
|
||||
|
||||
Raise the elements of a to the powers in elements of b with numpy-style
|
||||
@ -964,108 +1061,102 @@ void init_ops(py::module_& m) {
|
||||
Returns:
|
||||
array: Bases of ``a`` raised to powers in ``b``.
|
||||
)pbdoc");
|
||||
{
|
||||
// Disable function signature just for arange which we write manually
|
||||
py::options options;
|
||||
options.disable_function_signatures();
|
||||
m.def(
|
||||
"arange",
|
||||
[](Scalar stop, std::optional<Dtype> dtype_, StreamOrDevice s) {
|
||||
Dtype dtype =
|
||||
dtype_.has_value() ? dtype_.value() : scalar_to_dtype(stop);
|
||||
m.def(
|
||||
"arange",
|
||||
[](Scalar stop, std::optional<Dtype> dtype_, StreamOrDevice s) {
|
||||
Dtype dtype =
|
||||
dtype_.has_value() ? dtype_.value() : scalar_to_dtype(stop);
|
||||
|
||||
return arange(0.0, scalar_to_double(stop), 1.0, dtype, s);
|
||||
},
|
||||
"stop"_a,
|
||||
"dtype"_a = none,
|
||||
"stream"_a = none);
|
||||
m.def(
|
||||
"arange",
|
||||
[](Scalar start,
|
||||
Scalar stop,
|
||||
std::optional<Dtype> dtype_,
|
||||
StreamOrDevice s) {
|
||||
Dtype dtype = dtype_.has_value()
|
||||
? dtype_.value()
|
||||
: promote_types(scalar_to_dtype(start), scalar_to_dtype(stop));
|
||||
return arange(
|
||||
scalar_to_double(start), scalar_to_double(stop), dtype, s);
|
||||
},
|
||||
"start"_a,
|
||||
"stop"_a,
|
||||
"dtype"_a = none,
|
||||
"stream"_a = none);
|
||||
m.def(
|
||||
"arange",
|
||||
[](Scalar stop,
|
||||
Scalar step,
|
||||
std::optional<Dtype> dtype_,
|
||||
StreamOrDevice s) {
|
||||
Dtype dtype = dtype_.has_value()
|
||||
? dtype_.value()
|
||||
: promote_types(scalar_to_dtype(stop), scalar_to_dtype(step));
|
||||
return arange(0.0, scalar_to_double(stop), 1.0, dtype, s);
|
||||
},
|
||||
"stop"_a,
|
||||
"dtype"_a = none,
|
||||
"stream"_a = none);
|
||||
m.def(
|
||||
"arange",
|
||||
[](Scalar start,
|
||||
Scalar stop,
|
||||
std::optional<Dtype> dtype_,
|
||||
StreamOrDevice s) {
|
||||
Dtype dtype = dtype_.has_value()
|
||||
? dtype_.value()
|
||||
: promote_types(scalar_to_dtype(start), scalar_to_dtype(stop));
|
||||
return arange(
|
||||
scalar_to_double(start), scalar_to_double(stop), dtype, s);
|
||||
},
|
||||
"start"_a,
|
||||
"stop"_a,
|
||||
"dtype"_a = none,
|
||||
"stream"_a = none);
|
||||
m.def(
|
||||
"arange",
|
||||
[](Scalar stop,
|
||||
Scalar step,
|
||||
std::optional<Dtype> dtype_,
|
||||
StreamOrDevice s) {
|
||||
Dtype dtype = dtype_.has_value()
|
||||
? dtype_.value()
|
||||
: promote_types(scalar_to_dtype(stop), scalar_to_dtype(step));
|
||||
|
||||
return arange(
|
||||
0.0, scalar_to_double(stop), scalar_to_double(step), dtype, s);
|
||||
},
|
||||
"stop"_a,
|
||||
"step"_a,
|
||||
"dtype"_a = none,
|
||||
"stream"_a = none);
|
||||
m.def(
|
||||
"arange",
|
||||
[](Scalar start,
|
||||
Scalar stop,
|
||||
Scalar step,
|
||||
std::optional<Dtype> dtype_,
|
||||
StreamOrDevice s) {
|
||||
// Determine the final dtype based on input types
|
||||
Dtype dtype = dtype_.has_value()
|
||||
? dtype_.value()
|
||||
: promote_types(
|
||||
scalar_to_dtype(start),
|
||||
promote_types(
|
||||
scalar_to_dtype(stop), scalar_to_dtype(step)));
|
||||
return arange(
|
||||
0.0, scalar_to_double(stop), scalar_to_double(step), dtype, s);
|
||||
},
|
||||
"stop"_a,
|
||||
"step"_a,
|
||||
"dtype"_a = none,
|
||||
"stream"_a = none);
|
||||
m.def(
|
||||
"arange",
|
||||
[](Scalar start,
|
||||
Scalar stop,
|
||||
Scalar step,
|
||||
std::optional<Dtype> dtype_,
|
||||
StreamOrDevice s) {
|
||||
// Determine the final dtype based on input types
|
||||
Dtype dtype = dtype_.has_value()
|
||||
? dtype_.value()
|
||||
: promote_types(
|
||||
scalar_to_dtype(start),
|
||||
promote_types(scalar_to_dtype(stop), scalar_to_dtype(step)));
|
||||
|
||||
return arange(
|
||||
scalar_to_double(start),
|
||||
scalar_to_double(stop),
|
||||
scalar_to_double(step),
|
||||
dtype,
|
||||
s);
|
||||
},
|
||||
"start"_a,
|
||||
"stop"_a,
|
||||
"step"_a,
|
||||
"dtype"_a = none,
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
arange(start, stop, step, dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
return arange(
|
||||
scalar_to_double(start),
|
||||
scalar_to_double(stop),
|
||||
scalar_to_double(step),
|
||||
dtype,
|
||||
s);
|
||||
},
|
||||
"start"_a,
|
||||
"stop"_a,
|
||||
"step"_a,
|
||||
"dtype"_a = none,
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
arange(start, stop, step, dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Generates ranges of numbers.
|
||||
Generates ranges of numbers.
|
||||
|
||||
Generate numbers in the half-open interval ``[start, stop)`` in
|
||||
increments of ``step``.
|
||||
Generate numbers in the half-open interval ``[start, stop)`` in
|
||||
increments of ``step``.
|
||||
|
||||
Args:
|
||||
start (float or int, optional): Starting value which defaults to ``0``.
|
||||
stop (float or int): Stopping value.
|
||||
step (float or int, optional): Increment which defaults to ``1``.
|
||||
dtype (Dtype, optional): Specifies the data type of the output.
|
||||
If unspecified will default to ``float32`` if any of ``start``,
|
||||
``stop``, or ``step`` are ``float``. Otherwise will default to
|
||||
``int32``.
|
||||
Args:
|
||||
start (float or int, optional): Starting value which defaults to ``0``.
|
||||
stop (float or int): Stopping value.
|
||||
step (float or int, optional): Increment which defaults to ``1``.
|
||||
dtype (Dtype, optional): Specifies the data type of the output.
|
||||
If unspecified will default to ``float32`` if any of ``start``,
|
||||
``stop``, or ``step`` are ``float``. Otherwise will default to
|
||||
``int32``.
|
||||
|
||||
Returns:
|
||||
array: The range of values.
|
||||
Returns:
|
||||
array: The range of values.
|
||||
|
||||
Note:
|
||||
Following the Numpy convention the actual increment used to
|
||||
generate numbers is ``dtype(start + step) - dtype(start)``.
|
||||
This can lead to unexpected results for example if `start + step`
|
||||
is a fractional value and the `dtype` is integral.
|
||||
Note:
|
||||
Following the Numpy convention the actual increment used to
|
||||
generate numbers is ``dtype(start + step) - dtype(start)``.
|
||||
This can lead to unexpected results for example if `start + step`
|
||||
is a fractional value and the `dtype` is integral.
|
||||
)pbdoc");
|
||||
}
|
||||
m.def(
|
||||
"take",
|
||||
[](const array& a,
|
||||
@ -1085,6 +1176,8 @@ void init_ops(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
take(a: array, /, indices: array, axis: Optional[int] = None, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Take elements along an axis.
|
||||
|
||||
The elements are taken from ``indices`` along the specified axis.
|
||||
@ -1121,6 +1214,8 @@ void init_ops(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
take_along_axis(a: array, /, indices: array, axis: Optional[int] = None, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Take values along an axis at the specified indices.
|
||||
|
||||
Args:
|
||||
@ -1153,6 +1248,8 @@ void init_ops(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
full(shape: Union[int, List[int]], vals: Union[scalar, array], dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Construct an array with the given value.
|
||||
|
||||
Constructs an array of size ``shape`` filled with ``vals``. If ``vals``
|
||||
@ -1184,6 +1281,8 @@ void init_ops(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
zeros(shape: Union[int, List[int]], dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Construct an array of zeros.
|
||||
|
||||
Args:
|
||||
@ -1202,6 +1301,8 @@ void init_ops(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
zeros_like(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
An array of zeros like the input.
|
||||
|
||||
Args:
|
||||
@ -1227,6 +1328,8 @@ void init_ops(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
ones(shape: Union[int, List[int]], dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Construct an array of ones.
|
||||
|
||||
Args:
|
||||
@ -1245,6 +1348,8 @@ void init_ops(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
ones_like(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
An array of ones like the input.
|
||||
|
||||
Args:
|
||||
@ -1312,6 +1417,8 @@ void init_ops(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
allclose(a: array, b: array, /, rtol: float = 1e-05, atol: float = 1e-08, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Approximate comparison of two arrays.
|
||||
|
||||
The arrays are considered equal if:
|
||||
@ -1347,6 +1454,8 @@ void init_ops(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
all(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
An `and` reduction over the given axes.
|
||||
|
||||
Args:
|
||||
@ -1375,6 +1484,8 @@ void init_ops(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
any(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
An `or` reduction over the given axes.
|
||||
|
||||
Args:
|
||||
@ -1397,8 +1508,11 @@ void init_ops(py::module_& m) {
|
||||
"a"_a,
|
||||
"b"_a,
|
||||
py::pos_only(),
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
minimum(a: Union[scalar, array], b: Union[scalar, array], /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Element-wise minimum.
|
||||
|
||||
Take the element-wise min of two arrays with numpy-style broadcasting
|
||||
@ -1423,6 +1537,8 @@ void init_ops(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
maximum(a: Union[scalar, array], b: Union[scalar, array], /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Element-wise maximum.
|
||||
|
||||
Take the element-wise max of two arrays with numpy-style broadcasting
|
||||
@ -1452,6 +1568,8 @@ void init_ops(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
transpose(a: array, /, axes: Optional[List[int]] = None, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Transpose the dimensions of the array.
|
||||
|
||||
Args:
|
||||
@ -1477,6 +1595,8 @@ void init_ops(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
sum(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Sum reduce the array over the given axes.
|
||||
|
||||
Args:
|
||||
@ -1505,6 +1625,8 @@ void init_ops(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
prod(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
An product reduction over the given axes.
|
||||
|
||||
Args:
|
||||
@ -1533,6 +1655,8 @@ void init_ops(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
min(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
An `min` reduction over the given axes.
|
||||
|
||||
Args:
|
||||
@ -1561,6 +1685,8 @@ void init_ops(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
max(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
An `max` reduction over the given axes.
|
||||
|
||||
Args:
|
||||
@ -1589,6 +1715,8 @@ void init_ops(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
logsumexp(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
A `log-sum-exp` reduction over the given axes.
|
||||
|
||||
The log-sum-exp reduction is a numerically stable version of:
|
||||
@ -1623,6 +1751,8 @@ void init_ops(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
mean(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Compute the mean(s) over the given axes.
|
||||
|
||||
Args:
|
||||
@ -1653,6 +1783,8 @@ void init_ops(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
var(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, ddof: int = 0, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Compute the variance(s) over the given axes.
|
||||
|
||||
Args:
|
||||
@ -1688,6 +1820,8 @@ void init_ops(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
split(a: array, /, indices_or_sections: Union[int, List[int]], axis: int = 0, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Split an array along a given axis.
|
||||
|
||||
Args:
|
||||
@ -1721,6 +1855,8 @@ void init_ops(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
argmin(a: array, /, axis: Union[None, int] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Indices of the minimum values along the axis.
|
||||
|
||||
Args:
|
||||
@ -1752,6 +1888,8 @@ void init_ops(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
argmax(a: array, /, axis: Union[None, int] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Indices of the maximum values along the axis.
|
||||
|
||||
Args:
|
||||
@ -1779,6 +1917,8 @@ void init_ops(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
sort(a: array, /, axis: Union[None, int] = -1, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Returns a sorted copy of the array.
|
||||
|
||||
Args:
|
||||
@ -1805,6 +1945,8 @@ void init_ops(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
argsort(a: array, /, axis: Union[None, int] = -1, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Returns the indices that sort the array.
|
||||
|
||||
Args:
|
||||
@ -1832,6 +1974,8 @@ void init_ops(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
partition(a: array, /, kth: int, axis: Union[None, int] = -1, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Returns a partitioned copy of the array such that the smaller ``kth``
|
||||
elements are first.
|
||||
|
||||
@ -1866,6 +2010,8 @@ void init_ops(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
argpartition(a: array, /, kth: int, axis: Union[None, int] = -1, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Returns the indices that partition the array.
|
||||
|
||||
The ordering of the elements within a partition in given by the indices
|
||||
@ -1901,6 +2047,8 @@ void init_ops(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
topk(a: array, /, k: int, axis: Union[None, int] = -1, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Returns the ``k`` largest elements from the input along a given axis.
|
||||
|
||||
The elements will not necessarily be in sorted order.
|
||||
@ -1926,6 +2074,8 @@ void init_ops(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
broadcast_to(a: Union[scalar, array], /, shape: List[int], *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Broadcast an array to the given shape.
|
||||
|
||||
The broadcasting semantics are the same as Numpy.
|
||||
@ -1948,6 +2098,8 @@ void init_ops(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
softmax(a: array, /, axis: Union[None, int, List[int]] = None, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Perform the softmax along the given axis.
|
||||
|
||||
This operation is a numerically stable version of:
|
||||
@ -1982,6 +2134,8 @@ void init_ops(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
concatenate(arrays: List[array], axis: Optional[int] = 0, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Concatenate the arrays along the given axis.
|
||||
|
||||
Args:
|
||||
@ -2024,6 +2178,8 @@ void init_ops(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
pad(a: array, pad_with: Union[int, Tuple[int], Tuple[int, int], List[Tuple[int, int]]], constant_values: Union[scalar, array] = 0, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Pad an array with a constant value
|
||||
|
||||
Args:
|
||||
@ -2067,6 +2223,8 @@ void init_ops(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
as_strided(a: array, /, shape: Optional[List[int]] = None, strides: Optional[List[int]] = None, offset: int = 0, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Create a view into the array with the given shape and strides.
|
||||
|
||||
The resulting array will always be as if the provided array was row
|
||||
@ -2113,6 +2271,8 @@ void init_ops(py::module_& m) {
|
||||
"inclusive"_a = true,
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
cumsum(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Return the cumulative sum of the elements along the given axis.
|
||||
|
||||
Args:
|
||||
@ -2145,6 +2305,8 @@ void init_ops(py::module_& m) {
|
||||
"inclusive"_a = true,
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
cumprod(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Return the cumulative product of the elements along the given axis.
|
||||
|
||||
Args:
|
||||
@ -2177,6 +2339,8 @@ void init_ops(py::module_& m) {
|
||||
"inclusive"_a = true,
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
cummax(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Return the cumulative maximum of the elements along the given axis.
|
||||
|
||||
Args:
|
||||
@ -2209,6 +2373,8 @@ void init_ops(py::module_& m) {
|
||||
"inclusive"_a = true,
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
cummin(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Return the cumulative minimum of the elements along the given axis.
|
||||
|
||||
Args:
|
||||
@ -2275,6 +2441,8 @@ void init_ops(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
convolve(a: array, v: array, /, mode: str = "full", *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
The discrete convolution of 1D arrays.
|
||||
|
||||
If ``v`` is longer than ``a``, then they are swapped.
|
||||
@ -2301,6 +2469,8 @@ void init_ops(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
conv1d(input: array, weight: array, /, stride: int = 1, padding: int = 0, dilation: int = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
1D convolution over an input with several channels
|
||||
|
||||
Note: Only the default ``groups=1`` is currently supported.
|
||||
@ -2360,6 +2530,8 @@ void init_ops(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
conv2d(input: array, weight: array, /, stride: Union[int, Tuple[int, int]] = 1, padding: Union[int, Tuple[int, int]] = 0, dilation: Union[int, Tuple[int, int]] = 1, groups: Union[int, Tuple[int, int]] = 1, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
2D convolution over an input with several channels
|
||||
|
||||
Note: Only the default ``groups=1`` is currently supported.
|
||||
@ -2390,6 +2562,8 @@ void init_ops(py::module_& m) {
|
||||
"retain_graph"_a = true,
|
||||
py::kw_only(),
|
||||
R"pbdoc(
|
||||
save(file: str, arr: array, / , retain_graph: bool = True)
|
||||
|
||||
Save the array to a binary file in ``.npy`` format.
|
||||
|
||||
Args:
|
||||
@ -2408,6 +2582,8 @@ void init_ops(py::module_& m) {
|
||||
py::pos_only(),
|
||||
py::kw_only(),
|
||||
R"pbdoc(
|
||||
savez(file: str, *args, **kwargs)
|
||||
|
||||
Save several arrays to a binary file in uncompressed ``.npz`` format.
|
||||
|
||||
.. code-block:: python
|
||||
@ -2440,6 +2616,8 @@ void init_ops(py::module_& m) {
|
||||
py::pos_only(),
|
||||
py::kw_only(),
|
||||
R"pbdoc(
|
||||
savez_compressed(file: str, *args, **kwargs)
|
||||
|
||||
Save several arrays to a binary file in compressed ``.npz`` format.
|
||||
|
||||
Args:
|
||||
@ -2457,6 +2635,8 @@ void init_ops(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
load(file: str, /, *, stream: Union[None, Stream, Device] = None) -> Union[array, Dict[str, array]]
|
||||
|
||||
Load array(s) from a binary file in ``.npy`` or ``.npz`` format.
|
||||
|
||||
Args:
|
||||
@ -2481,6 +2661,8 @@ void init_ops(py::module_& m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
where(condition: Union[scalar, array], x: Union[scalar, array], y: Union[scalar, array], /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Select from ``x`` or ``y`` according to ``condition``.
|
||||
|
||||
The condition and input arrays must be the same shape or broadcastable
|
||||
|
@ -15,8 +15,8 @@ namespace py = pybind11;
|
||||
using namespace mlx::core;
|
||||
|
||||
using IntOrVec = std::variant<std::monostate, int, std::vector<int>>;
|
||||
using ScalarOrArray =
|
||||
std::variant<py::bool_, py::int_, py::float_, std::complex<float>, array>;
|
||||
using ScalarOrArray = std::
|
||||
variant<py::bool_, py::int_, py::float_, std::complex<float>, py::object>;
|
||||
static constexpr std::monostate none{};
|
||||
|
||||
inline std::vector<int> get_reduce_axes(const IntOrVec& v, int dims) {
|
||||
@ -32,6 +32,14 @@ inline std::vector<int> get_reduce_axes(const IntOrVec& v, int dims) {
|
||||
return axes;
|
||||
}
|
||||
|
||||
inline array to_array_with_accessor(py::object obj) {
|
||||
if (py::hasattr(obj, "__mlx_array__")) {
|
||||
return obj.attr("__mlx_array__")().cast<array>();
|
||||
} else {
|
||||
return obj.cast<array>();
|
||||
}
|
||||
}
|
||||
|
||||
inline array to_array(
|
||||
const ScalarOrArray& v,
|
||||
std::optional<Dtype> dtype = std::nullopt) {
|
||||
@ -48,7 +56,7 @@ inline array to_array(
|
||||
} else if (auto pv = std::get_if<std::complex<float>>(&v); pv) {
|
||||
return array(static_cast<complex64_t>(*pv), complex64);
|
||||
} else {
|
||||
return std::get<array>(v);
|
||||
return to_array_with_accessor(std::get<py::object>(v));
|
||||
}
|
||||
}
|
||||
|
||||
@ -60,13 +68,16 @@ inline std::pair<array, array> to_arrays(
|
||||
// - If a is an array but b is not, treat b as a weak python type
|
||||
// - If b is an array but a is not, treat a as a weak python type
|
||||
// - If neither is an array convert to arrays but leave their types alone
|
||||
if (auto pa = std::get_if<array>(&a); pa) {
|
||||
if (auto pb = std::get_if<array>(&b); pb) {
|
||||
return {*pa, *pb};
|
||||
if (auto pa = std::get_if<py::object>(&a); pa) {
|
||||
auto arr_a = to_array_with_accessor(*pa);
|
||||
if (auto pb = std::get_if<py::object>(&b); pb) {
|
||||
auto arr_b = to_array_with_accessor(*pb);
|
||||
return {arr_a, arr_b};
|
||||
}
|
||||
return {*pa, to_array(b, pa->dtype())};
|
||||
} else if (auto pb = std::get_if<array>(&b); pb) {
|
||||
return {to_array(a, pb->dtype()), *pb};
|
||||
return {arr_a, to_array(b, arr_a.dtype())};
|
||||
} else if (auto pb = std::get_if<py::object>(&b); pb) {
|
||||
auto arr_b = to_array_with_accessor(*pb);
|
||||
return {to_array(a, arr_b.dtype()), arr_b};
|
||||
} else {
|
||||
return {to_array(a), to_array(b)};
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user