fix python tests

This commit is contained in:
Awni Hannun 2023-12-26 14:47:56 -08:00
parent 49e3e99da3
commit 67e319488c
4 changed files with 65 additions and 73 deletions

View File

@ -4,9 +4,7 @@
#include <ostream> #include <ostream>
#include <vector> #include <vector>
#include "mlx/array.h"
#include "mlx/linalg.h" #include "mlx/linalg.h"
#include "mlx/ops.h"
namespace mlx::core::linalg { namespace mlx::core::linalg {
@ -48,25 +46,36 @@ inline array matrix_norm(
auto dtype = at_least_float(a.dtype()); auto dtype = at_least_float(a.dtype());
auto row_axis = axis[0]; auto row_axis = axis[0];
auto col_axis = axis[1]; auto col_axis = axis[1];
if (!keepdims && col_axis > row_axis && col_axis > 0) {
col_axis -= 1;
}
if (ord == -1.0) { if (ord == -1.0) {
col_axis -= (!keepdims && col_axis > row_axis && col_axis > 0);
return astype( return astype(
min(sum(abs(a, s), row_axis, keepdims, s), col_axis, keepdims, s), min(sum(abs(a, s), row_axis, keepdims, s), col_axis, keepdims, s),
dtype, dtype,
s); s);
} else if (ord == 1.0) { } else if (ord == 1.0) {
col_axis -= (!keepdims && col_axis > row_axis && col_axis > 0);
return astype( return astype(
max(sum(abs(a, s), row_axis, keepdims, s), col_axis, keepdims, s), max(sum(abs(a, s), row_axis, keepdims, s), col_axis, keepdims, s),
dtype, dtype,
s); s);
} else if (ord == std::numeric_limits<double>::infinity()) {
row_axis -= (!keepdims && row_axis > col_axis && row_axis > 0);
return astype(
max(sum(abs(a, s), col_axis, keepdims, s), row_axis, keepdims, s),
dtype,
s);
} else if (ord == -std::numeric_limits<double>::infinity()) {
row_axis -= (!keepdims && row_axis > col_axis && row_axis > 0);
return astype(
min(sum(abs(a, s), col_axis, keepdims, s), row_axis, keepdims, s),
dtype,
s);
} else if (ord == 2.0 || ord == -2.0) { } else if (ord == 2.0 || ord == -2.0) {
throw std::runtime_error( throw std::runtime_error(
"[linalg::norm] Singular value norms are not implemented."); "[linalg::norm] Singular value norms are not implemented.");
} else { } else {
std::ostringstream msg; std::ostringstream msg;
msg << "[linalg::norm] Invalid ord value " << ord << " for matrix norm"; msg << "[linalg::norm] Invalid ord " << ord << " for matrix norm.";
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
} }
@ -78,13 +87,13 @@ inline array matrix_norm(
bool keepdims, bool keepdims,
StreamOrDevice s) { StreamOrDevice s) {
if (ord == "f" || ord == "fro") { if (ord == "f" || ord == "fro") {
return sqrt(sum(abs(a, s) * abs(a, s), axis, keepdims, s)); return sqrt(sum(square(a, s), axis, keepdims, s), s);
} else if (ord == "nuc") { } else if (ord == "nuc") {
throw std::runtime_error( throw std::runtime_error(
"[linalg::norm] Nuclear norm not yet implemented."); "[linalg::norm] Nuclear norm not yet implemented.");
} else { } else {
std::ostringstream msg; std::ostringstream msg;
msg << "[linalg::norm] Invalid ord value '" << ord << "' for matrix norm"; msg << "[linalg::norm] Invalid ord value '" << ord << "' for matrix norm.";
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
} }
@ -100,7 +109,7 @@ array norm(
if (axis.value().size() > 2) { if (axis.value().size() > 2) {
throw std::invalid_argument( throw std::invalid_argument(
"[linalg::norm] Received too many axes for norm"); "[linalg::norm] Received too many axes for norm.");
} }
return sqrt(sum(square(a, s), axis.value(), keepdims, s), s); return sqrt(sum(square(a, s), axis.value(), keepdims, s), s);
} }
@ -124,7 +133,7 @@ array norm(
return matrix_norm(a, ord, ax, keepdims, s); return matrix_norm(a, ord, ax, keepdims, s);
} else { } else {
throw std::invalid_argument( throw std::invalid_argument(
"[linalg::norm] Received too many axes for norm"); "[linalg::norm] Received too many axes for norm.");
} }
} }

View File

@ -4,21 +4,22 @@
#include <optional> #include <optional>
#include "array.h" #include "mlx/array.h"
#include "device.h" #include "mlx/device.h"
#include "ops.h" #include "mlx/ops.h"
#include "stream.h" #include "mlx/stream.h"
namespace mlx::core::linalg { namespace mlx::core::linalg {
/* /**
* Compute vector or matrix norms. * Compute vector or matrix norms.
* *
* - If axis and ord are both unspecified, computes the 2-norm of flatten(x). * - If axis and ord are both unspecified, computes the 2-norm of flatten(x).
* - If axis is not provided but ord is, then x must be either 1D or 2D. * - If axis is not provided but ord is, then x must be either 1D or 2D.
* - If axis is provided, but ord is not, then the 2-norm is computed along the * - If axis is provided, but ord is not, then the 2-norm (or Frobenius norm
* given axes. At most 2 axes can be specified. * for matrices) is computed along the given axes. At most 2 axes can be
* - If both axis and ord are provided, then the corresponding matrix of vector * specified.
* - If both axis and ord are provided, then the corresponding matrix or vector
* norm is computed. At most 2 axes can be specified. * norm is computed. At most 2 axes can be specified.
*/ */
array norm( array norm(

View File

@ -20,8 +20,8 @@ void init_linalg(py::module_& parent_module) {
py::options options; py::options options;
options.disable_function_signatures(); options.disable_function_signatures();
auto m = auto m = parent_module.def_submodule(
parent_module.def_submodule("linalg", "mlx.core.linalg: Linear Algebra."); "linalg", "mlx.core.linalg: linear algebra routines.");
m.def( m.def(
"norm", "norm",
@ -72,8 +72,8 @@ void init_linalg(py::module_& parent_module) {
unless ``ord`` is ``None``. If both ``axis`` and ``ord`` are ``None``, the unless ``ord`` is ``None``. If both ``axis`` and ``ord`` are ``None``, the
2-norm of ``a.flatten`` will be returned. 2-norm of ``a.flatten`` will be returned.
ord (scalar or str, optional): Order of the norm (see table under ``Notes``). ord (scalar or str, optional): Order of the norm (see table under ``Notes``).
If ``None``, the 2-norm will be computed along the given ``axis``. If ``None``, the 2-norm (or Frobenius norm for matrices) will be computed
Default: ``None``. along the given ``axis``. Default: ``None``.
axis (int or list(int), optional): If ``axis`` is an integer, it specifies the axis (int or list(int), optional): If ``axis`` is an integer, it specifies the
axis of ``a`` along which to compute the vector norms. If ``axis`` is a axis of ``a`` along which to compute the vector norms. If ``axis`` is a
2-tuple, it specifies the axes that hold 2-D matrices, and the matrix 2-tuple, it specifies the axes that hold 2-D matrices, and the matrix

View File

@ -11,74 +11,56 @@ import numpy as np
class TestLinalg(mlx_tests.MLXTestCase): class TestLinalg(mlx_tests.MLXTestCase):
def test_norm(self): def test_norm(self):
vector_ords = [None, 0.5, 0, 1, 2, 3, -1, 1, float("inf"), -float("inf")] vector_ords = [None, 0.5, 0, 1, 2, 3, -1, float("inf"), -float("inf")]
matrix_ords = [None, "fro", -1, 1, float("inf"), -float("inf")] matrix_ords = [None, "fro", -1, 1, float("inf"), -float("inf")]
for shape in [(3,), (2, 3), (2, 3, 3)]: for shape in [(3,), (2, 3), (2, 3, 3)]:
x_mx = mx.arange(math.prod(shape)).reshape(shape) x_mx = mx.arange(1, math.prod(shape) + 1).reshape(shape)
x_np = np.arange(math.prod(shape)).reshape(shape) x_np = np.arange(1, math.prod(shape) + 1).reshape(shape)
# Test when at least one axis is provided # Test when at least one axis is provided
for num_axes in range(1, len(shape)): for num_axes in range(1, len(shape)):
for axis in itertools.combinations(range(len(shape)), num_axes): if num_axes == 1:
if num_axes == 1: ords = vector_ords
ords = vector_ords else:
else: ords = matrix_ords
ords = matrix_ords
for keepdims in [True, False]:
# Test axis provided, no ord provided
out_np = np.linalg.norm(x_np, axis=axis, keepdims=keepdims)
out_mx = mx.linalg.norm(x_mx, axis=axis, keepdims=keepdims)
assert np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)
# Test both ord and axis provided
for o in ords:
for keepdims in [True, False]:
if o:
out_np = np.linalg.norm(
x_np, ord=o, axis=axis, keepdims=keepdims
)
out_mx = mx.linalg.norm(
x_mx, ord=o, axis=axis, keepdims=keepdims
)
else:
out_np = np.linalg.norm(
x_np, axis=axis, keepdims=keepdims
)
out_mx = mx.linalg.norm(
x_mx, axis=axis, keepdims=keepdims
)
assert np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)
# Test only axis provided
for shape in [(3,), (2, 3), (2, 3, 3)]:
x_mx = mx.arange(math.prod(shape)).reshape(shape)
x_np = np.arange(math.prod(shape)).reshape(shape)
for num_axes in range(1, len(shape)):
for axis in itertools.combinations(range(len(shape)), num_axes): for axis in itertools.combinations(range(len(shape)), num_axes):
for keepdims in [True, False]: for keepdims in [True, False]:
out_np = np.linalg.norm(x_np, axis=axis, keepdims=keepdims) for o in ords:
out_mx = mx.linalg.norm(x_mx, axis=axis, keepdims=keepdims) out_np = np.linalg.norm(
assert np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6) x_np, ord=o, axis=axis, keepdims=keepdims
)
out_mx = mx.linalg.norm(
x_mx, ord=o, axis=axis, keepdims=keepdims
)
with self.subTest(
shape=shape, ord=o, axis=axis, keepdims=keepdims
):
self.assertTrue(
np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)
)
# Test only ord provided # Test only ord provided
for shape in [(3,), (2, 3)]: for shape in [(3,), (2, 3)]:
x_mx = mx.arange(math.prod(shape)).reshape(shape) x_mx = mx.arange(1, math.prod(shape) + 1).reshape(shape)
x_np = np.arange(math.prod(shape)).reshape(shape) x_np = np.arange(1, math.prod(shape) + 1).reshape(shape)
for o in [None, 1, -1, float("inf"), -float("inf")]: for o in [None, 1, -1, float("inf"), -float("inf")]:
for keepdims in [True, False]: for keepdims in [True, False]:
out_np = np.linalg.norm(x_np, ord=o, keepdims=keepdims) out_np = np.linalg.norm(x_np, ord=o, keepdims=keepdims)
out_mx = mx.linalg.norm(x_mx, ord=o, keepdims=keepdims) out_mx = mx.linalg.norm(x_mx, ord=o, keepdims=keepdims)
assert np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6) with self.subTest(shape=shape, ord=o, keepdims=keepdims):
self.assertTrue(
np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)
)
# Test no ord and no axis provided # Test no ord and no axis provided
for shape in [(3,), (2, 3), (2, 3, 3)]: for shape in [(3,), (2, 3), (2, 3, 3)]:
x_mx = mx.arange(math.prod(shape)).reshape(shape) x_mx = mx.arange(1, math.prod(shape) + 1).reshape(shape)
x_np = np.arange(math.prod(shape)).reshape(shape) x_np = np.arange(1, math.prod(shape) + 1).reshape(shape)
for o in [None, 1, -1, float("inf"), -float("inf")]: for keepdims in [True, False]:
for keepdims in [True, False]: out_np = np.linalg.norm(x_np, keepdims=keepdims)
out_np = np.linalg.norm(x_np, keepdims=keepdims) out_mx = mx.linalg.norm(x_mx, keepdims=keepdims)
out_mx = mx.linalg.norm(x_mx, keepdims=keepdims) with self.subTest(shape=shape, keepdims=keepdims):
assert np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6) self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6))
if __name__ == "__main__": if __name__ == "__main__":