Custom Metal Kernels from Python (#1325)
* start
* simple kernels working
* restructure
* inverse example working
* docs + fixes
* missing file
* fix imports
* address comments
* add docs + fix test
* Review comments + refactor to a single function
* update docs
* remove hashing
* fix contig bug in test
* back to a class
* trailing whitespace
* fix tests
* match c++ and python apis
* add link + make args kw_only
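The diff below adds tests for the new `mx.fast.metal_kernel` API. As orientation, here is a minimal sketch pieced together from those tests (the kernel name `double` and the buffer names `inp`/`out` are illustrative, not from the commit): only the body of the Metal kernel is passed as `source`, and the full function signature is generated from the inputs and outputs supplied at call time.

    import mlx.core as mx

    # Minimal sketch, assuming the dict-based API exercised by the tests below.
    kernel = mx.fast.metal_kernel(
        name="double",  # illustrative name
        source="""
            uint elem = thread_position_in_grid.x;
            out[elem] = 2.0f * inp[elem];
        """,
    )
    x = mx.arange(8, dtype=mx.float32)
    out = kernel(
        inputs={"inp": x},                  # buffer names become kernel arguments
        grid=(x.size, 1, 1),                # total threads launched
        threadgroup=(8, 1, 1),              # threads per threadgroup
        output_shapes={"out": x.shape},
        output_dtypes={"out": mx.float32},
        stream=mx.gpu,
    )["out"]                                # outputs come back keyed by name
    print(out)  # [0., 2., 4., ..., 14.]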
@@ -548,6 +548,104 @@ class TestFast(mlx_tests.MLXTestCase):
        )

        self.assertTrue(mx.allclose(w, w_p))

    @unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
    def test_custom_kernel_basic(self):
        mx.random.seed(7)
        a = mx.random.normal(shape=(3, 6))
        kernel = mx.fast.metal_kernel(
            name="basic",
            source="""
                uint elem = thread_position_in_grid.x;
                out1[elem] = a[elem];
            """,
        )
        out = kernel(
            inputs={"a": a},
            grid=(4, 1, 1),
            threadgroup=(2, 1, 1),
            output_shapes={"out1": (2, 2)},
            output_dtypes={"out1": mx.float32},
            stream=mx.gpu,
        )
        # The kernel copies the first grid.x == 4 flat elements of ``a``
        # into the (2, 2) output.
        self.assertTrue(mx.allclose(out["out1"], a.reshape(-1)[:4].reshape(2, 2)))

    @unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
    def test_custom_kernel_args(self):
        mx.random.seed(7)
        a = mx.random.normal(shape=(3, 6))
        c = mx.random.normal(shape=(2, 2)).astype(mx.bfloat16)

        kernel = mx.fast.metal_kernel(
            name="arg_test",
            source="""
                uint elem = thread_position_in_grid.x;
                T tmp = a[0];
                if (e) {
                    out1[elem] = a[1] + b[2] + c[3] + d + f;
                } else {
                    out1[elem] = 1;
                }
                out2[elem] = a[1] + b[2] + c[1] - d;
            """,
        )
        out = kernel(
            inputs={
                "a": a,
                "b": mx.array([3, 4, 5]),
                "c": c,
                "d": 7.3,
            },
            template={
                "e": True,
                "f": 3,
                "T": mx.float16,
            },
            grid=(6, 1, 1),
            threadgroup=(2, 1, 1),
            output_shapes={"out1": (2, 2), "out2": (3, 2)},
            output_dtypes={"out1": mx.float32, "out2": mx.int32},
            stream=mx.gpu,
        )

        self.assertTrue(mx.allclose(out["out1"], mx.full((2, 2), 14.0484)))
        self.assertTrue(mx.allclose(out["out2"], mx.full((3, 2), -2, dtype=mx.int32)))

    @unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
    def test_custom_kernel_strides(self):
        mx.random.seed(7)
        a = mx.random.normal(shape=(3, 6))
        source = """
            uint elem = thread_position_in_grid.x;
            uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
            T tmp = inp[loc];
            out[elem] = metal::exp(tmp);
        """
        source_contig = """
            uint elem = thread_position_in_grid.x;
            T tmp = inp[elem];
            out[elem] = metal::exp(tmp);
        """

        # non contiguous
        a = mx.tile(a[::2], [4, 1])

        for contig in [True, False]:
            kernel = mx.fast.metal_kernel(
                name="myexp" + str(contig),
                source=source_contig if contig else source,
                ensure_row_contiguous=contig,
            )
            outputs = kernel(
                inputs={"inp": a},
                template={"T": mx.float32},
                grid=(a.size, 1, 1),
                threadgroup=(256, 1, 1),
                output_shapes={"out": a.shape},
                output_dtypes={"out": a.dtype},
                stream=mx.gpu,
            )
            self.assertTrue(mx.allclose(mx.exp(a), outputs["out"]))


if __name__ == "__main__":
    unittest.main()
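A note on the constants asserted in test_custom_kernel_args: buffer indices inside the kernel are flat element indices, so `a[1]` and `c[3]` read the flattened arrays. A hypothetical host-side recomputation (not part of the commit), with `a` and `c` as defined in that test:

    # Hypothetical host-side mirror of the kernel line
    #   out1[elem] = a[1] + b[2] + c[3] + d + f;   (e == True)
    a1 = a.reshape(-1)[1].item()                     # flat index, as in the kernel
    c3 = c.reshape(-1)[3].astype(mx.float32).item()  # bfloat16 read, upcast for the sum
    expected_out1 = a1 + 5 + c3 + 7.3 + 3            # b[2] == 5, d == 7.3, f == 3
    # With seed 7 this should land on the 14.0484 the test asserts.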
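test_custom_kernel_strides keeps the input strided and indexes it through `elem_to_loc` with the `inp_shape`, `inp_strides`, and `inp_ndim` arguments that MLX supplies alongside each input when `ensure_row_contiguous` is False. A Python sketch of the index arithmetic such a helper performs (illustrative only; the real helper lives in MLX's Metal headers and may differ in detail):

    def elem_to_loc(elem, shape, strides, ndim):
        # Map a row-major element index to an offset into a strided buffer.
        loc = 0
        for i in range(ndim - 1, -1, -1):
            loc += (elem % shape[i]) * strides[i]
            elem //= shape[i]
        return loc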