This commit is contained in:
Awni Hannun
2025-04-21 13:04:39 -07:00
committed by GitHub
parent dc4eada7f0
commit 79b527f45f
3 changed files with 107 additions and 0 deletions

View File

@@ -669,6 +669,57 @@ class TestVmap(mlx_tests.MLXTestCase):
self.assertEqual(mx.vmap(fun, in_axes=(1,))(x).shape, (3, 8))
self.assertEqual(mx.vmap(fun, in_axes=(2,))(x).shape, (4, 6))
def test_vmap_conv(self):
# vmap input only
x = mx.random.uniform(shape=(2, 2, 5, 4))
w = mx.random.uniform(shape=(8, 3, 4))
expected = mx.stack([mx.conv1d(xi, w) for xi in x])
out = mx.vmap(mx.conv1d, in_axes=(0, None))(x, w)
self.assertTrue(mx.allclose(expected, out))
x = mx.moveaxis(x, 0, 2)
out = mx.vmap(mx.conv1d, in_axes=(2, None))(x, w)
self.assertTrue(mx.allclose(expected, out))
# vmap weights only
x = mx.random.uniform(shape=(2, 5, 4))
w = mx.random.uniform(shape=(3, 8, 3, 4))
expected = mx.stack([mx.conv1d(x, wi) for wi in w])
out = mx.vmap(mx.conv1d, in_axes=(None, 0))(x, w)
self.assertTrue(mx.allclose(expected, out))
w = mx.moveaxis(w, 0, 1)
out = mx.vmap(mx.conv1d, in_axes=(None, 1))(x, w)
self.assertTrue(mx.allclose(expected, out))
# vmap weights and input
x = mx.random.uniform(shape=(3, 2, 5, 4))
w = mx.random.uniform(shape=(3, 8, 3, 4))
expected = mx.stack([mx.conv1d(xi, wi) for xi, wi in zip(x, w)])
out = mx.vmap(mx.conv1d, in_axes=(0, 0))(x, w)
self.assertTrue(mx.allclose(expected, out))
x = mx.random.uniform(shape=(2, 3, 5, 4))
w = mx.random.uniform(shape=(8, 3, 4, 3))
expected = mx.stack([mx.conv1d(x[:, i], w[..., i]) for i in range(3)])
out = mx.vmap(mx.conv1d, in_axes=(1, 3))(x, w)
self.assertTrue(mx.allclose(expected, out))
# Test with groups
x = mx.random.uniform(shape=(3, 2, 5, 8))
w = mx.random.uniform(shape=(3, 2, 3, 4))
def gconv(x, w):
return mx.conv1d(x, w, groups=2)
expected = mx.stack([gconv(xi, wi) for xi, wi in zip(x, w)])
out = mx.vmap(gconv, in_axes=(0, 0))(x, w)
self.assertTrue(mx.allclose(expected, out))
if __name__ == "__main__":
unittest.main()