# Source: mlx/python/tests/test_export_import.py (352 lines, 10 KiB, Python)

# Copyright © 2024 Apple Inc.
import gc
import os
import tempfile
import unittest
import mlx.core as mx
import mlx.nn as nn
import mlx_tests
class TestExportImport(mlx_tests.MLXTestCase):
    """Round-trip tests for exporting functions to disk and importing them.

    Every test writes to ``fn.mlxfn`` inside a class-wide temporary
    directory, re-imports the file and compares the imported function's
    outputs with the outputs of the original Python function.
    """

    @classmethod
    def setUpClass(cls):
        """Create the temporary directory shared by all tests in the class."""
        # TemporaryDirectory creates the directory on construction, so no
        # separate existence check / mkdir is required.
        cls.test_dir_fid = tempfile.TemporaryDirectory()
        cls.test_dir = cls.test_dir_fid.name

    @classmethod
    def tearDownClass(cls):
        """Remove the shared temporary directory and everything in it."""
        cls.test_dir_fid.cleanup()

    def test_basic_export_import(self):
        """Check supported input/output containers and rejected arguments."""
        path = os.path.join(self.test_dir, "fn.mlxfn")

        # Function with no inputs
        def fun():
            return mx.zeros((3, 3))

        mx.export_function(path, fun)
        imported = mx.import_function(path)

        expected = fun()
        (out,) = imported()
        self.assertTrue(mx.array_equal(out, expected))

        # Simple function with inputs
        def fun(x):
            return mx.abs(mx.sin(x))

        inputs = mx.array([1.0, 2.0, 3.0, 4.0, 5.0])

        mx.export_function(path, fun, inputs)
        imported = mx.import_function(path)

        expected = fun(inputs)
        (out,) = imported(inputs)
        self.assertTrue(mx.allclose(out, expected))

        # Inputs in a list or tuple
        def fun(x):
            x = mx.abs(mx.sin(x))
            return x

        mx.export_function(path, fun, [inputs])
        imported = mx.import_function(path)

        expected = fun(inputs)
        (out,) = imported([inputs])
        self.assertTrue(mx.allclose(out, expected))

        # The same exported function also accepts the bare array.
        (out,) = imported(inputs)
        self.assertTrue(mx.allclose(out, expected))

        mx.export_function(path, fun, (inputs,))
        imported = mx.import_function(path)
        (out,) = imported((inputs,))
        self.assertTrue(mx.allclose(out, expected))

        # Outputs in a list
        def fun(x):
            return [mx.abs(mx.sin(x))]

        mx.export_function(path, fun, inputs)
        imported = mx.import_function(path)
        (out,) = imported(inputs)
        self.assertTrue(mx.allclose(out, expected))

        # Outputs in a tuple
        def fun(x):
            return (mx.abs(mx.sin(x)),)

        mx.export_function(path, fun, inputs)
        imported = mx.import_function(path)
        (out,) = imported(inputs)
        self.assertTrue(mx.allclose(out, expected))

        # Check throws on invalid inputs / outputs
        def fun(x):
            return mx.abs(x)

        with self.assertRaises(ValueError):
            mx.export_function(path, fun, "hi")

        with self.assertRaises(ValueError):
            mx.export_function(path, fun, mx.array(1.0), "hi")

        # Nested input containers are rejected.
        def fun(x):
            return mx.abs(x[0][0])

        with self.assertRaises(ValueError):
            mx.export_function(path, fun, [[mx.array(1.0)]])

        # Non-array outputs are rejected.
        def fun():
            return (mx.zeros((3, 3)), 1)

        with self.assertRaises(ValueError):
            mx.export_function(path, fun)

        # Nested output containers are rejected.
        def fun():
            return (mx.zeros((3, 3)), [mx.zeros((3, 3))])

        with self.assertRaises(ValueError):
            mx.export_function(path, fun)

        def fun(x, y):
            return x + y

        mx.export_function(path, fun, mx.array(1.0), mx.array(1.0))
        imported = mx.import_function(path)

        # The imported function rejects a non-array argument, an extra
        # argument, and an argument wrapped in a list.
        with self.assertRaises(ValueError):
            imported(mx.array(1.0), 1.0)

        with self.assertRaises(ValueError):
            imported(mx.array(1.0), mx.array(1.0), mx.array(1.0))

        with self.assertRaises(ValueError):
            imported(mx.array(1.0), [mx.array(1.0)])

    def test_export_random_sample(self):
        """An exported sampler matches the original run from the same seed."""
        path = os.path.join(self.test_dir, "fn.mlxfn")

        mx.random.seed(5)

        def fun():
            return mx.random.uniform(shape=(3,))

        mx.export_function(path, fun)
        imported = mx.import_function(path)

        (out,) = imported()

        # Re-seed so the reference sample is drawn from the same seed value.
        mx.random.seed(5)
        expected = fun()

        self.assertTrue(mx.array_equal(out, expected))

    def test_export_with_kwargs(self):
        """Keyword inputs must be passed back under the names used at export."""
        path = os.path.join(self.test_dir, "fn.mlxfn")

        def fun(x, z=None):
            out = x
            if z is not None:
                out += z
            return out

        x = mx.array([1, 2, 3])
        z = mx.array([2, 2, 2])

        # Positional args as a tuple, keyword args as a dict.
        mx.export_function(path, fun, (x,), {"z": z})
        imported_fun = mx.import_function(path)

        # `z` passed positionally, or under the wrong keyword, is rejected.
        with self.assertRaises(ValueError):
            imported_fun(x, z)

        with self.assertRaises(ValueError):
            imported_fun(x, y=z)

        with self.assertRaises(ValueError):
            imported_fun((x,), {"y": z})

        out = imported_fun(x, z=z)[0]
        self.assertTrue(mx.array_equal(out, mx.array([3, 4, 5])))

        out = imported_fun((x,), {"z": z})[0]
        self.assertTrue(mx.array_equal(out, mx.array([3, 4, 5])))

        # Same export expressed with Python call syntax.
        mx.export_function(path, fun, x, z=z)
        imported_fun = mx.import_function(path)
        out = imported_fun(x, z=z)[0]
        self.assertTrue(mx.array_equal(out, mx.array([3, 4, 5])))

        out = imported_fun((x,), {"z": z})[0]
        self.assertTrue(mx.array_equal(out, mx.array([3, 4, 5])))

        # Only specify kwargs
        mx.export_function(path, fun, x=x, z=z)
        imported_fun = mx.import_function(path)
        with self.assertRaises(ValueError):
            imported_fun(x, z=z)

        out = imported_fun(x=x, z=z)[0]
        self.assertTrue(mx.array_equal(out, mx.array([3, 4, 5])))

        out = imported_fun({"x": x, "z": z})[0]
        self.assertTrue(mx.array_equal(out, mx.array([3, 4, 5])))

    def test_export_variable_inputs(self):
        """Several call signatures can be recorded into one exported file."""
        path = os.path.join(self.test_dir, "fn.mlxfn")

        def fun(x, y, z=None):
            out = x + y
            if z is not None:
                out += z
            return out

        with mx.exporter(path, fun) as exporter:
            exporter(mx.array([1, 2, 3]), mx.array([1, 1, 1]))
            exporter(mx.array([1, 2, 3]), mx.array([1, 1, 1]), z=mx.array([2]))

        # The exporter is used here after its `with` block has exited.
        with self.assertRaises(RuntimeError):
            exporter(mx.array([1, 2, 3, 4]), mx.array([1, 1, 1, 1]))

        imported_fun = mx.import_function(path)
        out = imported_fun(mx.array([1, 2, 3]), mx.array([1, 1, 1]))[0]
        self.assertTrue(mx.array_equal(out, mx.array([2, 3, 4])))

        out = imported_fun(mx.array([1, 2, 3]), mx.array([1, 1, 1]), z=mx.array([2]))[0]
        self.assertTrue(mx.array_equal(out, mx.array([4, 5, 6])))

        # Input shapes that were never recorded are rejected on import.
        with self.assertRaises(ValueError):
            imported_fun(mx.array([1, 2, 3, 4]), mx.array([1, 1, 1, 1]))

        # A function with a large constant
        constant = mx.zeros((16, 2048))
        mx.eval(constant)

        def fun(*args):
            return constant + sum(args)

        with mx.exporter(path, fun) as exporter:
            for i in range(5):
                # `[mx.array(1)] * i` yields i positional arguments (0..4).
                exporter(*[mx.array(1)] * i)

        # Check the exported file size < constant size + small amount
        constants_size = constant.nbytes + 8192
        self.assertLess(os.path.getsize(path), constants_size)

    def test_leaks(self):
        """Active memory is unchanged after collecting exporter cycles."""
        path = os.path.join(self.test_dir, "fn.mlxfn")
        mx.synchronize()
        # Without Metal both readings are 0, so the final check is vacuous.
        if mx.metal.is_available():
            mem_pre = mx.get_active_memory()
        else:
            mem_pre = 0

        def outer():
            # `d` holds the exporter, the exporter wraps `f`, and `f` closes
            # over `d`: a reference cycle that also holds an array.
            d = {}

            def f(x):
                return d["x"]

            d["f"] = mx.exporter(path, f)
            d["x"] = mx.array([0] * 1000)

        for _ in range(5):
            outer()
            gc.collect()

        if mx.metal.is_available():
            mem_post = mx.get_active_memory()
        else:
            mem_post = 0

        self.assertEqual(mem_pre, mem_post)

    def test_export_import_shapeless(self):
        """Shapeless export with one, two and three positional inputs."""
        path = os.path.join(self.test_dir, "fn.mlxfn")

        def fun(*args):
            return sum(args)

        with mx.exporter(path, fun, shapeless=True) as exporter:
            exporter(mx.array(1))
            exporter(mx.array(1), mx.array(2))
            exporter(mx.array(1), mx.array(2), mx.array(3))

        f2 = mx.import_function(path)
        self.assertEqual(f2(mx.array(1))[0].item(), 1)
        self.assertEqual(f2(mx.array(1), mx.array(1))[0].item(), 2)
        self.assertEqual(f2(mx.array(1), mx.array(1), mx.array(1))[0].item(), 3)
        # A scalar paired with a 1-D array is expected to be rejected.
        with self.assertRaises(ValueError):
            f2(mx.array(10), mx.array([5, 10, 20]))

    def test_export_scatter_gather(self):
        """Round-trip ``take_along_axis`` and ``put_along_axis``."""
        path = os.path.join(self.test_dir, "fn.mlxfn")

        def fun(a, b):
            return mx.take_along_axis(a, b, axis=0)

        x = mx.random.uniform(shape=(4, 4))
        y = mx.array([[0, 1, 2, 3], [1, 2, 0, 3]])
        mx.export_function(path, fun, (x, y))
        imported_fun = mx.import_function(path)
        expected = fun(x, y)
        out = imported_fun(x, y)[0]
        self.assertTrue(mx.array_equal(expected, out))

        def fun(a, b, c):
            return mx.put_along_axis(a, b, c, axis=0)

        x = mx.random.uniform(shape=(4, 4))
        y = mx.array([[0, 1, 2, 3], [1, 2, 0, 3]])
        z = mx.random.uniform(shape=(2, 4))
        mx.export_function(path, fun, (x, y, z))
        imported_fun = mx.import_function(path)
        expected = fun(x, y, z)
        out = imported_fun(x, y, z)[0]
        self.assertTrue(mx.array_equal(expected, out))

    def test_export_conv(self):
        """Round-trip a stack of three bias-free ``nn.Conv2d`` layers."""
        path = os.path.join(self.test_dir, "fn.mlxfn")

        class Model(nn.Module):
            def __init__(self):
                super().__init__()
                self.c1 = nn.Conv2d(
                    3, 16, kernel_size=3, stride=1, padding=1, bias=False
                )
                self.c2 = nn.Conv2d(
                    16, 16, kernel_size=3, stride=2, padding=1, bias=False
                )
                self.c3 = nn.Conv2d(
                    16, 16, kernel_size=3, stride=1, padding=2, bias=False
                )

            def __call__(self, x):
                return self.c3(self.c2(self.c1(x)))

        model = Model()
        mx.eval(model.parameters())

        def forward(x):
            return model(x)

        input_data = mx.random.normal(shape=(4, 32, 32, 3))
        mx.export_function(path, forward, input_data)

        imported_fn = mx.import_function(path)
        out = imported_fn(input_data)[0]
        expected = forward(input_data)
        self.assertTrue(mx.allclose(expected, out))
# Run the tests through the project's test runner when executed as a script.
if __name__ == "__main__":
    mlx_tests.MLXTestRunner()