Add groups to Conv1d (#948)

* Add conv1d grouped convs on CPU

* Add GPU support

* Parallelize inside metal kernel

* cleanup

* Update mlx/ops.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* New unfold kernel + remove unused code

* Remove copy and refactor

* Update vjp and reuse steel gemm

* Fixed groups on cpu

* Fix metal validation

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:
Rifur13
2024-04-27 09:24:57 -04:00
committed by GitHub
parent 86f495985b
commit c4a471c99d
11 changed files with 633 additions and 55 deletions

View File

@@ -320,7 +320,7 @@ array reshape(
"[reshape] Cannot infer the shape of an empty array");
}
// Check the the reshaping is valid
// Check that the reshaping is valid
if (a.size() != size) {
std::ostringstream msg;
msg << "[reshape] Cannot reshape array of size " << a.size()
@@ -2947,7 +2947,8 @@ inline std::vector<int> conv_out_shape(
return out_shape;
}
inline void run_conv_checks(const array& in, const array& wt, int n_dim) {
inline void
run_conv_checks(const array& in, const array& wt, int n_dim, int groups) {
if (!issubdtype(in.dtype(), floating)) {
std::ostringstream msg;
msg << "[conv] Invalid input array with type " << in.dtype() << "."
@@ -2972,11 +2973,35 @@ inline void run_conv_checks(const array& in, const array& wt, int n_dim) {
throw std::invalid_argument(msg.str());
}
if (in.shape(n_dim + 1) != wt.shape(n_dim + 1)) {
if (in.shape(n_dim + 1) % groups != 0) {
std::ostringstream msg;
msg << "[conv] Expect the input channels in the input"
<< " and weight array to match but got shapes -"
<< " input: " << in.shape() << " and weight: " << wt.shape();
msg << "[conv] The input channels must be divisible by the number"
<< " of groups. Got input with shape " << in.shape() << " and "
<< groups << " groups.";
throw std::invalid_argument(msg.str());
}
if (groups > 1 && wt.shape(0) % groups != 0) {
std::ostringstream msg;
msg << "[conv] If groups > 1, the output channels must be divisible by the number"
<< " of groups. Got " << wt.shape(0) << " output channels and "
<< groups << " groups.";
throw std::invalid_argument(msg.str());
}
if (in.shape(n_dim + 1) != (groups * wt.shape(n_dim + 1))) {
std::ostringstream msg;
if (groups == 1) {
msg << "[conv] Expect the input channels in the input"
<< " and weight array to match but got shapes -"
<< " input: " << in.shape() << " and weight: " << wt.shape();
} else {
msg << "Given groups=" << groups << " and weights of shape " << wt.shape()
<< ", expected to have " << (groups * wt.shape(n_dim + 1))
<< " input channels but got " << in.shape(n_dim + 1)
<< " input channels instead.";
}
throw std::invalid_argument(msg.str());
}
}
@@ -3039,8 +3064,9 @@ array conv_general(
bool flip /* = false */,
StreamOrDevice s /* = {} */) {
// Run checks
if (groups != 1) {
throw std::invalid_argument("[conv] Cannot handle groups != 1 yet");
if (groups != 1 && in.ndim() != 3) {
throw std::invalid_argument(
"[conv] Can only handle groups != 1 in 1D convolutions.");
}
int spatial_dims = in.ndim() - 2;
@@ -3052,7 +3078,7 @@ array conv_general(
}
// Run checks
run_conv_checks(in, wt, spatial_dims);
run_conv_checks(in, wt, spatial_dims, groups);
// Type promotion
auto out_type = promote_types(in.dtype(), wt.dtype());