Add more checks and clearer error messages to conv operations (#563)

* Add more checks and clearer error messages to conv operations
This commit is contained in:
Jagrit Digani 2024-01-26 15:13:26 -08:00 committed by GitHub
parent 8fa6b322b9
commit bf17ab5002
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 38 additions and 1 deletions

View File

@@ -2656,9 +2656,40 @@ inline std::vector<int> conv_out_shape(
std::vector<int> out_shape(in_shape.size());
int i = 0;
out_shape[i++] = N;
for (; i < in_shape.size() - 1; i++) {
if (pads[i - 1] < 0) {
std::ostringstream msg;
msg << "[conv] Padding sizes must be non-negative."
<< " Got padding " << pads << ".";
throw std::invalid_argument(msg.str());
}
if (strides[i - 1] <= 0) {
std::ostringstream msg;
msg << "[conv] Stride sizes must be positive."
<< " Got strides " << strides << ".";
throw std::invalid_argument(msg.str());
}
if (dilation[i - 1] <= 0) {
std::ostringstream msg;
msg << "[conv] Dilation sizes must be positive."
<< " Got dilation " << dilation << ".";
throw std::invalid_argument(msg.str());
}
out_shape[i] = conv_out_axis_size(
in_shape[i], wt_shape[i], strides[i - 1], pads[i - 1], dilation[i - 1]);
if (out_shape[i] <= 0) {
std::ostringstream msg;
msg << "[conv] Spatial dimensions of input after padding "
<< " cannot be smaller than weight spatial dimensions."
<< " Got input with shape " << in_shape << " and padding " << pads
<< " for weight of shape " << wt_shape << ".";
throw std::invalid_argument(msg.str());
}
}
out_shape[i] = O;

View File

@@ -106,7 +106,9 @@ class RoPE(Module):
if (N, D, offset, base, scale, dtype) != cls._cos_sin_theta_key:
half_D = D // 2
positions = mx.arange(offset, N, dtype=dtype) * scale
freqs = mx.exp(-mx.arange(0.0, half_D, dtype=dtype) * (math.log(base) / half_D))
freqs = mx.exp(
-mx.arange(0.0, half_D, dtype=dtype) * (math.log(base) / half_D)
)
theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1))
cls._cos_sin_theta_key = (N, D, offset, base, scale, dtype)
cls._cos_sin_theta_value = (mx.cos(theta), mx.sin(theta))

View File

@@ -2930,6 +2930,10 @@ void init_ops(py::module_& m) {
throw std::invalid_argument("[convolve] Inputs must be 1D.");
}
if (a.size() == 0 || v.size() == 0) {
throw std::invalid_argument("[convolve] Inputs cannot be empty.");
}
array in = a.size() < v.size() ? v : a;
array wt = a.size() < v.size() ? a : v;
wt = slice(wt, {wt.shape(0) - 1}, {-wt.shape(0) - 1}, {-1}, s);