Fix a couple of slicing bugs (#1827)

* fix a few bugs

* fix conv grad

* speedup test

* comment
This commit is contained in:
Awni Hannun
2025-02-05 19:50:08 -08:00
committed by GitHub
parent 9174606d4c
commit af1b725fda
14 changed files with 170 additions and 107 deletions

View File

@@ -587,10 +587,10 @@ class TestConvTranspose(mlx_tests.MLXTestCase):
for idim, kdim, stride, padding, dilation in (
((1, 1, 1), (1, 1, 1), (1, 1, 1), (0, 0, 0), (1, 1, 1)),
((3, 3, 3), (3, 1, 1), (1, 1, 1), (0, 0, 0), (1, 1, 1)),
((15, 15, 15), (5, 5, 5), (5, 5, 5), (2, 2, 2), (1, 1, 1)),
((16, 16, 16), (3, 3, 3), (2, 2, 2), (1, 1, 1), (1, 1, 1)),
((15, 15, 15), (5, 5, 5), (5, 5, 5), (2, 2, 2), (3, 2, 2)),
((16, 16, 16), (3, 3, 3), (2, 2, 2), (1, 1, 1), (3, 2, 2)),
((7, 7, 7), (5, 5, 5), (5, 5, 5), (2, 2, 2), (1, 1, 1)),
((8, 8, 8), (3, 3, 3), (2, 2, 2), (1, 1, 1), (1, 1, 1)),
((7, 7, 7), (5, 5, 5), (3, 3, 3), (2, 2, 2), (3, 2, 2)),
((8, 8, 8), (3, 3, 3), (2, 2, 2), (1, 1, 1), (3, 2, 2)),
):
run_conv_transpose3D_grad(
N, C, O, idim, kdim, stride, padding, dilation, dtype=dtype