From f764bccbf0178e662de7c810231c7dab983b3281 Mon Sep 17 00:00:00 2001 From: CircleCI Docs Date: Tue, 1 Jul 2025 22:14:26 +0000 Subject: [PATCH] rebase --- docs/build/html/.buildinfo | 2 +- .../_sources/dev/custom_metal_kernels.rst | 498 +++++++++--------- docs/build/html/_sources/dev/extensions.rst | 6 +- docs/build/html/_sources/install.rst | 59 +++ docs/build/html/_sources/usage/indexing.rst | 10 + .../html/_static/documentation_options.js | 2 +- docs/build/html/cpp/ops.html | 8 +- docs/build/html/dev/custom_metal_kernels.html | 422 ++++++++------- docs/build/html/dev/extensions.html | 14 +- docs/build/html/dev/metal_debugger.html | 8 +- docs/build/html/dev/mlx_in_cpp.html | 8 +- .../html/examples/linear_regression.html | 8 +- docs/build/html/examples/llama-inference.html | 8 +- docs/build/html/examples/mlp.html | 8 +- docs/build/html/genindex.html | 8 +- docs/build/html/index.html | 8 +- docs/build/html/install.html | 68 ++- docs/build/html/objects.inv | Bin 28028 -> 28045 bytes .../python/_autosummary/mlx.core.Device.html | 8 +- .../python/_autosummary/mlx.core.Dtype.html | 8 +- .../_autosummary/mlx.core.DtypeCategory.html | 8 +- .../python/_autosummary/mlx.core.abs.html | 8 +- .../python/_autosummary/mlx.core.add.html | 8 +- .../python/_autosummary/mlx.core.addmm.html | 8 +- .../python/_autosummary/mlx.core.all.html | 8 +- .../_autosummary/mlx.core.allclose.html | 8 +- .../python/_autosummary/mlx.core.any.html | 8 +- .../python/_autosummary/mlx.core.arange.html | 8 +- .../python/_autosummary/mlx.core.arccos.html | 8 +- .../python/_autosummary/mlx.core.arccosh.html | 8 +- .../python/_autosummary/mlx.core.arcsin.html | 8 +- .../python/_autosummary/mlx.core.arcsinh.html | 8 +- .../python/_autosummary/mlx.core.arctan.html | 8 +- .../python/_autosummary/mlx.core.arctan2.html | 8 +- .../python/_autosummary/mlx.core.arctanh.html | 8 +- .../python/_autosummary/mlx.core.argmax.html | 8 +- .../python/_autosummary/mlx.core.argmin.html | 8 +- 
.../_autosummary/mlx.core.argpartition.html | 8 +- .../python/_autosummary/mlx.core.argsort.html | 8 +- .../python/_autosummary/mlx.core.array.T.html | 8 +- .../_autosummary/mlx.core.array.abs.html | 8 +- .../_autosummary/mlx.core.array.all.html | 8 +- .../_autosummary/mlx.core.array.any.html | 8 +- .../_autosummary/mlx.core.array.argmax.html | 8 +- .../_autosummary/mlx.core.array.argmin.html | 8 +- .../_autosummary/mlx.core.array.astype.html | 8 +- .../_autosummary/mlx.core.array.at.html | 8 +- .../_autosummary/mlx.core.array.conj.html | 8 +- .../_autosummary/mlx.core.array.cos.html | 8 +- .../_autosummary/mlx.core.array.cummax.html | 8 +- .../_autosummary/mlx.core.array.cummin.html | 8 +- .../_autosummary/mlx.core.array.cumprod.html | 8 +- .../_autosummary/mlx.core.array.cumsum.html | 8 +- .../_autosummary/mlx.core.array.diag.html | 8 +- .../_autosummary/mlx.core.array.diagonal.html | 8 +- .../_autosummary/mlx.core.array.dtype.html | 8 +- .../_autosummary/mlx.core.array.exp.html | 8 +- .../_autosummary/mlx.core.array.flatten.html | 8 +- .../python/_autosummary/mlx.core.array.html | 10 +- .../_autosummary/mlx.core.array.imag.html | 8 +- .../_autosummary/mlx.core.array.item.html | 8 +- .../_autosummary/mlx.core.array.itemsize.html | 8 +- .../_autosummary/mlx.core.array.log.html | 8 +- .../_autosummary/mlx.core.array.log10.html | 8 +- .../_autosummary/mlx.core.array.log1p.html | 8 +- .../_autosummary/mlx.core.array.log2.html | 8 +- .../mlx.core.array.logcumsumexp.html | 8 +- .../mlx.core.array.logsumexp.html | 8 +- .../_autosummary/mlx.core.array.max.html | 8 +- .../_autosummary/mlx.core.array.mean.html | 8 +- .../_autosummary/mlx.core.array.min.html | 8 +- .../_autosummary/mlx.core.array.moveaxis.html | 8 +- .../_autosummary/mlx.core.array.nbytes.html | 8 +- .../_autosummary/mlx.core.array.ndim.html | 8 +- .../_autosummary/mlx.core.array.prod.html | 8 +- .../_autosummary/mlx.core.array.real.html | 8 +- .../mlx.core.array.reciprocal.html | 8 +- 
.../_autosummary/mlx.core.array.reshape.html | 8 +- .../_autosummary/mlx.core.array.round.html | 8 +- .../_autosummary/mlx.core.array.rsqrt.html | 8 +- .../_autosummary/mlx.core.array.shape.html | 8 +- .../_autosummary/mlx.core.array.sin.html | 8 +- .../_autosummary/mlx.core.array.size.html | 8 +- .../_autosummary/mlx.core.array.split.html | 8 +- .../_autosummary/mlx.core.array.sqrt.html | 8 +- .../_autosummary/mlx.core.array.square.html | 8 +- .../_autosummary/mlx.core.array.squeeze.html | 8 +- .../_autosummary/mlx.core.array.std.html | 8 +- .../_autosummary/mlx.core.array.sum.html | 8 +- .../_autosummary/mlx.core.array.swapaxes.html | 8 +- .../_autosummary/mlx.core.array.tolist.html | 8 +- .../mlx.core.array.transpose.html | 8 +- .../_autosummary/mlx.core.array.var.html | 8 +- .../_autosummary/mlx.core.array.view.html | 8 +- .../_autosummary/mlx.core.array_equal.html | 8 +- .../_autosummary/mlx.core.as_strided.html | 8 +- .../_autosummary/mlx.core.async_eval.html | 8 +- .../_autosummary/mlx.core.atleast_1d.html | 8 +- .../_autosummary/mlx.core.atleast_2d.html | 8 +- .../_autosummary/mlx.core.atleast_3d.html | 8 +- .../_autosummary/mlx.core.bitwise_and.html | 8 +- .../_autosummary/mlx.core.bitwise_invert.html | 8 +- .../_autosummary/mlx.core.bitwise_or.html | 8 +- .../_autosummary/mlx.core.bitwise_xor.html | 8 +- .../mlx.core.block_masked_mm.html | 8 +- .../mlx.core.broadcast_arrays.html | 8 +- .../_autosummary/mlx.core.broadcast_to.html | 8 +- .../python/_autosummary/mlx.core.ceil.html | 8 +- .../_autosummary/mlx.core.clear_cache.html | 8 +- .../python/_autosummary/mlx.core.clip.html | 8 +- .../python/_autosummary/mlx.core.compile.html | 8 +- .../_autosummary/mlx.core.concatenate.html | 8 +- .../python/_autosummary/mlx.core.conj.html | 8 +- .../_autosummary/mlx.core.conjugate.html | 8 +- .../_autosummary/mlx.core.contiguous.html | 8 +- .../python/_autosummary/mlx.core.conv1d.html | 8 +- .../python/_autosummary/mlx.core.conv2d.html | 8 +- 
.../python/_autosummary/mlx.core.conv3d.html | 8 +- .../_autosummary/mlx.core.conv_general.html | 8 +- .../mlx.core.conv_transpose1d.html | 8 +- .../mlx.core.conv_transpose2d.html | 8 +- .../mlx.core.conv_transpose3d.html | 8 +- .../_autosummary/mlx.core.convolve.html | 8 +- .../python/_autosummary/mlx.core.cos.html | 8 +- .../python/_autosummary/mlx.core.cosh.html | 8 +- .../python/_autosummary/mlx.core.cummax.html | 8 +- .../python/_autosummary/mlx.core.cummin.html | 8 +- .../python/_autosummary/mlx.core.cumprod.html | 8 +- .../python/_autosummary/mlx.core.cumsum.html | 8 +- .../mlx.core.custom_function.html | 8 +- .../_autosummary/mlx.core.default_device.html | 8 +- .../_autosummary/mlx.core.default_stream.html | 8 +- .../python/_autosummary/mlx.core.degrees.html | 8 +- .../_autosummary/mlx.core.dequantize.html | 8 +- .../python/_autosummary/mlx.core.diag.html | 8 +- .../_autosummary/mlx.core.diagonal.html | 8 +- .../mlx.core.disable_compile.html | 8 +- .../mlx.core.distributed.Group.html | 8 +- .../mlx.core.distributed.all_gather.html | 8 +- .../mlx.core.distributed.all_sum.html | 8 +- .../mlx.core.distributed.init.html | 8 +- .../mlx.core.distributed.is_available.html | 8 +- .../mlx.core.distributed.recv.html | 8 +- .../mlx.core.distributed.recv_like.html | 8 +- .../mlx.core.distributed.send.html | 8 +- .../python/_autosummary/mlx.core.divide.html | 8 +- .../python/_autosummary/mlx.core.divmod.html | 8 +- .../python/_autosummary/mlx.core.einsum.html | 8 +- .../_autosummary/mlx.core.einsum_path.html | 8 +- .../_autosummary/mlx.core.enable_compile.html | 8 +- .../python/_autosummary/mlx.core.equal.html | 8 +- .../python/_autosummary/mlx.core.erf.html | 8 +- .../python/_autosummary/mlx.core.erfinv.html | 8 +- .../python/_autosummary/mlx.core.eval.html | 8 +- .../python/_autosummary/mlx.core.exp.html | 8 +- .../_autosummary/mlx.core.expand_dims.html | 8 +- .../python/_autosummary/mlx.core.expm1.html | 8 +- .../mlx.core.export_function.html | 8 +- 
.../_autosummary/mlx.core.export_to_dot.html | 8 +- .../_autosummary/mlx.core.exporter.html | 8 +- .../python/_autosummary/mlx.core.eye.html | 8 +- .../mlx.core.fast.layer_norm.html | 8 +- .../mlx.core.fast.metal_kernel.html | 8 +- .../_autosummary/mlx.core.fast.rms_norm.html | 8 +- .../_autosummary/mlx.core.fast.rope.html | 8 +- ...ore.fast.scaled_dot_product_attention.html | 8 +- .../python/_autosummary/mlx.core.fft.fft.html | 8 +- .../_autosummary/mlx.core.fft.fft2.html | 8 +- .../_autosummary/mlx.core.fft.fftn.html | 8 +- .../_autosummary/mlx.core.fft.fftshift.html | 8 +- .../_autosummary/mlx.core.fft.ifft.html | 8 +- .../_autosummary/mlx.core.fft.ifft2.html | 8 +- .../_autosummary/mlx.core.fft.ifftn.html | 8 +- .../_autosummary/mlx.core.fft.ifftshift.html | 8 +- .../_autosummary/mlx.core.fft.irfft.html | 8 +- .../_autosummary/mlx.core.fft.irfft2.html | 8 +- .../_autosummary/mlx.core.fft.irfftn.html | 8 +- .../_autosummary/mlx.core.fft.rfft.html | 8 +- .../_autosummary/mlx.core.fft.rfft2.html | 8 +- .../_autosummary/mlx.core.fft.rfftn.html | 8 +- .../python/_autosummary/mlx.core.finfo.html | 8 +- .../python/_autosummary/mlx.core.flatten.html | 8 +- .../python/_autosummary/mlx.core.floor.html | 8 +- .../_autosummary/mlx.core.floor_divide.html | 8 +- .../python/_autosummary/mlx.core.full.html | 8 +- .../_autosummary/mlx.core.gather_mm.html | 8 +- .../_autosummary/mlx.core.gather_qmm.html | 8 +- .../mlx.core.get_active_memory.html | 8 +- .../mlx.core.get_cache_memory.html | 8 +- .../mlx.core.get_peak_memory.html | 8 +- .../python/_autosummary/mlx.core.grad.html | 8 +- .../python/_autosummary/mlx.core.greater.html | 8 +- .../_autosummary/mlx.core.greater_equal.html | 8 +- .../mlx.core.hadamard_transform.html | 8 +- .../_autosummary/mlx.core.identity.html | 8 +- .../python/_autosummary/mlx.core.imag.html | 8 +- .../mlx.core.import_function.html | 8 +- .../python/_autosummary/mlx.core.inner.html | 8 +- .../python/_autosummary/mlx.core.isclose.html | 8 +- 
.../_autosummary/mlx.core.isfinite.html | 8 +- .../python/_autosummary/mlx.core.isinf.html | 8 +- .../python/_autosummary/mlx.core.isnan.html | 8 +- .../_autosummary/mlx.core.isneginf.html | 8 +- .../_autosummary/mlx.core.isposinf.html | 8 +- .../_autosummary/mlx.core.issubdtype.html | 8 +- .../python/_autosummary/mlx.core.jvp.html | 8 +- .../python/_autosummary/mlx.core.kron.html | 8 +- .../_autosummary/mlx.core.left_shift.html | 8 +- .../python/_autosummary/mlx.core.less.html | 8 +- .../_autosummary/mlx.core.less_equal.html | 8 +- .../mlx.core.linalg.cholesky.html | 8 +- .../mlx.core.linalg.cholesky_inv.html | 8 +- .../_autosummary/mlx.core.linalg.cross.html | 8 +- .../_autosummary/mlx.core.linalg.eig.html | 10 +- .../_autosummary/mlx.core.linalg.eigh.html | 8 +- .../_autosummary/mlx.core.linalg.eigvals.html | 10 +- .../mlx.core.linalg.eigvalsh.html | 8 +- .../_autosummary/mlx.core.linalg.inv.html | 8 +- .../_autosummary/mlx.core.linalg.lu.html | 8 +- .../mlx.core.linalg.lu_factor.html | 8 +- .../_autosummary/mlx.core.linalg.norm.html | 8 +- .../_autosummary/mlx.core.linalg.pinv.html | 8 +- .../_autosummary/mlx.core.linalg.qr.html | 8 +- .../_autosummary/mlx.core.linalg.solve.html | 8 +- .../mlx.core.linalg.solve_triangular.html | 8 +- .../_autosummary/mlx.core.linalg.svd.html | 8 +- .../_autosummary/mlx.core.linalg.tri_inv.html | 8 +- .../_autosummary/mlx.core.linspace.html | 8 +- .../python/_autosummary/mlx.core.load.html | 8 +- .../python/_autosummary/mlx.core.log.html | 8 +- .../python/_autosummary/mlx.core.log10.html | 8 +- .../python/_autosummary/mlx.core.log1p.html | 8 +- .../python/_autosummary/mlx.core.log2.html | 8 +- .../_autosummary/mlx.core.logaddexp.html | 8 +- .../_autosummary/mlx.core.logcumsumexp.html | 8 +- .../_autosummary/mlx.core.logical_and.html | 8 +- .../_autosummary/mlx.core.logical_not.html | 8 +- .../_autosummary/mlx.core.logical_or.html | 8 +- .../_autosummary/mlx.core.logsumexp.html | 8 +- .../python/_autosummary/mlx.core.matmul.html 
| 8 +- .../python/_autosummary/mlx.core.max.html | 8 +- .../python/_autosummary/mlx.core.maximum.html | 8 +- .../python/_autosummary/mlx.core.mean.html | 8 +- .../_autosummary/mlx.core.meshgrid.html | 8 +- .../mlx.core.metal.device_info.html | 8 +- .../mlx.core.metal.is_available.html | 8 +- .../mlx.core.metal.start_capture.html | 8 +- .../mlx.core.metal.stop_capture.html | 8 +- .../python/_autosummary/mlx.core.min.html | 8 +- .../python/_autosummary/mlx.core.minimum.html | 8 +- .../_autosummary/mlx.core.moveaxis.html | 8 +- .../_autosummary/mlx.core.multiply.html | 8 +- .../_autosummary/mlx.core.nan_to_num.html | 8 +- .../_autosummary/mlx.core.negative.html | 8 +- .../_autosummary/mlx.core.new_stream.html | 8 +- .../_autosummary/mlx.core.not_equal.html | 8 +- .../python/_autosummary/mlx.core.ones.html | 8 +- .../_autosummary/mlx.core.ones_like.html | 8 +- .../python/_autosummary/mlx.core.outer.html | 8 +- .../python/_autosummary/mlx.core.pad.html | 8 +- .../_autosummary/mlx.core.partition.html | 8 +- .../python/_autosummary/mlx.core.power.html | 8 +- .../python/_autosummary/mlx.core.prod.html | 8 +- .../_autosummary/mlx.core.put_along_axis.html | 8 +- .../_autosummary/mlx.core.quantize.html | 8 +- .../mlx.core.quantized_matmul.html | 8 +- .../python/_autosummary/mlx.core.radians.html | 8 +- .../mlx.core.random.bernoulli.html | 8 +- .../mlx.core.random.categorical.html | 8 +- .../_autosummary/mlx.core.random.gumbel.html | 8 +- .../_autosummary/mlx.core.random.key.html | 8 +- .../_autosummary/mlx.core.random.laplace.html | 8 +- .../mlx.core.random.multivariate_normal.html | 8 +- .../_autosummary/mlx.core.random.normal.html | 8 +- .../mlx.core.random.permutation.html | 8 +- .../_autosummary/mlx.core.random.randint.html | 8 +- .../_autosummary/mlx.core.random.seed.html | 8 +- .../_autosummary/mlx.core.random.split.html | 8 +- .../mlx.core.random.truncated_normal.html | 8 +- .../_autosummary/mlx.core.random.uniform.html | 8 +- .../python/_autosummary/mlx.core.real.html 
| 8 +- .../_autosummary/mlx.core.reciprocal.html | 8 +- .../_autosummary/mlx.core.remainder.html | 8 +- .../python/_autosummary/mlx.core.repeat.html | 8 +- .../mlx.core.reset_peak_memory.html | 8 +- .../python/_autosummary/mlx.core.reshape.html | 8 +- .../_autosummary/mlx.core.right_shift.html | 8 +- .../python/_autosummary/mlx.core.roll.html | 8 +- .../python/_autosummary/mlx.core.round.html | 8 +- .../python/_autosummary/mlx.core.rsqrt.html | 8 +- .../python/_autosummary/mlx.core.save.html | 8 +- .../_autosummary/mlx.core.save_gguf.html | 8 +- .../mlx.core.save_safetensors.html | 8 +- .../python/_autosummary/mlx.core.savez.html | 8 +- .../mlx.core.savez_compressed.html | 8 +- .../mlx.core.set_cache_limit.html | 8 +- .../mlx.core.set_default_device.html | 8 +- .../mlx.core.set_default_stream.html | 8 +- .../mlx.core.set_memory_limit.html | 8 +- .../mlx.core.set_wired_limit.html | 8 +- .../python/_autosummary/mlx.core.sigmoid.html | 8 +- .../python/_autosummary/mlx.core.sign.html | 8 +- .../python/_autosummary/mlx.core.sin.html | 8 +- .../python/_autosummary/mlx.core.sinh.html | 8 +- .../python/_autosummary/mlx.core.slice.html | 8 +- .../_autosummary/mlx.core.slice_update.html | 8 +- .../python/_autosummary/mlx.core.softmax.html | 8 +- .../python/_autosummary/mlx.core.sort.html | 8 +- .../python/_autosummary/mlx.core.split.html | 8 +- .../python/_autosummary/mlx.core.sqrt.html | 8 +- .../python/_autosummary/mlx.core.square.html | 8 +- .../python/_autosummary/mlx.core.squeeze.html | 8 +- .../python/_autosummary/mlx.core.stack.html | 8 +- .../python/_autosummary/mlx.core.std.html | 8 +- .../_autosummary/mlx.core.stop_gradient.html | 8 +- .../python/_autosummary/mlx.core.stream.html | 8 +- .../_autosummary/mlx.core.subtract.html | 8 +- .../python/_autosummary/mlx.core.sum.html | 8 +- .../_autosummary/mlx.core.swapaxes.html | 8 +- .../_autosummary/mlx.core.synchronize.html | 8 +- .../python/_autosummary/mlx.core.take.html | 8 +- .../mlx.core.take_along_axis.html | 8 +- 
.../python/_autosummary/mlx.core.tan.html | 8 +- .../python/_autosummary/mlx.core.tanh.html | 8 +- .../_autosummary/mlx.core.tensordot.html | 8 +- .../python/_autosummary/mlx.core.tile.html | 8 +- .../python/_autosummary/mlx.core.topk.html | 8 +- .../python/_autosummary/mlx.core.trace.html | 8 +- .../_autosummary/mlx.core.transpose.html | 8 +- .../python/_autosummary/mlx.core.tri.html | 8 +- .../python/_autosummary/mlx.core.tril.html | 8 +- .../python/_autosummary/mlx.core.triu.html | 8 +- .../_autosummary/mlx.core.unflatten.html | 8 +- .../_autosummary/mlx.core.value_and_grad.html | 8 +- .../python/_autosummary/mlx.core.var.html | 8 +- .../python/_autosummary/mlx.core.view.html | 8 +- .../python/_autosummary/mlx.core.vjp.html | 8 +- .../python/_autosummary/mlx.core.vmap.html | 8 +- .../python/_autosummary/mlx.core.where.html | 8 +- .../python/_autosummary/mlx.core.zeros.html | 8 +- .../_autosummary/mlx.core.zeros_like.html | 8 +- .../mlx.nn.average_gradients.html | 8 +- .../python/_autosummary/mlx.nn.quantize.html | 8 +- .../_autosummary/mlx.nn.value_and_grad.html | 8 +- .../mlx.optimizers.clip_grad_norm.html | 8 +- .../_autosummary/mlx.utils.tree_flatten.html | 8 +- .../_autosummary/mlx.utils.tree_map.html | 8 +- .../mlx.utils.tree_map_with_path.html | 8 +- .../_autosummary/mlx.utils.tree_reduce.html | 8 +- .../mlx.utils.tree_unflatten.html | 8 +- .../python/_autosummary/stream_class.html | 8 +- docs/build/html/python/array.html | 8 +- docs/build/html/python/data_types.html | 8 +- .../html/python/devices_and_streams.html | 8 +- docs/build/html/python/distributed.html | 8 +- docs/build/html/python/export.html | 8 +- docs/build/html/python/fast.html | 8 +- docs/build/html/python/fft.html | 8 +- docs/build/html/python/linalg.html | 8 +- docs/build/html/python/memory_management.html | 8 +- docs/build/html/python/metal.html | 8 +- docs/build/html/python/nn.html | 8 +- .../python/nn/_autosummary/mlx.nn.ALiBi.html | 8 +- .../nn/_autosummary/mlx.nn.AvgPool1d.html | 8 +- 
.../nn/_autosummary/mlx.nn.AvgPool2d.html | 8 +- .../nn/_autosummary/mlx.nn.AvgPool3d.html | 8 +- .../nn/_autosummary/mlx.nn.BatchNorm.html | 8 +- .../python/nn/_autosummary/mlx.nn.CELU.html | 8 +- .../python/nn/_autosummary/mlx.nn.Conv1d.html | 8 +- .../python/nn/_autosummary/mlx.nn.Conv2d.html | 8 +- .../python/nn/_autosummary/mlx.nn.Conv3d.html | 8 +- .../_autosummary/mlx.nn.ConvTranspose1d.html | 8 +- .../_autosummary/mlx.nn.ConvTranspose2d.html | 8 +- .../_autosummary/mlx.nn.ConvTranspose3d.html | 8 +- .../nn/_autosummary/mlx.nn.Dropout.html | 8 +- .../nn/_autosummary/mlx.nn.Dropout2d.html | 8 +- .../nn/_autosummary/mlx.nn.Dropout3d.html | 8 +- .../python/nn/_autosummary/mlx.nn.ELU.html | 8 +- .../nn/_autosummary/mlx.nn.Embedding.html | 8 +- .../python/nn/_autosummary/mlx.nn.GELU.html | 8 +- .../python/nn/_autosummary/mlx.nn.GLU.html | 8 +- .../python/nn/_autosummary/mlx.nn.GRU.html | 8 +- .../nn/_autosummary/mlx.nn.GroupNorm.html | 8 +- .../nn/_autosummary/mlx.nn.HardShrink.html | 8 +- .../nn/_autosummary/mlx.nn.HardTanh.html | 8 +- .../nn/_autosummary/mlx.nn.Hardswish.html | 8 +- .../nn/_autosummary/mlx.nn.InstanceNorm.html | 8 +- .../python/nn/_autosummary/mlx.nn.LSTM.html | 8 +- .../nn/_autosummary/mlx.nn.LayerNorm.html | 8 +- .../nn/_autosummary/mlx.nn.LeakyReLU.html | 8 +- .../python/nn/_autosummary/mlx.nn.Linear.html | 8 +- .../nn/_autosummary/mlx.nn.LogSigmoid.html | 8 +- .../nn/_autosummary/mlx.nn.LogSoftmax.html | 8 +- .../nn/_autosummary/mlx.nn.MaxPool1d.html | 8 +- .../nn/_autosummary/mlx.nn.MaxPool2d.html | 8 +- .../nn/_autosummary/mlx.nn.MaxPool3d.html | 8 +- .../python/nn/_autosummary/mlx.nn.Mish.html | 8 +- .../nn/_autosummary/mlx.nn.Module.apply.html | 8 +- .../mlx.nn.Module.apply_to_modules.html | 8 +- .../_autosummary/mlx.nn.Module.children.html | 8 +- .../nn/_autosummary/mlx.nn.Module.eval.html | 8 +- .../mlx.nn.Module.filter_and_map.html | 8 +- .../nn/_autosummary/mlx.nn.Module.freeze.html | 8 +- .../mlx.nn.Module.leaf_modules.html | 8 +- 
.../mlx.nn.Module.load_weights.html | 8 +- .../_autosummary/mlx.nn.Module.modules.html | 8 +- .../mlx.nn.Module.named_modules.html | 8 +- .../mlx.nn.Module.parameters.html | 8 +- .../mlx.nn.Module.save_weights.html | 8 +- .../_autosummary/mlx.nn.Module.set_dtype.html | 8 +- .../nn/_autosummary/mlx.nn.Module.state.html | 8 +- .../nn/_autosummary/mlx.nn.Module.train.html | 8 +- .../mlx.nn.Module.trainable_parameters.html | 8 +- .../_autosummary/mlx.nn.Module.training.html | 8 +- .../_autosummary/mlx.nn.Module.unfreeze.html | 8 +- .../nn/_autosummary/mlx.nn.Module.update.html | 18 +- .../mlx.nn.Module.update_modules.html | 20 +- .../mlx.nn.MultiHeadAttention.html | 8 +- .../python/nn/_autosummary/mlx.nn.PReLU.html | 8 +- .../mlx.nn.QuantizedEmbedding.html | 8 +- .../_autosummary/mlx.nn.QuantizedLinear.html | 8 +- .../nn/_autosummary/mlx.nn.RMSNorm.html | 8 +- .../python/nn/_autosummary/mlx.nn.RNN.html | 8 +- .../python/nn/_autosummary/mlx.nn.ReLU.html | 8 +- .../python/nn/_autosummary/mlx.nn.ReLU6.html | 8 +- .../python/nn/_autosummary/mlx.nn.RoPE.html | 8 +- .../python/nn/_autosummary/mlx.nn.SELU.html | 8 +- .../nn/_autosummary/mlx.nn.Sequential.html | 8 +- .../python/nn/_autosummary/mlx.nn.SiLU.html | 8 +- .../nn/_autosummary/mlx.nn.Sigmoid.html | 8 +- .../mlx.nn.SinusoidalPositionalEncoding.html | 8 +- .../nn/_autosummary/mlx.nn.Softmax.html | 8 +- .../nn/_autosummary/mlx.nn.Softmin.html | 8 +- .../nn/_autosummary/mlx.nn.Softplus.html | 8 +- .../nn/_autosummary/mlx.nn.Softshrink.html | 8 +- .../nn/_autosummary/mlx.nn.Softsign.html | 8 +- .../python/nn/_autosummary/mlx.nn.Step.html | 8 +- .../python/nn/_autosummary/mlx.nn.Tanh.html | 8 +- .../nn/_autosummary/mlx.nn.Transformer.html | 8 +- .../nn/_autosummary/mlx.nn.Upsample.html | 8 +- .../nn/_autosummary/mlx.nn.init.constant.html | 8 +- .../mlx.nn.init.glorot_normal.html | 8 +- .../mlx.nn.init.glorot_uniform.html | 8 +- .../_autosummary/mlx.nn.init.he_normal.html | 8 +- .../_autosummary/mlx.nn.init.he_uniform.html 
| 8 +- .../nn/_autosummary/mlx.nn.init.identity.html | 8 +- .../nn/_autosummary/mlx.nn.init.normal.html | 8 +- .../nn/_autosummary/mlx.nn.init.uniform.html | 8 +- .../_autosummary_functions/mlx.nn.celu.html | 8 +- .../nn/_autosummary_functions/mlx.nn.elu.html | 8 +- .../_autosummary_functions/mlx.nn.gelu.html | 8 +- .../mlx.nn.gelu_approx.html | 8 +- .../mlx.nn.gelu_fast_approx.html | 8 +- .../nn/_autosummary_functions/mlx.nn.glu.html | 8 +- .../mlx.nn.hard_shrink.html | 8 +- .../mlx.nn.hard_tanh.html | 8 +- .../mlx.nn.hardswish.html | 8 +- .../mlx.nn.leaky_relu.html | 8 +- .../mlx.nn.log_sigmoid.html | 8 +- .../mlx.nn.log_softmax.html | 8 +- .../mlx.nn.losses.binary_cross_entropy.html | 8 +- .../mlx.nn.losses.cosine_similarity_loss.html | 8 +- .../mlx.nn.losses.cross_entropy.html | 8 +- .../mlx.nn.losses.gaussian_nll_loss.html | 8 +- .../mlx.nn.losses.hinge_loss.html | 8 +- .../mlx.nn.losses.huber_loss.html | 8 +- .../mlx.nn.losses.kl_div_loss.html | 8 +- .../mlx.nn.losses.l1_loss.html | 8 +- .../mlx.nn.losses.log_cosh_loss.html | 8 +- .../mlx.nn.losses.margin_ranking_loss.html | 8 +- .../mlx.nn.losses.mse_loss.html | 8 +- .../mlx.nn.losses.nll_loss.html | 8 +- .../mlx.nn.losses.smooth_l1_loss.html | 8 +- .../mlx.nn.losses.triplet_loss.html | 8 +- .../_autosummary_functions/mlx.nn.mish.html | 8 +- .../_autosummary_functions/mlx.nn.prelu.html | 8 +- .../_autosummary_functions/mlx.nn.relu.html | 8 +- .../_autosummary_functions/mlx.nn.relu6.html | 8 +- .../_autosummary_functions/mlx.nn.selu.html | 8 +- .../mlx.nn.sigmoid.html | 8 +- .../_autosummary_functions/mlx.nn.silu.html | 8 +- .../mlx.nn.softmax.html | 8 +- .../mlx.nn.softmin.html | 8 +- .../mlx.nn.softplus.html | 8 +- .../mlx.nn.softshrink.html | 8 +- .../_autosummary_functions/mlx.nn.step.html | 8 +- .../_autosummary_functions/mlx.nn.tanh.html | 8 +- docs/build/html/python/nn/functions.html | 8 +- docs/build/html/python/nn/init.html | 8 +- docs/build/html/python/nn/layers.html | 8 +- 
docs/build/html/python/nn/losses.html | 8 +- docs/build/html/python/nn/module.html | 12 +- docs/build/html/python/ops.html | 8 +- docs/build/html/python/optimizers.html | 8 +- .../_autosummary/mlx.optimizers.AdaDelta.html | 8 +- .../mlx.optimizers.Adafactor.html | 8 +- .../_autosummary/mlx.optimizers.Adagrad.html | 8 +- .../_autosummary/mlx.optimizers.Adam.html | 8 +- .../_autosummary/mlx.optimizers.AdamW.html | 8 +- .../_autosummary/mlx.optimizers.Adamax.html | 8 +- .../_autosummary/mlx.optimizers.Lion.html | 8 +- .../mlx.optimizers.MultiOptimizer.html | 8 +- ....optimizers.Optimizer.apply_gradients.html | 8 +- .../mlx.optimizers.Optimizer.init.html | 8 +- .../mlx.optimizers.Optimizer.state.html | 8 +- .../mlx.optimizers.Optimizer.update.html | 8 +- .../_autosummary/mlx.optimizers.RMSprop.html | 8 +- .../_autosummary/mlx.optimizers.SGD.html | 8 +- .../mlx.optimizers.cosine_decay.html | 8 +- .../mlx.optimizers.exponential_decay.html | 8 +- .../mlx.optimizers.join_schedules.html | 8 +- .../mlx.optimizers.linear_schedule.html | 8 +- .../mlx.optimizers.step_decay.html | 8 +- .../python/optimizers/common_optimizers.html | 8 +- .../html/python/optimizers/optimizer.html | 8 +- .../html/python/optimizers/schedulers.html | 8 +- docs/build/html/python/random.html | 8 +- docs/build/html/python/transforms.html | 8 +- docs/build/html/python/tree_utils.html | 8 +- docs/build/html/search.html | 8 +- docs/build/html/searchindex.js | 2 +- docs/build/html/usage/compile.html | 8 +- docs/build/html/usage/distributed.html | 8 +- docs/build/html/usage/export.html | 8 +- .../build/html/usage/function_transforms.html | 8 +- docs/build/html/usage/indexing.html | 20 +- .../html/usage/launching_distributed.html | 8 +- docs/build/html/usage/lazy_evaluation.html | 10 +- docs/build/html/usage/numpy.html | 8 +- docs/build/html/usage/quick_start.html | 10 +- docs/build/html/usage/saving_and_loading.html | 10 +- docs/build/html/usage/unified_memory.html | 8 +- 
docs/build/html/usage/using_streams.html | 8 +- 533 files changed, 2735 insertions(+), 2574 deletions(-) diff --git a/docs/build/html/.buildinfo b/docs/build/html/.buildinfo index 27f8adb82..7e3c0a72c 100644 --- a/docs/build/html/.buildinfo +++ b/docs/build/html/.buildinfo @@ -1,4 +1,4 @@ # Sphinx build info version 1 # This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done. -config: f0a8f1723eac189223b8c3a08df4cb42 +config: 617e63568890a453837209bb7514fddd tags: 645f666f9bcd5a90fca523b33c5a78b7 diff --git a/docs/build/html/_sources/dev/custom_metal_kernels.rst b/docs/build/html/_sources/dev/custom_metal_kernels.rst index 3e92f2814..873b1e544 100644 --- a/docs/build/html/_sources/dev/custom_metal_kernels.rst +++ b/docs/build/html/_sources/dev/custom_metal_kernels.rst @@ -8,23 +8,26 @@ MLX supports writing custom Metal kernels through the Python and C++ APIs. Simple Example -------------- +.. currentmodule:: mlx.core + Let's write a custom kernel that computes ``exp`` elementwise: .. code-block:: python - def exp_elementwise(a: mx.array): - source = """ - uint elem = thread_position_in_grid.x; - T tmp = inp[elem]; - out[elem] = metal::exp(tmp); - """ + source = """ + uint elem = thread_position_in_grid.x; + T tmp = inp[elem]; + out[elem] = metal::exp(tmp); + """ - kernel = mx.fast.metal_kernel( - name="myexp", - input_names=["inp"], - output_names=["out"], - source=source, - ) + kernel = mx.fast.metal_kernel( + name="myexp", + input_names=["inp"], + output_names=["out"], + source=source, + ) + + def exp_elementwise(a: mx.array): outputs = kernel( inputs=[a], template=[("T", mx.float32)], @@ -39,8 +42,13 @@ Let's write a custom kernel that computes ``exp`` elementwise: b = exp_elementwise(a) assert mx.allclose(b, mx.exp(a)) +Every time you make a kernel, a new Metal library is created and possibly +JIT compiled. 
To reduce the overhead from that, build the kernel once with +:func:`fast.metal_kernel` and then use it many times. + .. note:: - We are only required to pass the body of the Metal kernel in ``source``. + Only pass the body of the Metal kernel in ``source``. The function + signature is generated automatically. The full function signature will be generated using: @@ -78,44 +86,51 @@ Putting this all together, the generated function signature for ``myexp`` is as template [[host_name("custom_kernel_myexp_float")]] [[kernel]] decltype(custom_kernel_myexp_float) custom_kernel_myexp_float; -Note: ``grid`` and ``threadgroup`` are parameters to the Metal `dispatchThreads `_ function. -This means we will launch ``mx.prod(grid)`` threads, subdivided into ``threadgroup`` size threadgroups. -For optimal performance, each thread group dimension should be less than or equal to the corresponding grid dimension. +Note: ``grid`` and ``threadgroup`` are parameters to the Metal `dispatchThreads +`_ +function. This means we will launch ``mx.prod(grid)`` threads, subdivided into +``threadgroup`` size threadgroups. For optimal performance, each thread group +dimension should be less than or equal to the corresponding grid dimension. -Passing ``verbose=True`` to ``mx.fast.metal_kernel.__call__`` will print the generated code for debugging purposes. +Passing ``verbose=True`` to :func:`ast.metal_kernel.__call__` will print the +generated code for debugging purposes. Using Shape/Strides ------------------- -``mx.fast.metal_kernel`` supports an argument ``ensure_row_contiguous`` which is ``True`` by default. -This will copy the ``mx.array`` inputs if needed before the kernel is launched to ensure that the memory layout is row contiguous. -Generally this makes writing the kernel easier, since we don't have to worry about gaps or the ordering of the dims -when indexing. +:func:`fast.metal_kernel` supports an argument ``ensure_row_contiguous`` which +is ``True`` by default. 
This will copy the array inputs if needed +before the kernel is launched to ensure that the memory layout is row +contiguous. Generally this makes writing the kernel easier, since we don't +have to worry about gaps or the ordering of the dims when indexing. -If we want to avoid this copy, ``metal_kernel`` automatically passes ``a_shape``, ``a_strides`` and ``a_ndim`` for each -input array ``a`` if any are present in ``source``. -We can then use MLX's built in indexing utils to fetch the right elements for each thread. +If we want to avoid this copy, :func:`fast.metal_kernel` automatically passes +``a_shape``, ``a_strides`` and ``a_ndim`` for each input array ``a`` if any are +present in ``source``. We can then use MLX's built in indexing utils to fetch +the right elements for each thread. -Let's convert ``myexp`` above to support arbitrarily strided arrays without relying on a copy from ``ensure_row_contiguous``: +Let's convert ``myexp`` above to support arbitrarily strided arrays without +relying on a copy from ``ensure_row_contiguous``: .. 
code-block:: python + + source = """ + uint elem = thread_position_in_grid.x; + // Utils from `mlx/backend/metal/kernels/utils.h` are automatically included + uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim); + T tmp = inp[loc]; + // Output arrays are always row contiguous + out[elem] = metal::exp(tmp); + """ + + kernel = mx.fast.metal_kernel( + name="myexp_strided", + input_names=["inp"], + output_names=["out"], + source=source + ) def exp_elementwise(a: mx.array): - source = """ - uint elem = thread_position_in_grid.x; - // Utils from `mlx/backend/metal/kernels/utils.h` are automatically included - uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim); - T tmp = inp[loc]; - // Output arrays are always row contiguous - out[elem] = metal::exp(tmp); - """ - - kernel = mx.fast.metal_kernel( - name="myexp_strided", - input_names=["inp"], - output_names=["out"], - source=source - ) outputs = kernel( inputs=[a], template=[("T", mx.float32)], @@ -142,137 +157,139 @@ We'll start with the following MLX implementation using standard ops: .. 
code-block:: python - def grid_sample_ref(x, grid): - N, H_in, W_in, _ = x.shape - ix = ((grid[..., 0] + 1) * W_in - 1) / 2 - iy = ((grid[..., 1] + 1) * H_in - 1) / 2 + def grid_sample_ref(x, grid): + N, H_in, W_in, _ = x.shape + ix = ((grid[..., 0] + 1) * W_in - 1) / 2 + iy = ((grid[..., 1] + 1) * H_in - 1) / 2 - ix_nw = mx.floor(ix).astype(mx.int32) - iy_nw = mx.floor(iy).astype(mx.int32) + ix_nw = mx.floor(ix).astype(mx.int32) + iy_nw = mx.floor(iy).astype(mx.int32) - ix_ne = ix_nw + 1 - iy_ne = iy_nw + ix_ne = ix_nw + 1 + iy_ne = iy_nw - ix_sw = ix_nw - iy_sw = iy_nw + 1 + ix_sw = ix_nw + iy_sw = iy_nw + 1 - ix_se = ix_nw + 1 - iy_se = iy_nw + 1 + ix_se = ix_nw + 1 + iy_se = iy_nw + 1 - nw = (ix_se - ix) * (iy_se - iy) - ne = (ix - ix_sw) * (iy_sw - iy) - sw = (ix_ne - ix) * (iy - iy_ne) - se = (ix - ix_nw) * (iy - iy_nw) + nw = (ix_se - ix) * (iy_se - iy) + ne = (ix - ix_sw) * (iy_sw - iy) + sw = (ix_ne - ix) * (iy - iy_ne) + se = (ix - ix_nw) * (iy - iy_nw) - I_nw = x[mx.arange(N)[:, None, None], iy_nw, ix_nw, :] - I_ne = x[mx.arange(N)[:, None, None], iy_ne, ix_ne, :] - I_sw = x[mx.arange(N)[:, None, None], iy_sw, ix_sw, :] - I_se = x[mx.arange(N)[:, None, None], iy_se, ix_se, :] + I_nw = x[mx.arange(N)[:, None, None], iy_nw, ix_nw, :] + I_ne = x[mx.arange(N)[:, None, None], iy_ne, ix_ne, :] + I_sw = x[mx.arange(N)[:, None, None], iy_sw, ix_sw, :] + I_se = x[mx.arange(N)[:, None, None], iy_se, ix_se, :] - mask_nw = (iy_nw >= 0) & (iy_nw <= H_in - 1) & (ix_nw >= 0) & (ix_nw <= W_in - 1) - mask_ne = (iy_ne >= 0) & (iy_ne <= H_in - 1) & (ix_ne >= 0) & (ix_ne <= W_in - 1) - mask_sw = (iy_sw >= 0) & (iy_sw <= H_in - 1) & (ix_sw >= 0) & (ix_sw <= W_in - 1) - mask_se = (iy_se >= 0) & (iy_se <= H_in - 1) & (ix_se >= 0) & (ix_se <= W_in - 1) + mask_nw = (iy_nw >= 0) & (iy_nw <= H_in - 1) & (ix_nw >= 0) & (ix_nw <= W_in - 1) + mask_ne = (iy_ne >= 0) & (iy_ne <= H_in - 1) & (ix_ne >= 0) & (ix_ne <= W_in - 1) + mask_sw = (iy_sw >= 0) & (iy_sw <= H_in - 1) & (ix_sw >= 0) 
& (ix_sw <= W_in - 1) + mask_se = (iy_se >= 0) & (iy_se <= H_in - 1) & (ix_se >= 0) & (ix_se <= W_in - 1) - I_nw *= mask_nw[..., None] - I_ne *= mask_ne[..., None] - I_sw *= mask_sw[..., None] - I_se *= mask_se[..., None] + I_nw *= mask_nw[..., None] + I_ne *= mask_ne[..., None] + I_sw *= mask_sw[..., None] + I_se *= mask_se[..., None] - output = nw[..., None] * I_nw + ne[..., None] * I_ne + sw[..., None] * I_sw + se[..., None] * I_se + output = nw[..., None] * I_nw + ne[..., None] * I_ne + sw[..., None] * I_sw + se[..., None] * I_se - return output + return output -Now let's use ``mx.custom_function`` together with ``mx.fast.metal_kernel`` +Now let's use :func:`custom_function` together with :func:`fast.metal_kernel` to write a fast GPU kernel for both the forward and backward passes. First we'll implement the forward pass as a fused kernel: .. code-block:: python - @mx.custom_function - def grid_sample(x, grid): + source = """ + uint elem = thread_position_in_grid.x; + int H = x_shape[1]; + int W = x_shape[2]; + int C = x_shape[3]; + int gH = grid_shape[1]; + int gW = grid_shape[2]; - assert x.ndim == 4, "`x` must be 4D." - assert grid.ndim == 4, "`grid` must be 4D." + int w_stride = C; + int h_stride = W * w_stride; + int b_stride = H * h_stride; - B, _, _, C = x.shape - _, gN, gM, D = grid.shape - out_shape = (B, gN, gM, C) + uint grid_idx = elem / C * 2; + float ix = ((grid[grid_idx] + 1) * W - 1) / 2; + float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2; - assert D == 2, "Last dim of `grid` must be size 2." 
+ int ix_nw = floor(ix); + int iy_nw = floor(iy); - source = """ - uint elem = thread_position_in_grid.x; - int H = x_shape[1]; - int W = x_shape[2]; - int C = x_shape[3]; - int gH = grid_shape[1]; - int gW = grid_shape[2]; + int ix_ne = ix_nw + 1; + int iy_ne = iy_nw; - int w_stride = C; - int h_stride = W * w_stride; - int b_stride = H * h_stride; + int ix_sw = ix_nw; + int iy_sw = iy_nw + 1; - uint grid_idx = elem / C * 2; - float ix = ((grid[grid_idx] + 1) * W - 1) / 2; - float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2; + int ix_se = ix_nw + 1; + int iy_se = iy_nw + 1; - int ix_nw = floor(ix); - int iy_nw = floor(iy); + T nw = (ix_se - ix) * (iy_se - iy); + T ne = (ix - ix_sw) * (iy_sw - iy); + T sw = (ix_ne - ix) * (iy - iy_ne); + T se = (ix - ix_nw) * (iy - iy_nw); - int ix_ne = ix_nw + 1; - int iy_ne = iy_nw; + int batch_idx = elem / C / gH / gW * b_stride; + int channel_idx = elem % C; + int base_idx = batch_idx + channel_idx; - int ix_sw = ix_nw; - int iy_sw = iy_nw + 1; + T I_nw = x[base_idx + iy_nw * h_stride + ix_nw * w_stride]; + T I_ne = x[base_idx + iy_ne * h_stride + ix_ne * w_stride]; + T I_sw = x[base_idx + iy_sw * h_stride + ix_sw * w_stride]; + T I_se = x[base_idx + iy_se * h_stride + ix_se * w_stride]; - int ix_se = ix_nw + 1; - int iy_se = iy_nw + 1; + I_nw = iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1 ? I_nw : 0; + I_ne = iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1 ? I_ne : 0; + I_sw = iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1 ? I_sw : 0; + I_se = iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1 ? 
I_se : 0; - T nw = (ix_se - ix) * (iy_se - iy); - T ne = (ix - ix_sw) * (iy_sw - iy); - T sw = (ix_ne - ix) * (iy - iy_ne); - T se = (ix - ix_nw) * (iy - iy_nw); + out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se; + """ - int batch_idx = elem / C / gH / gW * b_stride; - int channel_idx = elem % C; - int base_idx = batch_idx + channel_idx; + kernel = mx.fast.metal_kernel( + name="grid_sample", + input_names=["x", "grid"], + output_names=["out"], + source=source, + ) - T I_nw = x[base_idx + iy_nw * h_stride + ix_nw * w_stride]; - T I_ne = x[base_idx + iy_ne * h_stride + ix_ne * w_stride]; - T I_sw = x[base_idx + iy_sw * h_stride + ix_sw * w_stride]; - T I_se = x[base_idx + iy_se * h_stride + ix_se * w_stride]; + @mx.custom_function + def grid_sample(x, grid): - I_nw = iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1 ? I_nw : 0; - I_ne = iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1 ? I_ne : 0; - I_sw = iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1 ? I_sw : 0; - I_se = iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1 ? I_se : 0; + assert x.ndim == 4, "`x` must be 4D." + assert grid.ndim == 4, "`grid` must be 4D." - out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se; - """ - kernel = mx.fast.metal_kernel( - name="grid_sample", - input_names=["x", "grid"], - output_names=["out"], - source=source, - ) - outputs = kernel( - inputs=[x, grid], - template=[("T", x.dtype)], - output_shapes=[out_shape], - output_dtypes=[x.dtype], - grid=(np.prod(out_shape), 1, 1), - threadgroup=(256, 1, 1), - ) - return outputs[0] + B, _, _, C = x.shape + _, gN, gM, D = grid.shape + out_shape = (B, gN, gM, C) + + assert D == 2, "Last dim of `grid` must be size 2." + + outputs = kernel( + inputs=[x, grid], + template=[("T", x.dtype)], + output_shapes=[out_shape], + output_dtypes=[x.dtype], + grid=(np.prod(out_shape), 1, 1), + threadgroup=(256, 1, 1), + ) + return outputs[0] For a reasonably sized input such as: .. 
code-block:: python - x.shape = (8, 1024, 1024, 64) - grid.shape = (8, 256, 256, 2) + x.shape = (8, 1024, 1024, 64) + grid.shape = (8, 256, 256, 2) On an M1 Max, we see a big performance improvement: @@ -281,11 +298,11 @@ On an M1 Max, we see a big performance improvement: Grid Sample VJP --------------- -Since we decorated ``grid_sample`` with ``mx.custom_function``, we can now define -its custom vjp transform so MLX can differentiate it. +Since we decorated ``grid_sample`` with :func:`custom_function`, we can now +define its custom vjp transform so MLX can differentiate it. The backwards pass requires atomically updating ``x_grad``/``grid_grad`` and so -requires a few extra ``mx.fast.metal_kernel`` features: +requires a few extra :func:`fast.metal_kernel` features: * ``init_value=0`` Initialize all of the kernel's outputs to this value before it runs. This allows us to update only part of the output arrays with the kernel. @@ -299,128 +316,129 @@ We can then implement the backwards pass as follows: .. code-block:: python - @grid_sample.vjp - def grid_sample_vjp(primals, cotangent, _): - x, grid = primals - B, _, _, C = x.shape - _, gN, gM, D = grid.shape + source = """ + uint elem = thread_position_in_grid.x; + int H = x_shape[1]; + int W = x_shape[2]; + int C = x_shape[3]; + // Pad C to the nearest larger simdgroup size multiple + int C_padded = ceildiv(C, threads_per_simdgroup) * threads_per_simdgroup; - assert D == 2, "Last dim of `grid` must be size 2." 
+ int gH = grid_shape[1]; + int gW = grid_shape[2]; - source = """ - uint elem = thread_position_in_grid.x; - int H = x_shape[1]; - int W = x_shape[2]; - int C = x_shape[3]; - // Pad C to the nearest larger simdgroup size multiple - int C_padded = ceildiv(C, threads_per_simdgroup) * threads_per_simdgroup; + int w_stride = C; + int h_stride = W * w_stride; + int b_stride = H * h_stride; - int gH = grid_shape[1]; - int gW = grid_shape[2]; + uint grid_idx = elem / C_padded * 2; + float ix = ((grid[grid_idx] + 1) * W - 1) / 2; + float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2; - int w_stride = C; - int h_stride = W * w_stride; - int b_stride = H * h_stride; + int ix_nw = floor(ix); + int iy_nw = floor(iy); - uint grid_idx = elem / C_padded * 2; - float ix = ((grid[grid_idx] + 1) * W - 1) / 2; - float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2; + int ix_ne = ix_nw + 1; + int iy_ne = iy_nw; - int ix_nw = floor(ix); - int iy_nw = floor(iy); + int ix_sw = ix_nw; + int iy_sw = iy_nw + 1; - int ix_ne = ix_nw + 1; - int iy_ne = iy_nw; + int ix_se = ix_nw + 1; + int iy_se = iy_nw + 1; - int ix_sw = ix_nw; - int iy_sw = iy_nw + 1; + T nw = (ix_se - ix) * (iy_se - iy); + T ne = (ix - ix_sw) * (iy_sw - iy); + T sw = (ix_ne - ix) * (iy - iy_ne); + T se = (ix - ix_nw) * (iy - iy_nw); - int ix_se = ix_nw + 1; - int iy_se = iy_nw + 1; + int batch_idx = elem / C_padded / gH / gW * b_stride; + int channel_idx = elem % C_padded; + int base_idx = batch_idx + channel_idx; - T nw = (ix_se - ix) * (iy_se - iy); - T ne = (ix - ix_sw) * (iy_sw - iy); - T sw = (ix_ne - ix) * (iy - iy_ne); - T se = (ix - ix_nw) * (iy - iy_nw); + T gix = T(0); + T giy = T(0); + if (channel_idx < C) { + int cot_index = elem / C_padded * C + channel_idx; + T cot = cotangent[cot_index]; + if (iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1) { + int offset = base_idx + iy_nw * h_stride + ix_nw * w_stride; + atomic_fetch_add_explicit(&x_grad[offset], nw * cot, memory_order_relaxed); - int batch_idx = 
elem / C_padded / gH / gW * b_stride; - int channel_idx = elem % C_padded; - int base_idx = batch_idx + channel_idx; + T I_nw = x[offset]; + gix -= I_nw * (iy_se - iy) * cot; + giy -= I_nw * (ix_se - ix) * cot; + } + if (iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1) { + int offset = base_idx + iy_ne * h_stride + ix_ne * w_stride; + atomic_fetch_add_explicit(&x_grad[offset], ne * cot, memory_order_relaxed); - T gix = T(0); - T giy = T(0); - if (channel_idx < C) { - int cot_index = elem / C_padded * C + channel_idx; - T cot = cotangent[cot_index]; - if (iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1) { - int offset = base_idx + iy_nw * h_stride + ix_nw * w_stride; - atomic_fetch_add_explicit(&x_grad[offset], nw * cot, memory_order_relaxed); + T I_ne = x[offset]; + gix += I_ne * (iy_sw - iy) * cot; + giy -= I_ne * (ix - ix_sw) * cot; + } + if (iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1) { + int offset = base_idx + iy_sw * h_stride + ix_sw * w_stride; + atomic_fetch_add_explicit(&x_grad[offset], sw * cot, memory_order_relaxed); - T I_nw = x[offset]; - gix -= I_nw * (iy_se - iy) * cot; - giy -= I_nw * (ix_se - ix) * cot; - } - if (iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1) { - int offset = base_idx + iy_ne * h_stride + ix_ne * w_stride; - atomic_fetch_add_explicit(&x_grad[offset], ne * cot, memory_order_relaxed); + T I_sw = x[offset]; + gix -= I_sw * (iy - iy_ne) * cot; + giy += I_sw * (ix_ne - ix) * cot; + } + if (iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1) { + int offset = base_idx + iy_se * h_stride + ix_se * w_stride; + atomic_fetch_add_explicit(&x_grad[offset], se * cot, memory_order_relaxed); - T I_ne = x[offset]; - gix += I_ne * (iy_sw - iy) * cot; - giy -= I_ne * (ix - ix_sw) * cot; - } - if (iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1) { - int offset = base_idx + iy_sw * h_stride + ix_sw * w_stride; - atomic_fetch_add_explicit(&x_grad[offset], sw * cot, 
memory_order_relaxed); + T I_se = x[offset]; + gix += I_se * (iy - iy_nw) * cot; + giy += I_se * (ix - ix_nw) * cot; + } + } - T I_sw = x[offset]; - gix -= I_sw * (iy - iy_ne) * cot; - giy += I_sw * (ix_ne - ix) * cot; - } - if (iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1) { - int offset = base_idx + iy_se * h_stride + ix_se * w_stride; - atomic_fetch_add_explicit(&x_grad[offset], se * cot, memory_order_relaxed); + T gix_mult = W / 2; + T giy_mult = H / 2; - T I_se = x[offset]; - gix += I_se * (iy - iy_nw) * cot; - giy += I_se * (ix - ix_nw) * cot; - } - } + // Reduce across each simdgroup first. + // This is much faster than relying purely on atomics. + gix = simd_sum(gix); + giy = simd_sum(giy); - T gix_mult = W / 2; - T giy_mult = H / 2; + if (thread_index_in_simdgroup == 0) { + atomic_fetch_add_explicit(&grid_grad[grid_idx], gix * gix_mult, memory_order_relaxed); + atomic_fetch_add_explicit(&grid_grad[grid_idx + 1], giy * giy_mult, memory_order_relaxed); + } + """ + kernel = mx.fast.metal_kernel( + name="grid_sample_grad", + input_names=["x", "grid", "cotangent"], + output_names=["x_grad", "grid_grad"], + source=source, + atomic_outputs=True, + ) - // Reduce across each simdgroup first. - // This is much faster than relying purely on atomics. - gix = simd_sum(gix); - giy = simd_sum(giy); + @grid_sample.vjp + def grid_sample_vjp(primals, cotangent, _): + x, grid = primals + B, _, _, C = x.shape + _, gN, gM, D = grid.shape - if (thread_index_in_simdgroup == 0) { - atomic_fetch_add_explicit(&grid_grad[grid_idx], gix * gix_mult, memory_order_relaxed); - atomic_fetch_add_explicit(&grid_grad[grid_idx + 1], giy * giy_mult, memory_order_relaxed); - } - """ - kernel = mx.fast.metal_kernel( - name="grid_sample_grad", - input_names=["x", "grid", "cotangent"], - output_names=["x_grad", "grid_grad"], - source=source, - atomic_outputs=True, - ) - # pad the output channels to simd group size - # so that our `simd_sum`s don't overlap. 
- simdgroup_size = 32 - C_padded = (C + simdgroup_size - 1) // simdgroup_size * simdgroup_size - grid_size = B * gN * gM * C_padded - outputs = kernel( - inputs=[x, grid, cotangent], - template=[("T", x.dtype)], - output_shapes=[x.shape, grid.shape], - output_dtypes=[x.dtype, x.dtype], - grid=(grid_size, 1, 1), - threadgroup=(256, 1, 1), - init_value=0, - ) - return outputs[0], outputs[1] + assert D == 2, "Last dim of `grid` must be size 2." + + # pad the output channels to simd group size + # so that our `simd_sum`s don't overlap. + simdgroup_size = 32 + C_padded = (C + simdgroup_size - 1) // simdgroup_size * simdgroup_size + grid_size = B * gN * gM * C_padded + outputs = kernel( + inputs=[x, grid, cotangent], + template=[("T", x.dtype)], + output_shapes=[x.shape, grid.shape], + output_dtypes=[x.dtype, x.dtype], + grid=(grid_size, 1, 1), + threadgroup=(256, 1, 1), + init_value=0, + ) + return outputs[0], outputs[1] There's an even larger speed up for the vjp: diff --git a/docs/build/html/_sources/dev/extensions.rst b/docs/build/html/_sources/dev/extensions.rst index 2aef28f99..03f1c2163 100644 --- a/docs/build/html/_sources/dev/extensions.rst +++ b/docs/build/html/_sources/dev/extensions.rst @@ -397,11 +397,11 @@ below. std::ostringstream kname; kname << "axpby_" << "general_" << type_to_name(out); - // Make sure the metal library is available - d.register_library("mlx_ext"); + // Load the metal library + auto lib = d.get_library("mlx_ext"); // Make a kernel from this metal library - auto kernel = d.get_kernel(kname.str(), "mlx_ext"); + auto kernel = d.get_kernel(kname.str(), lib); // Prepare to encode kernel auto& compute_encoder = d.get_command_encoder(s.index); diff --git a/docs/build/html/_sources/install.rst b/docs/build/html/_sources/install.rst index 059b2cba4..22de94f90 100644 --- a/docs/build/html/_sources/install.rst +++ b/docs/build/html/_sources/install.rst @@ -30,6 +30,16 @@ MLX is also available on conda-forge. 
To install MLX with conda do: conda install conda-forge::mlx +CUDA +^^^^ + +MLX has a CUDA backend which you can use on any Linux platform with CUDA 12 +and SM 7.0 (Volta) and up. To install MLX with CUDA support, run: + +.. code-block:: shell + + pip install mlx-cuda + Troubleshooting ^^^^^^^^^^^^^^^ @@ -65,6 +75,8 @@ Build Requirements Python API ^^^^^^^^^^ +.. _python install: + To build and install the MLX python library from source, first, clone MLX from `its GitHub repo `_: @@ -107,6 +119,8 @@ IDE: C++ API ^^^^^^^ +.. _cpp install: + Currently, MLX must be built and installed from source. Similarly to the python library, to build and install the MLX C++ library start @@ -185,6 +199,7 @@ should point to the path to the built metal library. xcrun -sdk macosx --show-sdk-version + Binary Size Minimization ~~~~~~~~~~~~~~~~~~~~~~~~ @@ -213,6 +228,50 @@ be anwywhere from a few hundred millisecond to a few seconds depending on the application. Once a kernel is compiled, it will be cached by the system. The Metal kernel cache persists across reboots. +Linux +^^^^^ + +To build from source on Linux (CPU only), install the BLAS and LAPACK headers. +For example on Ubuntu, run the following: + +.. code-block:: shell + + apt-get update -y + apt-get install libblas-dev liblapack-dev liblapacke-dev -y + +From here follow the instructions to install either the :ref:`Python ` or :ref:`C++ ` APIs. + +CUDA +^^^^ + +To build from source on Linux with CUDA, install the BLAS and LAPACK headers +and the CUDA toolkit. For example on Ubuntu, run the following: + +.. code-block:: shell + + wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb + dpkg -i cuda-keyring_1.1-1_all.deb + apt-get update -y + apt-get -y install cuda-toolkit-12-9 + apt-get install libblas-dev liblapack-dev liblapacke-dev -y + + +When building either the Python or C++ APIs make sure to pass the cmake flag +``MLX_BUILD_CUDA=ON``. 
For example, to build the Python API run: + +.. code-block:: shell + + CMAKE_BUILD_PARALLEL_LEVEL=8 CMAKE_ARGS="-DMLX_BUILD_CUDA=ON" pip install -e ".[dev]" + +To build the C++ package run: + +.. code-block:: shell + + mkdir -p build && cd build + cmake .. -DMLX_BUILD_CUDA=ON && make -j + + Troubleshooting ^^^^^^^^^^^^^^^ diff --git a/docs/build/html/_sources/usage/indexing.rst b/docs/build/html/_sources/usage/indexing.rst index c74e357fa..dcbc84c1b 100644 --- a/docs/build/html/_sources/usage/indexing.rst +++ b/docs/build/html/_sources/usage/indexing.rst @@ -107,6 +107,16 @@ same array: >>> a array([1, 2, 0], dtype=int32) + +Note, unlike NumPy, updates to the same location are nondeterministic: + +.. code-block:: shell + + >>> a = mx.array([1, 2, 3]) + >>> a[[0, 0]] = mx.array([4, 5]) + +The first element of ``a`` could be ``4`` or ``5``. + Transformations of functions which use in-place updates are allowed and work as expected. For example: diff --git a/docs/build/html/_static/documentation_options.js b/docs/build/html/_static/documentation_options.js index 2bf8eee98..610b22e80 100644 --- a/docs/build/html/_static/documentation_options.js +++ b/docs/build/html/_static/documentation_options.js @@ -1,5 +1,5 @@ const DOCUMENTATION_OPTIONS = { - VERSION: '0.26.1', + VERSION: '0.26.2', LANGUAGE: 'en', COLLAPSE_INDEX: false, BUILDER: 'html', diff --git a/docs/build/html/cpp/ops.html b/docs/build/html/cpp/ops.html index 629ab50e0..ab71788aa 100644 --- a/docs/build/html/cpp/ops.html +++ b/docs/build/html/cpp/ops.html @@ -8,7 +8,7 @@ - Operations — MLX 0.26.1 documentation + Operations — MLX 0.26.2 documentation @@ -36,7 +36,7 @@ - + @@ -137,8 +137,8 @@ - MLX 0.26.1 documentation - Home - + MLX 0.26.2 documentation - Home + diff --git a/docs/build/html/dev/custom_metal_kernels.html b/docs/build/html/dev/custom_metal_kernels.html index dbff63ffa..14931cf22 100644 --- a/docs/build/html/dev/custom_metal_kernels.html +++ b/docs/build/html/dev/custom_metal_kernels.html @@ -8,7 
+8,7 @@ - Custom Metal Kernels — MLX 0.26.1 documentation + Custom Metal Kernels — MLX 0.26.2 documentation @@ -36,7 +36,7 @@ - + @@ -137,8 +137,8 @@ - MLX 0.26.1 documentation - Home - + MLX 0.26.2 documentation - Home + @@ -926,19 +926,20 @@ document.write(`

Simple Example#

Let’s write a custom kernel that computes exp elementwise:

-
def exp_elementwise(a: mx.array):
-    source = """
-        uint elem = thread_position_in_grid.x;
-        T tmp = inp[elem];
-        out[elem] = metal::exp(tmp);
-    """
+
source = """
+    uint elem = thread_position_in_grid.x;
+    T tmp = inp[elem];
+    out[elem] = metal::exp(tmp);
+"""
 
-    kernel = mx.fast.metal_kernel(
-        name="myexp",
-        input_names=["inp"],
-        output_names=["out"],
-        source=source,
-    )
+kernel = mx.fast.metal_kernel(
+    name="myexp",
+    input_names=["inp"],
+    output_names=["out"],
+    source=source,
+)
+
+def exp_elementwise(a: mx.array):
     outputs = kernel(
         inputs=[a],
         template=[("T", mx.float32)],
@@ -954,9 +955,13 @@ document.write(`
 assert mx.allclose(b, mx.exp(a))
 
+

Every time you make a kernel, a new Metal library is created and possibly +JIT compiled. To reduce the overhead from that, build the kernel once with +fast.metal_kernel() and then use it many times.

Note

-

We are only required to pass the body of the Metal kernel in source.

+

Only pass the body of the Metal kernel in source. The function +signature is generated automatically.

The full function signature will be generated using:

-

Note: grid and threadgroup are parameters to the Metal dispatchThreads function. -This means we will launch mx.prod(grid) threads, subdivided into threadgroup size threadgroups. -For optimal performance, each thread group dimension should be less than or equal to the corresponding grid dimension.

-

Passing verbose=True to mx.fast.metal_kernel.__call__ will print the generated code for debugging purposes.

+

Note: grid and threadgroup are parameters to the Metal dispatchThreads +function. This means we will launch mx.prod(grid) threads, subdivided into +threadgroup size threadgroups. For optimal performance, each thread group +dimension should be less than or equal to the corresponding grid dimension.

+

Passing verbose=True to fast.metal_kernel.__call__() will print the +generated code for debugging purposes.

Using Shape/Strides#

-

mx.fast.metal_kernel supports an argument ensure_row_contiguous which is True by default. -This will copy the mx.array inputs if needed before the kernel is launched to ensure that the memory layout is row contiguous. -Generally this makes writing the kernel easier, since we don’t have to worry about gaps or the ordering of the dims -when indexing.

-

If we want to avoid this copy, metal_kernel automatically passes a_shape, a_strides and a_ndim for each -input array a if any are present in source. -We can then use MLX’s built in indexing utils to fetch the right elements for each thread.

-

Let’s convert myexp above to support arbitrarily strided arrays without relying on a copy from ensure_row_contiguous:

-
def exp_elementwise(a: mx.array):
-    source = """
-        uint elem = thread_position_in_grid.x;
-        // Utils from `mlx/backend/metal/kernels/utils.h` are automatically included
-        uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
-        T tmp = inp[loc];
-        // Output arrays are always row contiguous
-        out[elem] = metal::exp(tmp);
-    """
+

fast.metal_kernel() supports an argument ensure_row_contiguous which +is True by default. This will copy the array inputs if needed +before the kernel is launched to ensure that the memory layout is row +contiguous. Generally this makes writing the kernel easier, since we don’t +have to worry about gaps or the ordering of the dims when indexing.

+

If we want to avoid this copy, fast.metal_kernel() automatically passes +a_shape, a_strides and a_ndim for each input array a if any are +present in source. We can then use MLX’s built in indexing utils to fetch +the right elements for each thread.

+

Let’s convert myexp above to support arbitrarily strided arrays without +relying on a copy from ensure_row_contiguous:

+
source = """
+    uint elem = thread_position_in_grid.x;
+    // Utils from `mlx/backend/metal/kernels/utils.h` are automatically included
+    uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
+    T tmp = inp[loc];
+    // Output arrays are always row contiguous
+    out[elem] = metal::exp(tmp);
+"""
 
-    kernel = mx.fast.metal_kernel(
-        name="myexp_strided",
-        input_names=["inp"],
-        output_names=["out"],
-        source=source
-    )
+kernel = mx.fast.metal_kernel(
+    name="myexp_strided",
+    input_names=["inp"],
+    output_names=["out"],
+    source=source
+)
+
+def exp_elementwise(a: mx.array):
     outputs = kernel(
         inputs=[a],
         template=[("T", mx.float32)],
@@ -1100,10 +1111,67 @@ We can then use MLX’s built in indexing utils to fetch the right elements for
     return output
 
-

Now let’s use mx.custom_function together with mx.fast.metal_kernel +

Now let’s use custom_function() together with fast.metal_kernel() to write a fast GPU kernel for both the forward and backward passes.

First we’ll implement the forward pass as a fused kernel:

-
@mx.custom_function
+
source = """
+    uint elem = thread_position_in_grid.x;
+    int H = x_shape[1];
+    int W = x_shape[2];
+    int C = x_shape[3];
+    int gH = grid_shape[1];
+    int gW = grid_shape[2];
+
+    int w_stride = C;
+    int h_stride = W * w_stride;
+    int b_stride = H * h_stride;
+
+    uint grid_idx = elem / C * 2;
+    float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
+    float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
+
+    int ix_nw = floor(ix);
+    int iy_nw = floor(iy);
+
+    int ix_ne = ix_nw + 1;
+    int iy_ne = iy_nw;
+
+    int ix_sw = ix_nw;
+    int iy_sw = iy_nw + 1;
+
+    int ix_se = ix_nw + 1;
+    int iy_se = iy_nw + 1;
+
+    T nw = (ix_se - ix)    * (iy_se - iy);
+    T ne = (ix    - ix_sw) * (iy_sw - iy);
+    T sw = (ix_ne - ix)    * (iy    - iy_ne);
+    T se = (ix    - ix_nw) * (iy    - iy_nw);
+
+    int batch_idx = elem / C / gH / gW * b_stride;
+    int channel_idx = elem % C;
+    int base_idx = batch_idx + channel_idx;
+
+    T I_nw = x[base_idx + iy_nw * h_stride + ix_nw * w_stride];
+    T I_ne = x[base_idx + iy_ne * h_stride + ix_ne * w_stride];
+    T I_sw = x[base_idx + iy_sw * h_stride + ix_sw * w_stride];
+    T I_se = x[base_idx + iy_se * h_stride + ix_se * w_stride];
+
+    I_nw = iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1 ? I_nw : 0;
+    I_ne = iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1 ? I_ne : 0;
+    I_sw = iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1 ? I_sw : 0;
+    I_se = iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1 ? I_se : 0;
+
+    out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se;
+"""
+
+kernel = mx.fast.metal_kernel(
+    name="grid_sample",
+    input_names=["x", "grid"],
+    output_names=["out"],
+    source=source,
+)
+
+@mx.custom_function
 def grid_sample(x, grid):
 
     assert x.ndim == 4, "`x` must be 4D."
@@ -1115,61 +1183,6 @@ to write a fast GPU kernel for both the forward and backward passes.

assert D == 2, "Last dim of `grid` must be size 2." - source = """ - uint elem = thread_position_in_grid.x; - int H = x_shape[1]; - int W = x_shape[2]; - int C = x_shape[3]; - int gH = grid_shape[1]; - int gW = grid_shape[2]; - - int w_stride = C; - int h_stride = W * w_stride; - int b_stride = H * h_stride; - - uint grid_idx = elem / C * 2; - float ix = ((grid[grid_idx] + 1) * W - 1) / 2; - float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2; - - int ix_nw = floor(ix); - int iy_nw = floor(iy); - - int ix_ne = ix_nw + 1; - int iy_ne = iy_nw; - - int ix_sw = ix_nw; - int iy_sw = iy_nw + 1; - - int ix_se = ix_nw + 1; - int iy_se = iy_nw + 1; - - T nw = (ix_se - ix) * (iy_se - iy); - T ne = (ix - ix_sw) * (iy_sw - iy); - T sw = (ix_ne - ix) * (iy - iy_ne); - T se = (ix - ix_nw) * (iy - iy_nw); - - int batch_idx = elem / C / gH / gW * b_stride; - int channel_idx = elem % C; - int base_idx = batch_idx + channel_idx; - - T I_nw = x[base_idx + iy_nw * h_stride + ix_nw * w_stride]; - T I_ne = x[base_idx + iy_ne * h_stride + ix_ne * w_stride]; - T I_sw = x[base_idx + iy_sw * h_stride + ix_sw * w_stride]; - T I_se = x[base_idx + iy_se * h_stride + ix_se * w_stride]; - - I_nw = iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1 ? I_nw : 0; - I_ne = iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1 ? I_ne : 0; - I_sw = iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1 ? I_sw : 0; - I_se = iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1 ? I_se : 0; - - out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se; - """ - kernel = mx.fast.metal_kernel( - name="grid_sample", - input_names=["x", "grid"], - output_names=["out"], - source=source, - ) outputs = kernel( inputs=[x, grid], template=[("T", x.dtype)], @@ -1191,10 +1204,10 @@ to write a fast GPU kernel for both the forward and backward passes.

Grid Sample VJP#

-

Since we decorated grid_sample with mx.custom_function, we can now define -its custom vjp transform so MLX can differentiate it.

+

Since we decorated grid_sample with custom_function(), we can now +define its custom vjp transform so MLX can differentiate it.

The backwards pass requires atomically updating x_grad/grid_grad and so -requires a few extra mx.fast.metal_kernel features:

+requires a few extra fast.metal_kernel() features:

  • init_value=0

    Initialize all of the kernel’s outputs to this value before it runs. This allows us to update only part of the output arrays with the kernel.

    @@ -1210,7 +1223,107 @@ See section 6.15 of the
    @grid_sample.vjp
    +
    source = """
    +    uint elem = thread_position_in_grid.x;
    +    int H = x_shape[1];
    +    int W = x_shape[2];
    +    int C = x_shape[3];
    +    // Pad C to the nearest larger simdgroup size multiple
    +    int C_padded = ceildiv(C, threads_per_simdgroup) * threads_per_simdgroup;
    +
    +    int gH = grid_shape[1];
    +    int gW = grid_shape[2];
    +
    +    int w_stride = C;
    +    int h_stride = W * w_stride;
    +    int b_stride = H * h_stride;
    +
    +    uint grid_idx = elem / C_padded * 2;
    +    float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
    +    float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
    +
    +    int ix_nw = floor(ix);
    +    int iy_nw = floor(iy);
    +
    +    int ix_ne = ix_nw + 1;
    +    int iy_ne = iy_nw;
    +
    +    int ix_sw = ix_nw;
    +    int iy_sw = iy_nw + 1;
    +
    +    int ix_se = ix_nw + 1;
    +    int iy_se = iy_nw + 1;
    +
    +    T nw = (ix_se - ix)    * (iy_se - iy);
    +    T ne = (ix    - ix_sw) * (iy_sw - iy);
    +    T sw = (ix_ne - ix)    * (iy    - iy_ne);
    +    T se = (ix    - ix_nw) * (iy    - iy_nw);
    +
    +    int batch_idx = elem / C_padded / gH / gW * b_stride;
    +    int channel_idx = elem % C_padded;
    +    int base_idx = batch_idx + channel_idx;
    +
    +    T gix = T(0);
    +    T giy = T(0);
    +    if (channel_idx < C) {
    +        int cot_index = elem / C_padded * C + channel_idx;
    +        T cot = cotangent[cot_index];
    +        if (iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1) {
    +            int offset = base_idx + iy_nw * h_stride + ix_nw * w_stride;
    +            atomic_fetch_add_explicit(&x_grad[offset], nw * cot, memory_order_relaxed);
    +
    +            T I_nw = x[offset];
    +            gix -= I_nw * (iy_se - iy) * cot;
    +            giy -= I_nw * (ix_se - ix) * cot;
    +        }
    +        if (iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1) {
    +            int offset = base_idx + iy_ne * h_stride + ix_ne * w_stride;
    +            atomic_fetch_add_explicit(&x_grad[offset], ne * cot, memory_order_relaxed);
    +
    +            T I_ne = x[offset];
    +            gix += I_ne * (iy_sw - iy) * cot;
    +            giy -= I_ne * (ix - ix_sw) * cot;
    +        }
    +        if (iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1) {
    +            int offset = base_idx + iy_sw * h_stride + ix_sw * w_stride;
    +            atomic_fetch_add_explicit(&x_grad[offset], sw * cot, memory_order_relaxed);
    +
    +            T I_sw = x[offset];
    +            gix -= I_sw * (iy - iy_ne) * cot;
    +            giy += I_sw * (ix_ne - ix) * cot;
    +        }
    +        if (iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1) {
    +            int offset = base_idx + iy_se * h_stride + ix_se * w_stride;
    +            atomic_fetch_add_explicit(&x_grad[offset], se * cot, memory_order_relaxed);
    +
    +            T I_se = x[offset];
    +            gix += I_se * (iy - iy_nw) * cot;
    +            giy += I_se * (ix - ix_nw) * cot;
    +        }
    +    }
    +
    +    T gix_mult = W / 2;
    +    T giy_mult = H / 2;
    +
    +    // Reduce across each simdgroup first.
    +    // This is much faster than relying purely on atomics.
    +    gix = simd_sum(gix);
    +    giy = simd_sum(giy);
    +
    +    if (thread_index_in_simdgroup == 0) {
    +        atomic_fetch_add_explicit(&grid_grad[grid_idx], gix * gix_mult, memory_order_relaxed);
    +        atomic_fetch_add_explicit(&grid_grad[grid_idx + 1], giy * giy_mult, memory_order_relaxed);
    +    }
    +"""
    +kernel = mx.fast.metal_kernel(
    +    name="grid_sample_grad",
    +    input_names=["x", "grid", "cotangent"],
    +    output_names=["x_grad", "grid_grad"],
    +    source=source,
    +    atomic_outputs=True,
    +)
    +
    +@grid_sample.vjp
     def grid_sample_vjp(primals, cotangent, _):
         x, grid = primals
         B, _, _, C = x.shape
    @@ -1218,105 +1331,6 @@ See section 6.15 of the assert D == 2, "Last dim of `grid` must be size 2."
     
    -    source = """
    -        uint elem = thread_position_in_grid.x;
    -        int H = x_shape[1];
    -        int W = x_shape[2];
    -        int C = x_shape[3];
    -        // Pad C to the nearest larger simdgroup size multiple
    -        int C_padded = ceildiv(C, threads_per_simdgroup) * threads_per_simdgroup;
    -
    -        int gH = grid_shape[1];
    -        int gW = grid_shape[2];
    -
    -        int w_stride = C;
    -        int h_stride = W * w_stride;
    -        int b_stride = H * h_stride;
    -
    -        uint grid_idx = elem / C_padded * 2;
    -        float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
    -        float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
    -
    -        int ix_nw = floor(ix);
    -        int iy_nw = floor(iy);
    -
    -        int ix_ne = ix_nw + 1;
    -        int iy_ne = iy_nw;
    -
    -        int ix_sw = ix_nw;
    -        int iy_sw = iy_nw + 1;
    -
    -        int ix_se = ix_nw + 1;
    -        int iy_se = iy_nw + 1;
    -
    -        T nw = (ix_se - ix)    * (iy_se - iy);
    -        T ne = (ix    - ix_sw) * (iy_sw - iy);
    -        T sw = (ix_ne - ix)    * (iy    - iy_ne);
    -        T se = (ix    - ix_nw) * (iy    - iy_nw);
    -
    -        int batch_idx = elem / C_padded / gH / gW * b_stride;
    -        int channel_idx = elem % C_padded;
    -        int base_idx = batch_idx + channel_idx;
    -
    -        T gix = T(0);
    -        T giy = T(0);
    -        if (channel_idx < C) {
    -            int cot_index = elem / C_padded * C + channel_idx;
    -            T cot = cotangent[cot_index];
    -            if (iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1) {
    -                int offset = base_idx + iy_nw * h_stride + ix_nw * w_stride;
    -                atomic_fetch_add_explicit(&x_grad[offset], nw * cot, memory_order_relaxed);
    -
    -                T I_nw = x[offset];
    -                gix -= I_nw * (iy_se - iy) * cot;
    -                giy -= I_nw * (ix_se - ix) * cot;
    -            }
    -            if (iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1) {
    -                int offset = base_idx + iy_ne * h_stride + ix_ne * w_stride;
    -                atomic_fetch_add_explicit(&x_grad[offset], ne * cot, memory_order_relaxed);
    -
    -                T I_ne = x[offset];
    -                gix += I_ne * (iy_sw - iy) * cot;
    -                giy -= I_ne * (ix - ix_sw) * cot;
    -            }
    -            if (iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1) {
    -                int offset = base_idx + iy_sw * h_stride + ix_sw * w_stride;
    -                atomic_fetch_add_explicit(&x_grad[offset], sw * cot, memory_order_relaxed);
    -
    -                T I_sw = x[offset];
    -                gix -= I_sw * (iy - iy_ne) * cot;
    -                giy += I_sw * (ix_ne - ix) * cot;
    -            }
    -            if (iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1) {
    -                int offset = base_idx + iy_se * h_stride + ix_se * w_stride;
    -                atomic_fetch_add_explicit(&x_grad[offset], se * cot, memory_order_relaxed);
    -
    -                T I_se = x[offset];
    -                gix += I_se * (iy - iy_nw) * cot;
    -                giy += I_se * (ix - ix_nw) * cot;
    -            }
    -        }
    -
    -        T gix_mult = W / 2;
    -        T giy_mult = H / 2;
    -
    -        // Reduce across each simdgroup first.
    -        // This is much faster than relying purely on atomics.
    -        gix = simd_sum(gix);
    -        giy = simd_sum(giy);
    -
    -        if (thread_index_in_simdgroup == 0) {
    -            atomic_fetch_add_explicit(&grid_grad[grid_idx], gix * gix_mult, memory_order_relaxed);
    -            atomic_fetch_add_explicit(&grid_grad[grid_idx + 1], giy * giy_mult, memory_order_relaxed);
    -        }
    -    """
    -    kernel = mx.fast.metal_kernel(
    -        name="grid_sample_grad",
    -        input_names=["x", "grid", "cotangent"],
    -        output_names=["x_grad", "grid_grad"],
    -        source=source,
    -        atomic_outputs=True,
    -    )
         # pad the output channels to simd group size
         # so that our `simd_sum`s don't overlap.
         simdgroup_size = 32
    diff --git a/docs/build/html/dev/extensions.html b/docs/build/html/dev/extensions.html
    index 2826dc74a..eb5044b28 100644
    --- a/docs/build/html/dev/extensions.html
    +++ b/docs/build/html/dev/extensions.html
    @@ -8,7 +8,7 @@
         
         
     
    -    Custom Extensions in MLX — MLX 0.26.1 documentation
    +    Custom Extensions in MLX — MLX 0.26.2 documentation
       
       
       
    @@ -36,7 +36,7 @@
     
       
     
    -    
    +    
         
         
         
    @@ -137,8 +137,8 @@
           
         
         
    -    MLX 0.26.1 documentation - Home
    -    
    +    MLX 0.26.2 documentation - Home
    +    
       
       
     
    @@ -1305,11 +1305,11 @@ below.

    std::ostringstream kname; kname << "axpby_" << "general_" << type_to_name(out); - // Make sure the metal library is available - d.register_library("mlx_ext"); + // Load the metal library + auto lib = d.get_library("mlx_ext"); // Make a kernel from this metal library - auto kernel = d.get_kernel(kname.str(), "mlx_ext"); + auto kernel = d.get_kernel(kname.str(), lib); // Prepare to encode kernel auto& compute_encoder = d.get_command_encoder(s.index); diff --git a/docs/build/html/dev/metal_debugger.html b/docs/build/html/dev/metal_debugger.html index 883040f0a..b1237669f 100644 --- a/docs/build/html/dev/metal_debugger.html +++ b/docs/build/html/dev/metal_debugger.html @@ -8,7 +8,7 @@ - Metal Debugger — MLX 0.26.1 documentation + Metal Debugger — MLX 0.26.2 documentation @@ -36,7 +36,7 @@ - + @@ -137,8 +137,8 @@ - MLX 0.26.1 documentation - Home - + MLX 0.26.2 documentation - Home +
    diff --git a/docs/build/html/dev/mlx_in_cpp.html b/docs/build/html/dev/mlx_in_cpp.html index fdf509479..55fcff1d8 100644 --- a/docs/build/html/dev/mlx_in_cpp.html +++ b/docs/build/html/dev/mlx_in_cpp.html @@ -8,7 +8,7 @@ - Using MLX in C++ — MLX 0.26.1 documentation + Using MLX in C++ — MLX 0.26.2 documentation @@ -36,7 +36,7 @@ - + @@ -136,8 +136,8 @@ - MLX 0.26.1 documentation - Home - + MLX 0.26.2 documentation - Home +
    diff --git a/docs/build/html/examples/linear_regression.html b/docs/build/html/examples/linear_regression.html index e20c7766e..74e9eafc7 100644 --- a/docs/build/html/examples/linear_regression.html +++ b/docs/build/html/examples/linear_regression.html @@ -8,7 +8,7 @@ - Linear Regression — MLX 0.26.1 documentation + Linear Regression — MLX 0.26.2 documentation @@ -36,7 +36,7 @@ - + @@ -137,8 +137,8 @@ - MLX 0.26.1 documentation - Home - + MLX 0.26.2 documentation - Home + diff --git a/docs/build/html/examples/llama-inference.html b/docs/build/html/examples/llama-inference.html index 7d78ffb81..d9dfd9444 100644 --- a/docs/build/html/examples/llama-inference.html +++ b/docs/build/html/examples/llama-inference.html @@ -8,7 +8,7 @@ - LLM inference — MLX 0.26.1 documentation + LLM inference — MLX 0.26.2 documentation @@ -36,7 +36,7 @@ - + @@ -137,8 +137,8 @@ - MLX 0.26.1 documentation - Home - + MLX 0.26.2 documentation - Home + diff --git a/docs/build/html/examples/mlp.html b/docs/build/html/examples/mlp.html index 238d8dec9..570831dac 100644 --- a/docs/build/html/examples/mlp.html +++ b/docs/build/html/examples/mlp.html @@ -8,7 +8,7 @@ - Multi-Layer Perceptron — MLX 0.26.1 documentation + Multi-Layer Perceptron — MLX 0.26.2 documentation @@ -36,7 +36,7 @@ - + @@ -137,8 +137,8 @@ - MLX 0.26.1 documentation - Home - + MLX 0.26.2 documentation - Home + diff --git a/docs/build/html/genindex.html b/docs/build/html/genindex.html index 6596c8206..c73daa9bf 100644 --- a/docs/build/html/genindex.html +++ b/docs/build/html/genindex.html @@ -7,7 +7,7 @@ - Index — MLX 0.26.1 documentation + Index — MLX 0.26.2 documentation @@ -35,7 +35,7 @@ - + @@ -136,8 +136,8 @@ - MLX 0.26.1 documentation - Home - + MLX 0.26.2 documentation - Home + diff --git a/docs/build/html/index.html b/docs/build/html/index.html index a92a7e854..d66d72608 100644 --- a/docs/build/html/index.html +++ b/docs/build/html/index.html @@ -8,7 +8,7 @@ - MLX — MLX 0.26.1 documentation + MLX — MLX 0.26.2 
documentation @@ -36,7 +36,7 @@ - + @@ -139,8 +139,8 @@ - MLX 0.26.1 documentation - Home - + MLX 0.26.2 documentation - Home + diff --git a/docs/build/html/install.html b/docs/build/html/install.html index 9a6271f82..85618d1f5 100644 --- a/docs/build/html/install.html +++ b/docs/build/html/install.html @@ -8,7 +8,7 @@ - Build and Install — MLX 0.26.1 documentation + Build and Install — MLX 0.26.2 documentation @@ -36,7 +36,7 @@ - + @@ -137,8 +137,8 @@ - MLX 0.26.1 documentation - Home - + MLX 0.26.2 documentation - Home + @@ -906,6 +906,7 @@ document.write(`
  • -
  • Troubleshooting