Add the test

This commit is contained in:
Angelos Katharopoulos 2025-08-19 23:52:27 -07:00
parent 2581a9ab85
commit 9b226a929e

View File

@ -1186,6 +1186,13 @@ class TestConv(mlx_tests.MLXTestCase):
y_hat = mx.conv2d(x, w)
self.assertTrue(mx.allclose(y, y_hat))
def test_conv2d_large_filter_small_channels(self):
x = mx.random.normal(shape=(1, 181, 181, 1))
w = mx.random.normal(shape=(1, 182, 182, 1))
y = mx.conv2d(x, w, (1, 1), (1, 1), stream=mx.cpu)
y_hat = mx.conv2d(x, w, (1, 1), (1, 1))
self.assertTrue(mx.allclose(y, y_hat, rtol=1e-3, atol=1e-3))
if __name__ == "__main__":
mlx_tests.MLXTestRunner()