Fix overflow in large filter small channels (#2520)

This commit is contained in:
Angelos Katharopoulos 2025-08-20 08:03:29 -07:00 committed by GitHub
parent 512281781c
commit 25c1e03205
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 8 additions and 1 deletion

View File

@ -83,7 +83,7 @@ struct Conv2DInputBlockLoaderSmallChannels {
const constant MLXConvParams<2>* params;
const constant ImplicitGemmConv2DParams* gemm_params;
short weight_hw;
int weight_hw;
const device T* src[n_rows];

View File

@ -1186,6 +1186,13 @@ class TestConv(mlx_tests.MLXTestCase):
y_hat = mx.conv2d(x, w)
self.assertTrue(mx.allclose(y, y_hat))
def test_conv2d_large_filter_small_channels(self):
    """Compare conv2d without an explicit stream against the CPU stream.

    The input is a random (1, 181, 181, 1) array and the filter a random
    (1, 182, 182, 1) array. Both calls receive the same two (1, 1)
    positional arguments (stride and padding in the conv2d signature).
    The two results must agree within rtol=1e-3 / atol=1e-3.
    """
    # Draw the input first and the filter second, keeping the RNG order.
    inputs = mx.random.normal(shape=(1, 181, 181, 1))
    weights = mx.random.normal(shape=(1, 182, 182, 1))

    stride = (1, 1)
    padding = (1, 1)

    # Reference result computed on the CPU stream.
    expected = mx.conv2d(inputs, weights, stride, padding, stream=mx.cpu)
    # Result computed with no stream argument.
    actual = mx.conv2d(inputs, weights, stride, padding)

    matches = mx.allclose(expected, actual, rtol=1e-3, atol=1e-3)
    self.assertTrue(matches)
# When this file is executed directly, invoke the mlx_tests runner.
if __name__ == "__main__":
    mlx_tests.MLXTestRunner()