Simplifications for MLX C (#1396)

* simplifications for MLX C

* use vectors instead of map

* update examples
Author: Awni Hannun
Date: 2024-09-06 19:16:50 -07:00
Committed by: GitHub
Parent: 7cca1727af
Commit: ba3e913c7a
7 changed files with 334 additions and 331 deletions
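
The Python-facing change these tests exercise: mx.fast.metal_kernel now takes input_names and output_names when the kernel is constructed, the call site passes plain lists ordered to match those names (with template given as a list of (name, value) pairs), and the outputs come back as a positionally indexed list instead of a dict keyed by output name. A minimal sketch of the new calling convention, mirroring the updated "basic" test below and assuming a machine where mx.metal.is_available() returns True:

import mlx.core as mx

# Input and output names are declared once, when the kernel is built.
kernel = mx.fast.metal_kernel(
    name="basic",
    input_names=["a"],
    output_names=["out1"],
    source="""
        uint elem = thread_position_in_grid.x;
        out1[elem] = a[elem];
    """,
)

a = mx.random.normal(shape=(2, 2))

# The call takes lists (ordered like input_names/output_names) rather than
# dicts, and returns the outputs as a list.
out = kernel(
    inputs=[a],
    grid=(4, 1, 1),
    threadgroup=(2, 1, 1),
    output_shapes=[(2, 2)],
    output_dtypes=[mx.float32],
    stream=mx.gpu,
)

print(out[0])  # first (and only) output; previously out["out1"]

Templated kernels follow the same pattern, with template=[("e", True), ("f", 3), ("T", mx.float16)] replacing the old dict form, as the arg_test hunk shows.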


@@ -562,20 +562,22 @@ class TestFast(mlx_tests.MLXTestCase):
         a = mx.random.normal(shape=(2, 2))
         kernel = mx.fast.metal_kernel(
             name="basic",
+            input_names=["a"],
+            output_names=["out1"],
             source="""
                 uint elem = thread_position_in_grid.x;
                 out1[elem] = a[elem];
             """,
         )
         out = kernel(
-            inputs={"a": a},
+            inputs=[a],
             grid=(4, 1, 1),
             threadgroup=(2, 1, 1),
-            output_shapes={"out1": (2, 2)},
-            output_dtypes={"out1": mx.float32},
+            output_shapes=[(2, 2)],
+            output_dtypes=[mx.float32],
             stream=mx.gpu,
         )
-        self.assertTrue(mx.allclose(out["out1"], a))
+        self.assertTrue(mx.allclose(out[0], a))

     @unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
     def test_custom_kernel_args(self):
@@ -585,6 +587,8 @@ class TestFast(mlx_tests.MLXTestCase):
         kernel = mx.fast.metal_kernel(
             name="arg_test",
+            input_names=["a", "b", "c", "d"],
+            output_names=["out1", "out2"],
             source="""
                 uint elem = thread_position_in_grid.x;
                 T tmp = a[0];
@@ -597,26 +601,26 @@ class TestFast(mlx_tests.MLXTestCase):
""",
)
out = kernel(
inputs={
"a": a,
"b": mx.array([3, 4, 5]),
"c": c,
"d": 7.3,
},
template={
"e": True,
"f": 3,
"T": mx.float16,
},
inputs=[
a,
mx.array([3, 4, 5]),
c,
7.3,
],
template=[
("e", True),
("f", 3),
("T", mx.float16),
],
grid=(6, 1, 1),
threadgroup=(2, 1, 1),
output_shapes={"out1": (2, 2), "out2": (3, 2)},
output_dtypes={"out1": mx.float32, "out2": mx.int32},
output_shapes=[(2, 2), (3, 2)],
output_dtypes=[mx.float32, mx.int32],
stream=mx.gpu,
)
self.assertTrue(mx.allclose(out["out1"], mx.full((2, 2), 14.0484)))
self.assertTrue(mx.allclose(out["out2"], mx.full((3, 2), -2, dtype=mx.int32)))
self.assertTrue(mx.allclose(out[0], mx.full((2, 2), 14.0484)))
self.assertTrue(mx.allclose(out[1], mx.full((3, 2), -2, dtype=mx.int32)))
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
def test_custom_kernel_strides(self):
@@ -640,19 +644,21 @@ class TestFast(mlx_tests.MLXTestCase):
         for contig in [True, False]:
             kernel = mx.fast.metal_kernel(
                 name="myexp" + str(contig),
+                input_names=["inp"],
+                output_names=["out"],
                 source=source_contig if contig else source,
                 ensure_row_contiguous=contig,
             )
             outputs = kernel(
-                inputs={"inp": a},
-                template={"T": mx.float32},
+                inputs=[a],
+                template=[("T", mx.float32)],
                 grid=(a.size, 1, 1),
                 threadgroup=(256, 1, 1),
-                output_shapes={"out": a.shape},
-                output_dtypes={"out": a.dtype},
+                output_shapes=[a.shape],
+                output_dtypes=[a.dtype],
                 stream=mx.gpu,
             )
-            self.assertTrue(mx.allclose(mx.exp(a) * 32, outputs["out"]))
+            self.assertTrue(mx.allclose(mx.exp(a) * 32, outputs[0]))

     @unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
     def test_custom_kernel_helper(self):
@@ -660,6 +666,8 @@ class TestFast(mlx_tests.MLXTestCase):
         a = mx.random.normal(shape=(2, 2))
         kernel = mx.fast.metal_kernel(
             name="helper",
+            input_names=["a"],
+            output_names=["out1"],
             header="""
                 template <typename T>
                 T do_exp(T x) {
@@ -672,14 +680,14 @@ class TestFast(mlx_tests.MLXTestCase):
""",
)
out = kernel(
inputs={"a": a},
inputs=[a],
grid=(4, 1, 1),
threadgroup=(2, 1, 1),
output_shapes={"out1": (2, 2)},
output_dtypes={"out1": mx.float32},
output_shapes=[(2, 2)],
output_dtypes=[mx.float32],
stream=mx.gpu,
)
self.assertTrue(mx.allclose(out["out1"], mx.exp(a)))
self.assertTrue(mx.allclose(out[0], mx.exp(a)))
if __name__ == "__main__":