mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-29 07:41:03 +08:00
Add the test
This commit is contained in:
parent
2581a9ab85
commit
9b226a929e
@ -1186,6 +1186,13 @@ class TestConv(mlx_tests.MLXTestCase):
|
||||
y_hat = mx.conv2d(x, w)
|
||||
self.assertTrue(mx.allclose(y, y_hat))
|
||||
|
||||
def test_conv2d_large_filter_small_channels(self):
|
||||
x = mx.random.normal(shape=(1, 181, 181, 1))
|
||||
w = mx.random.normal(shape=(1, 182, 182, 1))
|
||||
y = mx.conv2d(x, w, (1, 1), (1, 1), stream=mx.cpu)
|
||||
y_hat = mx.conv2d(x, w, (1, 1), (1, 1))
|
||||
self.assertTrue(mx.allclose(y, y_hat, rtol=1e-3, atol=1e-3))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
Loading…
Reference in New Issue
Block a user