mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 08:10:15 +08:00
conv vmap (#2102)
This commit is contained in:
@@ -669,6 +669,57 @@ class TestVmap(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(mx.vmap(fun, in_axes=(1,))(x).shape, (3, 8))
|
||||
self.assertEqual(mx.vmap(fun, in_axes=(2,))(x).shape, (4, 6))
|
||||
|
||||
def test_vmap_conv(self):
|
||||
# vmap input only
|
||||
x = mx.random.uniform(shape=(2, 2, 5, 4))
|
||||
w = mx.random.uniform(shape=(8, 3, 4))
|
||||
|
||||
expected = mx.stack([mx.conv1d(xi, w) for xi in x])
|
||||
out = mx.vmap(mx.conv1d, in_axes=(0, None))(x, w)
|
||||
self.assertTrue(mx.allclose(expected, out))
|
||||
|
||||
x = mx.moveaxis(x, 0, 2)
|
||||
out = mx.vmap(mx.conv1d, in_axes=(2, None))(x, w)
|
||||
self.assertTrue(mx.allclose(expected, out))
|
||||
|
||||
# vmap weights only
|
||||
x = mx.random.uniform(shape=(2, 5, 4))
|
||||
w = mx.random.uniform(shape=(3, 8, 3, 4))
|
||||
|
||||
expected = mx.stack([mx.conv1d(x, wi) for wi in w])
|
||||
out = mx.vmap(mx.conv1d, in_axes=(None, 0))(x, w)
|
||||
self.assertTrue(mx.allclose(expected, out))
|
||||
|
||||
w = mx.moveaxis(w, 0, 1)
|
||||
out = mx.vmap(mx.conv1d, in_axes=(None, 1))(x, w)
|
||||
self.assertTrue(mx.allclose(expected, out))
|
||||
|
||||
# vmap weights and input
|
||||
x = mx.random.uniform(shape=(3, 2, 5, 4))
|
||||
w = mx.random.uniform(shape=(3, 8, 3, 4))
|
||||
|
||||
expected = mx.stack([mx.conv1d(xi, wi) for xi, wi in zip(x, w)])
|
||||
out = mx.vmap(mx.conv1d, in_axes=(0, 0))(x, w)
|
||||
self.assertTrue(mx.allclose(expected, out))
|
||||
|
||||
x = mx.random.uniform(shape=(2, 3, 5, 4))
|
||||
w = mx.random.uniform(shape=(8, 3, 4, 3))
|
||||
|
||||
expected = mx.stack([mx.conv1d(x[:, i], w[..., i]) for i in range(3)])
|
||||
out = mx.vmap(mx.conv1d, in_axes=(1, 3))(x, w)
|
||||
self.assertTrue(mx.allclose(expected, out))
|
||||
|
||||
# Test with groups
|
||||
x = mx.random.uniform(shape=(3, 2, 5, 8))
|
||||
w = mx.random.uniform(shape=(3, 2, 3, 4))
|
||||
|
||||
def gconv(x, w):
|
||||
return mx.conv1d(x, w, groups=2)
|
||||
|
||||
expected = mx.stack([gconv(xi, wi) for xi, wi in zip(x, w)])
|
||||
out = mx.vmap(gconv, in_axes=(0, 0))(x, w)
|
||||
self.assertTrue(mx.allclose(expected, out))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user