From 9ce77798b1588ee58d0366ad8e218327b663d24e Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 9 Jun 2025 20:37:27 -0700 Subject: [PATCH] fix export to work with gather/scatter axis (#2263) --- mlx/export.cpp | 2 ++ python/tests/test_export_import.py | 26 ++++++++++++++++++++++++++ 2 files changed, 28 insertions(+) diff --git a/mlx/export.cpp b/mlx/export.cpp index bd2f24ba2..552c35cfb 100644 --- a/mlx/export.cpp +++ b/mlx/export.cpp @@ -266,6 +266,7 @@ struct PrimitiveFactory { SERIALIZE_PRIMITIVE(Floor), SERIALIZE_PRIMITIVE(Full), SERIALIZE_PRIMITIVE(Gather), + SERIALIZE_PRIMITIVE(GatherAxis), SERIALIZE_PRIMITIVE(GatherMM), SERIALIZE_PRIMITIVE(Greater), SERIALIZE_PRIMITIVE(GreaterEqual), @@ -307,6 +308,7 @@ struct PrimitiveFactory { "CumMax", "CumLogaddexp"), SERIALIZE_PRIMITIVE(Scatter), + SERIALIZE_PRIMITIVE(ScatterAxis), SERIALIZE_PRIMITIVE(Select), SERIALIZE_PRIMITIVE(Sigmoid), SERIALIZE_PRIMITIVE(Sign), diff --git a/python/tests/test_export_import.py b/python/tests/test_export_import.py index 0190827bd..ef9827cbe 100644 --- a/python/tests/test_export_import.py +++ b/python/tests/test_export_import.py @@ -286,6 +286,32 @@ class TestExportImport(mlx_tests.MLXTestCase): with self.assertRaises(ValueError): f2(mx.array(10), mx.array([5, 10, 20])) + def test_export_scatter_gather(self): + path = os.path.join(self.test_dir, "fn.mlxfn") + + def fun(a, b): + return mx.take_along_axis(a, b, axis=0) + + x = mx.random.uniform(shape=(4, 4)) + y = mx.array([[0, 1, 2, 3], [1, 2, 0, 3]]) + mx.export_function(path, fun, (x, y)) + imported_fun = mx.import_function(path) + expected = fun(x, y) + out = imported_fun(x, y)[0] + self.assertTrue(mx.array_equal(expected, out)) + + def fun(a, b, c): + return mx.put_along_axis(a, b, c, axis=0) + + x = mx.random.uniform(shape=(4, 4)) + y = mx.array([[0, 1, 2, 3], [1, 2, 0, 3]]) + z = mx.random.uniform(shape=(2, 4)) + mx.export_function(path, fun, (x, y, z)) + imported_fun = mx.import_function(path) + expected = fun(x, y, z) + out = imported_fun(x, y, z)[0] + self.assertTrue(mx.array_equal(expected, out)) + if __name__ == "__main__": unittest.main()