diff --git a/python/src/linalg.cpp b/python/src/linalg.cpp index b389ac15f..c2728c738 100644 --- a/python/src/linalg.cpp +++ b/python/src/linalg.cpp @@ -27,62 +27,125 @@ void init_linalg(py::module_& parent_module) { parent_module.def_submodule("linalg", "mlx.core.linalg: Linear Algebra."); m.def( - "vector_norm", - [](const array& a, - const std::variant& ord, - const std::variant>& axis, - bool keepdims, - StreamOrDevice s) { - std::vector axes = std::visit( - overloaded{ - [](std::monostate s) { return std::vector(); }, - [](int axis) { return std::vector({axis}); }, - [](const std::vector axes) { return axes; }}, - axis); - - if (axes.empty()) - return vector_norm(a, ord, keepdims, s); - else - return vector_norm(a, ord, axes, keepdims, s); + "norm", + [](const array& a, const bool keepdims, const StreamOrDevice stream) { + return norm(a, {}, keepdims, stream); }, "a"_a, - "ord"_a = 2.0, - "axis"_a = none, "keepdims"_a = false, "stream"_a = none, - R"pbdoc( - Computes a vector norm. + R"pbdoc()pbdoc"); - - If :attr:`axis`\ `= None`, :attr:`a` will be flattened before the norm is computed. - - If :attr:`axis` is an `int` or a `tuple`, the norm will be computed over these dimensions - and the other dimensions will be treated as batch dimensions. - - - :attr:`ord` defines the vector norm that is computed. The following norms are supported: - - ====================== =============================== - :attr:`ord` vector norm - ====================== =============================== - `2` (default) `2`-norm (see below) - `inf` `max(abs(x))` - `-inf` `min(abs(x))` - `0` `sum(x != 0)` - other `int` or `float` `sum(abs(x)^{ord})^{(1 / ord)}` - ====================== =============================== - - where `inf` refers to `float('inf')`, NumPy's `inf` object, or any equivalent object. - - Args: - a (Tensor): tensor, flattened by default, but this behavior can be - controlled using :attr:`dim`. - ord (int, float, inf, -inf, 'fro', 'nuc', optional): order of norm. Default: `2` - axis (int, Tuple[int], optional): dimensions over which to compute - the norm. See above for the behavior when :attr:`dim`\ `= None`. - Default: `None` - keepdims (bool, optional): If set to `True`, the reduced dimensions are retained - in the result as dimensions with size one. Default: `False` - - Returns: - A real-valued tensor, even when :attr:`a` is complex. - )pbdoc"); + m.def( + "norm", + [](const array& a, + const int axis, + const bool keepdims, + const StreamOrDevice stream) { + return norm(a, {axis}, keepdims, stream); + }, + "a"_a, + "axis"_a, + "keepdims"_a = false, + "stream"_a = none, + R"pbdoc()pbdoc"); + m.def( + "norm", + [](const array& a, + const std::vector& axis, + const bool keepdims, + const StreamOrDevice stream) { + return norm(a, axis, keepdims, stream); + }, + "a"_a, + "axis"_a, + "keepdims"_a = false, + "stream"_a = none, + R"pbdoc()pbdoc"); + m.def( + "norm", + [](const array& a, + const double ord, + const bool keepdims, + const StreamOrDevice stream) { + return norm(a, ord, {}, keepdims, stream); + }, + "a"_a, + "ord"_a, + "keepdims"_a = false, + "stream"_a = none, + R"pbdoc()pbdoc"); + m.def( + "norm", + [](const array& a, + const double ord, + const int axis, + const bool keepdims, + const StreamOrDevice stream) { + return norm(a, ord, {axis}, keepdims, stream); + }, + "a"_a, + "ord"_a, + "axis"_a, + "keepdims"_a = false, + "stream"_a = none, + R"pbdoc()pbdoc"); + m.def( + "norm", + [](const array& a, + const double ord, + const std::vector& axis, + const bool keepdims, + const StreamOrDevice stream) { + return norm(a, ord, axis, keepdims, stream); + }, + "a"_a, + "ord"_a, + "axis"_a, + "keepdims"_a = false, + "stream"_a = none, + R"pbdoc()pbdoc"); + m.def( + "norm", + [](const array& a, + const std::string& ord, + const bool keepdims, + const StreamOrDevice stream) { + return norm(a, ord, {}, keepdims, stream); + }, + "a"_a, + "ord"_a, + "keepdims"_a = false, + "stream"_a = none, + R"pbdoc()pbdoc"); + m.def( + "norm", + [](const array& a, + const std::string& ord, + const int axis, + const bool keepdims, + const StreamOrDevice stream) { + return norm(a, ord, {axis}, keepdims, stream); + }, + "a"_a, + "ord"_a, + "axis"_a, + "keepdims"_a = false, + "stream"_a = none, + R"pbdoc()pbdoc"); + m.def( + "norm", + [](const array& a, + const std::string& ord, + const std::vector& axis, + const bool keepdims, + const StreamOrDevice stream) { + return norm(a, ord, axis, keepdims, stream); + }, + "a"_a, + "ord"_a, + "axis"_a, + "keepdims"_a = false, + "stream"_a = none, + R"pbdoc()pbdoc"); }