mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 12:06:42 +08:00
fix python tests
This commit is contained in:
parent
49e3e99da3
commit
67e319488c
@ -4,9 +4,7 @@
|
|||||||
#include <ostream>
|
#include <ostream>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "mlx/array.h"
|
|
||||||
#include "mlx/linalg.h"
|
#include "mlx/linalg.h"
|
||||||
#include "mlx/ops.h"
|
|
||||||
|
|
||||||
namespace mlx::core::linalg {
|
namespace mlx::core::linalg {
|
||||||
|
|
||||||
@ -48,25 +46,36 @@ inline array matrix_norm(
|
|||||||
auto dtype = at_least_float(a.dtype());
|
auto dtype = at_least_float(a.dtype());
|
||||||
auto row_axis = axis[0];
|
auto row_axis = axis[0];
|
||||||
auto col_axis = axis[1];
|
auto col_axis = axis[1];
|
||||||
if (!keepdims && col_axis > row_axis && col_axis > 0) {
|
|
||||||
col_axis -= 1;
|
|
||||||
}
|
|
||||||
if (ord == -1.0) {
|
if (ord == -1.0) {
|
||||||
|
col_axis -= (!keepdims && col_axis > row_axis && col_axis > 0);
|
||||||
return astype(
|
return astype(
|
||||||
min(sum(abs(a, s), row_axis, keepdims, s), col_axis, keepdims, s),
|
min(sum(abs(a, s), row_axis, keepdims, s), col_axis, keepdims, s),
|
||||||
dtype,
|
dtype,
|
||||||
s);
|
s);
|
||||||
} else if (ord == 1.0) {
|
} else if (ord == 1.0) {
|
||||||
|
col_axis -= (!keepdims && col_axis > row_axis && col_axis > 0);
|
||||||
return astype(
|
return astype(
|
||||||
max(sum(abs(a, s), row_axis, keepdims, s), col_axis, keepdims, s),
|
max(sum(abs(a, s), row_axis, keepdims, s), col_axis, keepdims, s),
|
||||||
dtype,
|
dtype,
|
||||||
s);
|
s);
|
||||||
|
} else if (ord == std::numeric_limits<double>::infinity()) {
|
||||||
|
row_axis -= (!keepdims && row_axis > col_axis && row_axis > 0);
|
||||||
|
return astype(
|
||||||
|
max(sum(abs(a, s), col_axis, keepdims, s), row_axis, keepdims, s),
|
||||||
|
dtype,
|
||||||
|
s);
|
||||||
|
} else if (ord == -std::numeric_limits<double>::infinity()) {
|
||||||
|
row_axis -= (!keepdims && row_axis > col_axis && row_axis > 0);
|
||||||
|
return astype(
|
||||||
|
min(sum(abs(a, s), col_axis, keepdims, s), row_axis, keepdims, s),
|
||||||
|
dtype,
|
||||||
|
s);
|
||||||
} else if (ord == 2.0 || ord == -2.0) {
|
} else if (ord == 2.0 || ord == -2.0) {
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
"[linalg::norm] Singular value norms are not implemented.");
|
"[linalg::norm] Singular value norms are not implemented.");
|
||||||
} else {
|
} else {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[linalg::norm] Invalid ord value " << ord << " for matrix norm";
|
msg << "[linalg::norm] Invalid ord " << ord << " for matrix norm.";
|
||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -78,13 +87,13 @@ inline array matrix_norm(
|
|||||||
bool keepdims,
|
bool keepdims,
|
||||||
StreamOrDevice s) {
|
StreamOrDevice s) {
|
||||||
if (ord == "f" || ord == "fro") {
|
if (ord == "f" || ord == "fro") {
|
||||||
return sqrt(sum(abs(a, s) * abs(a, s), axis, keepdims, s));
|
return sqrt(sum(square(a, s), axis, keepdims, s), s);
|
||||||
} else if (ord == "nuc") {
|
} else if (ord == "nuc") {
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
"[linalg::norm] Nuclear norm not yet implemented.");
|
"[linalg::norm] Nuclear norm not yet implemented.");
|
||||||
} else {
|
} else {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[linalg::norm] Invalid ord value '" << ord << "' for matrix norm";
|
msg << "[linalg::norm] Invalid ord value '" << ord << "' for matrix norm.";
|
||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -100,7 +109,7 @@ array norm(
|
|||||||
|
|
||||||
if (axis.value().size() > 2) {
|
if (axis.value().size() > 2) {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"[linalg::norm] Received too many axes for norm");
|
"[linalg::norm] Received too many axes for norm.");
|
||||||
}
|
}
|
||||||
return sqrt(sum(square(a, s), axis.value(), keepdims, s), s);
|
return sqrt(sum(square(a, s), axis.value(), keepdims, s), s);
|
||||||
}
|
}
|
||||||
@ -124,7 +133,7 @@ array norm(
|
|||||||
return matrix_norm(a, ord, ax, keepdims, s);
|
return matrix_norm(a, ord, ax, keepdims, s);
|
||||||
} else {
|
} else {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"[linalg::norm] Received too many axes for norm");
|
"[linalg::norm] Received too many axes for norm.");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
17
mlx/linalg.h
17
mlx/linalg.h
@ -4,21 +4,22 @@
|
|||||||
|
|
||||||
#include <optional>
|
#include <optional>
|
||||||
|
|
||||||
#include "array.h"
|
#include "mlx/array.h"
|
||||||
#include "device.h"
|
#include "mlx/device.h"
|
||||||
#include "ops.h"
|
#include "mlx/ops.h"
|
||||||
#include "stream.h"
|
#include "mlx/stream.h"
|
||||||
|
|
||||||
namespace mlx::core::linalg {
|
namespace mlx::core::linalg {
|
||||||
|
|
||||||
/*
|
/**
|
||||||
* Compute vector or matrix norms.
|
* Compute vector or matrix norms.
|
||||||
*
|
*
|
||||||
* - If axis and ord are both unspecified, computes the 2-norm of flatten(x).
|
* - If axis and ord are both unspecified, computes the 2-norm of flatten(x).
|
||||||
* - If axis is not provided but ord is, then x must be either 1D or 2D.
|
* - If axis is not provided but ord is, then x must be either 1D or 2D.
|
||||||
* - If axis is provided, but ord is not, then the 2-norm is computed along the
|
* - If axis is provided, but ord is not, then the 2-norm (or Frobenius norm
|
||||||
* given axes. At most 2 axes can be specified.
|
* for matrices) is computed along the given axes. At most 2 axes can be
|
||||||
* - If both axis and ord are provided, then the corresponding matrix of vector
|
* specified.
|
||||||
|
* - If both axis and ord are provided, then the corresponding matrix or vector
|
||||||
* norm is computed. At most 2 axes can be specified.
|
* norm is computed. At most 2 axes can be specified.
|
||||||
*/
|
*/
|
||||||
array norm(
|
array norm(
|
||||||
|
@ -20,8 +20,8 @@ void init_linalg(py::module_& parent_module) {
|
|||||||
py::options options;
|
py::options options;
|
||||||
options.disable_function_signatures();
|
options.disable_function_signatures();
|
||||||
|
|
||||||
auto m =
|
auto m = parent_module.def_submodule(
|
||||||
parent_module.def_submodule("linalg", "mlx.core.linalg: Linear Algebra.");
|
"linalg", "mlx.core.linalg: linear algebra routines.");
|
||||||
|
|
||||||
m.def(
|
m.def(
|
||||||
"norm",
|
"norm",
|
||||||
@ -72,8 +72,8 @@ void init_linalg(py::module_& parent_module) {
|
|||||||
unless ``ord`` is ``None``. If both ``axis`` and ``ord`` are ``None``, the
|
unless ``ord`` is ``None``. If both ``axis`` and ``ord`` are ``None``, the
|
||||||
2-norm of ``a.flatten`` will be returned.
|
2-norm of ``a.flatten`` will be returned.
|
||||||
ord (scalar or str, optional): Order of the norm (see table under ``Notes``).
|
ord (scalar or str, optional): Order of the norm (see table under ``Notes``).
|
||||||
If ``None``, the 2-norm will be computed along the given ``axis``.
|
If ``None``, the 2-norm (or Frobenius norm for matrices) will be computed
|
||||||
Default: ``None``.
|
along the given ``axis``. Default: ``None``.
|
||||||
axis (int or list(int), optional): If ``axis`` is an integer, it specifies the
|
axis (int or list(int), optional): If ``axis`` is an integer, it specifies the
|
||||||
axis of ``a`` along which to compute the vector norms. If ``axis`` is a
|
axis of ``a`` along which to compute the vector norms. If ``axis`` is a
|
||||||
2-tuple, it specifies the axes that hold 2-D matrices, and the matrix
|
2-tuple, it specifies the axes that hold 2-D matrices, and the matrix
|
||||||
|
@ -11,74 +11,56 @@ import numpy as np
|
|||||||
|
|
||||||
class TestLinalg(mlx_tests.MLXTestCase):
|
class TestLinalg(mlx_tests.MLXTestCase):
|
||||||
def test_norm(self):
|
def test_norm(self):
|
||||||
vector_ords = [None, 0.5, 0, 1, 2, 3, -1, 1, float("inf"), -float("inf")]
|
vector_ords = [None, 0.5, 0, 1, 2, 3, -1, float("inf"), -float("inf")]
|
||||||
matrix_ords = [None, "fro", -1, 1, float("inf"), -float("inf")]
|
matrix_ords = [None, "fro", -1, 1, float("inf"), -float("inf")]
|
||||||
|
|
||||||
for shape in [(3,), (2, 3), (2, 3, 3)]:
|
for shape in [(3,), (2, 3), (2, 3, 3)]:
|
||||||
x_mx = mx.arange(math.prod(shape)).reshape(shape)
|
x_mx = mx.arange(1, math.prod(shape) + 1).reshape(shape)
|
||||||
x_np = np.arange(math.prod(shape)).reshape(shape)
|
x_np = np.arange(1, math.prod(shape) + 1).reshape(shape)
|
||||||
# Test when at least one axis is provided
|
# Test when at least one axis is provided
|
||||||
for num_axes in range(1, len(shape)):
|
for num_axes in range(1, len(shape)):
|
||||||
for axis in itertools.combinations(range(len(shape)), num_axes):
|
if num_axes == 1:
|
||||||
if num_axes == 1:
|
ords = vector_ords
|
||||||
ords = vector_ords
|
else:
|
||||||
else:
|
ords = matrix_ords
|
||||||
ords = matrix_ords
|
|
||||||
for keepdims in [True, False]:
|
|
||||||
# Test axis provided, no ord provided
|
|
||||||
out_np = np.linalg.norm(x_np, axis=axis, keepdims=keepdims)
|
|
||||||
out_mx = mx.linalg.norm(x_mx, axis=axis, keepdims=keepdims)
|
|
||||||
assert np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)
|
|
||||||
# Test both ord and axis provided
|
|
||||||
for o in ords:
|
|
||||||
for keepdims in [True, False]:
|
|
||||||
if o:
|
|
||||||
out_np = np.linalg.norm(
|
|
||||||
x_np, ord=o, axis=axis, keepdims=keepdims
|
|
||||||
)
|
|
||||||
out_mx = mx.linalg.norm(
|
|
||||||
x_mx, ord=o, axis=axis, keepdims=keepdims
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
out_np = np.linalg.norm(
|
|
||||||
x_np, axis=axis, keepdims=keepdims
|
|
||||||
)
|
|
||||||
out_mx = mx.linalg.norm(
|
|
||||||
x_mx, axis=axis, keepdims=keepdims
|
|
||||||
)
|
|
||||||
assert np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)
|
|
||||||
|
|
||||||
# Test only axis provided
|
|
||||||
for shape in [(3,), (2, 3), (2, 3, 3)]:
|
|
||||||
x_mx = mx.arange(math.prod(shape)).reshape(shape)
|
|
||||||
x_np = np.arange(math.prod(shape)).reshape(shape)
|
|
||||||
|
|
||||||
for num_axes in range(1, len(shape)):
|
|
||||||
for axis in itertools.combinations(range(len(shape)), num_axes):
|
for axis in itertools.combinations(range(len(shape)), num_axes):
|
||||||
for keepdims in [True, False]:
|
for keepdims in [True, False]:
|
||||||
out_np = np.linalg.norm(x_np, axis=axis, keepdims=keepdims)
|
for o in ords:
|
||||||
out_mx = mx.linalg.norm(x_mx, axis=axis, keepdims=keepdims)
|
out_np = np.linalg.norm(
|
||||||
assert np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)
|
x_np, ord=o, axis=axis, keepdims=keepdims
|
||||||
|
)
|
||||||
|
out_mx = mx.linalg.norm(
|
||||||
|
x_mx, ord=o, axis=axis, keepdims=keepdims
|
||||||
|
)
|
||||||
|
with self.subTest(
|
||||||
|
shape=shape, ord=o, axis=axis, keepdims=keepdims
|
||||||
|
):
|
||||||
|
self.assertTrue(
|
||||||
|
np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)
|
||||||
|
)
|
||||||
|
|
||||||
# Test only ord provided
|
# Test only ord provided
|
||||||
for shape in [(3,), (2, 3)]:
|
for shape in [(3,), (2, 3)]:
|
||||||
x_mx = mx.arange(math.prod(shape)).reshape(shape)
|
x_mx = mx.arange(1, math.prod(shape) + 1).reshape(shape)
|
||||||
x_np = np.arange(math.prod(shape)).reshape(shape)
|
x_np = np.arange(1, math.prod(shape) + 1).reshape(shape)
|
||||||
for o in [None, 1, -1, float("inf"), -float("inf")]:
|
for o in [None, 1, -1, float("inf"), -float("inf")]:
|
||||||
for keepdims in [True, False]:
|
for keepdims in [True, False]:
|
||||||
out_np = np.linalg.norm(x_np, ord=o, keepdims=keepdims)
|
out_np = np.linalg.norm(x_np, ord=o, keepdims=keepdims)
|
||||||
out_mx = mx.linalg.norm(x_mx, ord=o, keepdims=keepdims)
|
out_mx = mx.linalg.norm(x_mx, ord=o, keepdims=keepdims)
|
||||||
assert np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)
|
with self.subTest(shape=shape, ord=o, keepdims=keepdims):
|
||||||
|
self.assertTrue(
|
||||||
|
np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)
|
||||||
|
)
|
||||||
|
|
||||||
# Test no ord and no axis provided
|
# Test no ord and no axis provided
|
||||||
for shape in [(3,), (2, 3), (2, 3, 3)]:
|
for shape in [(3,), (2, 3), (2, 3, 3)]:
|
||||||
x_mx = mx.arange(math.prod(shape)).reshape(shape)
|
x_mx = mx.arange(1, math.prod(shape) + 1).reshape(shape)
|
||||||
x_np = np.arange(math.prod(shape)).reshape(shape)
|
x_np = np.arange(1, math.prod(shape) + 1).reshape(shape)
|
||||||
for o in [None, 1, -1, float("inf"), -float("inf")]:
|
for keepdims in [True, False]:
|
||||||
for keepdims in [True, False]:
|
out_np = np.linalg.norm(x_np, keepdims=keepdims)
|
||||||
out_np = np.linalg.norm(x_np, keepdims=keepdims)
|
out_mx = mx.linalg.norm(x_mx, keepdims=keepdims)
|
||||||
out_mx = mx.linalg.norm(x_mx, keepdims=keepdims)
|
with self.subTest(shape=shape, keepdims=keepdims):
|
||||||
assert np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)
|
self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
Loading…
Reference in New Issue
Block a user