mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-03 09:58:17 +08:00
Export / import functions to / from a file (#1642)
* export and import functions * refactor + works for few primitives * nit * allow primitives with state * nit * nit * simplify serialize / deserialize * fix for constants * python bindings * maybe fix serialize failure case * add example * more primitives, training kind of works * same result for python and c++ * some fixes * fix export * template it up * some simplificatoin * rebase * allow kwargs and multiple functions * exporter * more primitives for exporting * deal with endianness * handle invalid stream * add docstring
This commit is contained in:
244
python/tests/test_export_import.py
Normal file
244
python/tests/test_export_import.py
Normal file
@@ -0,0 +1,244 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx_tests
|
||||
|
||||
|
||||
class TestExportImport(mlx_tests.MLXTestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.test_dir_fid = tempfile.TemporaryDirectory()
|
||||
cls.test_dir = cls.test_dir_fid.name
|
||||
if not os.path.isdir(cls.test_dir):
|
||||
os.mkdir(cls.test_dir)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
cls.test_dir_fid.cleanup()
|
||||
|
||||
def test_basic_export_import(self):
|
||||
path = os.path.join(self.test_dir, "fn.mlxfn")
|
||||
|
||||
# Function with no inputs
|
||||
def fun():
|
||||
return mx.zeros((3, 3))
|
||||
|
||||
mx.export_function(path, fun)
|
||||
imported = mx.import_function(path)
|
||||
|
||||
expected = fun()
|
||||
(out,) = imported()
|
||||
self.assertTrue(mx.array_equal(out, expected))
|
||||
|
||||
# Simple function with inputs
|
||||
def fun(x):
|
||||
return mx.abs(mx.sin(x))
|
||||
|
||||
inputs = mx.array([1.0, 2.0, 3.0, 4.0, 5.0])
|
||||
|
||||
mx.export_function(path, fun, inputs)
|
||||
imported = mx.import_function(path)
|
||||
|
||||
expected = fun(inputs)
|
||||
(out,) = imported(inputs)
|
||||
self.assertTrue(mx.allclose(out, expected))
|
||||
|
||||
# Inputs in a list or tuple
|
||||
def fun(x):
|
||||
x = mx.abs(mx.sin(x))
|
||||
return x
|
||||
|
||||
mx.export_function(path, fun, [inputs])
|
||||
imported = mx.import_function(path)
|
||||
|
||||
expected = fun(inputs)
|
||||
(out,) = imported([inputs])
|
||||
self.assertTrue(mx.allclose(out, expected))
|
||||
|
||||
(out,) = imported(inputs)
|
||||
self.assertTrue(mx.allclose(out, expected))
|
||||
|
||||
mx.export_function(path, fun, (inputs,))
|
||||
imported = mx.import_function(path)
|
||||
(out,) = imported((inputs,))
|
||||
self.assertTrue(mx.allclose(out, expected))
|
||||
|
||||
# Outputs in a list
|
||||
def fun(x):
|
||||
return [mx.abs(mx.sin(x))]
|
||||
|
||||
mx.export_function(path, fun, inputs)
|
||||
imported = mx.import_function(path)
|
||||
(out,) = imported(inputs)
|
||||
self.assertTrue(mx.allclose(out, expected))
|
||||
|
||||
# Outputs in a tuple
|
||||
def fun(x):
|
||||
return (mx.abs(mx.sin(x)),)
|
||||
|
||||
mx.export_function(path, fun, inputs)
|
||||
imported = mx.import_function(path)
|
||||
(out,) = imported(inputs)
|
||||
self.assertTrue(mx.allclose(out, expected))
|
||||
|
||||
# Check throws on invalid inputs / outputs
|
||||
def fun(x):
|
||||
return mx.abs(x)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
mx.export_function(path, fun, "hi")
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
mx.export_function(path, fun, mx.array(1.0), "hi")
|
||||
|
||||
def fun(x):
|
||||
return mx.abs(x[0][0])
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
mx.export_function(path, fun, [[mx.array(1.0)]])
|
||||
|
||||
def fun():
|
||||
return (mx.zeros((3, 3)), 1)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
mx.export_function(path, fun)
|
||||
|
||||
def fun():
|
||||
return (mx.zeros((3, 3)), [mx.zeros((3, 3))])
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
mx.export_function(path, fun)
|
||||
|
||||
def fun(x, y):
|
||||
return x + y
|
||||
|
||||
mx.export_function(path, fun, mx.array(1.0), mx.array(1.0))
|
||||
imported = mx.import_function(path)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
imported(mx.array(1.0), 1.0)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
imported(mx.array(1.0), mx.array(1.0), mx.array(1.0))
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
imported(mx.array(1.0), [mx.array(1.0)])
|
||||
|
||||
def test_export_random_sample(self):
|
||||
path = os.path.join(self.test_dir, "fn.mlxfn")
|
||||
|
||||
mx.random.seed(5)
|
||||
|
||||
def fun():
|
||||
return mx.random.uniform(shape=(3,))
|
||||
|
||||
mx.export_function(path, fun)
|
||||
imported = mx.import_function(path)
|
||||
|
||||
(out,) = imported()
|
||||
|
||||
mx.random.seed(5)
|
||||
expected = fun()
|
||||
|
||||
self.assertTrue(mx.array_equal(out, expected))
|
||||
|
||||
def test_export_with_kwargs(self):
|
||||
path = os.path.join(self.test_dir, "fn.mlxfn")
|
||||
|
||||
def fun(x, z=None):
|
||||
out = x
|
||||
if z is not None:
|
||||
out += z
|
||||
return out
|
||||
|
||||
x = mx.array([1, 2, 3])
|
||||
y = mx.array([1, 1, 0])
|
||||
z = mx.array([2, 2, 2])
|
||||
|
||||
mx.export_function(path, fun, (x,), {"z": z})
|
||||
imported_fun = mx.import_function(path)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
imported_fun(x, z)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
imported_fun(x, y=z)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
imported_fun((x,), {"y": z})
|
||||
|
||||
out = imported_fun(x, z=z)[0]
|
||||
self.assertTrue(mx.array_equal(out, mx.array([3, 4, 5])))
|
||||
|
||||
out = imported_fun((x,), {"z": z})[0]
|
||||
self.assertTrue(mx.array_equal(out, mx.array([3, 4, 5])))
|
||||
|
||||
mx.export_function(path, fun, x, z=z)
|
||||
imported_fun = mx.import_function(path)
|
||||
out = imported_fun(x, z=z)[0]
|
||||
self.assertTrue(mx.array_equal(out, mx.array([3, 4, 5])))
|
||||
|
||||
out = imported_fun((x,), {"z": z})[0]
|
||||
self.assertTrue(mx.array_equal(out, mx.array([3, 4, 5])))
|
||||
|
||||
# Only specify kwargs
|
||||
mx.export_function(path, fun, x=x, z=z)
|
||||
imported_fun = mx.import_function(path)
|
||||
with self.assertRaises(ValueError):
|
||||
out = imported_fun(x, z=z)[0]
|
||||
|
||||
out = imported_fun(x=x, z=z)[0]
|
||||
self.assertTrue(mx.array_equal(out, mx.array([3, 4, 5])))
|
||||
|
||||
out = imported_fun({"x": x, "z": z})[0]
|
||||
self.assertTrue(mx.array_equal(out, mx.array([3, 4, 5])))
|
||||
|
||||
def test_export_variable_inputs(self):
|
||||
path = os.path.join(self.test_dir, "fn.mlxfn")
|
||||
|
||||
def fun(x, y, z=None):
|
||||
out = x + y
|
||||
if z is not None:
|
||||
out += z
|
||||
return out
|
||||
|
||||
with mx.exporter(path, fun) as exporter:
|
||||
exporter(mx.array([1, 2, 3]), mx.array([1, 1, 1]))
|
||||
exporter(mx.array([1, 2, 3]), mx.array([1, 1, 1]), z=mx.array([2]))
|
||||
|
||||
with self.assertRaises(RuntimeError):
|
||||
exporter(mx.array([1, 2, 3, 4]), mx.array([1, 1, 1, 1]))
|
||||
|
||||
imported_fun = mx.import_function(path)
|
||||
out = imported_fun(mx.array([1, 2, 3]), mx.array([1, 1, 1]))[0]
|
||||
self.assertTrue(mx.array_equal(out, mx.array([2, 3, 4])))
|
||||
|
||||
out = imported_fun(mx.array([1, 2, 3]), mx.array([1, 1, 1]), z=mx.array([2]))[0]
|
||||
self.assertTrue(mx.array_equal(out, mx.array([4, 5, 6])))
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
imported_fun(mx.array([1, 2, 3, 4]), mx.array([1, 1, 1, 1]))
|
||||
|
||||
# A function with a large constant
|
||||
constant = mx.zeros((16, 2048))
|
||||
mx.eval(constant)
|
||||
|
||||
def fun(*args):
|
||||
return constant + sum(args)
|
||||
|
||||
with mx.exporter(path, fun) as exporter:
|
||||
for i in range(5):
|
||||
exporter(*[mx.array(1)] * i)
|
||||
|
||||
# Check the exported file size < constant size + small amount
|
||||
constants_size = constant.nbytes + 8192
|
||||
self.assertTrue(os.path.getsize(path) < constants_size)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user