updated python bindings

This commit is contained in:
Gabrijel Boduljak 2023-12-21 19:09:36 +01:00 committed by Awni Hannun
parent b996d682d9
commit fa096d64a2

View File

@ -27,62 +27,125 @@ void init_linalg(py::module_& parent_module) {
parent_module.def_submodule("linalg", "mlx.core.linalg: Linear Algebra."); parent_module.def_submodule("linalg", "mlx.core.linalg: Linear Algebra.");
m.def( m.def(
"vector_norm", "norm",
[](const array& a, [](const array& a, const bool keepdims, const StreamOrDevice stream) {
const std::variant<double, std::string>& ord, return norm(a, {}, keepdims, stream);
const std::variant<std::monostate, int, std::vector<int>>& axis,
bool keepdims,
StreamOrDevice s) {
std::vector<int> axes = std::visit(
overloaded{
[](std::monostate s) { return std::vector<int>(); },
[](int axis) { return std::vector<int>({axis}); },
[](const std::vector<int> axes) { return axes; }},
axis);
if (axes.empty())
return vector_norm(a, ord, keepdims, s);
else
return vector_norm(a, ord, axes, keepdims, s);
}, },
"a"_a, "a"_a,
"ord"_a = 2.0,
"axis"_a = none,
"keepdims"_a = false, "keepdims"_a = false,
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc()pbdoc");
Computes a vector norm.
- If :attr:`axis`\ `= None`, :attr:`a` will be flattened before the norm is computed. m.def(
- If :attr:`axis` is an `int` or a `tuple`, the norm will be computed over these dimensions "norm",
and the other dimensions will be treated as batch dimensions. [](const array& a,
const int axis,
const bool keepdims,
:attr:`ord` defines the vector norm that is computed. The following norms are supported: const StreamOrDevice stream) {
return norm(a, {axis}, keepdims, stream);
====================== =============================== },
:attr:`ord` vector norm "a"_a,
====================== =============================== "axis"_a,
`2` (default) `2`-norm (see below) "keepdims"_a = false,
`inf` `max(abs(x))` "stream"_a = none,
`-inf` `min(abs(x))` R"pbdoc()pbdoc");
`0` `sum(x != 0)` m.def(
other `int` or `float` `sum(abs(x)^{ord})^{(1 / ord)}` "norm",
====================== =============================== [](const array& a,
const std::vector<int>& axis,
where `inf` refers to `float('inf')`, NumPy's `inf` object, or any equivalent object. const bool keepdims,
const StreamOrDevice stream) {
Args: return norm(a, axis, keepdims, stream);
a (Tensor): tensor, flattened by default, but this behavior can be },
controlled using :attr:`dim`. "a"_a,
ord (int, float, inf, -inf, 'fro', 'nuc', optional): order of norm. Default: `2` "axis"_a,
axis (int, Tuple[int], optional): dimensions over which to compute "keepdims"_a = false,
the norm. See above for the behavior when :attr:`dim`\ `= None`. "stream"_a = none,
Default: `None` R"pbdoc()pbdoc");
keepdims (bool, optional): If set to `True`, the reduced dimensions are retained m.def(
in the result as dimensions with size one. Default: `False` "norm",
[](const array& a,
Returns: const double ord,
A real-valued tensor, even when :attr:`a` is complex. const bool keepdims,
)pbdoc"); const StreamOrDevice stream) {
return norm(a, ord, {}, keepdims, stream);
},
"a"_a,
"ord"_a,
"keepdims"_a = false,
"stream"_a = none,
R"pbdoc()pbdoc");
m.def(
"norm",
[](const array& a,
const double ord,
const int axis,
const bool keepdims,
const StreamOrDevice stream) {
return norm(a, ord, {axis}, keepdims, stream);
},
"a"_a,
"ord"_a,
"axis"_a,
"keepdims"_a = false,
"stream"_a = none,
R"pbdoc()pbdoc");
m.def(
"norm",
[](const array& a,
const double ord,
const std::vector<int>& axis,
const bool keepdims,
const StreamOrDevice stream) {
return norm(a, ord, axis, keepdims, stream);
},
"a"_a,
"ord"_a,
"axis"_a,
"keepdims"_a = false,
"stream"_a = none,
R"pbdoc()pbdoc");
m.def(
"norm",
[](const array& a,
const std::string& ord,
const bool keepdims,
const StreamOrDevice stream) {
return norm(a, ord, {}, keepdims, stream);
},
"a"_a,
"ord"_a,
"keepdims"_a = false,
"stream"_a = none,
R"pbdoc()pbdoc");
m.def(
"norm",
[](const array& a,
const std::string& ord,
const int axis,
const bool keepdims,
const StreamOrDevice stream) {
return norm(a, ord, {axis}, keepdims, stream);
},
"a"_a,
"ord"_a,
"axis"_a,
"keepdims"_a = false,
"stream"_a = none,
R"pbdoc()pbdoc");
m.def(
"norm",
[](const array& a,
const std::string& ord,
const std::vector<int>& axis,
const bool keepdims,
const StreamOrDevice stream) {
return norm(a, ord, axis, keepdims, stream);
},
"a"_a,
"ord"_a,
"axis"_a,
"keepdims"_a = false,
"stream"_a = none,
R"pbdoc()pbdoc");
} }