Add groups to 2-D convolutions (#1129)

* Added groups to 2-D convolutions. Only implemented for **some** specializations.

Also fixed 1D grouped convs with different kernel strides and added more tests.

* fix channels condition
This commit is contained in:
Rifur13
2024-05-22 23:01:44 -04:00
committed by GitHub
parent eb8321d863
commit 9401507336
9 changed files with 322 additions and 132 deletions

View File

@@ -3268,22 +3268,22 @@ TEST_CASE("test conv1d") {
float16);
auto expected = array(
{1.5685,
0.5672,
1.8121,
1.2948,
2.3448,
1.6104,
2.7743,
1.6126,
1.4056,
0.9331,
1.8739,
1.0909},
{1.56836,
0.567383,
1.8125,
1.29492,
2.34375,
1.61035,
2.77539,
1.61328,
1.40527,
0.933105,
1.87402,
1.09082},
{1, 3, 4});
auto out = conv1d(in, wt, stride, padding, /* dilation= */ 1, groups);
CHECK(allclose(out, expected, /* rtol = */ 1.0e-3).item<bool>());
CHECK(allclose(out, expected).item<bool>());
}
{
@@ -3309,22 +3309,151 @@ TEST_CASE("test conv1d") {
{4, 3, 1});
auto expected = array(
{1.0703,
0.7533,
0.7007,
0.4681,
1.1859,
0.9117,
0.9565,
0.6111,
0.6416,
0.5665,
0.9074,
0.0605},
{1.07007,
0.753201,
0.700818,
0.468176,
1.18568,
0.91152,
0.956607,
0.611213,
0.641404,
0.566401,
0.907472,
0.0605397},
{1, 3, 4});
auto out = conv1d(in, wt, stride, padding, /* dilation= */ 1, groups);
CHECK(allclose(out, expected, /* rtol = */ 1.0e-3).item<bool>());
CHECK(allclose(out, expected).item<bool>());
}
}
TEST_CASE("test conv2d") {
auto in = array(
{0.57429284,
-0.21628855,
-0.18673691,
-0.3793517,
0.3059678,
-0.8137168,
0.6168841,
-0.26912728},
{1, 2, 2, 2});
std::pair<int, int> kernel{2, 2};
std::pair<int, int> stride{1, 1};
std::pair<int, int> padding{0, 0};
{
int groups = 1;
auto wt = array(
{0.3190391, -0.24937038, 1.4621079, -2.0601406, -0.3224172,
-0.38405436, 1.1337694, -1.0998913, -0.1724282, -0.8778584,
0.04221375, 0.58281523, -1.1006192, 1.1447237, 0.9015907,
0.50249434, 0.90085596, -0.68372786, -0.12289023, -0.93576944,
-0.26788807, 0.53035545, -0.69166076, -0.39675352, -0.6871727,
-0.84520566, -0.6712461, -0.0126646, -1.1173104, 0.2344157,
1.6598022, 0.74204415},
{4, 2, 2, 2});
auto expected =
array({1.9549234, -0.98542136, 0.2097499, 0.20991313}, {1, 1, 1, 4});
auto out = conv2d(in, wt, stride, padding, /* dilation= */ {1, 1}, groups);
CHECK(allclose(out, expected).item<bool>());
}
{
int groups = 2;
auto wt = array(
{0.3190391,
-0.24937038,
1.46210794,
-2.06014071,
-0.3224172,
-0.38405435,
1.13376944,
-1.09989127,
-0.17242821,
-0.87785842,
0.04221375,
0.58281521,
-1.10061918,
1.14472371,
0.90159072,
0.50249434},
{4, 2, 2, 1});
auto expected = array(
{-0.59372161, -0.44505326, 0.17910982, -1.06507601}, {1, 1, 1, 4});
auto out = conv2d(in, wt, stride, padding, /* dilation= */ {1, 1}, groups);
CHECK(allclose(out, expected).item<bool>());
}
{
in = array(
{0.57429284,
-0.21628855,
-0.18673691,
-0.3793517,
0.3059678,
-0.8137168,
0.6168841,
-0.26912728,
0.57429284,
-0.21628855,
-0.18673691,
-0.3793517,
0.3059678,
-0.8137168,
0.6168841,
-0.26912728},
{2, 2, 2, 2});
int groups = 2;
auto wt = array(
{0.3190391,
-0.24937038,
1.46210794,
-2.06014071,
-0.3224172,
-0.38405435,
1.13376944,
-1.09989127,
-0.17242821,
-0.87785842,
0.04221375,
0.58281521,
-1.10061918,
1.14472371,
0.90159072,
0.50249434},
{4, 2, 2, 1});
auto expected = array(
{-0.59372161, -0.44505326, 0.17910982, -1.06507601}, {1, 1, 1, 4});
auto out = conv2d(in, wt, stride, padding, /* dilation= */ {1, 1}, groups);
CHECK(allclose(out, expected).item<bool>());
}
}