mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 12:49:44 +08:00
Add groups to 2-D convolutions (#1129)
* Added groups to 2-D convolutions. Only implemented for **some** specializations. Also fixed 1D grouped convs with different kernel strides and added more tests. * fix channels condition
This commit is contained in:
@@ -3268,22 +3268,22 @@ TEST_CASE("test conv1d") {
|
||||
float16);
|
||||
|
||||
auto expected = array(
|
||||
{1.5685,
|
||||
0.5672,
|
||||
1.8121,
|
||||
1.2948,
|
||||
2.3448,
|
||||
1.6104,
|
||||
2.7743,
|
||||
1.6126,
|
||||
1.4056,
|
||||
0.9331,
|
||||
1.8739,
|
||||
1.0909},
|
||||
{1.56836,
|
||||
0.567383,
|
||||
1.8125,
|
||||
1.29492,
|
||||
2.34375,
|
||||
1.61035,
|
||||
2.77539,
|
||||
1.61328,
|
||||
1.40527,
|
||||
0.933105,
|
||||
1.87402,
|
||||
1.09082},
|
||||
{1, 3, 4});
|
||||
|
||||
auto out = conv1d(in, wt, stride, padding, /* dilation= */ 1, groups);
|
||||
CHECK(allclose(out, expected, /* rtol = */ 1.0e-3).item<bool>());
|
||||
CHECK(allclose(out, expected).item<bool>());
|
||||
}
|
||||
|
||||
{
|
||||
@@ -3309,22 +3309,151 @@ TEST_CASE("test conv1d") {
|
||||
{4, 3, 1});
|
||||
|
||||
auto expected = array(
|
||||
{1.0703,
|
||||
0.7533,
|
||||
0.7007,
|
||||
0.4681,
|
||||
1.1859,
|
||||
0.9117,
|
||||
0.9565,
|
||||
0.6111,
|
||||
0.6416,
|
||||
0.5665,
|
||||
0.9074,
|
||||
0.0605},
|
||||
{1.07007,
|
||||
0.753201,
|
||||
0.700818,
|
||||
0.468176,
|
||||
1.18568,
|
||||
0.91152,
|
||||
0.956607,
|
||||
0.611213,
|
||||
0.641404,
|
||||
0.566401,
|
||||
0.907472,
|
||||
0.0605397},
|
||||
{1, 3, 4});
|
||||
|
||||
auto out = conv1d(in, wt, stride, padding, /* dilation= */ 1, groups);
|
||||
CHECK(allclose(out, expected, /* rtol = */ 1.0e-3).item<bool>());
|
||||
CHECK(allclose(out, expected).item<bool>());
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("test conv2d") {
|
||||
auto in = array(
|
||||
{0.57429284,
|
||||
-0.21628855,
|
||||
-0.18673691,
|
||||
-0.3793517,
|
||||
|
||||
0.3059678,
|
||||
-0.8137168,
|
||||
0.6168841,
|
||||
-0.26912728},
|
||||
{1, 2, 2, 2});
|
||||
|
||||
std::pair<int, int> kernel{2, 2};
|
||||
std::pair<int, int> stride{1, 1};
|
||||
std::pair<int, int> padding{0, 0};
|
||||
|
||||
{
|
||||
int groups = 1;
|
||||
|
||||
auto wt = array(
|
||||
{0.3190391, -0.24937038, 1.4621079, -2.0601406, -0.3224172,
|
||||
-0.38405436, 1.1337694, -1.0998913, -0.1724282, -0.8778584,
|
||||
0.04221375, 0.58281523, -1.1006192, 1.1447237, 0.9015907,
|
||||
0.50249434, 0.90085596, -0.68372786, -0.12289023, -0.93576944,
|
||||
-0.26788807, 0.53035545, -0.69166076, -0.39675352, -0.6871727,
|
||||
-0.84520566, -0.6712461, -0.0126646, -1.1173104, 0.2344157,
|
||||
1.6598022, 0.74204415},
|
||||
{4, 2, 2, 2});
|
||||
|
||||
auto expected =
|
||||
array({1.9549234, -0.98542136, 0.2097499, 0.20991313}, {1, 1, 1, 4});
|
||||
auto out = conv2d(in, wt, stride, padding, /* dilation= */ {1, 1}, groups);
|
||||
CHECK(allclose(out, expected).item<bool>());
|
||||
}
|
||||
|
||||
{
|
||||
int groups = 2;
|
||||
auto wt = array(
|
||||
{0.3190391,
|
||||
-0.24937038,
|
||||
|
||||
1.46210794,
|
||||
-2.06014071,
|
||||
|
||||
-0.3224172,
|
||||
-0.38405435,
|
||||
|
||||
1.13376944,
|
||||
-1.09989127,
|
||||
|
||||
-0.17242821,
|
||||
-0.87785842,
|
||||
|
||||
0.04221375,
|
||||
0.58281521,
|
||||
|
||||
-1.10061918,
|
||||
1.14472371,
|
||||
|
||||
0.90159072,
|
||||
0.50249434},
|
||||
{4, 2, 2, 1});
|
||||
|
||||
auto expected = array(
|
||||
{-0.59372161, -0.44505326, 0.17910982, -1.06507601}, {1, 1, 1, 4});
|
||||
|
||||
auto out = conv2d(in, wt, stride, padding, /* dilation= */ {1, 1}, groups);
|
||||
CHECK(allclose(out, expected).item<bool>());
|
||||
}
|
||||
|
||||
{
|
||||
in = array(
|
||||
{0.57429284,
|
||||
-0.21628855,
|
||||
-0.18673691,
|
||||
-0.3793517,
|
||||
|
||||
0.3059678,
|
||||
-0.8137168,
|
||||
0.6168841,
|
||||
-0.26912728,
|
||||
|
||||
0.57429284,
|
||||
-0.21628855,
|
||||
-0.18673691,
|
||||
-0.3793517,
|
||||
|
||||
0.3059678,
|
||||
-0.8137168,
|
||||
0.6168841,
|
||||
-0.26912728},
|
||||
{2, 2, 2, 2});
|
||||
|
||||
int groups = 2;
|
||||
auto wt = array(
|
||||
{0.3190391,
|
||||
-0.24937038,
|
||||
|
||||
1.46210794,
|
||||
-2.06014071,
|
||||
|
||||
-0.3224172,
|
||||
-0.38405435,
|
||||
|
||||
1.13376944,
|
||||
-1.09989127,
|
||||
|
||||
-0.17242821,
|
||||
-0.87785842,
|
||||
|
||||
0.04221375,
|
||||
0.58281521,
|
||||
|
||||
-1.10061918,
|
||||
1.14472371,
|
||||
|
||||
0.90159072,
|
||||
0.50249434},
|
||||
{4, 2, 2, 1});
|
||||
|
||||
auto expected = array(
|
||||
{-0.59372161, -0.44505326, 0.17910982, -1.06507601}, {1, 1, 1, 4});
|
||||
|
||||
auto out = conv2d(in, wt, stride, padding, /* dilation= */ {1, 1}, groups);
|
||||
CHECK(allclose(out, expected).item<bool>());
|
||||
}
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user