mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-09 01:18:41 +08:00
Custom cuda kernel (#2517)
This commit is contained in:

committed by
GitHub

parent
f4c8888cbe
commit
e397177f6e
@@ -581,18 +581,28 @@ class TestFast(mlx_tests.MLXTestCase):
|
||||
)(x)
|
||||
self.assertTrue(mx.allclose(vmap_out, vmap_fast_out))
|
||||
|
||||
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
|
||||
@unittest.skipIf(not mx.is_available(mx.gpu), "No GPU available")
|
||||
def test_custom_kernel_basic(self):
|
||||
if mx.metal.is_available():
|
||||
source = """
|
||||
uint elem = thread_position_in_grid.x;
|
||||
out1[elem] = a[elem];
|
||||
"""
|
||||
custom_kernel = mx.fast.metal_kernel
|
||||
elif mx.cuda.is_available():
|
||||
source = """
|
||||
auto elem = cooperative_groups::this_grid().thread_rank();
|
||||
out1[elem] = a[elem];
|
||||
"""
|
||||
custom_kernel = mx.fast.cuda_kernel
|
||||
|
||||
mx.random.seed(7)
|
||||
a = mx.random.normal(shape=(2, 2))
|
||||
kernel = mx.fast.metal_kernel(
|
||||
kernel = custom_kernel(
|
||||
name="basic",
|
||||
input_names=["a"],
|
||||
output_names=["out1"],
|
||||
source="""
|
||||
uint elem = thread_position_in_grid.x;
|
||||
out1[elem] = a[elem];
|
||||
""",
|
||||
source=source,
|
||||
)
|
||||
out = kernel(
|
||||
inputs=[a],
|
||||
@@ -604,17 +614,10 @@ class TestFast(mlx_tests.MLXTestCase):
|
||||
)
|
||||
self.assertTrue(mx.allclose(out[0], a))
|
||||
|
||||
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
|
||||
@unittest.skipIf(not mx.is_available(mx.gpu), "No GPU available")
|
||||
def test_custom_kernel_args(self):
|
||||
mx.random.seed(7)
|
||||
a = mx.random.normal(shape=(3, 6))
|
||||
c = mx.random.normal(shape=(2, 2)).astype(mx.bfloat16)
|
||||
|
||||
kernel = mx.fast.metal_kernel(
|
||||
name="arg_test",
|
||||
input_names=["a", "b", "c", "d"],
|
||||
output_names=["out1", "out2"],
|
||||
source="""
|
||||
if mx.metal.is_available():
|
||||
source = """
|
||||
uint elem = thread_position_in_grid.x;
|
||||
T tmp = a[0];
|
||||
if (e) {
|
||||
@@ -623,7 +626,30 @@ class TestFast(mlx_tests.MLXTestCase):
|
||||
out1[elem] = 1;
|
||||
}
|
||||
out2[elem] = a[1] + b[2] + c[1] - d;
|
||||
""",
|
||||
"""
|
||||
custom_kernel = mx.fast.metal_kernel
|
||||
elif mx.cuda.is_available():
|
||||
source = """
|
||||
auto elem = cooperative_groups::this_grid().thread_rank();
|
||||
T tmp = a[0];
|
||||
if (e) {
|
||||
out1[elem] = a[1] + b[2] + static_cast<float>(c[3]) + d[0] + f;
|
||||
} else {
|
||||
out1[elem] = 1;
|
||||
}
|
||||
out2[elem] = a[1] + b[2] + static_cast<float>(c[1]) - d[0];
|
||||
"""
|
||||
custom_kernel = mx.fast.cuda_kernel
|
||||
|
||||
mx.random.seed(7)
|
||||
a = mx.random.normal(shape=(3, 6))
|
||||
c = mx.random.normal(shape=(2, 2)).astype(mx.bfloat16)
|
||||
|
||||
kernel = custom_kernel(
|
||||
name="arg_test",
|
||||
input_names=["a", "b", "c", "d"],
|
||||
output_names=["out1", "out2"],
|
||||
source=source,
|
||||
)
|
||||
out = kernel(
|
||||
inputs=[
|
||||
@@ -647,27 +673,43 @@ class TestFast(mlx_tests.MLXTestCase):
|
||||
self.assertTrue(mx.allclose(out[0], mx.full((3, 2), 14.0484)))
|
||||
self.assertTrue(mx.allclose(out[1], mx.full((3, 2), -2, dtype=mx.int32)))
|
||||
|
||||
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
|
||||
@unittest.skipIf(not mx.is_available(mx.gpu), "No GPU available")
|
||||
def test_custom_kernel_strides(self):
|
||||
if mx.metal.is_available():
|
||||
source = """
|
||||
uint elem = thread_position_in_grid.x;
|
||||
uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
|
||||
T tmp = inp[loc];
|
||||
out[elem] = metal::precise::exp(tmp) * threads_per_simdgroup;
|
||||
"""
|
||||
source_contig = """
|
||||
uint elem = thread_position_in_grid.x;
|
||||
T tmp = inp[elem];
|
||||
out[elem] = metal::precise::exp(tmp) * threads_per_simdgroup;
|
||||
"""
|
||||
custom_kernel = mx.fast.metal_kernel
|
||||
elif mx.cuda.is_available():
|
||||
source = """
|
||||
auto elem = cooperative_groups::this_grid().thread_rank();
|
||||
auto loc = elem_to_loc(elem, inp_shape.data(), inp_strides.data(), inp_ndim);
|
||||
T tmp = inp[loc];
|
||||
out[elem] = exp(tmp) * WARP_SIZE;
|
||||
"""
|
||||
source_contig = """
|
||||
auto elem = cooperative_groups::this_grid().thread_rank();
|
||||
T tmp = inp[elem];
|
||||
out[elem] = exp(tmp) * WARP_SIZE;
|
||||
"""
|
||||
custom_kernel = mx.fast.cuda_kernel
|
||||
|
||||
mx.random.seed(7)
|
||||
a = mx.random.normal(shape=(3, 6))
|
||||
source = """
|
||||
uint elem = thread_position_in_grid.x;
|
||||
uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
|
||||
T tmp = inp[loc];
|
||||
out[elem] = metal::precise::exp(tmp) * threads_per_simdgroup;
|
||||
"""
|
||||
source_contig = """
|
||||
uint elem = thread_position_in_grid.x;
|
||||
T tmp = inp[elem];
|
||||
out[elem] = metal::precise::exp(tmp) * threads_per_simdgroup;
|
||||
"""
|
||||
|
||||
# non contiguous
|
||||
a = mx.tile(a[::2], [4, 1])
|
||||
|
||||
for contig in [True, False]:
|
||||
kernel = mx.fast.metal_kernel(
|
||||
kernel = custom_kernel(
|
||||
name="myexp" + str(contig),
|
||||
input_names=["inp"],
|
||||
output_names=["out"],
|
||||
@@ -685,24 +727,41 @@ class TestFast(mlx_tests.MLXTestCase):
|
||||
)
|
||||
self.assertTrue(mx.allclose(mx.exp(a) * 32, outputs[0]))
|
||||
|
||||
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
|
||||
@unittest.skipIf(not mx.is_available(mx.gpu), "No GPU available")
|
||||
def test_custom_kernel_helper(self):
|
||||
mx.random.seed(7)
|
||||
a = mx.random.normal(shape=(2, 2))
|
||||
kernel = mx.fast.metal_kernel(
|
||||
name="helper",
|
||||
input_names=["a"],
|
||||
output_names=["out1"],
|
||||
header="""
|
||||
if mx.metal.is_available():
|
||||
header = """
|
||||
template <typename T>
|
||||
T do_exp(T x) {
|
||||
return metal::precise::exp(x);
|
||||
}
|
||||
""",
|
||||
source="""
|
||||
"""
|
||||
source = """
|
||||
uint elem = thread_position_in_grid.x;
|
||||
out1[elem] = do_exp(a[elem]);
|
||||
""",
|
||||
"""
|
||||
custom_kernel = mx.fast.metal_kernel
|
||||
elif mx.cuda.is_available():
|
||||
header = """
|
||||
template <typename T>
|
||||
__device__ T do_exp(T x) {
|
||||
return exp(x);
|
||||
}
|
||||
"""
|
||||
source = """
|
||||
auto elem = cooperative_groups::this_grid().thread_rank();
|
||||
out1[elem] = do_exp(a[elem]);
|
||||
"""
|
||||
custom_kernel = mx.fast.cuda_kernel
|
||||
|
||||
mx.random.seed(7)
|
||||
a = mx.random.normal(shape=(2, 2))
|
||||
kernel = custom_kernel(
|
||||
name="helper",
|
||||
input_names=["a"],
|
||||
output_names=["out1"],
|
||||
header=header,
|
||||
source=source,
|
||||
)
|
||||
out = kernel(
|
||||
inputs=[a],
|
||||
@@ -714,16 +773,21 @@ class TestFast(mlx_tests.MLXTestCase):
|
||||
)
|
||||
self.assertTrue(mx.allclose(out[0], mx.exp(a)))
|
||||
|
||||
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
|
||||
@unittest.skipIf(not mx.is_available(mx.gpu), "No GPU available")
|
||||
def test_custom_kernel_attributes(self):
|
||||
if mx.metal.is_available():
|
||||
source = "out[0] = threads_per_threadgroup.x;"
|
||||
custom_kernel = mx.fast.metal_kernel
|
||||
elif mx.cuda.is_available():
|
||||
source = "out[0] = blockDim.x;"
|
||||
custom_kernel = mx.fast.cuda_kernel
|
||||
|
||||
a = mx.zeros(shape=(1, 1))
|
||||
kernel = mx.fast.metal_kernel(
|
||||
kernel = custom_kernel(
|
||||
name="test_fun",
|
||||
input_names=["a"],
|
||||
output_names=["out"],
|
||||
source="""
|
||||
out[0] = threads_per_threadgroup.x;
|
||||
""",
|
||||
source=source,
|
||||
)
|
||||
out = kernel(
|
||||
inputs=[a],
|
||||
|
Reference in New Issue
Block a user