mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 18:28:12 +08:00
Add more checks and clearer error messages to conv operations (#563)
* Add more checks and clearer error messages to conv operations
This commit is contained in:
31
mlx/ops.cpp
31
mlx/ops.cpp
@@ -2656,9 +2656,40 @@ inline std::vector<int> conv_out_shape(
|
||||
std::vector<int> out_shape(in_shape.size());
|
||||
int i = 0;
|
||||
out_shape[i++] = N;
|
||||
|
||||
for (; i < in_shape.size() - 1; i++) {
|
||||
if (pads[i - 1] < 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "[conv] Padding sizes must be non-negative."
|
||||
<< " Got padding " << pads << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
if (strides[i - 1] <= 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "[conv] Stride sizes must be positive."
|
||||
<< " Got strides " << strides << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
if (dilation[i - 1] <= 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "[conv] Dilation sizes must be positive."
|
||||
<< " Got dilation " << dilation << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
out_shape[i] = conv_out_axis_size(
|
||||
in_shape[i], wt_shape[i], strides[i - 1], pads[i - 1], dilation[i - 1]);
|
||||
|
||||
if (out_shape[i] <= 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "[conv] Spatial dimensions of input after padding "
|
||||
<< " cannot be smaller than weight spatial dimensions."
|
||||
<< " Got input with shape " << in_shape << " and padding " << pads
|
||||
<< " for weight of shape " << wt_shape << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
}
|
||||
out_shape[i] = O;
|
||||
|
||||
|
Reference in New Issue
Block a user