Fix export scatters (#2852)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled

This commit is contained in:
Awni Hannun
2025-12-02 11:24:40 -08:00
committed by GitHub
parent 6c5785bc2f
commit eff0e31f00
3 changed files with 19 additions and 5 deletions

View File

@@ -596,6 +596,19 @@ class TestExportImport(mlx_tests.MLXTestCase):
for y in ys:
self.assertEqual(imported(y)[0].item(), fun(y).item())
def test_export_import_scatter_sum(self):
def fun(x, y, z):
return x.at[y].add(z)
x = mx.array([1, 2, 3])
y = mx.array([0, 0, 1])
z = mx.array([1, 1, 1])
path = os.path.join(self.test_dir, "fn.mlxfn")
mx.export_function(path, fun, x, y, z)
imported = mx.import_function(path)
self.assertTrue(mx.array_equal(imported(x, y, z)[0], fun(x, y, z)))
if __name__ == "__main__":
mlx_tests.MLXTestRunner()