Mirror of https://github.com/ml-explore/mlx.git, synced 2025-12-15 09:29:26 +08:00
export custom kernel (#2756)
@@ -531,6 +531,50 @@ class TestExportImport(mlx_tests.MLXTestCase):
        self.assertEqual(keywords[0][0], "y")
        self.assertEqual(primitives, ["Subtract", "Abs", "Log"])

    @unittest.skipIf(not mx.is_available(mx.gpu), "No GPU available")
    def test_export_import_custom_kernel(self):
        if mx.metal.is_available():
            source = """
                uint elem = thread_position_in_grid.x;
                out1[elem] = a[elem];
            """
            custom_kernel = mx.fast.metal_kernel
        elif mx.cuda.is_available():
            source = """
                auto elem = cooperative_groups::this_grid().thread_rank();
                out1[elem] = a[elem];
            """
            custom_kernel = mx.fast.cuda_kernel

        kernel = custom_kernel(
            name="basic",
            input_names=["a"],
            output_names=["out1"],
            source=source,
        )

        def call(a):
            return kernel(
                inputs=[a],
                grid=(4, 1, 1),
                threadgroup=(2, 1, 1),
                output_shapes=[(2, 2)],
                output_dtypes=[mx.float32],
                stream=mx.gpu,
            )[0]

        mx.random.seed(7)
        a = mx.random.normal(shape=(2, 2))

        path = os.path.join(self.test_dir, "fn.mlxfn")
        expected = call(a)
        mx.export_function(path, call, a)

        imported = mx.import_function(path)

        out = imported(a)[0]
        self.assertTrue(mx.allclose(expected, out))


if __name__ == "__main__":
    mlx_tests.MLXTestRunner()
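For context, the exported .mlxfn file produced by mx.export_function can later be loaded and called outside the test harness. Below is a minimal sketch of that workflow, not part of the commit; it assumes a GPU backend is available and that the file path "fn.mlxfn" matches whatever was passed to mx.export_function.

    import mlx.core as mx

    # Load the previously exported function (path is an assumed example
    # matching the one used with mx.export_function in the test above).
    imported = mx.import_function("fn.mlxfn")

    # Call it with an input of the same shape/dtype used at export time;
    # the imported function returns a list of outputs.
    a = mx.random.normal(shape=(2, 2))
    out = imported(a)[0]
    print(out)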