mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 04:24:36 +08:00
Add groups to Conv1d (#948)
* Add conv1d grouped convs on CPU * Add GPU support * Parallelize inside metal kernel * clenaup * Update mlx/ops.cpp Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * New unfold kernel + remove unused code * Remove copy and refactor * Update vjp and reuse steel gemm * Fixed groups on cpu * Fix metal validation --------- Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:
@@ -3228,3 +3228,102 @@ TEST_CASE("test meshgrid") {
|
||||
CHECK(array_equal(out[0], expected_zero).item<bool>());
|
||||
CHECK(array_equal(out[1], expected_one).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test conv1d") {
|
||||
auto in = astype(
|
||||
array(
|
||||
{0.5488135,
|
||||
0.71518937,
|
||||
0.60276338,
|
||||
0.54488318,
|
||||
0.4236548,
|
||||
0.64589411},
|
||||
{1, 3, 2}),
|
||||
float16);
|
||||
|
||||
int kernel = 3;
|
||||
int stride = 1;
|
||||
int padding = 1;
|
||||
|
||||
{
|
||||
int groups = 1;
|
||||
auto wt = astype(
|
||||
array(
|
||||
{
|
||||
|
||||
0.43758721, 0.891773, 0.96366276, 0.38344152,
|
||||
0.79172504, 0.52889492,
|
||||
|
||||
0.56804456, 0.92559664, 0.07103606, 0.0871293,
|
||||
0.0202184, 0.83261985,
|
||||
|
||||
0.77815675, 0.87001215, 0.97861834, 0.79915856,
|
||||
0.46147936, 0.78052918,
|
||||
|
||||
0.11827443, 0.63992102, 0.14335329, 0.94466892,
|
||||
0.52184832, 0.41466194
|
||||
|
||||
},
|
||||
{4, 3, 2}),
|
||||
float16);
|
||||
|
||||
auto expected = array(
|
||||
{1.5685,
|
||||
0.5672,
|
||||
1.8121,
|
||||
1.2948,
|
||||
2.3448,
|
||||
1.6104,
|
||||
2.7743,
|
||||
1.6126,
|
||||
1.4056,
|
||||
0.9331,
|
||||
1.8739,
|
||||
1.0909},
|
||||
{1, 3, 4});
|
||||
|
||||
auto out = conv1d(in, wt, stride, padding, /* dilation= */ 1, groups);
|
||||
CHECK(allclose(out, expected, /* rtol = */ 1.0e-3).item<bool>());
|
||||
}
|
||||
|
||||
{
|
||||
int groups = 2;
|
||||
auto wt = array(
|
||||
{0.43758721,
|
||||
0.891773,
|
||||
0.96366276,
|
||||
|
||||
0.38344152,
|
||||
0.79172504,
|
||||
0.52889492,
|
||||
|
||||
0.56804456,
|
||||
0.92559664,
|
||||
0.07103606,
|
||||
|
||||
0.0871293,
|
||||
0.0202184,
|
||||
0.83261985
|
||||
|
||||
},
|
||||
{4, 3, 1});
|
||||
|
||||
auto expected = array(
|
||||
{1.0703,
|
||||
0.7533,
|
||||
0.7007,
|
||||
0.4681,
|
||||
1.1859,
|
||||
0.9117,
|
||||
0.9565,
|
||||
0.6111,
|
||||
0.6416,
|
||||
0.5665,
|
||||
0.9074,
|
||||
0.0605},
|
||||
{1, 3, 4});
|
||||
|
||||
auto out = conv1d(in, wt, stride, padding, /* dilation= */ 1, groups);
|
||||
CHECK(allclose(out, expected, /* rtol = */ 1.0e-3).item<bool>());
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user