updated python bindings

This commit is contained in:
Gabrijel Boduljak 2023-12-21 19:09:36 +01:00 committed by Awni Hannun
parent b996d682d9
commit fa096d64a2

View File

@ -27,62 +27,125 @@ void init_linalg(py::module_& parent_module) {
parent_module.def_submodule("linalg", "mlx.core.linalg: Linear Algebra.");
m.def(
"vector_norm",
[](const array& a,
const std::variant<double, std::string>& ord,
const std::variant<std::monostate, int, std::vector<int>>& axis,
bool keepdims,
StreamOrDevice s) {
std::vector<int> axes = std::visit(
overloaded{
[](std::monostate s) { return std::vector<int>(); },
[](int axis) { return std::vector<int>({axis}); },
[](const std::vector<int> axes) { return axes; }},
axis);
if (axes.empty())
return vector_norm(a, ord, keepdims, s);
else
return vector_norm(a, ord, axes, keepdims, s);
"norm",
[](const array& a, const bool keepdims, const StreamOrDevice stream) {
return norm(a, {}, keepdims, stream);
},
"a"_a,
"ord"_a = 2.0,
"axis"_a = none,
"keepdims"_a = false,
"stream"_a = none,
R"pbdoc(
Computes a vector norm.
R"pbdoc()pbdoc");
- If :attr:`axis`\ `= None`, :attr:`a` will be flattened before the norm is computed.
- If :attr:`axis` is an `int` or a `tuple`, the norm will be computed over these dimensions
and the other dimensions will be treated as batch dimensions.
:attr:`ord` defines the vector norm that is computed. The following norms are supported:
====================== ===============================
:attr:`ord` vector norm
====================== ===============================
`2` (default) `2`-norm (see below)
`inf` `max(abs(x))`
`-inf` `min(abs(x))`
`0` `sum(x != 0)`
other `int` or `float` `sum(abs(x)^{ord})^{(1 / ord)}`
====================== ===============================
where `inf` refers to `float('inf')`, NumPy's `inf` object, or any equivalent object.
Args:
a (Tensor): tensor, flattened by default, but this behavior can be
controlled using :attr:`dim`.
ord (int, float, inf, -inf, 'fro', 'nuc', optional): order of norm. Default: `2`
axis (int, Tuple[int], optional): dimensions over which to compute
the norm. See above for the behavior when :attr:`dim`\ `= None`.
Default: `None`
keepdims (bool, optional): If set to `True`, the reduced dimensions are retained
in the result as dimensions with size one. Default: `False`
Returns:
A real-valued tensor, even when :attr:`a` is complex.
)pbdoc");
m.def(
"norm",
[](const array& a,
const int axis,
const bool keepdims,
const StreamOrDevice stream) {
return norm(a, {axis}, keepdims, stream);
},
"a"_a,
"axis"_a,
"keepdims"_a = false,
"stream"_a = none,
R"pbdoc()pbdoc");
m.def(
"norm",
[](const array& a,
const std::vector<int>& axis,
const bool keepdims,
const StreamOrDevice stream) {
return norm(a, axis, keepdims, stream);
},
"a"_a,
"axis"_a,
"keepdims"_a = false,
"stream"_a = none,
R"pbdoc()pbdoc");
m.def(
"norm",
[](const array& a,
const double ord,
const bool keepdims,
const StreamOrDevice stream) {
return norm(a, ord, {}, keepdims, stream);
},
"a"_a,
"ord"_a,
"keepdims"_a = false,
"stream"_a = none,
R"pbdoc()pbdoc");
m.def(
"norm",
[](const array& a,
const double ord,
const int axis,
const bool keepdims,
const StreamOrDevice stream) {
return norm(a, ord, {axis}, keepdims, stream);
},
"a"_a,
"ord"_a,
"axis"_a,
"keepdims"_a = false,
"stream"_a = none,
R"pbdoc()pbdoc");
m.def(
"norm",
[](const array& a,
const double ord,
const std::vector<int>& axis,
const bool keepdims,
const StreamOrDevice stream) {
return norm(a, ord, axis, keepdims, stream);
},
"a"_a,
"ord"_a,
"axis"_a,
"keepdims"_a = false,
"stream"_a = none,
R"pbdoc()pbdoc");
m.def(
"norm",
[](const array& a,
const std::string& ord,
const bool keepdims,
const StreamOrDevice stream) {
return norm(a, ord, {}, keepdims, stream);
},
"a"_a,
"ord"_a,
"keepdims"_a = false,
"stream"_a = none,
R"pbdoc()pbdoc");
m.def(
"norm",
[](const array& a,
const std::string& ord,
const int axis,
const bool keepdims,
const StreamOrDevice stream) {
return norm(a, ord, {axis}, keepdims, stream);
},
"a"_a,
"ord"_a,
"axis"_a,
"keepdims"_a = false,
"stream"_a = none,
R"pbdoc()pbdoc");
m.def(
"norm",
[](const array& a,
const std::string& ord,
const std::vector<int>& axis,
const bool keepdims,
const StreamOrDevice stream) {
return norm(a, ord, axis, keepdims, stream);
},
"a"_a,
"ord"_a,
"axis"_a,
"keepdims"_a = false,
"stream"_a = none,
R"pbdoc()pbdoc");
}