mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-27 00:08:09 +08:00
implement Convolution::output_shape (#2601)
- pull conv_out_shape out for re-use - add Conv::output_shape - add e2e python tests confirming shapeless=True support and correctness Updates #2599
This commit is contained in:

committed by
GitHub

parent
ec2ab42888
commit
aa9d44b3d4
@@ -346,6 +346,105 @@ class TestExportImport(mlx_tests.MLXTestCase):
|
||||
expected = forward(input_data)
|
||||
self.assertTrue(mx.allclose(expected, out))
|
||||
|
||||
def test_export_conv_shapeless(self):
|
||||
# Conv1d (NLC)
|
||||
path = os.path.join(self.test_dir, "conv1d.mlxfn")
|
||||
|
||||
class M1(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.c = nn.Conv1d(3, 8, kernel_size=3, stride=2, padding=1, bias=False)
|
||||
|
||||
def __call__(self, x):
|
||||
return self.c(x)
|
||||
|
||||
m1 = M1()
|
||||
mx.eval(m1.parameters())
|
||||
|
||||
def f1(x):
|
||||
return m1(x)
|
||||
|
||||
x = mx.random.normal(shape=(4, 64, 3))
|
||||
mx.export_function(path, f1, x, shapeless=True)
|
||||
f1_imp = mx.import_function(path)
|
||||
for shape in [(4, 64, 3), (1, 33, 3), (2, 128, 3)]:
|
||||
xt = mx.random.normal(shape=shape)
|
||||
self.assertTrue(mx.allclose(f1_imp(xt)[0], f1(xt)))
|
||||
|
||||
# Conv2d (NHWC)
|
||||
path = os.path.join(self.test_dir, "conv2d.mlxfn")
|
||||
|
||||
class M2(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.c = nn.Conv2d(3, 6, kernel_size=3, stride=2, padding=1, bias=False)
|
||||
|
||||
def __call__(self, x):
|
||||
return self.c(x)
|
||||
|
||||
m2 = M2()
|
||||
mx.eval(m2.parameters())
|
||||
|
||||
def f2(x):
|
||||
return m2(x)
|
||||
|
||||
x = mx.random.normal(shape=(2, 32, 32, 3))
|
||||
mx.export_function(path, f2, x, shapeless=True)
|
||||
f2_imp = mx.import_function(path)
|
||||
for shape in [(2, 32, 32, 3), (1, 31, 31, 3), (4, 64, 48, 3)]:
|
||||
xt = mx.random.normal(shape=shape)
|
||||
self.assertTrue(mx.allclose(f2_imp(xt)[0], f2(xt)))
|
||||
|
||||
# Conv3d (NDHWC)
|
||||
path = os.path.join(self.test_dir, "conv3d.mlxfn")
|
||||
|
||||
class M3(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.c = nn.Conv3d(2, 4, kernel_size=3, stride=2, padding=1, bias=False)
|
||||
|
||||
def __call__(self, x):
|
||||
return self.c(x)
|
||||
|
||||
m3 = M3()
|
||||
mx.eval(m3.parameters())
|
||||
|
||||
def f3(x):
|
||||
return m3(x)
|
||||
|
||||
x = mx.random.normal(shape=(1, 8, 8, 8, 2))
|
||||
mx.export_function(path, f3, x, shapeless=True)
|
||||
f3_imp = mx.import_function(path)
|
||||
for shape in [(1, 8, 8, 8, 2), (2, 7, 8, 9, 2), (1, 16, 16, 4, 2)]:
|
||||
xt = mx.random.normal(shape=shape)
|
||||
self.assertTrue(mx.allclose(f3_imp(xt)[0], f3(xt)))
|
||||
|
||||
# Grouped Conv2d (NHWC)
|
||||
path = os.path.join(self.test_dir, "conv2d_grouped.mlxfn")
|
||||
|
||||
class MG(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.c = nn.Conv2d(
|
||||
4, 6, kernel_size=3, stride=2, padding=1, groups=2, bias=False
|
||||
)
|
||||
|
||||
def __call__(self, x):
|
||||
return self.c(x)
|
||||
|
||||
mg = MG()
|
||||
mx.eval(mg.parameters())
|
||||
|
||||
def fg(x):
|
||||
return mg(x)
|
||||
|
||||
x = mx.random.normal(shape=(2, 32, 32, 4))
|
||||
mx.export_function(path, fg, x, shapeless=True)
|
||||
fg_imp = mx.import_function(path)
|
||||
for shape in [(2, 32, 32, 4), (1, 32, 32, 4), (3, 15, 20, 4)]:
|
||||
xt = mx.random.normal(shape=shape)
|
||||
self.assertTrue(mx.allclose(fg_imp(xt)[0], fg(xt)))
|
||||
|
||||
def test_export_control_flow(self):
|
||||
|
||||
def fun(x, y):
|
||||
|
Reference in New Issue
Block a user