Simplifications for MLX C (#1396)

* simplifications for MLX C

* use vectors instead of map

* update examples
Author: Awni Hannun
Date: 2024-09-06 19:16:50 -07:00
Committed by: GitHub
Parent: 7cca1727af
Commit: ba3e913c7a
7 changed files with 334 additions and 331 deletions

@@ -19,17 +19,19 @@ Let's write a custom kernel that computes ``exp`` elementwise:
 kernel = mx.fast.metal_kernel(
     name="myexp",
+    input_names=["inp"],
+    output_names=["out"],
     source=source,
 )
 outputs = kernel(
-    inputs={"inp": a},
-    template={"T": mx.float32},
+    inputs=[a],
+    template=[("T", mx.float32)],
     grid=(a.size, 1, 1),
     threadgroup=(256, 1, 1),
-    output_shapes={"out": a.shape},
-    output_dtypes={"out": a.dtype},
+    output_shapes=[a.shape],
+    output_dtypes=[a.dtype],
 )
-return outputs["out"]
+return outputs[0]
 a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
 b = exp_elementwise(a)
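
Put together, the updated example looks roughly as follows. This is a sketch assuming the list-based API introduced in this commit; the kernel body is the elementwise ``exp`` source from the surrounding documentation page, which the hunk above does not show, so treat it as assumed rather than quoted.

    import mlx.core as mx

    def exp_elementwise(a: mx.array):
        # Kernel body assumed from the docs example: cast to T, apply exp.
        source = """
            uint elem = thread_position_in_grid.x;
            T tmp = inp[elem];
            out[elem] = metal::exp(tmp);
        """
        kernel = mx.fast.metal_kernel(
            name="myexp",
            input_names=["inp"],   # names map positionally onto `inputs`
            output_names=["out"],  # names map positionally onto the outputs
            source=source,
        )
        outputs = kernel(
            inputs=[a],
            template=[("T", mx.float32)],
            grid=(a.size, 1, 1),
            threadgroup=(256, 1, 1),
            output_shapes=[a.shape],
            output_dtypes=[a.dtype],
        )
        # kernel(...) now returns a list of arrays instead of a dict.
        return outputs[0]

    a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
    b = exp_elementwise(a)
    assert mx.allclose(b, mx.exp(a))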
@@ -40,16 +42,16 @@ Let's write a custom kernel that computes ``exp`` elementwise:
 The full function signature will be generated using:
-* The keys and shapes/dtypes of ``inputs``
+* The shapes/dtypes of ``inputs``
   In the above, ``a`` is an ``mx.array`` of type ``mx.float16`` and we pass it with the key ``inp``
   so we will add ``const device float16_t* inp`` to the signature.
   ``inp_shape``, ``inp_strides`` and ``inp_ndim`` are also added for convenience if they are present
   in ``source``.
-* The keys and values of ``output_shapes`` and ``output_dtypes``
+* The list of ``output_dtypes``
   In the above, ``out`` is an ``mx.array`` of type ``mx.float16``
   so we add ``device float16_t* out``.
 * Template parameters passed using ``template``
-  In the above, ``template={"T": mx.float32}`` adds a template of ``template <typename T>`` to the function
+  In the above, ``template=[("T", mx.float32)]`` adds a template of ``template <typename T>`` to the function
   and instantiates the template with ``custom_kernel_myexp_float<float>``.
   Template parameters can be ``mx.core.Dtype``, ``int`` or ``bool``.
 * Metal attributes used in ``source`` such as ``[[thread_position_in_grid]]``
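
As a concrete illustration of the template parameter types mentioned in the ``template`` bullet above, here is a hedged sketch that passes a ``Dtype``, an ``int``, and a ``bool``. The kernel name, source, and the parameter names ``N`` and ``NEGATE`` are invented for this example; only the ``metal_kernel`` API itself comes from the commit.

    import mlx.core as mx

    # Hypothetical kernel: scales the input by an integer template parameter
    # and optionally negates it, controlled by a bool template parameter.
    source = """
        uint elem = thread_position_in_grid.x;
        T scaled = static_cast<T>(N) * static_cast<T>(inp[elem]);
        out[elem] = NEGATE ? -scaled : scaled;
    """
    kernel = mx.fast.metal_kernel(
        name="scale_by_n",
        input_names=["inp"],
        output_names=["out"],
        source=source,
    )

    a = mx.random.normal(shape=(8,))
    out = kernel(
        inputs=[a],
        # One (name, value) pair per template parameter: Dtype, int, bool.
        template=[("T", mx.float32), ("N", 3), ("NEGATE", False)],
        grid=(a.size, 1, 1),
        threadgroup=(8, 1, 1),
        output_shapes=[a.shape],
        output_dtypes=[a.dtype],
    )[0]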
@@ -104,18 +106,20 @@ Let's convert ``myexp`` above to support arbitrarily strided arrays without rely
 kernel = mx.fast.metal_kernel(
     name="myexp_strided",
+    input_names=["inp"],
+    output_names=["out"],
     source=source
 )
 outputs = kernel(
-    inputs={"inp": a},
-    template={"T": mx.float32},
+    inputs=[a],
+    template=[("T", mx.float32)],
     grid=(a.size, 1, 1),
     threadgroup=(256, 1, 1),
-    output_shapes={"out": a.shape},
-    output_dtypes={"out": a.dtype},
+    output_shapes=[a.shape],
+    output_dtypes=[a.dtype],
     ensure_row_contiguous=False,
 )
-return outputs["out"]
+return outputs[0]
 a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
 # make non-contiguous
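
The rest of this snippet is cut off by the hunk. For completeness, one way the strided version can be exercised, as a sketch of an assumed continuation rather than a quote from the diff (``exp_elementwise_strided`` stands in for whatever wrapper function the documentation defines around this kernel):

    # Slicing every other row gives a non-contiguous view. Because the kernel
    # was built with ensure_row_contiguous=False, MLX will not copy the input,
    # so the source must index through inp_shape/inp_strides/inp_ndim.
    a = a[::2]
    b = exp_elementwise_strided(a)
    assert mx.allclose(b, mx.exp(a))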
@@ -243,17 +247,19 @@ First we'll implement the forward pass as a fused kernel:
"""
kernel = mx.fast.metal_kernel(
name="grid_sample",
input_names=["x", "grid"],
output_names=["out"],
source=source,
)
outputs = kernel(
inputs={"x": x, "grid": grid},
template={"T": x.dtype},
output_shapes={"out": out_shape},
output_dtypes={"out": x.dtype},
inputs=[x, grid],
template=[("T", x.dtype)],
output_shapes=[out_shape],
output_dtypes=[x.dtype],
grid=(np.prod(out_shape), 1, 1),
threadgroup=(256, 1, 1),
)
return outputs["out"]
return outputs[0]
For a reasonably sized input such as:
@@ -389,6 +395,8 @@ We can then implement the backwards pass as follows:
"""
kernel = mx.fast.metal_kernel(
name="grid_sample_grad",
input_names=["x", "grid", "cotangent"],
output_names=["x_grad", "grid_grad"],
source=source,
atomic_outputs=True,
)
@@ -398,15 +406,15 @@ We can then implement the backwards pass as follows:
 C_padded = (C + simdgroup_size - 1) // simdgroup_size * simdgroup_size
 grid_size = B * gN * gM * C_padded
 outputs = kernel(
-    inputs={"x": x, "grid": grid, "cotangent": cotangent},
-    template={"T": x.dtype},
-    output_shapes={"x_grad": x.shape, "grid_grad": grid.shape},
-    output_dtypes={"x_grad": x.dtype, "grid_grad": x.dtype},
+    inputs=[x, grid, cotangent],
+    template=[("T", x.dtype)],
+    output_shapes=[x.shape, grid.shape],
+    output_dtypes=[x.dtype, x.dtype],
     grid=(grid_size, 1, 1),
     threadgroup=(256, 1, 1),
     init_value=0,
 )
-return outputs["x_grad"], outputs["grid_grad"]
+return outputs[0], outputs[1]
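
A note on the two remaining arguments above: ``atomic_outputs=True`` (set when the kernel is constructed) declares the outputs as Metal atomics so concurrent adds from many threads are safe, and ``init_value=0`` (passed at call time) zero-fills ``x_grad`` and ``grid_grad`` before the kernel runs so the accumulation starts from zero. Below is a minimal toy sketch of the same pattern, with an invented kernel name and source, assuming the list-based API from this commit:

    import mlx.core as mx

    # Toy bincount-style kernel: many threads atomically add into one output.
    # With atomic_outputs=True the outputs are declared as device atomic<T>*.
    source = """
        uint elem = thread_position_in_grid.x;
        atomic_fetch_add_explicit(&out[idx[elem]], T(1), memory_order_relaxed);
    """
    kernel = mx.fast.metal_kernel(
        name="toy_bincount",
        input_names=["idx"],
        output_names=["out"],
        source=source,
        atomic_outputs=True,
    )

    idx = mx.array([0, 1, 1, 3, 0, 2], dtype=mx.uint32)
    counts = kernel(
        inputs=[idx],
        template=[("T", mx.float32)],
        grid=(idx.size, 1, 1),
        threadgroup=(256, 1, 1),
        output_shapes=[(4,)],
        output_dtypes=[mx.float32],
        init_value=0,  # zero-fill the output before the atomic accumulation
    )[0]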
There's an even larger speed up for the vjp: