Add groups to Conv1d (#948)

* Add conv1d grouped convs on CPU

* Add GPU support

* Parallelize inside metal kernel

* Cleanup

* Update mlx/ops.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* New unfold kernel + remove unused code (see the conceptual sketch before the diff)

* Remove copy and refactor

* Update vjp and reuse steel gemm

* Fix groups on CPU

* Fix metal validation

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
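
For context: the headline change adds a groups argument to conv1d. Below is a minimal usage sketch, assuming the mlx C++ headers; the argument order mirrors the conv1d calls in the test diff further down, and the shapes are illustrative only.

#include "mlx/mlx.h"
using namespace mlx::core;

int main() {
  // mlx conv1d is channels-last: input is (N, L, C_in).
  auto x = zeros({1, 10, 4});
  // Weight is (C_out, K, C_in / groups); each filter only sees its
  // group's slice of the input channels.
  auto w = zeros({8, 3, 2});
  // groups = 2: filters 0-3 read channels 0-1, filters 4-7 read channels 2-3.
  auto y = conv1d(x, w, /* stride= */ 1, /* padding= */ 1,
                  /* dilation= */ 1, /* groups= */ 2);
  eval(y); // y has shape (N, L, C_out) = (1, 10, 8)
  return 0;
}

Both C_in and C_out must be divisible by groups for the per-group slicing to line up.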
Commit c4a471c99d (parent 86f495985b)
Author: Rifur13
Date: 2024-04-27 09:24:57 -04:00 (committed by GitHub)
11 changed files with 633 additions and 55 deletions
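
The "new unfold kernel" and "reuse steel gemm" bullets suggest the usual im2col strategy: unfold each group's input channels into a patch matrix so that the group's convolution becomes a single matrix multiply. The sketch below is conceptual C++ only, with stride = dilation = 1 and illustrative names; it is not the actual Metal kernel.

#include <vector>

// Unfold one group's input channels into an (L_out x K*cig) patch matrix.
std::vector<float> unfold_group(
    const std::vector<float>& in, // (L, C_in), row-major, channels-last
    int L, int C_in, int K, int padding, int group, int cig) {
  int L_out = L + 2 * padding - K + 1; // stride = dilation = 1
  std::vector<float> patches(L_out * K * cig, 0.0f);
  for (int l = 0; l < L_out; ++l) {
    for (int k = 0; k < K; ++k) {
      int pos = l + k - padding; // input position under this tap
      if (pos < 0 || pos >= L) {
        continue; // zero padding
      }
      for (int c = 0; c < cig; ++c) {
        patches[(l * K + k) * cig + c] = in[pos * C_in + group * cig + c];
      }
    }
  }
  return patches;
}

Multiplying patches (L_out x K*cig) by the group's flattened weights (K*cig x C_out/groups) yields that group's slice of the output, the kind of per-group GEMM the steel kernels can then handle.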


@@ -3228,3 +3228,102 @@ TEST_CASE("test meshgrid") {
CHECK(array_equal(out[0], expected_zero).item<bool>());
CHECK(array_equal(out[1], expected_one).item<bool>());
}
TEST_CASE("test conv1d") {
auto in = astype(
array(
{0.5488135,
0.71518937,
0.60276338,
0.54488318,
0.4236548,
0.64589411},
{1, 3, 2}),
float16);
int kernel = 3;
int stride = 1;
int padding = 1;
  {
    // Weight: (C_out = 4, K = 3, C_in = 2); one filter per row below.
    int groups = 1;
    auto wt = astype(
        array(
            {0.43758721, 0.891773,   0.96366276, 0.38344152, 0.79172504, 0.52889492,
             0.56804456, 0.92559664, 0.07103606, 0.0871293,  0.0202184,  0.83261985,
             0.77815675, 0.87001215, 0.97861834, 0.79915856, 0.46147936, 0.78052918,
             0.11827443, 0.63992102, 0.14335329, 0.94466892, 0.52184832, 0.41466194},
            {4, 3, 2}),
        float16);
    // Expected output: (N, L, C_out) = (1, 3, 4).
    auto expected = array(
        {1.5685, 0.5672, 1.8121, 1.2948,
         2.3448, 1.6104, 2.7743, 1.6126,
         1.4056, 0.9331, 1.8739, 1.0909},
        {1, 3, 4});
auto out = conv1d(in, wt, stride, padding, /* dilation= */ 1, groups);
CHECK(allclose(out, expected, /* rtol = */ 1.0e-3).item<bool>());
}
  {
    // groups = 2: filters 0-1 read input channel 0, filters 2-3 read input
    // channel 1. Weight: (C_out = 4, K = 3, C_in / groups = 1).
    int groups = 2;
    auto wt = array(
        {0.43758721, 0.891773,   0.96366276,
         0.38344152, 0.79172504, 0.52889492,
         0.56804456, 0.92559664, 0.07103606,
         0.0871293,  0.0202184,  0.83261985},
        {4, 3, 1});
    // Expected output: (N, L, C_out) = (1, 3, 4).
    auto expected = array(
        {1.0703, 0.7533, 0.7007, 0.4681,
         1.1859, 0.9117, 0.9565, 0.6111,
         0.6416, 0.5665, 0.9074, 0.0605},
        {1, 3, 4});
auto out = conv1d(in, wt, stride, padding, /* dilation= */ 1, groups);
CHECK(allclose(out, expected, /* rtol = */ 1.0e-3).item<bool>());
}
}
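
As a hand check on the grouped case: expected[0] is output position l = 0 for filter 0, which belongs to group 0 and therefore reads only input channel 0. With padding = 1 the k = 0 tap falls in the left padding, leaving 0.891773 * 0.5488135 + 0.96366276 * 0.60276338 ≈ 0.4894 + 0.5809 = 1.0703; the float16 inputs and rtol = 1e-3 absorb the rounding.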