* added conv3d

added conv3d

implemented explicit_gemm_conv_ND_cpu and bounds checks for slow_conv_3D

* incorporated reviewer comments

* fixed test

* reduced tensor shapes in test for conv3d

* Reviewer suggestion

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

Reviewer suggestion

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

Reviewer suggestion

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

Reviewer suggestion
This commit is contained in:
Max-Heinrich Laves
2024-05-11 15:15:02 +02:00
committed by GitHub
parent a9f80d60f6
commit ff4223904d
10 changed files with 951 additions and 13 deletions

View File

@@ -2878,14 +2878,14 @@ inline std::vector<int> conv_out_shape(
if (strides.size() != spatial_dims) {
std::ostringstream msg;
msg << "[conv] Invalid strides " << strides << "for " << spatial_dims
msg << "[conv] Invalid strides " << strides << " for " << spatial_dims
<< "D convolution.";
throw std::invalid_argument(msg.str());
}
if (pads_lo.size() != spatial_dims || pads_hi.size() != spatial_dims) {
std::ostringstream msg;
msg << "[conv] Invalid pading " << pads_lo << " | " << pads_hi << "for "
msg << "[conv] Invalid padding " << pads_lo << " | " << pads_hi << "for "
<< spatial_dims << "D convolution.";
throw std::invalid_argument(msg.str());
}
@@ -3058,6 +3058,30 @@ array conv2d(
s);
}
/** 3D convolution with a filter */
array conv3d(
const array& in_,
const array& wt_,
const std::tuple<int, int, int>& stride /* = {1, 1, 1} */,
const std::tuple<int, int, int>& padding /* = {0, 0, 0} */,
const std::tuple<int, int, int>& dilation /* = {1, 1, 1} */,
int groups /* = 1 */,
StreamOrDevice s /* = {} */) {
return conv_general(
/* const array& input = */ in_,
/* const array& weight = */ wt_,
/* std::vector<int> stride = */
{std::get<0>(stride), std::get<1>(stride), std::get<2>(stride)},
/* std::vector<int> padding = */
{std::get<0>(padding), std::get<1>(padding), std::get<2>(padding)},
/* std::vector<int> kernel_dilation = */
{std::get<0>(dilation), std::get<1>(dilation), std::get<2>(dilation)},
/* std::vector<int> input_dilation = */ {1, 1, 1},
/* int groups = */ groups,
/* bool flip = */ false,
s);
}
/** General convolution with a filter */
array conv_general(
array in,
@@ -3078,9 +3102,9 @@ array conv_general(
int spatial_dims = in.ndim() - 2;
if (spatial_dims < 1 || spatial_dims > 2) {
if (spatial_dims < 1 || spatial_dims > 3) {
throw std::invalid_argument(
"[conv] Can only work with inputs that have 1 or 2 spatial dimensions."
"[conv] Only works for inputs with 1-3 spatial dimensions."
" The inputs must be in the format [N, ..., C_in]");
}
@@ -3120,10 +3144,10 @@ array conv_general(
// Check for negative padding
bool has_neg_padding = false;
for (auto& pd : padding_lo) {
has_neg_padding = (pd < 0);
has_neg_padding |= (pd < 0);
}
for (auto& pd : padding_hi) {
has_neg_padding = (pd < 0);
has_neg_padding |= (pd < 0);
}
// Handle negative padding