mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 12:06:42 +08:00
fixed a bug with no ord and axis provided
This commit is contained in:
parent
5a184d5b5d
commit
bbfe042a2b
@ -44,7 +44,9 @@ void init_linalg(py::module_& parent_module) {
|
|||||||
return norm(
|
return norm(
|
||||||
a,
|
a,
|
||||||
"inf",
|
"inf",
|
||||||
get_reduce_axes(axis, a.ndim()),
|
std::holds_alternative<std::monostate>(axis)
|
||||||
|
? std::vector<int>()
|
||||||
|
: get_reduce_axes(axis, a.ndim()),
|
||||||
keepdims,
|
keepdims,
|
||||||
stream);
|
stream);
|
||||||
}
|
}
|
||||||
@ -56,15 +58,32 @@ void init_linalg(py::module_& parent_module) {
|
|||||||
stream);
|
stream);
|
||||||
}
|
}
|
||||||
return norm(
|
return norm(
|
||||||
a, p, get_reduce_axes(axis, a.ndim()), keepdims, stream);
|
a,
|
||||||
|
p,
|
||||||
|
std::holds_alternative<std::monostate>(axis)
|
||||||
|
? std::vector<int>()
|
||||||
|
: get_reduce_axes(axis, a.ndim()),
|
||||||
|
keepdims,
|
||||||
|
stream);
|
||||||
},
|
},
|
||||||
[&](const std::string& p) {
|
[&](const std::string& p) {
|
||||||
return norm(
|
return norm(
|
||||||
a, p, get_reduce_axes(axis, a.ndim()), keepdims, stream);
|
a,
|
||||||
|
p,
|
||||||
|
std::holds_alternative<std::monostate>(axis)
|
||||||
|
? std::vector<int>()
|
||||||
|
: get_reduce_axes(axis, a.ndim()),
|
||||||
|
keepdims,
|
||||||
|
stream);
|
||||||
},
|
},
|
||||||
[&](const std::monostate _) {
|
[&](const std::monostate _) {
|
||||||
return norm(
|
return norm(
|
||||||
a, get_reduce_axes(axis, a.ndim()), keepdims, stream);
|
a,
|
||||||
|
std::holds_alternative<std::monostate>(axis)
|
||||||
|
? std::vector<int>()
|
||||||
|
: get_reduce_axes(axis, a.ndim()),
|
||||||
|
keepdims,
|
||||||
|
stream);
|
||||||
}},
|
}},
|
||||||
ord);
|
ord);
|
||||||
},
|
},
|
||||||
|
@ -24,6 +24,12 @@ class TestLinalg(mlx_tests.MLXTestCase):
|
|||||||
ords = vector_ords
|
ords = vector_ords
|
||||||
else:
|
else:
|
||||||
ords = matrix_ords
|
ords = matrix_ords
|
||||||
|
for keepdims in [True, False]:
|
||||||
|
# Test axis provided, no ord provided
|
||||||
|
out_np = np.linalg.norm(x_np, axis=axis, keepdims=keepdims)
|
||||||
|
out_mx = mx.linalg.norm(x_mx, axis=axis, keepdims=keepdims)
|
||||||
|
assert np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)
|
||||||
|
# Test both ord and axis provided
|
||||||
for o in ords:
|
for o in ords:
|
||||||
for keepdims in [True, False]:
|
for keepdims in [True, False]:
|
||||||
if o:
|
if o:
|
||||||
@ -64,6 +70,16 @@ class TestLinalg(mlx_tests.MLXTestCase):
|
|||||||
out_mx = mx.linalg.norm(x_mx, ord=o, keepdims=keepdims)
|
out_mx = mx.linalg.norm(x_mx, ord=o, keepdims=keepdims)
|
||||||
assert np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)
|
assert np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)
|
||||||
|
|
||||||
|
# Test no ord and no axis provided
|
||||||
|
for shape in [(3,), (2, 3), (2, 3, 3)]:
|
||||||
|
x_mx = mx.arange(math.prod(shape)).reshape(shape)
|
||||||
|
x_np = np.arange(math.prod(shape)).reshape(shape)
|
||||||
|
for o in [None, 1, -1, float("inf"), -float("inf")]:
|
||||||
|
for keepdims in [True, False]:
|
||||||
|
out_np = np.linalg.norm(x_np, keepdims=keepdims)
|
||||||
|
out_mx = mx.linalg.norm(x_mx, keepdims=keepdims)
|
||||||
|
assert np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user