mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
fix conv export (#2265)
This commit is contained in:
parent
7c4eb5d03e
commit
62fecf3e13
@ -719,9 +719,9 @@ class Convolution : public UnaryPrimitive {
|
|||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
auto state() const {
|
auto state() const {
|
||||||
return std::make_tuple(
|
return std::make_tuple(
|
||||||
|
kernel_strides_,
|
||||||
padding_lo_,
|
padding_lo_,
|
||||||
padding_hi_,
|
padding_hi_,
|
||||||
kernel_strides_,
|
|
||||||
kernel_dilation_,
|
kernel_dilation_,
|
||||||
input_dilation_,
|
input_dilation_,
|
||||||
groups_,
|
groups_,
|
||||||
|
@ -6,6 +6,7 @@ import tempfile
|
|||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
import mlx_tests
|
import mlx_tests
|
||||||
|
|
||||||
|
|
||||||
@ -312,6 +313,39 @@ class TestExportImport(mlx_tests.MLXTestCase):
|
|||||||
out = imported_fun(x, y, z)[0]
|
out = imported_fun(x, y, z)[0]
|
||||||
self.assertTrue(mx.array_equal(expected, out))
|
self.assertTrue(mx.array_equal(expected, out))
|
||||||
|
|
||||||
|
def test_export_conv(self):
|
||||||
|
path = os.path.join(self.test_dir, "fn.mlxfn")
|
||||||
|
|
||||||
|
class Model(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.c1 = nn.Conv2d(
|
||||||
|
3, 16, kernel_size=3, stride=1, padding=1, bias=False
|
||||||
|
)
|
||||||
|
self.c2 = nn.Conv2d(
|
||||||
|
16, 16, kernel_size=3, stride=2, padding=1, bias=False
|
||||||
|
)
|
||||||
|
self.c3 = nn.Conv2d(
|
||||||
|
16, 16, kernel_size=3, stride=1, padding=2, bias=False
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(self, x):
|
||||||
|
return self.c3(self.c2(self.c1(x)))
|
||||||
|
|
||||||
|
model = Model()
|
||||||
|
mx.eval(model.parameters())
|
||||||
|
|
||||||
|
def forward(x):
|
||||||
|
return model(x)
|
||||||
|
|
||||||
|
input_data = mx.random.normal(shape=(4, 32, 32, 3))
|
||||||
|
mx.export_function(path, forward, input_data)
|
||||||
|
|
||||||
|
imported_fn = mx.import_function(path)
|
||||||
|
out = imported_fn(input_data)[0]
|
||||||
|
expected = forward(input_data)
|
||||||
|
self.assertTrue(mx.allclose(expected, out))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user