added tests

This commit is contained in:
Gabrijel Boduljak 2023-12-21 18:34:02 +01:00 committed by Awni Hannun
parent 8c43d820d9
commit 5d7a06717c

View File

@ -16,33 +16,34 @@ TEST_CASE("[mlx.core.linalg.norm] no ord") {
array arr_three_d = reshape(arange(18), {2, 3, 3}); array arr_three_d = reshape(arange(18), {2, 3, 3});
CHECK(array_equal(norm(arr_one_d), array(sqrt(1 + 4 + 9))).item<bool>()); CHECK(array_equal(norm(arr_one_d), array(sqrt(1 + 4 + 9))).item<bool>());
CHECK(array_equal(norm(arr_one_d, {0}), array(sqrt(1 + 4 + 9))).item<bool>()); CHECK(array_equal(norm(arr_one_d, {0}, false), array(sqrt(1 + 4 + 9)))
.item<bool>());
CHECK(array_equal( CHECK(array_equal(
norm(arr_two_d), norm(arr_two_d, {}, false),
array(sqrt( array(sqrt(
0 + 1 + 2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 6 * 6 + 7 * 7 + 8 * 8))) 0 + 1 + 2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 6 * 6 + 7 * 7 + 8 * 8)))
.item<bool>()); .item<bool>());
CHECK(array_equal( CHECK(array_equal(
norm(arr_two_d, {0}), norm(arr_two_d, {0}, false),
array( array(
{sqrt(0 + 3 * 3 + 6 * 6), {sqrt(0 + 3 * 3 + 6 * 6),
sqrt(1 + 4 * 4 + 7 * 7), sqrt(1 + 4 * 4 + 7 * 7),
sqrt(2 * 2 + 5 * 5 + 8 * 8)})) sqrt(2 * 2 + 5 * 5 + 8 * 8)}))
.item<bool>()); .item<bool>());
CHECK(array_equal( CHECK(array_equal(
norm(arr_two_d, {1}), norm(arr_two_d, {1}, false),
array( array(
{sqrt(0 + 1 + 2 * 2), {sqrt(0 + 1 + 2 * 2),
sqrt(3 * 3 + 4 * 4 + 5 * 5), sqrt(3 * 3 + 4 * 4 + 5 * 5),
sqrt(6 * 6 + 7 * 7 + 8 * 8)})) sqrt(6 * 6 + 7 * 7 + 8 * 8)}))
.item<bool>()); .item<bool>());
CHECK(array_equal( CHECK(array_equal(
norm(arr_two_d, {0, 1}), norm(arr_two_d, {0, 1}, false),
array(sqrt( array(sqrt(
0 + 1 + 2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 6 * 6 + 7 * 7 + 8 * 8))) 0 + 1 + 2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 6 * 6 + 7 * 7 + 8 * 8)))
.item<bool>()); .item<bool>());
CHECK(array_equal( CHECK(array_equal(
norm(arr_three_d, {2}), norm(arr_three_d, {2}, false),
array( array(
{ {
sqrt(0 + 1 + 2 * 2), sqrt(0 + 1 + 2 * 2),
@ -55,7 +56,7 @@ TEST_CASE("[mlx.core.linalg.norm] no ord") {
{2, 3})) {2, 3}))
.item<bool>()); .item<bool>());
CHECK(array_equal( CHECK(array_equal(
norm(arr_three_d, {1}), norm(arr_three_d, {1}, false),
array( array(
{ {
sqrt(0 + 3 * 3 + 6 * 6), sqrt(0 + 3 * 3 + 6 * 6),
@ -68,7 +69,7 @@ TEST_CASE("[mlx.core.linalg.norm] no ord") {
{2, 3})) {2, 3}))
.item<bool>()); .item<bool>());
CHECK(array_equal( CHECK(array_equal(
norm(arr_three_d, {0}), norm(arr_three_d, {0}, false),
array( array(
{ {
sqrt(0 + 9 * 9), sqrt(0 + 9 * 9),
@ -84,7 +85,7 @@ TEST_CASE("[mlx.core.linalg.norm] no ord") {
{3, 3})) {3, 3}))
.item<bool>()); .item<bool>());
CHECK(array_equal( CHECK(array_equal(
norm(arr_three_d, {1, 2}), norm(arr_three_d, {1, 2}, false),
array( array(
{sqrt( {sqrt(
0 + 1 + 2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 6 * 6 + 7 * 7 + 0 + 1 + 2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 6 * 6 + 7 * 7 +
@ -94,4 +95,212 @@ TEST_CASE("[mlx.core.linalg.norm] no ord") {
15 * 15 + 16 * 16 + 17 * 17)}, 15 * 15 + 16 * 16 + 17 * 17)},
{2})) {2}))
.item<bool>()); .item<bool>());
}
TEST_CASE("[mlx.core.linalg.norm] double ord") {
array arr_one_d({1, 2, 3});
array arr_two_d = reshape(arange(9), {3, 3});
array arr_three_d = reshape(arange(18), {2, 3, 3});
CHECK(array_equal(norm(arr_one_d, 2.0), array(sqrt(1 + 4 + 9))).item<bool>());
CHECK(array_equal(norm(arr_one_d, 1.0), array(1 + 2 + 3)).item<bool>());
CHECK(array_equal(norm(arr_one_d, 0.0), array(3)).item<bool>());
CHECK(array_equal(norm(arr_one_d, 2.0, {0}, false), array(sqrt(1 + 4 + 9)))
.item<bool>());
CHECK(array_equal(
norm(arr_two_d, 2.0, {0}, false),
array(
{sqrt(0 + 3 * 3 + 6 * 6),
sqrt(1 + 4 * 4 + 7 * 7),
sqrt(2 * 2 + 5 * 5 + 8 * 8)}))
.item<bool>());
CHECK(array_equal(
norm(arr_two_d, 2.0, {1}, false),
array(
{sqrt(0 + 1 + 2 * 2),
sqrt(3 * 3 + 4 * 4 + 5 * 5),
sqrt(6 * 6 + 7 * 7 + 8 * 8)}))
.item<bool>());
CHECK(array_equal(
norm(arr_three_d, 2.0, {2}, false),
array(
{
sqrt(0 + 1 + 2 * 2),
sqrt(3 * 3 + 4 * 4 + 5 * 5),
sqrt(6 * 6 + 7 * 7 + 8 * 8),
sqrt(9 * 9 + 10 * 10 + 11 * 11),
sqrt(12 * 12 + 13 * 13 + 14 * 14),
sqrt(15 * 15 + 16 * 16 + 17 * 17),
},
{2, 3}))
.item<bool>());
CHECK(array_equal(
norm(arr_three_d, 2.0, {1}, false),
array(
{
sqrt(0 + 3 * 3 + 6 * 6),
sqrt(1 + 4 * 4 + 7 * 7),
sqrt(2 * 2 + 5 * 5 + 8 * 8),
sqrt(9 * 9 + 12 * 12 + 15 * 15),
sqrt(10 * 10 + 13 * 13 + 16 * 16),
sqrt(11 * 11 + 14 * 14 + 17 * 17),
},
{2, 3}))
.item<bool>());
CHECK(array_equal(
norm(arr_three_d, 2.0, {0}, false),
array(
{
sqrt(0 + 9 * 9),
sqrt(1 + 10 * 10),
sqrt(2 * 2 + 11 * 11),
sqrt(3 * 3 + 12 * 12),
sqrt(4 * 4 + 13 * 13),
sqrt(5 * 5 + 14 * 14),
sqrt(6 * 6 + 15 * 15),
sqrt(7 * 7 + 16 * 16),
sqrt(8 * 8 + 17 * 17),
},
{3, 3}))
.item<bool>());
CHECK(allclose(
norm(arr_three_d, 3.0, {0}),
array(
{9.,
10.00333222,
11.02199456,
12.06217728,
13.12502645,
14.2094363,
15.31340617,
16.43469751,
17.57113899},
{3, 3}))
.item<bool>());
CHECK(
allclose(
norm(arr_three_d, 3.0, {1}),
array(
{6.24025147, 7.41685954, 8.6401226, 18., 19.39257164, 20.7915893},
{2, 3}))
.item<bool>());
CHECK(allclose(
norm(arr_three_d, 3.0, {2}),
array(
{2.08008382,
6.,
10.23127655,
14.5180117,
18.82291607,
23.13593104},
{2, 3}))
.item<bool>());
CHECK(allclose(
norm(arr_three_d, 0.0, {0}),
array({1., 2., 2., 2., 2., 2., 2., 2., 2.}, {3, 3}))
.item<bool>());
CHECK(
allclose(
norm(arr_three_d, 0.0, {1}), array({2., 3., 3., 3., 3., 3.}, {2, 3}))
.item<bool>());
CHECK(
allclose(
norm(arr_three_d, 0.0, {2}), array({2., 3., 3., 3., 3., 3.}, {2, 3}))
.item<bool>());
CHECK(allclose(
norm(arr_three_d, 1.0, {0}),
array({9., 11., 13., 15., 17., 19., 21., 23., 25.}, {3, 3}))
.item<bool>());
CHECK(allclose(
norm(arr_three_d, 1.0, {1}),
array({9., 12., 15., 36., 39., 42.}, {2, 3}))
.item<bool>());
CHECK(allclose(
norm(arr_three_d, 1.0, {2}),
array({3., 12., 21., 30., 39., 48.}, {2, 3}))
.item<bool>());
CHECK(allclose(norm(arr_two_d, 1.0, {0, 1}), array({15.0})).item<bool>());
CHECK(allclose(norm(arr_two_d, 1.0, {1, 0}), array({21.0})).item<bool>());
CHECK(allclose(norm(arr_two_d, -1.0, {0, 1}), array({9.0})).item<bool>());
CHECK(allclose(norm(arr_two_d, -1.0, {1, 0}), array({3.0})).item<bool>());
CHECK(allclose(norm(arr_two_d, 1.0, {0, 1}, true), array({15.0}, {1, 1}))
.item<bool>());
CHECK(allclose(norm(arr_two_d, 1.0, {1, 0}, true), array({21.0}, {1, 1}))
.item<bool>());
CHECK(allclose(norm(arr_two_d, -1.0, {0, 1}, true), array({9.0}, {1, 1}))
.item<bool>());
CHECK(allclose(norm(arr_two_d, -1.0, {1, 0}, true), array({3.0}, {1, 1}))
.item<bool>());
CHECK(array_equal(norm(arr_two_d, -1.0, {-2, -1}, false), array(9.0))
.item<bool>());
CHECK(array_equal(norm(arr_two_d, 1.0, {-2, -1}, false), array(15.0))
.item<bool>());
//
CHECK(allclose(norm(arr_three_d, 1.0, {0, 1}), array({21., 23., 25.}))
.item<bool>());
CHECK(
allclose(norm(arr_three_d, 1.0, {1, 2}), array({15., 42.})).item<bool>());
CHECK(allclose(norm(arr_three_d, -1.0, {0, 1}), array({9., 11., 13.}))
.item<bool>());
CHECK(
allclose(norm(arr_three_d, -1.0, {1, 2}), array({9., 36.})).item<bool>());
CHECK(allclose(norm(arr_three_d, -1.0, {1, 0}), array({9., 12., 15.}))
.item<bool>());
CHECK(allclose(norm(arr_three_d, -1.0, {2, 1}), array({3, 30})).item<bool>());
CHECK(allclose(norm(arr_three_d, -1.0, {1, 2}), array({9, 36})).item<bool>());
}
TEST_CASE("[mlx.core.linalg.norm] string ord") {
array arr_one_d({1, 2, 3});
array arr_two_d = reshape(arange(9), {3, 3});
array arr_three_d = reshape(arange(18), {2, 3, 3});
CHECK(allclose(norm(arr_one_d, "inf", {}), array({3.0})).item<bool>());
CHECK(allclose(norm(arr_one_d, "-inf", {}), array({1.0})).item<bool>());
CHECK(allclose(norm(arr_two_d, "f", {0, 1}), array({14.2828568570857}))
.item<bool>());
CHECK(allclose(norm(arr_two_d, "fro", {0, 1}), array({14.2828568570857}))
.item<bool>());
CHECK(allclose(norm(arr_two_d, "inf", {0, 1}), array({21.0})).item<bool>());
CHECK(allclose(norm(arr_two_d, "-inf", {0, 1}), array({3.0})).item<bool>());
CHECK(allclose(
norm(arr_three_d, "fro", {0, 1}),
array({22.24859546, 24.31049156, 26.43860813}))
.item<bool>());
CHECK(allclose(
norm(arr_three_d, "fro", {1, 2}), array({14.28285686, 39.7617907}))
.item<bool>());
CHECK(allclose(
norm(arr_three_d, "f", {0, 1}),
array({22.24859546, 24.31049156, 26.43860813}))
.item<bool>());
CHECK(allclose(
norm(arr_three_d, "f", {1, 0}),
array({22.24859546, 24.31049156, 26.43860813}))
.item<bool>());
CHECK(
allclose(norm(arr_three_d, "f", {1, 2}), array({14.28285686, 39.7617907}))
.item<bool>());
CHECK(
allclose(norm(arr_three_d, "f", {2, 1}), array({14.28285686, 39.7617907}))
.item<bool>());
CHECK(allclose(norm(arr_three_d, "inf", {0, 1}), array({36., 39., 42.}))
.item<bool>());
CHECK(allclose(norm(arr_three_d, "inf", {1, 2}), array({21., 48.}))
.item<bool>());
CHECK(allclose(norm(arr_three_d, "-inf", {0, 1}), array({9., 12., 15.}))
.item<bool>());
CHECK(allclose(norm(arr_three_d, "-inf", {1, 2}), array({3., 30.}))
.item<bool>());
CHECK(allclose(norm(arr_three_d, "-inf", {1, 0}), array({9., 11., 13.}))
.item<bool>());
CHECK(allclose(norm(arr_three_d, "-inf", {2, 1}), array({9., 36.}))
.item<bool>());
} }