diff --git a/python/src/CMakeLists.txt b/python/src/CMakeLists.txt
index 5ab8a50bf..1ad9d207d 100644
--- a/python/src/CMakeLists.txt
+++ b/python/src/CMakeLists.txt
@@ -11,6 +11,7 @@ pybind11_add_module(
   ${CMAKE_CURRENT_SOURCE_DIR}/stream.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/random.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
 )
 
 if (NOT MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY)
diff --git a/python/src/linalg.cpp b/python/src/linalg.cpp
new file mode 100644
index 000000000..b389ac15f
--- /dev/null
+++ b/python/src/linalg.cpp
@@ -0,0 +1,88 @@
+
+// Copyright © 2023 Apple Inc.
+
+#include <variant>
+#include <vector>
+
+#include <pybind11/numpy.h>
+#include <pybind11/pybind11.h>
+#include <pybind11/stl.h>
+
+#include "mlx/linalg.h"
+#include "mlx/ops.h"
+#include "mlx/utils.h"
+
+#include "python/src/load.h"
+#include "python/src/utils.h"
+
+namespace py = pybind11;
+using namespace py::literals;
+
+using namespace mlx::core;
+using namespace mlx::core::linalg;
+
+void init_linalg(py::module_& parent_module) {
+  auto m =
+      parent_module.def_submodule("linalg", "mlx.core.linalg: Linear Algebra.");
+
+  m.def(
+      "vector_norm",
+      [](const array& a,
+         const std::variant<int, double>& ord,
+         const std::variant<std::monostate, int, std::vector<int>>& axis,
+         bool keepdims,
+         StreamOrDevice s) {
+        std::vector<int> axes = std::visit(
+            overloaded{
+                [](std::monostate s) { return std::vector<int>(); },
+                [](int axis) { return std::vector<int>({axis}); },
+                [](const std::vector<int> axes) { return axes; }},
+            axis);
+
+        if (axes.empty())
+          return vector_norm(a, ord, keepdims, s);
+        else
+          return vector_norm(a, ord, axes, keepdims, s);
+      },
+      "a"_a,
+      "ord"_a = 2.0,
+      "axis"_a = none,
+      "keepdims"_a = false,
+      "stream"_a = none,
+      R"pbdoc(
+        Computes a vector norm.
+
+        - If :attr:`axis`\ `= None`, :attr:`a` will be flattened before the norm is computed.
+        - If :attr:`axis` is an `int` or a `tuple`, the norm will be computed over these dimensions
+          and the other dimensions will be treated as batch dimensions.
+
+        :attr:`ord` defines the vector norm that is computed. The following norms are supported:
+
+        ======================  ===============================
+        :attr:`ord`             vector norm
+        ======================  ===============================
+        `2` (default)           `2`-norm (see below)
+        `inf`                   `max(abs(x))`
+        `-inf`                  `min(abs(x))`
+        `0`                     `sum(x != 0)`
+        other `int` or `float`  `sum(abs(x)^{ord})^{(1 / ord)}`
+        ======================  ===============================
+
+        where `inf` refers to `float('inf')`, NumPy's `inf` object, or any equivalent object.
+
+        Args:
+            a (array): input array, flattened by default, but this behavior can be
+                controlled using :attr:`axis`.
+            ord (int, float, inf, -inf, optional): order of norm. Default: `2`
+            axis (int, Tuple[int], optional): dimensions over which to compute
+                the norm. See above for the behavior when :attr:`axis`\ `= None`.
+                Default: `None`
+            keepdims (bool, optional): If set to `True`, the reduced dimensions are retained
+                in the result as dimensions with size one. Default: `False`
+
+        Returns:
+            A real-valued array, even when :attr:`a` is complex.
+      )pbdoc");
+}
diff --git a/python/src/mlx.cpp b/python/src/mlx.cpp
index ebadf767d..d7cf15751 100644
--- a/python/src/mlx.cpp
+++ b/python/src/mlx.cpp
@@ -15,6 +15,7 @@ void init_ops(py::module_&);
 void init_transforms(py::module_&);
 void init_random(py::module_&);
 void init_fft(py::module_&);
+void init_linalg(py::module_&);
 
 PYBIND11_MODULE(core, m) {
   m.doc() = "mlx: A framework for machine learning on Apple silicon.";
@@ -29,5 +30,6 @@ PYBIND11_MODULE(core, m) {
   init_transforms(m);
   init_random(m);
   init_fft(m);
+  init_linalg(m);
   m.attr("__version__") = TOSTRING(_VERSION_);
 }
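
For reference, a minimal usage sketch of the binding added above (not part of the diff). It assumes the package builds as `mlx.core` and that `linalg.vector_norm` is exposed with the signature shown in `linalg.cpp`:

```python
# Usage sketch only: assumes the bindings build as `mlx.core` and that
# `linalg.vector_norm` is exposed exactly as in the diff above.
import mlx.core as mx

a = mx.array([[1.0, 2.0], [3.0, 4.0]])

# Default: 2-norm of the flattened array (axis=None).
print(mx.linalg.vector_norm(a))  # sqrt(1 + 4 + 9 + 16) ≈ 5.477

# Infinity norm, i.e. max(abs(x)), per row; keepdims keeps the result shape (2, 1).
print(mx.linalg.vector_norm(a, ord=float("inf"), axis=1, keepdims=True))
```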