Fix export scatters (#2852)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled

This commit is contained in:
Awni Hannun
2025-12-02 11:24:40 -08:00
committed by GitHub
parent 6c5785bc2f
commit eff0e31f00
3 changed files with 19 additions and 5 deletions

View File

@@ -382,6 +382,7 @@ struct PrimitiveFactory {
SERIALIZE_PRIMITIVE(LogicalOr),
SERIALIZE_PRIMITIVE(LogAddExp),
SERIALIZE_PRIMITIVE(LogSumExp),
SERIALIZE_PRIMITIVE(MaskedScatter),
SERIALIZE_PRIMITIVE(Matmul),
SERIALIZE_PRIMITIVE(Maximum),
SERIALIZE_PRIMITIVE(Minimum),

View File

@@ -1871,13 +1871,13 @@ class Scatter : public UnaryPrimitive {
const char* name() const override {
switch (reduce_type_) {
case Sum:
return "ScatterSum";
return "Scatter Sum";
case Prod:
return "ScatterProd";
return "Scatter Prod";
case Min:
return "ScatterMin";
return "Scatter Min";
case Max:
return "ScatterMax";
return "Scatter Max";
case None:
return "Scatter";
}
@@ -1910,7 +1910,7 @@ class ScatterAxis : public UnaryPrimitive {
const char* name() const override {
switch (reduce_type_) {
case Sum:
return "ScatterAxisSum";
return "ScatterAxis Sum";
case None:
return "ScatterAxis";
}

View File

@@ -596,6 +596,19 @@ class TestExportImport(mlx_tests.MLXTestCase):
for y in ys:
self.assertEqual(imported(y)[0].item(), fun(y).item())
def test_export_import_scatter_sum(self):
def fun(x, y, z):
return x.at[y].add(z)
x = mx.array([1, 2, 3])
y = mx.array([0, 0, 1])
z = mx.array([1, 1, 1])
path = os.path.join(self.test_dir, "fn.mlxfn")
mx.export_function(path, fun, x, y, z)
imported = mx.import_function(path)
self.assertTrue(mx.array_equal(imported(x, y, z)[0], fun(x, y, z)))
if __name__ == "__main__":
mlx_tests.MLXTestRunner()