mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-05 07:34:42 +08:00
Add load_safe to the general conv loaders (#2258)
This commit is contained in:

committed by
GitHub

parent
095163b8d1
commit
8590c0941e
@@ -1173,6 +1173,19 @@ class TestConv(mlx_tests.MLXTestCase):
|
||||
|
||||
self.assertTrue(mx.allclose(out, out_2d.squeeze(2)))
|
||||
|
||||
def test_conv2d_unaligned_channels(self):
|
||||
x = mx.random.uniform(shape=(2, 16, 16, 21))
|
||||
w = mx.random.uniform(shape=(32, 3, 3, 21))
|
||||
y = mx.conv2d(x, w, stream=mx.cpu)
|
||||
y_hat = mx.conv2d(x, w)
|
||||
self.assertTrue(mx.allclose(y, y_hat))
|
||||
|
||||
x = mx.random.uniform(shape=(2, 16, 16, 21))
|
||||
w = mx.random.uniform(shape=(21, 3, 3, 21))
|
||||
y = mx.conv2d(x, w, stream=mx.cpu)
|
||||
y_hat = mx.conv2d(x, w)
|
||||
self.assertTrue(mx.allclose(y, y_hat))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user