export custom kernel (#2756)

This commit is contained in:
Awni Hannun
2025-11-13 11:29:50 -08:00
committed by GitHub
parent 3f866be665
commit 8973550ff3
6 changed files with 161 additions and 37 deletions

View File

@@ -531,6 +531,50 @@ class TestExportImport(mlx_tests.MLXTestCase):
self.assertEqual(keywords[0][0], "y")
self.assertEqual(primitives, ["Subtract", "Abs", "Log"])
@unittest.skipIf(not mx.is_available(mx.gpu), "No GPU available")
def test_export_import_custom_kernel(self):
if mx.metal.is_available():
source = """
uint elem = thread_position_in_grid.x;
out1[elem] = a[elem];
"""
custom_kernel = mx.fast.metal_kernel
elif mx.cuda.is_available():
source = """
auto elem = cooperative_groups::this_grid().thread_rank();
out1[elem] = a[elem];
"""
custom_kernel = mx.fast.cuda_kernel
kernel = custom_kernel(
name="basic",
input_names=["a"],
output_names=["out1"],
source=source,
)
def call(a):
return kernel(
inputs=[a],
grid=(4, 1, 1),
threadgroup=(2, 1, 1),
output_shapes=[(2, 2)],
output_dtypes=[mx.float32],
stream=mx.gpu,
)[0]
mx.random.seed(7)
a = mx.random.normal(shape=(2, 2))
path = os.path.join(self.test_dir, "fn.mlxfn")
expected = call(a)
mx.export_function(path, call, a)
imported = mx.import_function(path)
out = imported(a)[0]
self.assertTrue(mx.allclose(expected, out))
if __name__ == "__main__":
mlx_tests.MLXTestRunner()