Winograd Update for Small batches (#1803)

* Build in padding to Winograd kernels
* Add new fused Winograd kernel
* Enable weight flipping in Winograd kernels
This commit is contained in:
Jagrit Digani
2025-02-14 13:08:13 -08:00
committed by GitHub
parent 7aea5b1895
commit 2dc307f2e6
4 changed files with 505 additions and 86 deletions

View File

@@ -341,7 +341,7 @@ class TestConv(mlx_tests.MLXTestCase):
atol, rtol = 1e-1, 1e-3
else:
atol, rtol = 1e-5, 1e-6
self.assertTrue(np.allclose(out_pt, out_mx, atol=atol))
self.assertTrue(np.allclose(out_pt, out_mx, atol=atol, rtol=rtol))
for dtype in ("float32", "bfloat16"):
for N, C, O in (