Add optional headers to `mx.fast.metal_kernel` (#1358)

This commit is contained in:
Alex Barron
2024-08-26 21:45:45 -07:00
committed by GitHub
parent 5f7d19d1f5
commit 1d94ac3f90
4 changed files with 50 additions and 10 deletions

View File

@@ -551,7 +551,7 @@ class TestFast(mlx_tests.MLXTestCase):
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
def test_custom_kernel_basic(self):
mx.random.seed(7)
a = mx.random.normal(shape=(3, 6))
a = mx.random.normal(shape=(2, 2))
kernel = mx.fast.metal_kernel(
name="basic",
source="""
@@ -567,7 +567,7 @@ class TestFast(mlx_tests.MLXTestCase):
output_dtypes={"out1": mx.float32},
stream=mx.gpu,
)
mx.allclose(out["out1"], a[:2, :2])
self.assertTrue(mx.allclose(out["out1"], a))
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
def test_custom_kernel_args(self):
@@ -618,12 +618,12 @@ class TestFast(mlx_tests.MLXTestCase):
uint elem = thread_position_in_grid.x;
uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
T tmp = inp[loc];
out[elem] = metal::exp(tmp) * threads_per_simdgroup;
out[elem] = metal::precise::exp(tmp) * threads_per_simdgroup;
"""
source_contig = """
uint elem = thread_position_in_grid.x;
T tmp = inp[elem];
out[elem] = metal::exp(tmp) * threads_per_simdgroup;
out[elem] = metal::precise::exp(tmp) * threads_per_simdgroup;
"""
# non contiguous
@@ -646,6 +646,33 @@ class TestFast(mlx_tests.MLXTestCase):
)
self.assertTrue(mx.allclose(mx.exp(a) * 32, outputs["out"]))
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
def test_custom_kernel_helper(self):
    """Check that a function defined in ``header`` is callable from ``source``.

    The header declares a templated ``do_exp`` helper; the kernel body
    applies it to each input element, and the result is compared with
    ``mx.exp`` of the same input.
    """
    mx.random.seed(7)
    values = mx.random.normal(shape=(2, 2))

    # Helper made available to the kernel body through the `header` argument.
    helper_header = """
template <typename T>
T do_exp(T x) {
return metal::precise::exp(x);
}
"""
    # Kernel body: each position along grid x handles one element.
    kernel_body = """
uint elem = thread_position_in_grid.x;
out1[elem] = do_exp(a[elem]);
"""
    exp_kernel = mx.fast.metal_kernel(
        name="helper",
        header=helper_header,
        source=kernel_body,
    )

    # A grid of 4 along x matches the 4 elements of the (2, 2) input.
    results = exp_kernel(
        inputs={"a": values},
        grid=(4, 1, 1),
        threadgroup=(2, 1, 1),
        output_shapes={"out1": (2, 2)},
        output_dtypes={"out1": mx.float32},
        stream=mx.gpu,
    )

    actual = results["out1"]
    expected = mx.exp(values)
    self.assertTrue(mx.allclose(actual, expected))
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()