mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Add groups to 2-D convolutions (#1129)
* Added groups to 2-D convolutions. Only implemented for **some** specializations. Also fixed 1D grouped convs with different kernel strides and added more tests. * fix channels condition
This commit is contained in:
@@ -3180,9 +3180,9 @@ array conv_general(
|
||||
bool flip /* = false */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
// Run checks
|
||||
if (groups != 1 && in.ndim() != 3) {
|
||||
if (groups != 1 && in.ndim() != 3 && in.ndim() != 4) {
|
||||
throw std::invalid_argument(
|
||||
"[conv] Can only handle groups != 1 in 1D convolutions.");
|
||||
"[conv] Can only handle groups != 1 in 1D or 2D convolutions.");
|
||||
}
|
||||
|
||||
int spatial_dims = in.ndim() - 2;
|
||||
|
||||
Reference in New Issue
Block a user