mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	conv vmap (#2102)
This commit is contained in:
		| @@ -669,6 +669,57 @@ class TestVmap(mlx_tests.MLXTestCase): | ||||
|         self.assertEqual(mx.vmap(fun, in_axes=(1,))(x).shape, (3, 8)) | ||||
|         self.assertEqual(mx.vmap(fun, in_axes=(2,))(x).shape, (4, 6)) | ||||
|  | ||||
|     def test_vmap_conv(self): | ||||
|         # vmap input only | ||||
|         x = mx.random.uniform(shape=(2, 2, 5, 4)) | ||||
|         w = mx.random.uniform(shape=(8, 3, 4)) | ||||
|  | ||||
|         expected = mx.stack([mx.conv1d(xi, w) for xi in x]) | ||||
|         out = mx.vmap(mx.conv1d, in_axes=(0, None))(x, w) | ||||
|         self.assertTrue(mx.allclose(expected, out)) | ||||
|  | ||||
|         x = mx.moveaxis(x, 0, 2) | ||||
|         out = mx.vmap(mx.conv1d, in_axes=(2, None))(x, w) | ||||
|         self.assertTrue(mx.allclose(expected, out)) | ||||
|  | ||||
|         # vmap weights only | ||||
|         x = mx.random.uniform(shape=(2, 5, 4)) | ||||
|         w = mx.random.uniform(shape=(3, 8, 3, 4)) | ||||
|  | ||||
|         expected = mx.stack([mx.conv1d(x, wi) for wi in w]) | ||||
|         out = mx.vmap(mx.conv1d, in_axes=(None, 0))(x, w) | ||||
|         self.assertTrue(mx.allclose(expected, out)) | ||||
|  | ||||
|         w = mx.moveaxis(w, 0, 1) | ||||
|         out = mx.vmap(mx.conv1d, in_axes=(None, 1))(x, w) | ||||
|         self.assertTrue(mx.allclose(expected, out)) | ||||
|  | ||||
|         # vmap weights and input | ||||
|         x = mx.random.uniform(shape=(3, 2, 5, 4)) | ||||
|         w = mx.random.uniform(shape=(3, 8, 3, 4)) | ||||
|  | ||||
|         expected = mx.stack([mx.conv1d(xi, wi) for xi, wi in zip(x, w)]) | ||||
|         out = mx.vmap(mx.conv1d, in_axes=(0, 0))(x, w) | ||||
|         self.assertTrue(mx.allclose(expected, out)) | ||||
|  | ||||
|         x = mx.random.uniform(shape=(2, 3, 5, 4)) | ||||
|         w = mx.random.uniform(shape=(8, 3, 4, 3)) | ||||
|  | ||||
|         expected = mx.stack([mx.conv1d(x[:, i], w[..., i]) for i in range(3)]) | ||||
|         out = mx.vmap(mx.conv1d, in_axes=(1, 3))(x, w) | ||||
|         self.assertTrue(mx.allclose(expected, out)) | ||||
|  | ||||
|         # Test with groups | ||||
|         x = mx.random.uniform(shape=(3, 2, 5, 8)) | ||||
|         w = mx.random.uniform(shape=(3, 2, 3, 4)) | ||||
|  | ||||
|         def gconv(x, w): | ||||
|             return mx.conv1d(x, w, groups=2) | ||||
|  | ||||
|         expected = mx.stack([gconv(xi, wi) for xi, wi in zip(x, w)]) | ||||
|         out = mx.vmap(gconv, in_axes=(0, 0))(x, w) | ||||
|         self.assertTrue(mx.allclose(expected, out)) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun