diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 387f8e814..2012613e0 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -2656,9 +2656,40 @@ inline std::vector conv_out_shape( std::vector out_shape(in_shape.size()); int i = 0; out_shape[i++] = N; + for (; i < in_shape.size() - 1; i++) { + if (pads[i - 1] < 0) { + std::ostringstream msg; + msg << "[conv] Padding sizes must be non-negative." + << " Got padding " << pads << "."; + throw std::invalid_argument(msg.str()); + } + + if (strides[i - 1] <= 0) { + std::ostringstream msg; + msg << "[conv] Stride sizes must be positive." + << " Got strides " << strides << "."; + throw std::invalid_argument(msg.str()); + } + + if (dilation[i - 1] <= 0) { + std::ostringstream msg; + msg << "[conv] Dilation sizes must be positive." + << " Got dilation " << dilation << "."; + throw std::invalid_argument(msg.str()); + } + out_shape[i] = conv_out_axis_size( in_shape[i], wt_shape[i], strides[i - 1], pads[i - 1], dilation[i - 1]); + + if (out_shape[i] <= 0) { + std::ostringstream msg; + msg << "[conv] Spatial dimensions of input after padding " + << " cannot be smaller than weight spatial dimensions." + << " Got input with shape " << in_shape << " and padding " << pads + << " for weight of shape " << wt_shape << "."; + throw std::invalid_argument(msg.str()); + } } out_shape[i] = O; diff --git a/python/mlx/nn/layers/positional_encoding.py b/python/mlx/nn/layers/positional_encoding.py index 1935792a4..a8024f0a4 100644 --- a/python/mlx/nn/layers/positional_encoding.py +++ b/python/mlx/nn/layers/positional_encoding.py @@ -106,7 +106,9 @@ class RoPE(Module): if (N, D, offset, base, scale, dtype) != cls._cos_sin_theta_key: half_D = D // 2 positions = mx.arange(offset, N, dtype=dtype) * scale - freqs = mx.exp(-mx.arange(0.0, half_D, dtype=dtype) * (math.log(base) / half_D)) + freqs = mx.exp( + -mx.arange(0.0, half_D, dtype=dtype) * (math.log(base) / half_D) + ) theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1)) cls._cos_sin_theta_key = (N, D, offset, base, scale, dtype) cls._cos_sin_theta_value = (mx.cos(theta), mx.sin(theta)) diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 432b98d76..d622dcdf1 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -2930,6 +2930,10 @@ void init_ops(py::module_& m) { throw std::invalid_argument("[convolve] Inputs must be 1D."); } + if (a.size() == 0 || v.size() == 0) { + throw std::invalid_argument("[convolve] Inputs cannot be empty."); + } + array in = a.size() < v.size() ? v : a; array wt = a.size() < v.size() ? a : v; wt = slice(wt, {wt.shape(0) - 1}, {-wt.shape(0) - 1}, {-1}, s);