Add load_safe to the general conv loaders (#2258)

This commit is contained in:
Angelos Katharopoulos
2025-06-10 20:58:16 -07:00
committed by GitHub
parent 095163b8d1
commit 8590c0941e
8 changed files with 302 additions and 22 deletions

View File

@@ -1173,6 +1173,19 @@ class TestConv(mlx_tests.MLXTestCase):
self.assertTrue(mx.allclose(out, out_2d.squeeze(2)))
def test_conv2d_unaligned_channels(self):
x = mx.random.uniform(shape=(2, 16, 16, 21))
w = mx.random.uniform(shape=(32, 3, 3, 21))
y = mx.conv2d(x, w, stream=mx.cpu)
y_hat = mx.conv2d(x, w)
self.assertTrue(mx.allclose(y, y_hat))
x = mx.random.uniform(shape=(2, 16, 16, 21))
w = mx.random.uniform(shape=(21, 3, 3, 21))
y = mx.conv2d(x, w, stream=mx.cpu)
y_hat = mx.conv2d(x, w)
self.assertTrue(mx.allclose(y, y_hat))
if __name__ == "__main__":
unittest.main()