mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Add more checks and clearer error messages to conv operations (#563)
* Add more checks and clearer error messages to conv operations
This commit is contained in:
parent
8fa6b322b9
commit
bf17ab5002
31
mlx/ops.cpp
31
mlx/ops.cpp
@ -2656,9 +2656,40 @@ inline std::vector<int> conv_out_shape(
|
||||
std::vector<int> out_shape(in_shape.size());
|
||||
int i = 0;
|
||||
out_shape[i++] = N;
|
||||
|
||||
for (; i < in_shape.size() - 1; i++) {
|
||||
if (pads[i - 1] < 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "[conv] Padding sizes must be non-negative."
|
||||
<< " Got padding " << pads << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
if (strides[i - 1] <= 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "[conv] Stride sizes must be positive."
|
||||
<< " Got strides " << strides << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
if (dilation[i - 1] <= 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "[conv] Dilation sizes must be positive."
|
||||
<< " Got dilation " << dilation << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
out_shape[i] = conv_out_axis_size(
|
||||
in_shape[i], wt_shape[i], strides[i - 1], pads[i - 1], dilation[i - 1]);
|
||||
|
||||
if (out_shape[i] <= 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "[conv] Spatial dimensions of input after padding "
|
||||
<< " cannot be smaller than weight spatial dimensions."
|
||||
<< " Got input with shape " << in_shape << " and padding " << pads
|
||||
<< " for weight of shape " << wt_shape << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
}
|
||||
out_shape[i] = O;
|
||||
|
||||
|
@ -106,7 +106,9 @@ class RoPE(Module):
|
||||
if (N, D, offset, base, scale, dtype) != cls._cos_sin_theta_key:
|
||||
half_D = D // 2
|
||||
positions = mx.arange(offset, N, dtype=dtype) * scale
|
||||
freqs = mx.exp(-mx.arange(0.0, half_D, dtype=dtype) * (math.log(base) / half_D))
|
||||
freqs = mx.exp(
|
||||
-mx.arange(0.0, half_D, dtype=dtype) * (math.log(base) / half_D)
|
||||
)
|
||||
theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1))
|
||||
cls._cos_sin_theta_key = (N, D, offset, base, scale, dtype)
|
||||
cls._cos_sin_theta_value = (mx.cos(theta), mx.sin(theta))
|
||||
|
@ -2930,6 +2930,10 @@ void init_ops(py::module_& m) {
|
||||
throw std::invalid_argument("[convolve] Inputs must be 1D.");
|
||||
}
|
||||
|
||||
if (a.size() == 0 || v.size() == 0) {
|
||||
throw std::invalid_argument("[convolve] Inputs cannot be empty.");
|
||||
}
|
||||
|
||||
array in = a.size() < v.size() ? v : a;
|
||||
array wt = a.size() < v.size() ? a : v;
|
||||
wt = slice(wt, {wt.shape(0) - 1}, {-wt.shape(0) - 1}, {-1}, s);
|
||||
|
Loading…
Reference in New Issue
Block a user