Add load_safe to conv general loader

This commit is contained in:
Angelos Katharopoulos
2025-06-03 10:13:16 -07:00
parent 5866b3857b
commit c9af09d118
7 changed files with 198 additions and 22 deletions

View File

@@ -1173,6 +1173,19 @@ class TestConv(mlx_tests.MLXTestCase):
self.assertTrue(mx.allclose(out, out_2d.squeeze(2)))
def test_conv2d_unaligned_channels(self):
x = mx.random.uniform(shape=(2, 16, 16, 21))
w = mx.random.uniform(shape=(32, 3, 3, 21))
y = mx.conv2d(x, w, stream=mx.cpu)
y_hat = mx.conv2d(x, w)
self.assertTrue(mx.allclose(y, y_hat))
x = mx.random.uniform(shape=(2, 16, 16, 21))
w = mx.random.uniform(shape=(21, 3, 3, 21))
y = mx.conv2d(x, w, stream=mx.cpu)
y_hat = mx.conv2d(x, w)
self.assertTrue(mx.allclose(y, y_hat))
if __name__ == "__main__":
unittest.main()