mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-27 16:28:10 +08:00
Simplifications for MLX C (#1396)
* simplifications for MLX C * use vectors instead of map * update examples
This commit is contained in:
@@ -19,17 +19,19 @@ Let's write a custom kernel that computes ``exp`` elementwise:
|
||||
|
||||
kernel = mx.fast.metal_kernel(
|
||||
name="myexp",
|
||||
input_names=["inp"],
|
||||
output_names=["out"],
|
||||
source=source,
|
||||
)
|
||||
outputs = kernel(
|
||||
inputs={"inp": a},
|
||||
template={"T": mx.float32},
|
||||
inputs=[a],
|
||||
template=[("T", mx.float32)],
|
||||
grid=(a.size, 1, 1),
|
||||
threadgroup=(256, 1, 1),
|
||||
output_shapes={"out": a.shape},
|
||||
output_dtypes={"out": a.dtype},
|
||||
output_shapes=[a.shape],
|
||||
output_dtypes=[a.dtype],
|
||||
)
|
||||
return outputs["out"]
|
||||
return outputs[0]
|
||||
|
||||
a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
|
||||
b = exp_elementwise(a)
|
||||
@@ -40,16 +42,16 @@ Let's write a custom kernel that computes ``exp`` elementwise:
|
||||
|
||||
The full function signature will be generated using:
|
||||
|
||||
* The keys and shapes/dtypes of ``inputs``
|
||||
* The shapes/dtypes of ``inputs``
|
||||
In the above, ``a`` is an ``mx.array`` of type ``mx.float16`` and we pass it with the key ``inp``
|
||||
so we will add ``const device float16_t* inp`` to the signature.
|
||||
``inp_shape``, ``inp_strides`` and ``inp_ndim`` are also added for convenience if they are present
|
||||
in ``source``.
|
||||
* The keys and values of ``output_shapes`` and ``output_dtypes``
|
||||
* The list of ``output_dtypes``
|
||||
In the above, ``out`` is an ``mx.array`` of type ``mx.float16``
|
||||
so we add ``device float16_t* out``.
|
||||
* Template parameters passed using ``template``
|
||||
In the above, ``template={"T": mx.float32}`` adds a template of ``template <typename T>`` to the function
|
||||
In the above, ``template=[("T", mx.float32)]`` adds a template of ``template <typename T>`` to the function
|
||||
and instantiates the template with ``custom_kernel_myexp_float<float>``.
|
||||
Template parameters can be ``mx.core.Dtype``, ``int`` or ``bool``.
|
||||
* Metal attributes used in ``source`` such as ``[[thread_position_in_grid]]``
|
||||
@@ -104,18 +106,20 @@ Let's convert ``myexp`` above to support arbitrarily strided arrays without rely
|
||||
|
||||
kernel = mx.fast.metal_kernel(
|
||||
name="myexp_strided",
|
||||
input_names=["inp"],
|
||||
output_names=["out"],
|
||||
source=source
|
||||
)
|
||||
outputs = kernel(
|
||||
inputs={"inp": a},
|
||||
template={"T": mx.float32},
|
||||
inputs=[a],
|
||||
template=[("T", mx.float32)],
|
||||
grid=(a.size, 1, 1),
|
||||
threadgroup=(256, 1, 1),
|
||||
output_shapes={"out": a.shape},
|
||||
output_dtypes={"out": a.dtype},
|
||||
output_shapes=[a.shape],
|
||||
output_dtypes=[a.dtype],
|
||||
ensure_row_contiguous=False,
|
||||
)
|
||||
return outputs["out"]
|
||||
return outputs[0]
|
||||
|
||||
a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
|
||||
# make non-contiguous
|
||||
@@ -243,17 +247,19 @@ First we'll implement the forward pass as a fused kernel:
|
||||
"""
|
||||
kernel = mx.fast.metal_kernel(
|
||||
name="grid_sample",
|
||||
input_names=["x", "grid"],
|
||||
output_names=["out"],
|
||||
source=source,
|
||||
)
|
||||
outputs = kernel(
|
||||
inputs={"x": x, "grid": grid},
|
||||
template={"T": x.dtype},
|
||||
output_shapes={"out": out_shape},
|
||||
output_dtypes={"out": x.dtype},
|
||||
inputs=[x, grid],
|
||||
template=[("T", x.dtype)],
|
||||
output_shapes=[out_shape],
|
||||
output_dtypes=[x.dtype],
|
||||
grid=(np.prod(out_shape), 1, 1),
|
||||
threadgroup=(256, 1, 1),
|
||||
)
|
||||
return outputs["out"]
|
||||
return outputs[0]
|
||||
|
||||
For a reasonably sized input such as:
|
||||
|
||||
@@ -389,6 +395,8 @@ We can then implement the backwards pass as follows:
|
||||
"""
|
||||
kernel = mx.fast.metal_kernel(
|
||||
name="grid_sample_grad",
|
||||
input_names=["x", "grid", "cotangent"],
|
||||
output_names=["x_grad", "grid_grad"],
|
||||
source=source,
|
||||
atomic_outputs=True,
|
||||
)
|
||||
@@ -398,15 +406,15 @@ We can then implement the backwards pass as follows:
|
||||
C_padded = (C + simdgroup_size - 1) // simdgroup_size * simdgroup_size
|
||||
grid_size = B * gN * gM * C_padded
|
||||
outputs = kernel(
|
||||
inputs={"x": x, "grid": grid, "cotangent": cotangent},
|
||||
template={"T": x.dtype},
|
||||
output_shapes={"x_grad": x.shape, "grid_grad": grid.shape},
|
||||
output_dtypes={"x_grad": x.dtype, "grid_grad": x.dtype},
|
||||
inputs=[x, grid, cotangent],
|
||||
template=[("T", x.dtype)],
|
||||
output_shapes=[x.shape, grid.shape],
|
||||
output_dtypes=[x.dtype, x.dtype],
|
||||
grid=(grid_size, 1, 1),
|
||||
threadgroup=(256, 1, 1),
|
||||
init_value=0,
|
||||
)
|
||||
return outputs["x_grad"], outputs["grid_grad"]
|
||||
return outputs[0], outputs[1]
|
||||
|
||||
There's an even larger speed up for the vjp:
|
||||
|
||||
|
Reference in New Issue
Block a user