fixed a bug with no ord and axis provided

This commit is contained in:
Gabrijel Boduljak 2023-12-22 12:00:18 +01:00 committed by Awni Hannun
parent 5a184d5b5d
commit bbfe042a2b
2 changed files with 39 additions and 4 deletions

View File

@ -44,7 +44,9 @@ void init_linalg(py::module_& parent_module) {
return norm( return norm(
a, a,
"inf", "inf",
get_reduce_axes(axis, a.ndim()), std::holds_alternative<std::monostate>(axis)
? std::vector<int>()
: get_reduce_axes(axis, a.ndim()),
keepdims, keepdims,
stream); stream);
} }
@ -56,15 +58,32 @@ void init_linalg(py::module_& parent_module) {
stream); stream);
} }
return norm( return norm(
a, p, get_reduce_axes(axis, a.ndim()), keepdims, stream); a,
p,
std::holds_alternative<std::monostate>(axis)
? std::vector<int>()
: get_reduce_axes(axis, a.ndim()),
keepdims,
stream);
}, },
[&](const std::string& p) { [&](const std::string& p) {
return norm( return norm(
a, p, get_reduce_axes(axis, a.ndim()), keepdims, stream); a,
p,
std::holds_alternative<std::monostate>(axis)
? std::vector<int>()
: get_reduce_axes(axis, a.ndim()),
keepdims,
stream);
}, },
[&](const std::monostate _) { [&](const std::monostate _) {
return norm( return norm(
a, get_reduce_axes(axis, a.ndim()), keepdims, stream); a,
std::holds_alternative<std::monostate>(axis)
? std::vector<int>()
: get_reduce_axes(axis, a.ndim()),
keepdims,
stream);
}}, }},
ord); ord);
}, },

View File

@ -24,6 +24,12 @@ class TestLinalg(mlx_tests.MLXTestCase):
ords = vector_ords ords = vector_ords
else: else:
ords = matrix_ords ords = matrix_ords
for keepdims in [True, False]:
# Test axis provided, no ord provided
out_np = np.linalg.norm(x_np, axis=axis, keepdims=keepdims)
out_mx = mx.linalg.norm(x_mx, axis=axis, keepdims=keepdims)
assert np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)
# Test both ord and axis provided
for o in ords: for o in ords:
for keepdims in [True, False]: for keepdims in [True, False]:
if o: if o:
@ -64,6 +70,16 @@ class TestLinalg(mlx_tests.MLXTestCase):
out_mx = mx.linalg.norm(x_mx, ord=o, keepdims=keepdims) out_mx = mx.linalg.norm(x_mx, ord=o, keepdims=keepdims)
assert np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6) assert np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)
# Test no ord and no axis provided
for shape in [(3,), (2, 3), (2, 3, 3)]:
x_mx = mx.arange(math.prod(shape)).reshape(shape)
x_np = np.arange(math.prod(shape)).reshape(shape)
for o in [None, 1, -1, float("inf"), -float("inf")]:
for keepdims in [True, False]:
out_np = np.linalg.norm(x_np, keepdims=keepdims)
out_mx = mx.linalg.norm(x_mx, keepdims=keepdims)
assert np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()