mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 00:04:41 +08:00
Add optional headers to `mx.fast.metal_kernel` (#1358)
This commit is contained in:
@@ -551,7 +551,7 @@ class TestFast(mlx_tests.MLXTestCase):
|
||||
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
|
||||
def test_custom_kernel_basic(self):
|
||||
mx.random.seed(7)
|
||||
a = mx.random.normal(shape=(3, 6))
|
||||
a = mx.random.normal(shape=(2, 2))
|
||||
kernel = mx.fast.metal_kernel(
|
||||
name="basic",
|
||||
source="""
|
||||
@@ -567,7 +567,7 @@ class TestFast(mlx_tests.MLXTestCase):
|
||||
output_dtypes={"out1": mx.float32},
|
||||
stream=mx.gpu,
|
||||
)
|
||||
mx.allclose(out["out1"], a[:2, :2])
|
||||
self.assertTrue(mx.allclose(out["out1"], a))
|
||||
|
||||
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
|
||||
def test_custom_kernel_args(self):
|
||||
@@ -618,12 +618,12 @@ class TestFast(mlx_tests.MLXTestCase):
|
||||
uint elem = thread_position_in_grid.x;
|
||||
uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
|
||||
T tmp = inp[loc];
|
||||
out[elem] = metal::exp(tmp) * threads_per_simdgroup;
|
||||
out[elem] = metal::precise::exp(tmp) * threads_per_simdgroup;
|
||||
"""
|
||||
source_contig = """
|
||||
uint elem = thread_position_in_grid.x;
|
||||
T tmp = inp[elem];
|
||||
out[elem] = metal::exp(tmp) * threads_per_simdgroup;
|
||||
out[elem] = metal::precise::exp(tmp) * threads_per_simdgroup;
|
||||
"""
|
||||
|
||||
# non contiguous
|
||||
@@ -646,6 +646,33 @@ class TestFast(mlx_tests.MLXTestCase):
|
||||
)
|
||||
self.assertTrue(mx.allclose(mx.exp(a) * 32, outputs["out"]))
|
||||
|
||||
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
|
||||
def test_custom_kernel_helper(self):
|
||||
mx.random.seed(7)
|
||||
a = mx.random.normal(shape=(2, 2))
|
||||
kernel = mx.fast.metal_kernel(
|
||||
name="helper",
|
||||
header="""
|
||||
template <typename T>
|
||||
T do_exp(T x) {
|
||||
return metal::precise::exp(x);
|
||||
}
|
||||
""",
|
||||
source="""
|
||||
uint elem = thread_position_in_grid.x;
|
||||
out1[elem] = do_exp(a[elem]);
|
||||
""",
|
||||
)
|
||||
out = kernel(
|
||||
inputs={"a": a},
|
||||
grid=(4, 1, 1),
|
||||
threadgroup=(2, 1, 1),
|
||||
output_shapes={"out1": (2, 2)},
|
||||
output_dtypes={"out1": mx.float32},
|
||||
stream=mx.gpu,
|
||||
)
|
||||
self.assertTrue(mx.allclose(out["out1"], mx.exp(a)))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user