implemented vector_norm python binding

This commit is contained in:
Gabrijel Boduljak 2023-12-17 07:06:04 +01:00 committed by Awni Hannun
parent cc9b2dc3c2
commit 24da85025f
3 changed files with 91 additions and 0 deletions

View File

@ -11,6 +11,7 @@ pybind11_add_module(
${CMAKE_CURRENT_SOURCE_DIR}/stream.cpp ${CMAKE_CURRENT_SOURCE_DIR}/stream.cpp
${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp ${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp
${CMAKE_CURRENT_SOURCE_DIR}/random.cpp ${CMAKE_CURRENT_SOURCE_DIR}/random.cpp
${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
) )
if (NOT MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY) if (NOT MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY)

88
python/src/linalg.cpp Normal file
View File

@ -0,0 +1,88 @@
// Copyright © 2023 Apple Inc.
#include <numeric>
#include <ostream>
#include <variant>
#include <pybind11/iostream.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "mlx/linalg.h"
#include "mlx/ops.h"
#include "mlx/utils.h"
#include "python/src/load.h"
#include "python/src/utils.h"
namespace py = pybind11;
using namespace py::literals;
using namespace mlx::core;
using namespace mlx::core::linalg;
void init_linalg(py::module_& parent_module) {
auto m =
parent_module.def_submodule("linalg", "mlx.core.linalg: Linear Algebra.");
m.def(
"vector_norm",
[](const array& a,
const std::variant<double, std::string>& ord,
const std::variant<std::monostate, int, std::vector<int>>& axis,
bool keepdims,
StreamOrDevice s) {
std::vector<int> axes = std::visit(
overloaded{
[](std::monostate s) { return std::vector<int>(); },
[](int axis) { return std::vector<int>({axis}); },
[](const std::vector<int> axes) { return axes; }},
axis);
if (axes.empty())
return vector_norm(a, ord, keepdims, s);
else
return vector_norm(a, ord, axes, keepdims, s);
},
"a"_a,
"ord"_a = 2.0,
"axis"_a = none,
"keepdims"_a = false,
"stream"_a = none,
R"pbdoc(
Computes a vector norm.
- If :attr:`axis`\ `= None`, :attr:`a` will be flattened before the norm is computed.
- If :attr:`axis` is an `int` or a `tuple`, the norm will be computed over these dimensions
and the other dimensions will be treated as batch dimensions.
:attr:`ord` defines the vector norm that is computed. The following norms are supported:
====================== ===============================
:attr:`ord` vector norm
====================== ===============================
`2` (default) `2`-norm (see below)
`inf` `max(abs(x))`
`-inf` `min(abs(x))`
`0` `sum(x != 0)`
other `int` or `float` `sum(abs(x)^{ord})^{(1 / ord)}`
====================== ===============================
where `inf` refers to `float('inf')`, NumPy's `inf` object, or any equivalent object.
Args:
a (Tensor): tensor, flattened by default, but this behavior can be
controlled using :attr:`dim`.
ord (int, float, inf, -inf, 'fro', 'nuc', optional): order of norm. Default: `2`
axis (int, Tuple[int], optional): dimensions over which to compute
the norm. See above for the behavior when :attr:`dim`\ `= None`.
Default: `None`
keepdims (bool, optional): If set to `True`, the reduced dimensions are retained
in the result as dimensions with size one. Default: `False`
Returns:
A real-valued tensor, even when :attr:`a` is complex.
)pbdoc");
}

View File

@ -15,6 +15,7 @@ void init_ops(py::module_&);
void init_transforms(py::module_&); void init_transforms(py::module_&);
void init_random(py::module_&); void init_random(py::module_&);
void init_fft(py::module_&); void init_fft(py::module_&);
void init_linalg(py::module_&);
PYBIND11_MODULE(core, m) { PYBIND11_MODULE(core, m) {
m.doc() = "mlx: A framework for machine learning on Apple silicon."; m.doc() = "mlx: A framework for machine learning on Apple silicon.";
@ -29,5 +30,6 @@ PYBIND11_MODULE(core, m) {
init_transforms(m); init_transforms(m);
init_random(m); init_random(m);
init_fft(m); init_fft(m);
init_linalg(m);
m.attr("__version__") = TOSTRING(_VERSION_); m.attr("__version__") = TOSTRING(_VERSION_);
} }