mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 12:06:42 +08:00
updated python bindings
This commit is contained in:
parent
b996d682d9
commit
fa096d64a2
@ -27,62 +27,125 @@ void init_linalg(py::module_& parent_module) {
|
|||||||
parent_module.def_submodule("linalg", "mlx.core.linalg: Linear Algebra.");
|
parent_module.def_submodule("linalg", "mlx.core.linalg: Linear Algebra.");
|
||||||
|
|
||||||
m.def(
|
m.def(
|
||||||
"vector_norm",
|
"norm",
|
||||||
[](const array& a,
|
[](const array& a, const bool keepdims, const StreamOrDevice stream) {
|
||||||
const std::variant<double, std::string>& ord,
|
return norm(a, {}, keepdims, stream);
|
||||||
const std::variant<std::monostate, int, std::vector<int>>& axis,
|
|
||||||
bool keepdims,
|
|
||||||
StreamOrDevice s) {
|
|
||||||
std::vector<int> axes = std::visit(
|
|
||||||
overloaded{
|
|
||||||
[](std::monostate s) { return std::vector<int>(); },
|
|
||||||
[](int axis) { return std::vector<int>({axis}); },
|
|
||||||
[](const std::vector<int> axes) { return axes; }},
|
|
||||||
axis);
|
|
||||||
|
|
||||||
if (axes.empty())
|
|
||||||
return vector_norm(a, ord, keepdims, s);
|
|
||||||
else
|
|
||||||
return vector_norm(a, ord, axes, keepdims, s);
|
|
||||||
},
|
},
|
||||||
"a"_a,
|
"a"_a,
|
||||||
"ord"_a = 2.0,
|
|
||||||
"axis"_a = none,
|
|
||||||
"keepdims"_a = false,
|
"keepdims"_a = false,
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc()pbdoc");
|
||||||
Computes a vector norm.
|
|
||||||
|
|
||||||
- If :attr:`axis`\ `= None`, :attr:`a` will be flattened before the norm is computed.
|
m.def(
|
||||||
- If :attr:`axis` is an `int` or a `tuple`, the norm will be computed over these dimensions
|
"norm",
|
||||||
and the other dimensions will be treated as batch dimensions.
|
[](const array& a,
|
||||||
|
const int axis,
|
||||||
|
const bool keepdims,
|
||||||
:attr:`ord` defines the vector norm that is computed. The following norms are supported:
|
const StreamOrDevice stream) {
|
||||||
|
return norm(a, {axis}, keepdims, stream);
|
||||||
====================== ===============================
|
},
|
||||||
:attr:`ord` vector norm
|
"a"_a,
|
||||||
====================== ===============================
|
"axis"_a,
|
||||||
`2` (default) `2`-norm (see below)
|
"keepdims"_a = false,
|
||||||
`inf` `max(abs(x))`
|
"stream"_a = none,
|
||||||
`-inf` `min(abs(x))`
|
R"pbdoc()pbdoc");
|
||||||
`0` `sum(x != 0)`
|
m.def(
|
||||||
other `int` or `float` `sum(abs(x)^{ord})^{(1 / ord)}`
|
"norm",
|
||||||
====================== ===============================
|
[](const array& a,
|
||||||
|
const std::vector<int>& axis,
|
||||||
where `inf` refers to `float('inf')`, NumPy's `inf` object, or any equivalent object.
|
const bool keepdims,
|
||||||
|
const StreamOrDevice stream) {
|
||||||
Args:
|
return norm(a, axis, keepdims, stream);
|
||||||
a (Tensor): tensor, flattened by default, but this behavior can be
|
},
|
||||||
controlled using :attr:`dim`.
|
"a"_a,
|
||||||
ord (int, float, inf, -inf, 'fro', 'nuc', optional): order of norm. Default: `2`
|
"axis"_a,
|
||||||
axis (int, Tuple[int], optional): dimensions over which to compute
|
"keepdims"_a = false,
|
||||||
the norm. See above for the behavior when :attr:`dim`\ `= None`.
|
"stream"_a = none,
|
||||||
Default: `None`
|
R"pbdoc()pbdoc");
|
||||||
keepdims (bool, optional): If set to `True`, the reduced dimensions are retained
|
m.def(
|
||||||
in the result as dimensions with size one. Default: `False`
|
"norm",
|
||||||
|
[](const array& a,
|
||||||
Returns:
|
const double ord,
|
||||||
A real-valued tensor, even when :attr:`a` is complex.
|
const bool keepdims,
|
||||||
)pbdoc");
|
const StreamOrDevice stream) {
|
||||||
|
return norm(a, ord, {}, keepdims, stream);
|
||||||
|
},
|
||||||
|
"a"_a,
|
||||||
|
"ord"_a,
|
||||||
|
"keepdims"_a = false,
|
||||||
|
"stream"_a = none,
|
||||||
|
R"pbdoc()pbdoc");
|
||||||
|
m.def(
|
||||||
|
"norm",
|
||||||
|
[](const array& a,
|
||||||
|
const double ord,
|
||||||
|
const int axis,
|
||||||
|
const bool keepdims,
|
||||||
|
const StreamOrDevice stream) {
|
||||||
|
return norm(a, ord, {axis}, keepdims, stream);
|
||||||
|
},
|
||||||
|
"a"_a,
|
||||||
|
"ord"_a,
|
||||||
|
"axis"_a,
|
||||||
|
"keepdims"_a = false,
|
||||||
|
"stream"_a = none,
|
||||||
|
R"pbdoc()pbdoc");
|
||||||
|
m.def(
|
||||||
|
"norm",
|
||||||
|
[](const array& a,
|
||||||
|
const double ord,
|
||||||
|
const std::vector<int>& axis,
|
||||||
|
const bool keepdims,
|
||||||
|
const StreamOrDevice stream) {
|
||||||
|
return norm(a, ord, axis, keepdims, stream);
|
||||||
|
},
|
||||||
|
"a"_a,
|
||||||
|
"ord"_a,
|
||||||
|
"axis"_a,
|
||||||
|
"keepdims"_a = false,
|
||||||
|
"stream"_a = none,
|
||||||
|
R"pbdoc()pbdoc");
|
||||||
|
m.def(
|
||||||
|
"norm",
|
||||||
|
[](const array& a,
|
||||||
|
const std::string& ord,
|
||||||
|
const bool keepdims,
|
||||||
|
const StreamOrDevice stream) {
|
||||||
|
return norm(a, ord, {}, keepdims, stream);
|
||||||
|
},
|
||||||
|
"a"_a,
|
||||||
|
"ord"_a,
|
||||||
|
"keepdims"_a = false,
|
||||||
|
"stream"_a = none,
|
||||||
|
R"pbdoc()pbdoc");
|
||||||
|
m.def(
|
||||||
|
"norm",
|
||||||
|
[](const array& a,
|
||||||
|
const std::string& ord,
|
||||||
|
const int axis,
|
||||||
|
const bool keepdims,
|
||||||
|
const StreamOrDevice stream) {
|
||||||
|
return norm(a, ord, {axis}, keepdims, stream);
|
||||||
|
},
|
||||||
|
"a"_a,
|
||||||
|
"ord"_a,
|
||||||
|
"axis"_a,
|
||||||
|
"keepdims"_a = false,
|
||||||
|
"stream"_a = none,
|
||||||
|
R"pbdoc()pbdoc");
|
||||||
|
m.def(
|
||||||
|
"norm",
|
||||||
|
[](const array& a,
|
||||||
|
const std::string& ord,
|
||||||
|
const std::vector<int>& axis,
|
||||||
|
const bool keepdims,
|
||||||
|
const StreamOrDevice stream) {
|
||||||
|
return norm(a, ord, axis, keepdims, stream);
|
||||||
|
},
|
||||||
|
"a"_a,
|
||||||
|
"ord"_a,
|
||||||
|
"axis"_a,
|
||||||
|
"keepdims"_a = false,
|
||||||
|
"stream"_a = none,
|
||||||
|
R"pbdoc()pbdoc");
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user