Add more checks and clearer error messages to conv operations (#563)

* Add more checks and clearer error messages to conv operations
This commit is contained in:
Jagrit Digani
2024-01-26 15:13:26 -08:00
committed by GitHub
parent 8fa6b322b9
commit bf17ab5002
3 changed files with 38 additions and 1 deletions

View File

@@ -2656,9 +2656,40 @@ inline std::vector<int> conv_out_shape(
std::vector<int> out_shape(in_shape.size());
int i = 0;
out_shape[i++] = N;
for (; i < in_shape.size() - 1; i++) {
if (pads[i - 1] < 0) {
std::ostringstream msg;
msg << "[conv] Padding sizes must be non-negative."
<< " Got padding " << pads << ".";
throw std::invalid_argument(msg.str());
}
if (strides[i - 1] <= 0) {
std::ostringstream msg;
msg << "[conv] Stride sizes must be positive."
<< " Got strides " << strides << ".";
throw std::invalid_argument(msg.str());
}
if (dilation[i - 1] <= 0) {
std::ostringstream msg;
msg << "[conv] Dilation sizes must be positive."
<< " Got dilation " << dilation << ".";
throw std::invalid_argument(msg.str());
}
out_shape[i] = conv_out_axis_size(
in_shape[i], wt_shape[i], strides[i - 1], pads[i - 1], dilation[i - 1]);
if (out_shape[i] <= 0) {
std::ostringstream msg;
msg << "[conv] Spatial dimensions of input after padding "
<< " cannot be smaller than weight spatial dimensions."
<< " Got input with shape " << in_shape << " and padding " << pads
<< " for weight of shape " << wt_shape << ".";
throw std::invalid_argument(msg.str());
}
}
out_shape[i] = O;