mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 00:04:41 +08:00
Simplifications for MLX C (#1396)
* simplifications for MLX C * use vectors instead of map * update examples
This commit is contained in:
@@ -562,20 +562,22 @@ class TestFast(mlx_tests.MLXTestCase):
|
||||
a = mx.random.normal(shape=(2, 2))
|
||||
kernel = mx.fast.metal_kernel(
|
||||
name="basic",
|
||||
input_names=["a"],
|
||||
output_names=["out1"],
|
||||
source="""
|
||||
uint elem = thread_position_in_grid.x;
|
||||
out1[elem] = a[elem];
|
||||
""",
|
||||
)
|
||||
out = kernel(
|
||||
inputs={"a": a},
|
||||
inputs=[a],
|
||||
grid=(4, 1, 1),
|
||||
threadgroup=(2, 1, 1),
|
||||
output_shapes={"out1": (2, 2)},
|
||||
output_dtypes={"out1": mx.float32},
|
||||
output_shapes=[(2, 2)],
|
||||
output_dtypes=[mx.float32],
|
||||
stream=mx.gpu,
|
||||
)
|
||||
self.assertTrue(mx.allclose(out["out1"], a))
|
||||
self.assertTrue(mx.allclose(out[0], a))
|
||||
|
||||
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
|
||||
def test_custom_kernel_args(self):
|
||||
@@ -585,6 +587,8 @@ class TestFast(mlx_tests.MLXTestCase):
|
||||
|
||||
kernel = mx.fast.metal_kernel(
|
||||
name="arg_test",
|
||||
input_names=["a", "b", "c", "d"],
|
||||
output_names=["out1", "out2"],
|
||||
source="""
|
||||
uint elem = thread_position_in_grid.x;
|
||||
T tmp = a[0];
|
||||
@@ -597,26 +601,26 @@ class TestFast(mlx_tests.MLXTestCase):
|
||||
""",
|
||||
)
|
||||
out = kernel(
|
||||
inputs={
|
||||
"a": a,
|
||||
"b": mx.array([3, 4, 5]),
|
||||
"c": c,
|
||||
"d": 7.3,
|
||||
},
|
||||
template={
|
||||
"e": True,
|
||||
"f": 3,
|
||||
"T": mx.float16,
|
||||
},
|
||||
inputs=[
|
||||
a,
|
||||
mx.array([3, 4, 5]),
|
||||
c,
|
||||
7.3,
|
||||
],
|
||||
template=[
|
||||
("e", True),
|
||||
("f", 3),
|
||||
("T", mx.float16),
|
||||
],
|
||||
grid=(6, 1, 1),
|
||||
threadgroup=(2, 1, 1),
|
||||
output_shapes={"out1": (2, 2), "out2": (3, 2)},
|
||||
output_dtypes={"out1": mx.float32, "out2": mx.int32},
|
||||
output_shapes=[(2, 2), (3, 2)],
|
||||
output_dtypes=[mx.float32, mx.int32],
|
||||
stream=mx.gpu,
|
||||
)
|
||||
|
||||
self.assertTrue(mx.allclose(out["out1"], mx.full((2, 2), 14.0484)))
|
||||
self.assertTrue(mx.allclose(out["out2"], mx.full((3, 2), -2, dtype=mx.int32)))
|
||||
self.assertTrue(mx.allclose(out[0], mx.full((2, 2), 14.0484)))
|
||||
self.assertTrue(mx.allclose(out[1], mx.full((3, 2), -2, dtype=mx.int32)))
|
||||
|
||||
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
|
||||
def test_custom_kernel_strides(self):
|
||||
@@ -640,19 +644,21 @@ class TestFast(mlx_tests.MLXTestCase):
|
||||
for contig in [True, False]:
|
||||
kernel = mx.fast.metal_kernel(
|
||||
name="myexp" + str(contig),
|
||||
input_names=["inp"],
|
||||
output_names=["out"],
|
||||
source=source_contig if contig else source,
|
||||
ensure_row_contiguous=contig,
|
||||
)
|
||||
outputs = kernel(
|
||||
inputs={"inp": a},
|
||||
template={"T": mx.float32},
|
||||
inputs=[a],
|
||||
template=[("T", mx.float32)],
|
||||
grid=(a.size, 1, 1),
|
||||
threadgroup=(256, 1, 1),
|
||||
output_shapes={"out": a.shape},
|
||||
output_dtypes={"out": a.dtype},
|
||||
output_shapes=[a.shape],
|
||||
output_dtypes=[a.dtype],
|
||||
stream=mx.gpu,
|
||||
)
|
||||
self.assertTrue(mx.allclose(mx.exp(a) * 32, outputs["out"]))
|
||||
self.assertTrue(mx.allclose(mx.exp(a) * 32, outputs[0]))
|
||||
|
||||
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
|
||||
def test_custom_kernel_helper(self):
|
||||
@@ -660,6 +666,8 @@ class TestFast(mlx_tests.MLXTestCase):
|
||||
a = mx.random.normal(shape=(2, 2))
|
||||
kernel = mx.fast.metal_kernel(
|
||||
name="helper",
|
||||
input_names=["a"],
|
||||
output_names=["out1"],
|
||||
header="""
|
||||
template <typename T>
|
||||
T do_exp(T x) {
|
||||
@@ -672,14 +680,14 @@ class TestFast(mlx_tests.MLXTestCase):
|
||||
""",
|
||||
)
|
||||
out = kernel(
|
||||
inputs={"a": a},
|
||||
inputs=[a],
|
||||
grid=(4, 1, 1),
|
||||
threadgroup=(2, 1, 1),
|
||||
output_shapes={"out1": (2, 2)},
|
||||
output_dtypes={"out1": mx.float32},
|
||||
output_shapes=[(2, 2)],
|
||||
output_dtypes=[mx.float32],
|
||||
stream=mx.gpu,
|
||||
)
|
||||
self.assertTrue(mx.allclose(out["out1"], mx.exp(a)))
|
||||
self.assertTrue(mx.allclose(out[0], mx.exp(a)))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Reference in New Issue
Block a user