diff --git a/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h b/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h index b0b98d21a..2312e1ca6 100644 --- a/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +++ b/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h @@ -83,7 +83,7 @@ struct Conv2DInputBlockLoaderSmallChannels { const constant MLXConvParams<2>* params; const constant ImplicitGemmConv2DParams* gemm_params; - short weight_hw; + int weight_hw; const device T* src[n_rows]; diff --git a/python/tests/test_conv.py b/python/tests/test_conv.py index 9be22e01b..cef912aee 100644 --- a/python/tests/test_conv.py +++ b/python/tests/test_conv.py @@ -1186,6 +1186,13 @@ class TestConv(mlx_tests.MLXTestCase): y_hat = mx.conv2d(x, w) self.assertTrue(mx.allclose(y, y_hat)) + def test_conv2d_large_filter_small_channels(self): + x = mx.random.normal(shape=(1, 181, 181, 1)) + w = mx.random.normal(shape=(1, 182, 182, 1)) + y = mx.conv2d(x, w, (1, 1), (1, 1), stream=mx.cpu) + y_hat = mx.conv2d(x, w, (1, 1), (1, 1)) + self.assertTrue(mx.allclose(y, y_hat, rtol=1e-3, atol=1e-3)) + if __name__ == "__main__": mlx_tests.MLXTestRunner()