mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 12:06:42 +08:00
fix python tests
This commit is contained in:
parent
49e3e99da3
commit
67e319488c
@ -4,9 +4,7 @@
|
||||
#include <ostream>
|
||||
#include <vector>
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/linalg.h"
|
||||
#include "mlx/ops.h"
|
||||
|
||||
namespace mlx::core::linalg {
|
||||
|
||||
@ -48,25 +46,36 @@ inline array matrix_norm(
|
||||
auto dtype = at_least_float(a.dtype());
|
||||
auto row_axis = axis[0];
|
||||
auto col_axis = axis[1];
|
||||
if (!keepdims && col_axis > row_axis && col_axis > 0) {
|
||||
col_axis -= 1;
|
||||
}
|
||||
if (ord == -1.0) {
|
||||
col_axis -= (!keepdims && col_axis > row_axis && col_axis > 0);
|
||||
return astype(
|
||||
min(sum(abs(a, s), row_axis, keepdims, s), col_axis, keepdims, s),
|
||||
dtype,
|
||||
s);
|
||||
} else if (ord == 1.0) {
|
||||
col_axis -= (!keepdims && col_axis > row_axis && col_axis > 0);
|
||||
return astype(
|
||||
max(sum(abs(a, s), row_axis, keepdims, s), col_axis, keepdims, s),
|
||||
dtype,
|
||||
s);
|
||||
} else if (ord == std::numeric_limits<double>::infinity()) {
|
||||
row_axis -= (!keepdims && row_axis > col_axis && row_axis > 0);
|
||||
return astype(
|
||||
max(sum(abs(a, s), col_axis, keepdims, s), row_axis, keepdims, s),
|
||||
dtype,
|
||||
s);
|
||||
} else if (ord == -std::numeric_limits<double>::infinity()) {
|
||||
row_axis -= (!keepdims && row_axis > col_axis && row_axis > 0);
|
||||
return astype(
|
||||
min(sum(abs(a, s), col_axis, keepdims, s), row_axis, keepdims, s),
|
||||
dtype,
|
||||
s);
|
||||
} else if (ord == 2.0 || ord == -2.0) {
|
||||
throw std::runtime_error(
|
||||
"[linalg::norm] Singular value norms are not implemented.");
|
||||
} else {
|
||||
std::ostringstream msg;
|
||||
msg << "[linalg::norm] Invalid ord value " << ord << " for matrix norm";
|
||||
msg << "[linalg::norm] Invalid ord " << ord << " for matrix norm.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
}
|
||||
@ -78,13 +87,13 @@ inline array matrix_norm(
|
||||
bool keepdims,
|
||||
StreamOrDevice s) {
|
||||
if (ord == "f" || ord == "fro") {
|
||||
return sqrt(sum(abs(a, s) * abs(a, s), axis, keepdims, s));
|
||||
return sqrt(sum(square(a, s), axis, keepdims, s), s);
|
||||
} else if (ord == "nuc") {
|
||||
throw std::runtime_error(
|
||||
"[linalg::norm] Nuclear norm not yet implemented.");
|
||||
} else {
|
||||
std::ostringstream msg;
|
||||
msg << "[linalg::norm] Invalid ord value '" << ord << "' for matrix norm";
|
||||
msg << "[linalg::norm] Invalid ord value '" << ord << "' for matrix norm.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
}
|
||||
@ -100,7 +109,7 @@ array norm(
|
||||
|
||||
if (axis.value().size() > 2) {
|
||||
throw std::invalid_argument(
|
||||
"[linalg::norm] Received too many axes for norm");
|
||||
"[linalg::norm] Received too many axes for norm.");
|
||||
}
|
||||
return sqrt(sum(square(a, s), axis.value(), keepdims, s), s);
|
||||
}
|
||||
@ -124,7 +133,7 @@ array norm(
|
||||
return matrix_norm(a, ord, ax, keepdims, s);
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[linalg::norm] Received too many axes for norm");
|
||||
"[linalg::norm] Received too many axes for norm.");
|
||||
}
|
||||
}
|
||||
|
||||
|
17
mlx/linalg.h
17
mlx/linalg.h
@ -4,21 +4,22 @@
|
||||
|
||||
#include <optional>
|
||||
|
||||
#include "array.h"
|
||||
#include "device.h"
|
||||
#include "ops.h"
|
||||
#include "stream.h"
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/device.h"
|
||||
#include "mlx/ops.h"
|
||||
#include "mlx/stream.h"
|
||||
|
||||
namespace mlx::core::linalg {
|
||||
|
||||
/*
|
||||
/**
|
||||
* Compute vector or matrix norms.
|
||||
*
|
||||
* - If axis and ord are both unspecified, computes the 2-norm of flatten(x).
|
||||
* - If axis is not provided but ord is, then x must be either 1D or 2D.
|
||||
* - If axis is provided, but ord is not, then the 2-norm is computed along the
|
||||
* given axes. At most 2 axes can be specified.
|
||||
* - If both axis and ord are provided, then the corresponding matrix of vector
|
||||
* - If axis is provided, but ord is not, then the 2-norm (or Frobenius norm
|
||||
* for matrices) is computed along the given axes. At most 2 axes can be
|
||||
* specified.
|
||||
* - If both axis and ord are provided, then the corresponding matrix or vector
|
||||
* norm is computed. At most 2 axes can be specified.
|
||||
*/
|
||||
array norm(
|
||||
|
@ -20,8 +20,8 @@ void init_linalg(py::module_& parent_module) {
|
||||
py::options options;
|
||||
options.disable_function_signatures();
|
||||
|
||||
auto m =
|
||||
parent_module.def_submodule("linalg", "mlx.core.linalg: Linear Algebra.");
|
||||
auto m = parent_module.def_submodule(
|
||||
"linalg", "mlx.core.linalg: linear algebra routines.");
|
||||
|
||||
m.def(
|
||||
"norm",
|
||||
@ -72,8 +72,8 @@ void init_linalg(py::module_& parent_module) {
|
||||
unless ``ord`` is ``None``. If both ``axis`` and ``ord`` are ``None``, the
|
||||
2-norm of ``a.flatten`` will be returned.
|
||||
ord (scalar or str, optional): Order of the norm (see table under ``Notes``).
|
||||
If ``None``, the 2-norm will be computed along the given ``axis``.
|
||||
Default: ``None``.
|
||||
If ``None``, the 2-norm (or Frobenius norm for matrices) will be computed
|
||||
along the given ``axis``. Default: ``None``.
|
||||
axis (int or list(int), optional): If ``axis`` is an integer, it specifies the
|
||||
axis of ``a`` along which to compute the vector norms. If ``axis`` is a
|
||||
2-tuple, it specifies the axes that hold 2-D matrices, and the matrix
|
||||
|
@ -11,74 +11,56 @@ import numpy as np
|
||||
|
||||
class TestLinalg(mlx_tests.MLXTestCase):
|
||||
def test_norm(self):
|
||||
vector_ords = [None, 0.5, 0, 1, 2, 3, -1, 1, float("inf"), -float("inf")]
|
||||
vector_ords = [None, 0.5, 0, 1, 2, 3, -1, float("inf"), -float("inf")]
|
||||
matrix_ords = [None, "fro", -1, 1, float("inf"), -float("inf")]
|
||||
|
||||
for shape in [(3,), (2, 3), (2, 3, 3)]:
|
||||
x_mx = mx.arange(math.prod(shape)).reshape(shape)
|
||||
x_np = np.arange(math.prod(shape)).reshape(shape)
|
||||
x_mx = mx.arange(1, math.prod(shape) + 1).reshape(shape)
|
||||
x_np = np.arange(1, math.prod(shape) + 1).reshape(shape)
|
||||
# Test when at least one axis is provided
|
||||
for num_axes in range(1, len(shape)):
|
||||
for axis in itertools.combinations(range(len(shape)), num_axes):
|
||||
if num_axes == 1:
|
||||
ords = vector_ords
|
||||
else:
|
||||
ords = matrix_ords
|
||||
for keepdims in [True, False]:
|
||||
# Test axis provided, no ord provided
|
||||
out_np = np.linalg.norm(x_np, axis=axis, keepdims=keepdims)
|
||||
out_mx = mx.linalg.norm(x_mx, axis=axis, keepdims=keepdims)
|
||||
assert np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)
|
||||
# Test both ord and axis provided
|
||||
for o in ords:
|
||||
for keepdims in [True, False]:
|
||||
if o:
|
||||
out_np = np.linalg.norm(
|
||||
x_np, ord=o, axis=axis, keepdims=keepdims
|
||||
)
|
||||
out_mx = mx.linalg.norm(
|
||||
x_mx, ord=o, axis=axis, keepdims=keepdims
|
||||
)
|
||||
else:
|
||||
out_np = np.linalg.norm(
|
||||
x_np, axis=axis, keepdims=keepdims
|
||||
)
|
||||
out_mx = mx.linalg.norm(
|
||||
x_mx, axis=axis, keepdims=keepdims
|
||||
)
|
||||
assert np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)
|
||||
|
||||
# Test only axis provided
|
||||
for shape in [(3,), (2, 3), (2, 3, 3)]:
|
||||
x_mx = mx.arange(math.prod(shape)).reshape(shape)
|
||||
x_np = np.arange(math.prod(shape)).reshape(shape)
|
||||
|
||||
for num_axes in range(1, len(shape)):
|
||||
if num_axes == 1:
|
||||
ords = vector_ords
|
||||
else:
|
||||
ords = matrix_ords
|
||||
for axis in itertools.combinations(range(len(shape)), num_axes):
|
||||
for keepdims in [True, False]:
|
||||
out_np = np.linalg.norm(x_np, axis=axis, keepdims=keepdims)
|
||||
out_mx = mx.linalg.norm(x_mx, axis=axis, keepdims=keepdims)
|
||||
assert np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)
|
||||
for o in ords:
|
||||
out_np = np.linalg.norm(
|
||||
x_np, ord=o, axis=axis, keepdims=keepdims
|
||||
)
|
||||
out_mx = mx.linalg.norm(
|
||||
x_mx, ord=o, axis=axis, keepdims=keepdims
|
||||
)
|
||||
with self.subTest(
|
||||
shape=shape, ord=o, axis=axis, keepdims=keepdims
|
||||
):
|
||||
self.assertTrue(
|
||||
np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)
|
||||
)
|
||||
|
||||
# Test only ord provided
|
||||
for shape in [(3,), (2, 3)]:
|
||||
x_mx = mx.arange(math.prod(shape)).reshape(shape)
|
||||
x_np = np.arange(math.prod(shape)).reshape(shape)
|
||||
x_mx = mx.arange(1, math.prod(shape) + 1).reshape(shape)
|
||||
x_np = np.arange(1, math.prod(shape) + 1).reshape(shape)
|
||||
for o in [None, 1, -1, float("inf"), -float("inf")]:
|
||||
for keepdims in [True, False]:
|
||||
out_np = np.linalg.norm(x_np, ord=o, keepdims=keepdims)
|
||||
out_mx = mx.linalg.norm(x_mx, ord=o, keepdims=keepdims)
|
||||
assert np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)
|
||||
with self.subTest(shape=shape, ord=o, keepdims=keepdims):
|
||||
self.assertTrue(
|
||||
np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)
|
||||
)
|
||||
|
||||
# Test no ord and no axis provided
|
||||
for shape in [(3,), (2, 3), (2, 3, 3)]:
|
||||
x_mx = mx.arange(math.prod(shape)).reshape(shape)
|
||||
x_np = np.arange(math.prod(shape)).reshape(shape)
|
||||
for o in [None, 1, -1, float("inf"), -float("inf")]:
|
||||
for keepdims in [True, False]:
|
||||
out_np = np.linalg.norm(x_np, keepdims=keepdims)
|
||||
out_mx = mx.linalg.norm(x_mx, keepdims=keepdims)
|
||||
assert np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)
|
||||
x_mx = mx.arange(1, math.prod(shape) + 1).reshape(shape)
|
||||
x_np = np.arange(1, math.prod(shape) + 1).reshape(shape)
|
||||
for keepdims in [True, False]:
|
||||
out_np = np.linalg.norm(x_np, keepdims=keepdims)
|
||||
out_mx = mx.linalg.norm(x_mx, keepdims=keepdims)
|
||||
with self.subTest(shape=shape, keepdims=keepdims):
|
||||
self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Loading…
Reference in New Issue
Block a user