implement Convolution::output_shape (#2601)

- pull conv_out_shape out for re-use
- add Conv::output_shape
- add e2e python tests confirming shapeless=True support and correctness

Updates #2599
This commit is contained in:
Josh Bleecher Snyder
2025-09-22 10:09:45 -07:00
committed by GitHub
parent ec2ab42888
commit aa9d44b3d4
4 changed files with 230 additions and 105 deletions

View File

@@ -346,6 +346,105 @@ class TestExportImport(mlx_tests.MLXTestCase):
expected = forward(input_data)
self.assertTrue(mx.allclose(expected, out))
def test_export_conv_shapeless(self):
# Conv1d (NLC)
path = os.path.join(self.test_dir, "conv1d.mlxfn")
class M1(nn.Module):
def __init__(self):
super().__init__()
self.c = nn.Conv1d(3, 8, kernel_size=3, stride=2, padding=1, bias=False)
def __call__(self, x):
return self.c(x)
m1 = M1()
mx.eval(m1.parameters())
def f1(x):
return m1(x)
x = mx.random.normal(shape=(4, 64, 3))
mx.export_function(path, f1, x, shapeless=True)
f1_imp = mx.import_function(path)
for shape in [(4, 64, 3), (1, 33, 3), (2, 128, 3)]:
xt = mx.random.normal(shape=shape)
self.assertTrue(mx.allclose(f1_imp(xt)[0], f1(xt)))
# Conv2d (NHWC)
path = os.path.join(self.test_dir, "conv2d.mlxfn")
class M2(nn.Module):
def __init__(self):
super().__init__()
self.c = nn.Conv2d(3, 6, kernel_size=3, stride=2, padding=1, bias=False)
def __call__(self, x):
return self.c(x)
m2 = M2()
mx.eval(m2.parameters())
def f2(x):
return m2(x)
x = mx.random.normal(shape=(2, 32, 32, 3))
mx.export_function(path, f2, x, shapeless=True)
f2_imp = mx.import_function(path)
for shape in [(2, 32, 32, 3), (1, 31, 31, 3), (4, 64, 48, 3)]:
xt = mx.random.normal(shape=shape)
self.assertTrue(mx.allclose(f2_imp(xt)[0], f2(xt)))
# Conv3d (NDHWC)
path = os.path.join(self.test_dir, "conv3d.mlxfn")
class M3(nn.Module):
def __init__(self):
super().__init__()
self.c = nn.Conv3d(2, 4, kernel_size=3, stride=2, padding=1, bias=False)
def __call__(self, x):
return self.c(x)
m3 = M3()
mx.eval(m3.parameters())
def f3(x):
return m3(x)
x = mx.random.normal(shape=(1, 8, 8, 8, 2))
mx.export_function(path, f3, x, shapeless=True)
f3_imp = mx.import_function(path)
for shape in [(1, 8, 8, 8, 2), (2, 7, 8, 9, 2), (1, 16, 16, 4, 2)]:
xt = mx.random.normal(shape=shape)
self.assertTrue(mx.allclose(f3_imp(xt)[0], f3(xt)))
# Grouped Conv2d (NHWC)
path = os.path.join(self.test_dir, "conv2d_grouped.mlxfn")
class MG(nn.Module):
def __init__(self):
super().__init__()
self.c = nn.Conv2d(
4, 6, kernel_size=3, stride=2, padding=1, groups=2, bias=False
)
def __call__(self, x):
return self.c(x)
mg = MG()
mx.eval(mg.parameters())
def fg(x):
return mg(x)
x = mx.random.normal(shape=(2, 32, 32, 4))
mx.export_function(path, fg, x, shapeless=True)
fg_imp = mx.import_function(path)
for shape in [(2, 32, 32, 4), (1, 32, 32, 4), (3, 15, 20, 4)]:
xt = mx.random.normal(shape=shape)
self.assertTrue(mx.allclose(fg_imp(xt)[0], fg(xt)))
def test_export_control_flow(self):
def fun(x, y):