mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00
fix conv export (#2265)
This commit is contained in:
parent
7c4eb5d03e
commit
62fecf3e13
@ -719,9 +719,9 @@ class Convolution : public UnaryPrimitive {
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
auto state() const {
|
||||
return std::make_tuple(
|
||||
kernel_strides_,
|
||||
padding_lo_,
|
||||
padding_hi_,
|
||||
kernel_strides_,
|
||||
kernel_dilation_,
|
||||
input_dilation_,
|
||||
groups_,
|
||||
|
@ -6,6 +6,7 @@ import tempfile
|
||||
import unittest
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import mlx_tests
|
||||
|
||||
|
||||
@ -312,6 +313,39 @@ class TestExportImport(mlx_tests.MLXTestCase):
|
||||
out = imported_fun(x, y, z)[0]
|
||||
self.assertTrue(mx.array_equal(expected, out))
|
||||
|
||||
def test_export_conv(self):
|
||||
path = os.path.join(self.test_dir, "fn.mlxfn")
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.c1 = nn.Conv2d(
|
||||
3, 16, kernel_size=3, stride=1, padding=1, bias=False
|
||||
)
|
||||
self.c2 = nn.Conv2d(
|
||||
16, 16, kernel_size=3, stride=2, padding=1, bias=False
|
||||
)
|
||||
self.c3 = nn.Conv2d(
|
||||
16, 16, kernel_size=3, stride=1, padding=2, bias=False
|
||||
)
|
||||
|
||||
def __call__(self, x):
|
||||
return self.c3(self.c2(self.c1(x)))
|
||||
|
||||
model = Model()
|
||||
mx.eval(model.parameters())
|
||||
|
||||
def forward(x):
|
||||
return model(x)
|
||||
|
||||
input_data = mx.random.normal(shape=(4, 32, 32, 3))
|
||||
mx.export_function(path, forward, input_data)
|
||||
|
||||
imported_fn = mx.import_function(path)
|
||||
out = imported_fn(input_data)[0]
|
||||
expected = forward(input_data)
|
||||
self.assertTrue(mx.allclose(expected, out))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user