Add groups to 2-D convolutions (#1129)

* Added groups to 2-D convolutions. Only implemented for **some** specializations.

Also fixed 1D grouped convs with different kernel strides and added more tests.

* fix channels condition
This commit is contained in:
Rifur13
2024-05-22 23:01:44 -04:00
committed by GitHub
parent eb8321d863
commit 9401507336
9 changed files with 322 additions and 132 deletions

View File

@@ -3180,9 +3180,9 @@ array conv_general(
bool flip /* = false */,
StreamOrDevice s /* = {} */) {
// Run checks
if (groups != 1 && in.ndim() != 3) {
if (groups != 1 && in.ndim() != 3 && in.ndim() != 4) {
throw std::invalid_argument(
"[conv] Can only handle groups != 1 in 1D convolutions.");
"[conv] Can only handle groups != 1 in 1D or 2D convolutions.");
}
int spatial_dims = in.ndim() - 2;