diff --git a/tests/linalg_tests.cpp b/tests/linalg_tests.cpp index 0b6b31801..9841f03bf 100644 --- a/tests/linalg_tests.cpp +++ b/tests/linalg_tests.cpp @@ -16,33 +16,34 @@ TEST_CASE("[mlx.core.linalg.norm] no ord") { array arr_three_d = reshape(arange(18), {2, 3, 3}); CHECK(array_equal(norm(arr_one_d), array(sqrt(1 + 4 + 9))).item()); - CHECK(array_equal(norm(arr_one_d, {0}), array(sqrt(1 + 4 + 9))).item()); + CHECK(array_equal(norm(arr_one_d, {0}, false), array(sqrt(1 + 4 + 9))) + .item()); CHECK(array_equal( - norm(arr_two_d), + norm(arr_two_d, {}, false), array(sqrt( 0 + 1 + 2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 6 * 6 + 7 * 7 + 8 * 8))) .item()); CHECK(array_equal( - norm(arr_two_d, {0}), + norm(arr_two_d, {0}, false), array( {sqrt(0 + 3 * 3 + 6 * 6), sqrt(1 + 4 * 4 + 7 * 7), sqrt(2 * 2 + 5 * 5 + 8 * 8)})) .item()); CHECK(array_equal( - norm(arr_two_d, {1}), + norm(arr_two_d, {1}, false), array( {sqrt(0 + 1 + 2 * 2), sqrt(3 * 3 + 4 * 4 + 5 * 5), sqrt(6 * 6 + 7 * 7 + 8 * 8)})) .item()); CHECK(array_equal( - norm(arr_two_d, {0, 1}), + norm(arr_two_d, {0, 1}, false), array(sqrt( 0 + 1 + 2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 6 * 6 + 7 * 7 + 8 * 8))) .item()); CHECK(array_equal( - norm(arr_three_d, {2}), + norm(arr_three_d, {2}, false), array( { sqrt(0 + 1 + 2 * 2), @@ -55,7 +56,7 @@ TEST_CASE("[mlx.core.linalg.norm] no ord") { {2, 3})) .item()); CHECK(array_equal( - norm(arr_three_d, {1}), + norm(arr_three_d, {1}, false), array( { sqrt(0 + 3 * 3 + 6 * 6), @@ -68,7 +69,7 @@ TEST_CASE("[mlx.core.linalg.norm] no ord") { {2, 3})) .item()); CHECK(array_equal( - norm(arr_three_d, {0}), + norm(arr_three_d, {0}, false), array( { sqrt(0 + 9 * 9), @@ -84,7 +85,7 @@ TEST_CASE("[mlx.core.linalg.norm] no ord") { {3, 3})) .item()); CHECK(array_equal( - norm(arr_three_d, {1, 2}), + norm(arr_three_d, {1, 2}, false), array( {sqrt( 0 + 1 + 2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 6 * 6 + 7 * 7 + @@ -94,4 +95,212 @@ TEST_CASE("[mlx.core.linalg.norm] no ord") { 15 * 15 + 16 * 16 + 17 * 17)}, {2})) .item()); +} + +TEST_CASE("[mlx.core.linalg.norm] double ord") { + array arr_one_d({1, 2, 3}); + array arr_two_d = reshape(arange(9), {3, 3}); + array arr_three_d = reshape(arange(18), {2, 3, 3}); + + CHECK(array_equal(norm(arr_one_d, 2.0), array(sqrt(1 + 4 + 9))).item()); + CHECK(array_equal(norm(arr_one_d, 1.0), array(1 + 2 + 3)).item()); + CHECK(array_equal(norm(arr_one_d, 0.0), array(3)).item()); + + CHECK(array_equal(norm(arr_one_d, 2.0, {0}, false), array(sqrt(1 + 4 + 9))) + .item()); + CHECK(array_equal( + norm(arr_two_d, 2.0, {0}, false), + array( + {sqrt(0 + 3 * 3 + 6 * 6), + sqrt(1 + 4 * 4 + 7 * 7), + sqrt(2 * 2 + 5 * 5 + 8 * 8)})) + .item()); + CHECK(array_equal( + norm(arr_two_d, 2.0, {1}, false), + array( + {sqrt(0 + 1 + 2 * 2), + sqrt(3 * 3 + 4 * 4 + 5 * 5), + sqrt(6 * 6 + 7 * 7 + 8 * 8)})) + .item()); + CHECK(array_equal( + norm(arr_three_d, 2.0, {2}, false), + array( + { + sqrt(0 + 1 + 2 * 2), + sqrt(3 * 3 + 4 * 4 + 5 * 5), + sqrt(6 * 6 + 7 * 7 + 8 * 8), + sqrt(9 * 9 + 10 * 10 + 11 * 11), + sqrt(12 * 12 + 13 * 13 + 14 * 14), + sqrt(15 * 15 + 16 * 16 + 17 * 17), + }, + {2, 3})) + .item()); + CHECK(array_equal( + norm(arr_three_d, 2.0, {1}, false), + array( + { + sqrt(0 + 3 * 3 + 6 * 6), + sqrt(1 + 4 * 4 + 7 * 7), + sqrt(2 * 2 + 5 * 5 + 8 * 8), + sqrt(9 * 9 + 12 * 12 + 15 * 15), + sqrt(10 * 10 + 13 * 13 + 16 * 16), + sqrt(11 * 11 + 14 * 14 + 17 * 17), + }, + {2, 3})) + .item()); + CHECK(array_equal( + norm(arr_three_d, 2.0, {0}, false), + array( + { + sqrt(0 + 9 * 9), + sqrt(1 + 10 * 10), + sqrt(2 * 2 + 11 * 11), + sqrt(3 * 3 + 12 * 12), + sqrt(4 * 4 + 13 * 13), + sqrt(5 * 5 + 14 * 14), + sqrt(6 * 6 + 15 * 15), + sqrt(7 * 7 + 16 * 16), + sqrt(8 * 8 + 17 * 17), + }, + {3, 3})) + .item()); + + CHECK(allclose( + norm(arr_three_d, 3.0, {0}), + array( + {9., + 10.00333222, + 11.02199456, + 12.06217728, + 13.12502645, + 14.2094363, + 15.31340617, + 16.43469751, + 17.57113899}, + {3, 3})) + .item()); + CHECK( + allclose( + norm(arr_three_d, 3.0, {1}), + array( + {6.24025147, 7.41685954, 8.6401226, 18., 19.39257164, 20.7915893}, + {2, 3})) + .item()); + CHECK(allclose( + norm(arr_three_d, 3.0, {2}), + array( + {2.08008382, + 6., + 10.23127655, + 14.5180117, + 18.82291607, + 23.13593104}, + {2, 3})) + .item()); + CHECK(allclose( + norm(arr_three_d, 0.0, {0}), + array({1., 2., 2., 2., 2., 2., 2., 2., 2.}, {3, 3})) + .item()); + CHECK( + allclose( + norm(arr_three_d, 0.0, {1}), array({2., 3., 3., 3., 3., 3.}, {2, 3})) + .item()); + CHECK( + allclose( + norm(arr_three_d, 0.0, {2}), array({2., 3., 3., 3., 3., 3.}, {2, 3})) + .item()); + CHECK(allclose( + norm(arr_three_d, 1.0, {0}), + array({9., 11., 13., 15., 17., 19., 21., 23., 25.}, {3, 3})) + .item()); + CHECK(allclose( + norm(arr_three_d, 1.0, {1}), + array({9., 12., 15., 36., 39., 42.}, {2, 3})) + .item()); + CHECK(allclose( + norm(arr_three_d, 1.0, {2}), + array({3., 12., 21., 30., 39., 48.}, {2, 3})) + .item()); + + CHECK(allclose(norm(arr_two_d, 1.0, {0, 1}), array({15.0})).item()); + CHECK(allclose(norm(arr_two_d, 1.0, {1, 0}), array({21.0})).item()); + CHECK(allclose(norm(arr_two_d, -1.0, {0, 1}), array({9.0})).item()); + CHECK(allclose(norm(arr_two_d, -1.0, {1, 0}), array({3.0})).item()); + + CHECK(allclose(norm(arr_two_d, 1.0, {0, 1}, true), array({15.0}, {1, 1})) + .item()); + CHECK(allclose(norm(arr_two_d, 1.0, {1, 0}, true), array({21.0}, {1, 1})) + .item()); + CHECK(allclose(norm(arr_two_d, -1.0, {0, 1}, true), array({9.0}, {1, 1})) + .item()); + CHECK(allclose(norm(arr_two_d, -1.0, {1, 0}, true), array({3.0}, {1, 1})) + .item()); + + CHECK(array_equal(norm(arr_two_d, -1.0, {-2, -1}, false), array(9.0)) + .item()); + CHECK(array_equal(norm(arr_two_d, 1.0, {-2, -1}, false), array(15.0)) + .item()); + // + CHECK(allclose(norm(arr_three_d, 1.0, {0, 1}), array({21., 23., 25.})) + .item()); + CHECK( + allclose(norm(arr_three_d, 1.0, {1, 2}), array({15., 42.})).item()); + CHECK(allclose(norm(arr_three_d, -1.0, {0, 1}), array({9., 11., 13.})) + .item()); + CHECK( + allclose(norm(arr_three_d, -1.0, {1, 2}), array({9., 36.})).item()); + CHECK(allclose(norm(arr_three_d, -1.0, {1, 0}), array({9., 12., 15.})) + .item()); + CHECK(allclose(norm(arr_three_d, -1.0, {2, 1}), array({3, 30})).item()); + CHECK(allclose(norm(arr_three_d, -1.0, {1, 2}), array({9, 36})).item()); +} + +TEST_CASE("[mlx.core.linalg.norm] string ord") { + array arr_one_d({1, 2, 3}); + array arr_two_d = reshape(arange(9), {3, 3}); + array arr_three_d = reshape(arange(18), {2, 3, 3}); + + CHECK(allclose(norm(arr_one_d, "inf", {}), array({3.0})).item()); + CHECK(allclose(norm(arr_one_d, "-inf", {}), array({1.0})).item()); + + CHECK(allclose(norm(arr_two_d, "f", {0, 1}), array({14.2828568570857})) + .item()); + CHECK(allclose(norm(arr_two_d, "fro", {0, 1}), array({14.2828568570857})) + .item()); + CHECK(allclose(norm(arr_two_d, "inf", {0, 1}), array({21.0})).item()); + CHECK(allclose(norm(arr_two_d, "-inf", {0, 1}), array({3.0})).item()); + + CHECK(allclose( + norm(arr_three_d, "fro", {0, 1}), + array({22.24859546, 24.31049156, 26.43860813})) + .item()); + CHECK(allclose( + norm(arr_three_d, "fro", {1, 2}), array({14.28285686, 39.7617907})) + .item()); + CHECK(allclose( + norm(arr_three_d, "f", {0, 1}), + array({22.24859546, 24.31049156, 26.43860813})) + .item()); + CHECK(allclose( + norm(arr_three_d, "f", {1, 0}), + array({22.24859546, 24.31049156, 26.43860813})) + .item()); + CHECK( + allclose(norm(arr_three_d, "f", {1, 2}), array({14.28285686, 39.7617907})) + .item()); + CHECK( + allclose(norm(arr_three_d, "f", {2, 1}), array({14.28285686, 39.7617907})) + .item()); + CHECK(allclose(norm(arr_three_d, "inf", {0, 1}), array({36., 39., 42.})) + .item()); + CHECK(allclose(norm(arr_three_d, "inf", {1, 2}), array({21., 48.})) + .item()); + CHECK(allclose(norm(arr_three_d, "-inf", {0, 1}), array({9., 12., 15.})) + .item()); + CHECK(allclose(norm(arr_three_d, "-inf", {1, 2}), array({3., 30.})) + .item()); + CHECK(allclose(norm(arr_three_d, "-inf", {1, 0}), array({9., 11., 13.})) + .item()); + CHECK(allclose(norm(arr_three_d, "-inf", {2, 1}), array({9., 36.})) + .item()); } \ No newline at end of file