From f77d99b2852fd66cd5d457f4e0b604e975df3648 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 11 Apr 2024 17:33:33 -0700 Subject: [PATCH] docs update --- docs/build/html/.buildinfo | 2 +- docs/build/html/_sources/dev/extensions.rst | 475 ++++----- .../html/_sources/dev/metal_debugger.rst | 45 +- .../python/_autosummary/mlx.core.expm1.rst | 6 + .../python/_autosummary/mlx.core.meshgrid.rst | 6 + .../mlx.core.metal.start_capture.rst | 6 + .../mlx.core.metal.stop_capture.rst | 6 + .../mlx.core.random.multivariate_normal.rst | 6 + .../python/_autosummary/mlx.core.std.rst | 6 + docs/build/html/_sources/python/metal.rst | 4 +- docs/build/html/_sources/python/ops.rst | 7 +- docs/build/html/_sources/python/random.rst | 1 + .../html/_sources/usage/lazy_evaluation.rst | 2 +- .../html/_static/documentation_options.js | 2 +- docs/build/html/cpp/ops.html | 14 +- docs/build/html/dev/extensions.html | 468 +++++---- docs/build/html/dev/metal_debugger.html | 54 +- .../html/examples/linear_regression.html | 14 +- docs/build/html/examples/llama-inference.html | 14 +- docs/build/html/examples/mlp.html | 14 +- docs/build/html/genindex.html | 34 +- docs/build/html/index.html | 14 +- docs/build/html/install.html | 14 +- docs/build/html/objects.inv | Bin 9447 -> 9586 bytes .../python/_autosummary/mlx.core.Device.html | 14 +- .../python/_autosummary/mlx.core.Dtype.html | 14 +- .../_autosummary/mlx.core.DtypeCategory.html | 14 +- .../python/_autosummary/mlx.core.Stream.html | 14 +- .../python/_autosummary/mlx.core.abs.html | 14 +- .../python/_autosummary/mlx.core.add.html | 14 +- .../python/_autosummary/mlx.core.all.html | 14 +- .../_autosummary/mlx.core.allclose.html | 14 +- .../python/_autosummary/mlx.core.any.html | 14 +- .../python/_autosummary/mlx.core.arange.html | 14 +- .../python/_autosummary/mlx.core.arccos.html | 14 +- .../python/_autosummary/mlx.core.arccosh.html | 14 +- .../python/_autosummary/mlx.core.arcsin.html | 14 +- .../python/_autosummary/mlx.core.arcsinh.html | 14 +- .../python/_autosummary/mlx.core.arctan.html | 14 +- .../python/_autosummary/mlx.core.arctanh.html | 14 +- .../python/_autosummary/mlx.core.argmax.html | 14 +- .../python/_autosummary/mlx.core.argmin.html | 14 +- .../_autosummary/mlx.core.argpartition.html | 14 +- .../python/_autosummary/mlx.core.argsort.html | 14 +- .../python/_autosummary/mlx.core.array.T.html | 14 +- .../_autosummary/mlx.core.array.abs.html | 14 +- .../_autosummary/mlx.core.array.all.html | 14 +- .../_autosummary/mlx.core.array.any.html | 14 +- .../_autosummary/mlx.core.array.argmax.html | 14 +- .../_autosummary/mlx.core.array.argmin.html | 14 +- .../_autosummary/mlx.core.array.astype.html | 14 +- .../_autosummary/mlx.core.array.at.html | 14 +- .../_autosummary/mlx.core.array.cos.html | 14 +- .../_autosummary/mlx.core.array.cummax.html | 14 +- .../_autosummary/mlx.core.array.cummin.html | 14 +- .../_autosummary/mlx.core.array.cumprod.html | 14 +- .../_autosummary/mlx.core.array.cumsum.html | 14 +- .../_autosummary/mlx.core.array.diag.html | 14 +- .../_autosummary/mlx.core.array.diagonal.html | 14 +- .../_autosummary/mlx.core.array.dtype.html | 14 +- .../_autosummary/mlx.core.array.exp.html | 14 +- .../_autosummary/mlx.core.array.flatten.html | 14 +- .../python/_autosummary/mlx.core.array.html | 14 +- .../_autosummary/mlx.core.array.item.html | 14 +- .../_autosummary/mlx.core.array.itemsize.html | 14 +- .../_autosummary/mlx.core.array.log.html | 14 +- .../_autosummary/mlx.core.array.log10.html | 14 +- .../_autosummary/mlx.core.array.log1p.html | 14 +- 
.../_autosummary/mlx.core.array.log2.html | 14 +- .../mlx.core.array.logsumexp.html | 14 +- .../_autosummary/mlx.core.array.max.html | 14 +- .../_autosummary/mlx.core.array.mean.html | 14 +- .../_autosummary/mlx.core.array.min.html | 14 +- .../_autosummary/mlx.core.array.moveaxis.html | 14 +- .../_autosummary/mlx.core.array.nbytes.html | 14 +- .../_autosummary/mlx.core.array.ndim.html | 14 +- .../_autosummary/mlx.core.array.prod.html | 14 +- .../mlx.core.array.reciprocal.html | 14 +- .../_autosummary/mlx.core.array.reshape.html | 14 +- .../_autosummary/mlx.core.array.round.html | 14 +- .../_autosummary/mlx.core.array.rsqrt.html | 14 +- .../_autosummary/mlx.core.array.shape.html | 14 +- .../_autosummary/mlx.core.array.sin.html | 14 +- .../_autosummary/mlx.core.array.size.html | 14 +- .../_autosummary/mlx.core.array.split.html | 14 +- .../_autosummary/mlx.core.array.sqrt.html | 14 +- .../_autosummary/mlx.core.array.square.html | 14 +- .../_autosummary/mlx.core.array.squeeze.html | 14 +- .../_autosummary/mlx.core.array.sum.html | 14 +- .../_autosummary/mlx.core.array.swapaxes.html | 14 +- .../_autosummary/mlx.core.array.tolist.html | 14 +- .../mlx.core.array.transpose.html | 14 +- .../_autosummary/mlx.core.array.var.html | 14 +- .../_autosummary/mlx.core.array_equal.html | 14 +- .../_autosummary/mlx.core.atleast_1d.html | 14 +- .../_autosummary/mlx.core.atleast_2d.html | 14 +- .../_autosummary/mlx.core.atleast_3d.html | 14 +- .../_autosummary/mlx.core.broadcast_to.html | 14 +- .../python/_autosummary/mlx.core.ceil.html | 14 +- .../python/_autosummary/mlx.core.clip.html | 14 +- .../python/_autosummary/mlx.core.compile.html | 14 +- .../_autosummary/mlx.core.concatenate.html | 14 +- .../python/_autosummary/mlx.core.conv1d.html | 14 +- .../python/_autosummary/mlx.core.conv2d.html | 14 +- .../_autosummary/mlx.core.conv_general.html | 14 +- .../_autosummary/mlx.core.convolve.html | 14 +- .../python/_autosummary/mlx.core.cos.html | 14 +- .../python/_autosummary/mlx.core.cosh.html | 14 +- .../python/_autosummary/mlx.core.cummax.html | 14 +- .../python/_autosummary/mlx.core.cummin.html | 14 +- .../python/_autosummary/mlx.core.cumprod.html | 14 +- .../python/_autosummary/mlx.core.cumsum.html | 14 +- .../_autosummary/mlx.core.default_device.html | 14 +- .../_autosummary/mlx.core.default_stream.html | 14 +- .../_autosummary/mlx.core.dequantize.html | 14 +- .../python/_autosummary/mlx.core.diag.html | 14 +- .../_autosummary/mlx.core.diagonal.html | 14 +- .../mlx.core.disable_compile.html | 14 +- .../python/_autosummary/mlx.core.divide.html | 14 +- .../python/_autosummary/mlx.core.divmod.html | 14 +- .../_autosummary/mlx.core.enable_compile.html | 14 +- .../python/_autosummary/mlx.core.equal.html | 14 +- .../python/_autosummary/mlx.core.erf.html | 14 +- .../python/_autosummary/mlx.core.erfinv.html | 14 +- .../python/_autosummary/mlx.core.eval.html | 14 +- .../python/_autosummary/mlx.core.exp.html | 20 +- .../_autosummary/mlx.core.expand_dims.html | 20 +- .../python/_autosummary/mlx.core.expm1.html | 901 +++++++++++++++++ .../python/_autosummary/mlx.core.eye.html | 14 +- .../mlx.core.fast.layer_norm.html | 14 +- .../_autosummary/mlx.core.fast.rms_norm.html | 14 +- .../_autosummary/mlx.core.fast.rope.html | 14 +- ...ore.fast.scaled_dot_product_attention.html | 14 +- .../python/_autosummary/mlx.core.fft.fft.html | 14 +- .../_autosummary/mlx.core.fft.fft2.html | 14 +- .../_autosummary/mlx.core.fft.fftn.html | 14 +- .../_autosummary/mlx.core.fft.ifft.html | 14 +- .../_autosummary/mlx.core.fft.ifft2.html | 14 +- 
.../_autosummary/mlx.core.fft.ifftn.html | 14 +- .../_autosummary/mlx.core.fft.irfft.html | 14 +- .../_autosummary/mlx.core.fft.irfft2.html | 14 +- .../_autosummary/mlx.core.fft.irfftn.html | 14 +- .../_autosummary/mlx.core.fft.rfft.html | 14 +- .../_autosummary/mlx.core.fft.rfft2.html | 14 +- .../_autosummary/mlx.core.fft.rfftn.html | 14 +- .../python/_autosummary/mlx.core.flatten.html | 14 +- .../python/_autosummary/mlx.core.floor.html | 14 +- .../_autosummary/mlx.core.floor_divide.html | 14 +- .../python/_autosummary/mlx.core.full.html | 14 +- .../python/_autosummary/mlx.core.grad.html | 14 +- .../python/_autosummary/mlx.core.greater.html | 14 +- .../_autosummary/mlx.core.greater_equal.html | 14 +- .../_autosummary/mlx.core.identity.html | 14 +- .../python/_autosummary/mlx.core.inner.html | 14 +- .../python/_autosummary/mlx.core.isclose.html | 14 +- .../python/_autosummary/mlx.core.isinf.html | 14 +- .../python/_autosummary/mlx.core.isnan.html | 14 +- .../_autosummary/mlx.core.isneginf.html | 14 +- .../_autosummary/mlx.core.isposinf.html | 14 +- .../_autosummary/mlx.core.issubdtype.html | 14 +- .../python/_autosummary/mlx.core.jvp.html | 14 +- .../python/_autosummary/mlx.core.less.html | 14 +- .../_autosummary/mlx.core.less_equal.html | 14 +- .../_autosummary/mlx.core.linalg.norm.html | 14 +- .../_autosummary/mlx.core.linalg.qr.html | 14 +- .../_autosummary/mlx.core.linspace.html | 14 +- .../python/_autosummary/mlx.core.load.html | 14 +- .../python/_autosummary/mlx.core.log.html | 14 +- .../python/_autosummary/mlx.core.log10.html | 14 +- .../python/_autosummary/mlx.core.log1p.html | 14 +- .../python/_autosummary/mlx.core.log2.html | 14 +- .../_autosummary/mlx.core.logaddexp.html | 14 +- .../_autosummary/mlx.core.logical_and.html | 14 +- .../_autosummary/mlx.core.logical_not.html | 14 +- .../_autosummary/mlx.core.logical_or.html | 14 +- .../_autosummary/mlx.core.logsumexp.html | 14 +- .../python/_autosummary/mlx.core.matmul.html | 14 +- .../python/_autosummary/mlx.core.max.html | 14 +- .../python/_autosummary/mlx.core.maximum.html | 14 +- .../python/_autosummary/mlx.core.mean.html | 20 +- .../_autosummary/mlx.core.meshgrid.html | 907 +++++++++++++++++ .../mlx.core.metal.get_active_memory.html | 14 +- .../mlx.core.metal.get_cache_memory.html | 14 +- .../mlx.core.metal.get_peak_memory.html | 14 +- .../mlx.core.metal.is_available.html | 14 +- .../mlx.core.metal.set_cache_limit.html | 20 +- .../mlx.core.metal.set_memory_limit.html | 14 +- .../mlx.core.metal.start_capture.html | 901 +++++++++++++++++ .../mlx.core.metal.stop_capture.html | 889 +++++++++++++++++ .../python/_autosummary/mlx.core.min.html | 20 +- .../python/_autosummary/mlx.core.minimum.html | 14 +- .../_autosummary/mlx.core.moveaxis.html | 14 +- .../_autosummary/mlx.core.multiply.html | 14 +- .../_autosummary/mlx.core.negative.html | 14 +- .../_autosummary/mlx.core.new_stream.html | 14 +- .../python/_autosummary/mlx.core.ones.html | 14 +- .../_autosummary/mlx.core.ones_like.html | 14 +- .../python/_autosummary/mlx.core.outer.html | 14 +- .../python/_autosummary/mlx.core.pad.html | 14 +- .../_autosummary/mlx.core.partition.html | 14 +- .../python/_autosummary/mlx.core.prod.html | 14 +- .../_autosummary/mlx.core.quantize.html | 14 +- .../mlx.core.quantized_matmul.html | 14 +- .../mlx.core.random.bernoulli.html | 14 +- .../mlx.core.random.categorical.html | 14 +- .../_autosummary/mlx.core.random.gumbel.html | 14 +- .../_autosummary/mlx.core.random.key.html | 14 +- .../mlx.core.random.multivariate_normal.html | 914 
++++++++++++++++++ .../_autosummary/mlx.core.random.normal.html | 20 +- .../_autosummary/mlx.core.random.randint.html | 20 +- .../_autosummary/mlx.core.random.seed.html | 14 +- .../_autosummary/mlx.core.random.split.html | 14 +- .../mlx.core.random.truncated_normal.html | 14 +- .../_autosummary/mlx.core.random.uniform.html | 14 +- .../_autosummary/mlx.core.reciprocal.html | 14 +- .../python/_autosummary/mlx.core.repeat.html | 14 +- .../python/_autosummary/mlx.core.reshape.html | 14 +- .../python/_autosummary/mlx.core.round.html | 14 +- .../python/_autosummary/mlx.core.rsqrt.html | 14 +- .../python/_autosummary/mlx.core.save.html | 14 +- .../_autosummary/mlx.core.save_gguf.html | 14 +- .../mlx.core.save_safetensors.html | 14 +- .../python/_autosummary/mlx.core.savez.html | 14 +- .../mlx.core.savez_compressed.html | 14 +- .../mlx.core.set_default_device.html | 14 +- .../mlx.core.set_default_stream.html | 14 +- .../python/_autosummary/mlx.core.sigmoid.html | 14 +- .../python/_autosummary/mlx.core.sign.html | 14 +- .../python/_autosummary/mlx.core.sin.html | 14 +- .../python/_autosummary/mlx.core.sinh.html | 14 +- .../python/_autosummary/mlx.core.softmax.html | 14 +- .../python/_autosummary/mlx.core.sort.html | 14 +- .../python/_autosummary/mlx.core.split.html | 14 +- .../python/_autosummary/mlx.core.sqrt.html | 14 +- .../python/_autosummary/mlx.core.square.html | 14 +- .../python/_autosummary/mlx.core.squeeze.html | 14 +- .../python/_autosummary/mlx.core.stack.html | 20 +- .../python/_autosummary/mlx.core.std.html | 909 +++++++++++++++++ .../_autosummary/mlx.core.stop_gradient.html | 20 +- .../_autosummary/mlx.core.subtract.html | 14 +- .../python/_autosummary/mlx.core.sum.html | 14 +- .../_autosummary/mlx.core.swapaxes.html | 14 +- .../python/_autosummary/mlx.core.take.html | 14 +- .../mlx.core.take_along_axis.html | 14 +- .../python/_autosummary/mlx.core.tan.html | 14 +- .../python/_autosummary/mlx.core.tanh.html | 14 +- .../_autosummary/mlx.core.tensordot.html | 14 +- .../python/_autosummary/mlx.core.tile.html | 14 +- .../python/_autosummary/mlx.core.topk.html | 14 +- .../_autosummary/mlx.core.transpose.html | 14 +- .../python/_autosummary/mlx.core.tri.html | 14 +- .../python/_autosummary/mlx.core.tril.html | 14 +- .../python/_autosummary/mlx.core.triu.html | 14 +- .../_autosummary/mlx.core.value_and_grad.html | 14 +- .../python/_autosummary/mlx.core.var.html | 14 +- .../python/_autosummary/mlx.core.vjp.html | 14 +- .../python/_autosummary/mlx.core.vmap.html | 14 +- .../python/_autosummary/mlx.core.where.html | 14 +- .../python/_autosummary/mlx.core.zeros.html | 14 +- .../_autosummary/mlx.core.zeros_like.html | 14 +- .../_autosummary/mlx.nn.value_and_grad.html | 14 +- .../_autosummary/mlx.utils.tree_flatten.html | 14 +- .../_autosummary/mlx.utils.tree_map.html | 14 +- .../mlx.utils.tree_unflatten.html | 14 +- .../python/_autosummary/stream_class.html | 14 +- docs/build/html/python/array.html | 14 +- docs/build/html/python/data_types.html | 14 +- .../html/python/devices_and_streams.html | 14 +- docs/build/html/python/fast.html | 14 +- docs/build/html/python/fft.html | 14 +- docs/build/html/python/linalg.html | 14 +- docs/build/html/python/metal.html | 20 +- docs/build/html/python/nn.html | 20 +- .../python/nn/_autosummary/mlx.nn.ALiBi.html | 14 +- .../nn/_autosummary/mlx.nn.AvgPool1d.html | 14 +- .../nn/_autosummary/mlx.nn.AvgPool2d.html | 14 +- .../nn/_autosummary/mlx.nn.BatchNorm.html | 14 +- .../python/nn/_autosummary/mlx.nn.Conv1d.html | 14 +- 
.../python/nn/_autosummary/mlx.nn.Conv2d.html | 14 +- .../nn/_autosummary/mlx.nn.Dropout.html | 14 +- .../nn/_autosummary/mlx.nn.Dropout2d.html | 14 +- .../nn/_autosummary/mlx.nn.Dropout3d.html | 14 +- .../nn/_autosummary/mlx.nn.Embedding.html | 14 +- .../python/nn/_autosummary/mlx.nn.GELU.html | 14 +- .../python/nn/_autosummary/mlx.nn.GRU.html | 14 +- .../nn/_autosummary/mlx.nn.GroupNorm.html | 14 +- .../nn/_autosummary/mlx.nn.InstanceNorm.html | 14 +- .../python/nn/_autosummary/mlx.nn.LSTM.html | 14 +- .../nn/_autosummary/mlx.nn.LayerNorm.html | 14 +- .../python/nn/_autosummary/mlx.nn.Linear.html | 14 +- .../nn/_autosummary/mlx.nn.MaxPool1d.html | 14 +- .../nn/_autosummary/mlx.nn.MaxPool2d.html | 14 +- .../python/nn/_autosummary/mlx.nn.Mish.html | 14 +- .../nn/_autosummary/mlx.nn.Module.apply.html | 14 +- .../mlx.nn.Module.apply_to_modules.html | 14 +- .../_autosummary/mlx.nn.Module.children.html | 14 +- .../nn/_autosummary/mlx.nn.Module.eval.html | 14 +- .../mlx.nn.Module.filter_and_map.html | 14 +- .../nn/_autosummary/mlx.nn.Module.freeze.html | 14 +- .../mlx.nn.Module.leaf_modules.html | 14 +- .../mlx.nn.Module.load_weights.html | 14 +- .../_autosummary/mlx.nn.Module.modules.html | 14 +- .../mlx.nn.Module.named_modules.html | 14 +- .../mlx.nn.Module.parameters.html | 14 +- .../mlx.nn.Module.save_weights.html | 14 +- .../_autosummary/mlx.nn.Module.set_dtype.html | 14 +- .../nn/_autosummary/mlx.nn.Module.state.html | 14 +- .../nn/_autosummary/mlx.nn.Module.train.html | 14 +- .../mlx.nn.Module.trainable_parameters.html | 14 +- .../_autosummary/mlx.nn.Module.training.html | 14 +- .../_autosummary/mlx.nn.Module.unfreeze.html | 14 +- .../nn/_autosummary/mlx.nn.Module.update.html | 14 +- .../mlx.nn.Module.update_modules.html | 14 +- .../mlx.nn.MultiHeadAttention.html | 14 +- .../python/nn/_autosummary/mlx.nn.PReLU.html | 14 +- .../_autosummary/mlx.nn.QuantizedLinear.html | 14 +- .../nn/_autosummary/mlx.nn.RMSNorm.html | 14 +- .../python/nn/_autosummary/mlx.nn.RNN.html | 14 +- .../python/nn/_autosummary/mlx.nn.ReLU.html | 14 +- .../python/nn/_autosummary/mlx.nn.RoPE.html | 14 +- .../python/nn/_autosummary/mlx.nn.SELU.html | 14 +- .../nn/_autosummary/mlx.nn.Sequential.html | 14 +- .../python/nn/_autosummary/mlx.nn.SiLU.html | 14 +- .../mlx.nn.SinusoidalPositionalEncoding.html | 14 +- .../nn/_autosummary/mlx.nn.Softshrink.html | 14 +- .../python/nn/_autosummary/mlx.nn.Step.html | 14 +- .../nn/_autosummary/mlx.nn.Transformer.html | 14 +- .../nn/_autosummary/mlx.nn.Upsample.html | 31 +- .../nn/_autosummary/mlx.nn.init.constant.html | 14 +- .../mlx.nn.init.glorot_normal.html | 14 +- .../mlx.nn.init.glorot_uniform.html | 14 +- .../_autosummary/mlx.nn.init.he_normal.html | 14 +- .../_autosummary/mlx.nn.init.he_uniform.html | 14 +- .../nn/_autosummary/mlx.nn.init.identity.html | 14 +- .../nn/_autosummary/mlx.nn.init.normal.html | 14 +- .../nn/_autosummary/mlx.nn.init.uniform.html | 14 +- .../nn/_autosummary_functions/mlx.nn.elu.html | 14 +- .../_autosummary_functions/mlx.nn.gelu.html | 14 +- .../mlx.nn.gelu_approx.html | 14 +- .../mlx.nn.gelu_fast_approx.html | 14 +- .../nn/_autosummary_functions/mlx.nn.glu.html | 14 +- .../mlx.nn.hardswish.html | 14 +- .../mlx.nn.leaky_relu.html | 14 +- .../mlx.nn.log_sigmoid.html | 14 +- .../mlx.nn.log_softmax.html | 14 +- .../mlx.nn.losses.binary_cross_entropy.html | 14 +- .../mlx.nn.losses.cosine_similarity_loss.html | 14 +- .../mlx.nn.losses.cross_entropy.html | 14 +- .../mlx.nn.losses.gaussian_nll_loss.html | 14 +- .../mlx.nn.losses.hinge_loss.html | 14 +- 
.../mlx.nn.losses.huber_loss.html | 14 +- .../mlx.nn.losses.kl_div_loss.html | 14 +- .../mlx.nn.losses.l1_loss.html | 14 +- .../mlx.nn.losses.log_cosh_loss.html | 14 +- .../mlx.nn.losses.margin_ranking_loss.html | 14 +- .../mlx.nn.losses.mse_loss.html | 14 +- .../mlx.nn.losses.nll_loss.html | 14 +- .../mlx.nn.losses.smooth_l1_loss.html | 14 +- .../mlx.nn.losses.triplet_loss.html | 14 +- .../_autosummary_functions/mlx.nn.mish.html | 14 +- .../_autosummary_functions/mlx.nn.prelu.html | 14 +- .../_autosummary_functions/mlx.nn.relu.html | 14 +- .../_autosummary_functions/mlx.nn.relu6.html | 14 +- .../_autosummary_functions/mlx.nn.selu.html | 14 +- .../mlx.nn.sigmoid.html | 14 +- .../_autosummary_functions/mlx.nn.silu.html | 14 +- .../mlx.nn.softmax.html | 14 +- .../mlx.nn.softplus.html | 14 +- .../mlx.nn.softshrink.html | 14 +- .../_autosummary_functions/mlx.nn.step.html | 14 +- .../_autosummary_functions/mlx.nn.tanh.html | 14 +- docs/build/html/python/nn/functions.html | 14 +- docs/build/html/python/nn/init.html | 14 +- docs/build/html/python/nn/layers.html | 14 +- docs/build/html/python/nn/losses.html | 14 +- docs/build/html/python/nn/module.html | 14 +- docs/build/html/python/ops.html | 125 +-- docs/build/html/python/optimizers.html | 14 +- .../_autosummary/mlx.optimizers.AdaDelta.html | 14 +- .../mlx.optimizers.Adafactor.html | 14 +- .../_autosummary/mlx.optimizers.Adagrad.html | 14 +- .../_autosummary/mlx.optimizers.Adam.html | 14 +- .../_autosummary/mlx.optimizers.AdamW.html | 14 +- .../_autosummary/mlx.optimizers.Adamax.html | 14 +- .../_autosummary/mlx.optimizers.Lion.html | 14 +- ....optimizers.Optimizer.apply_gradients.html | 14 +- .../mlx.optimizers.Optimizer.init.html | 14 +- .../mlx.optimizers.Optimizer.state.html | 14 +- .../mlx.optimizers.Optimizer.update.html | 14 +- .../_autosummary/mlx.optimizers.RMSprop.html | 14 +- .../_autosummary/mlx.optimizers.SGD.html | 14 +- .../mlx.optimizers.cosine_decay.html | 14 +- .../mlx.optimizers.exponential_decay.html | 14 +- .../mlx.optimizers.join_schedules.html | 14 +- .../mlx.optimizers.linear_schedule.html | 14 +- .../mlx.optimizers.step_decay.html | 14 +- .../python/optimizers/common_optimizers.html | 14 +- .../html/python/optimizers/optimizer.html | 14 +- .../html/python/optimizers/schedulers.html | 14 +- docs/build/html/python/random.html | 27 +- docs/build/html/python/transforms.html | 14 +- docs/build/html/python/tree_utils.html | 14 +- docs/build/html/search.html | 14 +- docs/build/html/searchindex.js | 2 +- docs/build/html/usage/compile.html | 14 +- .../build/html/usage/function_transforms.html | 14 +- docs/build/html/usage/indexing.html | 14 +- docs/build/html/usage/lazy_evaluation.html | 16 +- docs/build/html/usage/numpy.html | 14 +- docs/build/html/usage/quick_start.html | 14 +- docs/build/html/usage/saving_and_loading.html | 14 +- docs/build/html/usage/unified_memory.html | 14 +- docs/build/html/usage/using_streams.html | 14 +- 413 files changed, 9992 insertions(+), 2202 deletions(-) create mode 100644 docs/build/html/_sources/python/_autosummary/mlx.core.expm1.rst create mode 100644 docs/build/html/_sources/python/_autosummary/mlx.core.meshgrid.rst create mode 100644 docs/build/html/_sources/python/_autosummary/mlx.core.metal.start_capture.rst create mode 100644 docs/build/html/_sources/python/_autosummary/mlx.core.metal.stop_capture.rst create mode 100644 docs/build/html/_sources/python/_autosummary/mlx.core.random.multivariate_normal.rst create mode 100644 docs/build/html/_sources/python/_autosummary/mlx.core.std.rst create 
mode 100644 docs/build/html/python/_autosummary/mlx.core.expm1.html create mode 100644 docs/build/html/python/_autosummary/mlx.core.meshgrid.html create mode 100644 docs/build/html/python/_autosummary/mlx.core.metal.start_capture.html create mode 100644 docs/build/html/python/_autosummary/mlx.core.metal.stop_capture.html create mode 100644 docs/build/html/python/_autosummary/mlx.core.random.multivariate_normal.html create mode 100644 docs/build/html/python/_autosummary/mlx.core.std.html diff --git a/docs/build/html/.buildinfo b/docs/build/html/.buildinfo index 5ac4f4ae3..ba59ddb17 100644 --- a/docs/build/html/.buildinfo +++ b/docs/build/html/.buildinfo @@ -1,4 +1,4 @@ # Sphinx build info version 1 # This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done. -config: 0404ad38f8b7b0d4bcdda401dd97c652 +config: 130317d5cc5607ddb0ed4187804765bc tags: 645f666f9bcd5a90fca523b33c5a78b7 diff --git a/docs/build/html/_sources/dev/extensions.rst b/docs/build/html/_sources/dev/extensions.rst index 9198548a4..acf41a773 100644 --- a/docs/build/html/_sources/dev/extensions.rst +++ b/docs/build/html/_sources/dev/extensions.rst @@ -1,24 +1,16 @@ Developer Documentation ======================= -MLX provides a open and flexible backend to which users may add operations -and specialized implementations without much hassle. While the library supplies -efficient operations that can be used and composed for any number of -applications, there may arise cases where new functionalities or highly -optimized implementations are needed. For such cases, you may design and -implement your own operations that link to and build on top of :mod:`mlx.core`. -We will introduce the inner-workings of MLX and go over a simple example to -learn the steps involved in adding new operations to MLX with your own CPU -and GPU implementations. +You can extend MLX with custom operations on the CPU or GPU. This guide +explains how to do that with a simple example. Introducing the Example ----------------------- -Let's say that you would like an operation that takes in two arrays, -``x`` and ``y``, scales them both by some coefficients ``alpha`` and ``beta`` -respectively, and then adds them together to get the result -``z = alpha * x + beta * y``. Well, you can very easily do that by just -writing out a function as follows: +Let's say you would like an operation that takes in two arrays, ``x`` and +``y``, scales them both by coefficients ``alpha`` and ``beta`` respectively, +and then adds them together to get the result ``z = alpha * x + beta * y``. +You can do that in MLX directly: .. code-block:: python @@ -27,44 +19,35 @@ writing out a function as follows: def simple_axpby(x: mx.array, y: mx.array, alpha: float, beta: float) -> mx.array: return alpha * x + beta * y -This function performs that operation while leaving the implementations and -differentiation to MLX. +This function performs that operation while leaving the implementation and +function transformations to MLX. -However, you work with vector math libraries often and realize that the -``axpby`` routine defines the same operation ``Y = (alpha * X) + (beta * Y)``. -You would really like the part of your applications that does this operation -on the CPU to be very fast - so you decide that you want it to rely on the -``axpby`` routine provided by the Accelerate_ framework. 
Continuing to impose -our assumptions on to you, let's also assume that you want to learn how to add -your own implementation for the gradients of your new operation while going -over the ins-and-outs of the MLX framework. +However you may need to customize the underlying implementation, perhaps to +make it faster or for custom differentiation. In this tutorial we will go +through adding custom extensions. It will cover: -Well, what a coincidence! You are in the right place. Over the course of this -example, we will learn: - -* The structure of the MLX library from the frontend API to the backend implementations. -* How to implement your own CPU backend that redirects to Accelerate_ when appropriate (and a fallback if needed). -* How to implement your own GPU implementation using metal. -* How to add your own ``vjp`` and ``jvp``. -* How to build your implementations, link them to MLX, and bind them to python. +* The structure of the MLX library. +* Implementing a CPU operation that redirects to Accelerate_ when appropriate. +* Implementing a GPU operation using metal. +* Adding the ``vjp`` and ``jvp`` function transformation. +* Building a custom extension and binding it to python. Operations and Primitives ------------------------- -In one sentence, operations in MLX build the computation graph, and primitives -provide the rules for evaluation and transformations of said graph. Let's start -by discussing operations in more detail. +Operations in MLX build the computation graph. Primitives provide the rules for +evaluating and transforming the graph. Let's start by discussing operations in +more detail. Operations ^^^^^^^^^^^ -Operations are the frontend functions that operate on arrays. They are defined -in the C++ API (:ref:`cpp_ops`) and then we provide bindings to these -operations in the Python API (:ref:`ops`). +Operations are the front-end functions that operate on arrays. They are defined +in the C++ API (:ref:`cpp_ops`), and the Python API (:ref:`ops`) binds them. -We would like an operation, :meth:`axpby` that takes in two arrays ``x`` and ``y``, -and two scalars, ``alpha`` and ``beta``. This is how we would define it in the -C++ API: +We would like an operation, :meth:`axpby` that takes in two arrays ``x`` and +``y``, and two scalars, ``alpha`` and ``beta``. This is how to define it in +C++: .. code-block:: C++ @@ -83,10 +66,7 @@ C++ API: StreamOrDevice s = {} // Stream on which to schedule the operation ); - -This operation itself can call other operations within it if needed. So, the -simplest way to go about implementing this operation would be do so in terms -of existing operations. +The simplest way to this operation is in terms of existing operations: .. code-block:: C++ @@ -100,25 +80,23 @@ of existing operations. // Scale x and y on the provided stream auto ax = multiply(array(alpha), x, s); auto by = multiply(array(beta), y, s); - + // Add and return return add(ax, by, s); } -However, as we discussed earlier, this is not our goal. The operations themselves -do not contain the implementations that act on the data, nor do they contain the -rules of transformations. Rather, they are an easy to use interface that build -on top of the building blocks we call :class:`Primitive`. +The operations themselves do not contain the implementations that act on the +data, nor do they contain the rules of transformations. Rather, they are an +easy to use interface that use :class:`Primitive` building blocks. 
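Note that because this composed version of :meth:`axpby` only calls existing
operations, it already works with MLX function transformations; the rules come
from the primitives behind those operations. As a quick illustration, here is a
small sketch using the ``simple_axpby`` function from the introduction:

.. code-block:: python

    import mlx.core as mx

    def simple_axpby(x: mx.array, y: mx.array, alpha: float, beta: float) -> mx.array:
        return alpha * x + beta * y

    x = mx.ones((3,))
    y = mx.ones((3,))

    # Gradients come for free since the function is composed of existing ops
    grad_fn = mx.grad(lambda x: simple_axpby(x, y, 4.0, 2.0).sum())
    print(grad_fn(x))  # every entry of the gradient with respect to x is alpha = 4.0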
Primitives ^^^^^^^^^^^ -A :class:`Primitive` is part of the computation graph of an :class:`array`. It -defines how to create an output given a set of input :class:`array` . Further, -a :class:`Primitive` is a class that contains rules on how it is evaluated -on the CPU or GPU, and how it acts under transformations such as ``vjp`` and -``jvp``. These words on their own can be a bit abstract, so lets take a step -back and go to our example to give ourselves a more concrete image. +A :class:`Primitive` is part of the computation graph of an :class:`array`. It +defines how to create outputs arrays given a input arrays. Further, a +:class:`Primitive` has methods to run on the CPU or GPU and for function +transformations such as ``vjp`` and ``jvp``. Lets go back to our example to be +more concrete: .. code-block:: C++ @@ -134,11 +112,15 @@ back and go to our example to give ourselves a more concrete image. * To avoid unnecessary allocations, the evaluation function * is responsible for allocating space for the array. */ - void eval_cpu(const std::vector& inputs, array& out) override; - void eval_gpu(const std::vector& inputs, array& out) override; + void eval_cpu( + const std::vector& inputs, + std::vector& outputs) override; + void eval_gpu( + const std::vector& inputs, + std::vector& outputs) override; /** The Jacobian-vector product. */ - array jvp( + std::vector jvp( const std::vector& primals, const std::vector& tangents, const std::vector& argnums) override; @@ -147,7 +129,8 @@ back and go to our example to give ourselves a more concrete image. std::vector vjp( const std::vector& primals, const array& cotan, - const std::vector& argnums) override; + const std::vector& argnums, + const std::vector& outputs) override; /** * The primitive must know how to vectorize itself across @@ -155,7 +138,7 @@ back and go to our example to give ourselves a more concrete image. * representing the vectorized computation and the axis which * corresponds to the output vectorized dimension. */ - std::pair vmap( + virtual std::pair, std::vector> vmap( const std::vector& inputs, const std::vector& axes) override; @@ -175,22 +158,22 @@ back and go to our example to give ourselves a more concrete image. void eval(const std::vector& inputs, array& out); }; -The :class:`Axpby` class derives from the base :class:`Primitive` class and -follows the above demonstrated interface. :class:`Axpby` treats ``alpha`` and -``beta`` as parameters. It then provides implementations of how the array ``out`` -is produced given ``inputs`` through :meth:`Axpby::eval_cpu` and -:meth:`Axpby::eval_gpu`. Further, it provides rules of transformations in -:meth:`Axpby::jvp`, :meth:`Axpby::vjp`, and :meth:`Axpby::vmap`. +The :class:`Axpby` class derives from the base :class:`Primitive` class. The +:class:`Axpby` treats ``alpha`` and ``beta`` as parameters. It then provides +implementations of how the output array is produced given the inputs through +:meth:`Axpby::eval_cpu` and :meth:`Axpby::eval_gpu`. It also provides rules +of transformations in :meth:`Axpby::jvp`, :meth:`Axpby::vjp`, and +:meth:`Axpby::vmap`. -Using the Primitives -^^^^^^^^^^^^^^^^^^^^^ +Using the Primitive +^^^^^^^^^^^^^^^^^^^ -Operations can use this :class:`Primitive` to add a new :class:`array` to -the computation graph. An :class:`array` can be constructed by providing its -data type, shape, the :class:`Primitive` that computes it, and the -:class:`array` inputs that are passed to the primitive. 
+Operations can use this :class:`Primitive` to add a new :class:`array` to the +computation graph. An :class:`array` can be constructed by providing its data +type, shape, the :class:`Primitive` that computes it, and the :class:`array` +inputs that are passed to the primitive. -Let's re-implement our operation now in terms of our :class:`Axpby` primitive. +Let's reimplement our operation now in terms of our :class:`Axpby` primitive. .. code-block:: C++ @@ -238,27 +221,26 @@ This operation now handles the following: Implementing the Primitive -------------------------- -No computation happens when we call the operation alone. In effect, the -operation only builds the computation graph. When we evaluate the output -array, MLX schedules the execution of the computation graph, and calls -:meth:`Axpby::eval_cpu` or :meth:`Axpby::eval_gpu` depending on the -stream/device specified by the user. +No computation happens when we call the operation alone. The operation only +builds the computation graph. When we evaluate the output array, MLX schedules +the execution of the computation graph, and calls :meth:`Axpby::eval_cpu` or +:meth:`Axpby::eval_gpu` depending on the stream/device specified by the user. .. warning:: When :meth:`Primitive::eval_cpu` or :meth:`Primitive::eval_gpu` are called, no memory has been allocated for the output array. It falls on the implementation - of these functions to allocate memory as needed + of these functions to allocate memory as needed. -Implementing the CPU Backend +Implementing the CPU Back-end ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -Let's start by trying to implement a naive and generic version of -:meth:`Axpby::eval_cpu`. We declared this as a private member function of -:class:`Axpby` earlier called :meth:`Axpby::eval`. +Let's start by implementing a naive and generic version of +:meth:`Axpby::eval_cpu`. We declared this as a private member function of +:class:`Axpby` earlier called :meth:`Axpby::eval`. -Our naive method will go over each element of the output array, find the -corresponding input elements of ``x`` and ``y`` and perform the operation -pointwise. This is captured in the templated function :meth:`axpby_impl`. +Our naive method will go over each element of the output array, find the +corresponding input elements of ``x`` and ``y`` and perform the operation +point-wise. This is captured in the templated function :meth:`axpby_impl`. .. code-block:: C++ @@ -296,19 +278,19 @@ pointwise. This is captured in the templated function :meth:`axpby_impl`. } } -Now, we would like our implementation to be able to do this pointwise operation -for all incoming floating point arrays. Accordingly, we add dispatches for -``float32``, ``float16``, ``bfloat16`` and ``complex64``. We throw an error -if we encounter an unexpected type. +Our implementation should work for all incoming floating point arrays. +Accordingly, we add dispatches for ``float32``, ``float16``, ``bfloat16`` and +``complex64``. We throw an error if we encounter an unexpected type. .. code-block:: C++ /** Fall back implementation for evaluation on CPU */ - void Axpby::eval(const std::vector& inputs, array& out) { - // Check the inputs (registered in the op while constructing the out array) - assert(inputs.size() == 2); + void Axpby::eval( + const std::vector& inputs, + const std::vector& outputs) { auto& x = inputs[0]; auto& y = inputs[1]; + auto& out = outputs[0]; // Dispatch to the correct dtype if (out.dtype() == float32) { @@ -321,28 +303,26 @@ if we encounter an unexpected type. 
return axpby_impl(x, y, out, alpha_, beta_); } else { throw std::runtime_error( - "Axpby is only supported for floating point types."); + "[Axpby] Only supports floating point types."); } } -We have a fallback implementation! Now, to do what we are really here to do. -Remember we wanted to use the ``axpby`` routine provided by the Accelerate_ -framework? Well, there are 3 complications to keep in mind: +This is good as a fallback implementation. We can use the ``axpby`` routine +provided by the Accelerate_ framework for a faster implementation in certain +cases: #. Accelerate does not provide implementations of ``axpby`` for half precision - floats. We can only direct to it for ``float32`` types -#. Accelerate assumes the inputs ``x`` and ``y`` are contiguous and all elements - have fixed strides between them. Possibly due to broadcasts and transposes, - we aren't guaranteed that the inputs fit this requirement. We can - only direct to Accelerate if both ``x`` and ``y`` are row contiguous or - column contiguous. -#. Accelerate performs the routine ``Y = (alpha * X) + (beta * Y)`` inplace. - MLX expects to write out the answer to a new array. We must copy the elements - of ``y`` into the output array and use that as an input to ``axpby`` + floats. We can only use it for ``float32`` types. +#. Accelerate assumes the inputs ``x`` and ``y`` are contiguous and all + elements have fixed strides between them. We only direct to Accelerate + if both ``x`` and ``y`` are row contiguous or column contiguous. +#. Accelerate performs the routine ``Y = (alpha * X) + (beta * Y)`` in-place. + MLX expects to write the output to a new array. We must copy the elements + of ``y`` into the output and use that as an input to ``axpby``. -Let's write out an implementation that uses Accelerate in the right conditions. -It must simply allocate data for the output, copy elements of ``y`` into it, -and then call the :meth:`catlas_saxpby` from accelerate. +Let's write an implementation that uses Accelerate in the right conditions. +It allocates data for the output, copies ``y`` into it, and then calls the +:func:`catlas_saxpby` from accelerate. .. code-block:: C++ @@ -356,17 +336,7 @@ and then call the :meth:`catlas_saxpby` from accelerate. // Accelerate library provides catlas_saxpby which does // Y = (alpha * X) + (beta * Y) in place // To use it, we first copy the data in y over to the output array - - // This specialization requires both x and y be contiguous in the same mode - // i.e: corresponding linear indices in both point to corresponding elements - // The data in the output array is allocated to match the strides in y - // such that x, y, and out are contiguous in the same mode and - // no transposition is needed - out.set_data( - allocator::malloc_or_wait(y.data_size() * out.itemsize()), - y.data_size(), - y.strides(), - y.flags()); + out.set_data(allocator::malloc_or_wait(out.nbytes())); // We then copy over the elements using the contiguous vector specialization copy_inplace(y, out, CopyType::Vector); @@ -389,18 +359,20 @@ and then call the :meth:`catlas_saxpby` from accelerate. /* INCY = */ 1); } -Great! But what about the inputs that do not fit the criteria for accelerate? -Luckily, we can always just direct back to :meth:`Axpby::eval`. - -With this in mind, lets finally implement our :meth:`Axpby::eval_cpu`. +For inputs that do not fit the criteria for accelerate, we fall back to +:meth:`Axpby::eval`. With this in mind, let's finish our +:meth:`Axpby::eval_cpu`. .. 
code-block:: C++ /** Evaluate primitive on CPU using accelerate specializations */ - void Axpby::eval_cpu(const std::vector& inputs, array& out) { + void Axpby::eval_cpu( + const std::vector& inputs, + const std::vector& outputs) { assert(inputs.size() == 2); auto& x = inputs[0]; auto& y = inputs[1]; + auto& out = outputs[0]; // Accelerate specialization for contiguous single precision float arrays if (out.dtype() == float32 && @@ -410,35 +382,33 @@ With this in mind, lets finally implement our :meth:`Axpby::eval_cpu`. return; } - // Fall back to common backend if specializations are not available - eval(inputs, out); + // Fall back to common back-end if specializations are not available + eval(inputs, outputs); } -We have now hit a milestone! Just this much is enough to run the operation -:meth:`axpby` on a CPU stream! +Just this much is enough to run the operation :meth:`axpby` on a CPU stream! If +you do not plan on running the operation on the GPU or using transforms on +computation graphs that contain :class:`Axpby`, you can stop implementing the +primitive here and enjoy the speed-ups you get from the Accelerate library. -If you do not plan on running the operation on the GPU or using transforms on -computation graphs that contain :class:`Axpby`, you can stop implementing the -primitive here and enjoy the speed-ups you get from the Accelerate library. - -Implementing the GPU Backend +Implementing the GPU Back-end ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -Apple silicon devices address their GPUs using the Metal_ shading language, and -all GPU kernels in MLX are written using metal. +Apple silicon devices address their GPUs using the Metal_ shading language, and +GPU kernels in MLX are written using Metal. .. note:: - Here are some helpful resources if you are new to metal! + Here are some helpful resources if you are new to Metal: * A walkthrough of the metal compute pipeline: `Metal Example`_ * Documentation for metal shading language: `Metal Specification`_ * Using metal from C++: `Metal-cpp`_ -Let's keep the GPU algorithm simple. We will launch exactly as many threads -as there are elements in the output. Each thread will pick the element it needs -from ``x`` and ``y``, do the pointwise operation, and then update its assigned -element in the output. +Let's keep the GPU kernel simple. We will launch exactly as many threads as +there are elements in the output. Each thread will pick the element it needs +from ``x`` and ``y``, do the point-wise operation, and update its assigned +element in the output. .. code-block:: C++ @@ -457,15 +427,14 @@ element in the output. // Convert linear indices to offsets in array auto x_offset = elem_to_loc(index, shape, x_strides, ndim); auto y_offset = elem_to_loc(index, shape, y_strides, ndim); - + // Do the operation and update the output - out[index] = + out[index] = static_cast(alpha) * x[x_offset] + static_cast(beta) * y[y_offset]; } We then need to instantiate this template for all floating point types and give -each instantiation a unique host name so we can identify the right kernel for -each data type. +each instantiation a unique host name so we can identify it. .. code-block:: C++ @@ -488,29 +457,21 @@ each data type. instantiate_axpby(bfloat16, bfloat16_t); instantiate_axpby(complex64, complex64_t); -This kernel will be compiled into a metal library ``mlx_ext.metallib`` as we -will see later in :ref:`Building with CMake`. 
In the following example, we -assume that the library ``mlx_ext.metallib`` will always be co-located with -the executable/ shared-library calling the :meth:`register_library` function. -The :meth:`register_library` function takes the library's name and potential -path (or in this case, a function that can produce the path of the metal -library) and tries to load that library if it hasn't already been registered -by the relevant static :class:`mlx::core::metal::Device` object. This is why, -it is important to package your C++ library with the metal library. We will -go over this process in more detail later. - -The logic to determine the kernel, set the inputs, resolve the grid dimensions -and dispatch it to the GPU are contained in :meth:`Axpby::eval_gpu` as shown +The logic to determine the kernel, set the inputs, resolve the grid dimensions, +and dispatch to the GPU are contained in :meth:`Axpby::eval_gpu` as shown below. .. code-block:: C++ /** Evaluate primitive on GPU */ - void Axpby::eval_gpu(const std::vector& inputs, array& out) { + void Axpby::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { // Prepare inputs assert(inputs.size() == 2); auto& x = inputs[0]; auto& y = inputs[1]; + auto& out = outputs[0]; // Each primitive carries the stream it should execute on // and each stream carries its device identifiers @@ -518,10 +479,10 @@ below. // We get the needed metal device using the stream auto& d = metal::device(s.device); - // Allocate output memory + // Allocate output memory out.set_data(allocator::malloc_or_wait(out.nbytes())); - // Resolve name of kernel (corresponds to axpby.metal) + // Resolve name of kernel std::ostringstream kname; kname << "axpby_" << "general_" << type_to_name(out); @@ -552,7 +513,7 @@ below. compute_encoder->setBytes(&alpha_, sizeof(float), 3); compute_encoder->setBytes(&beta_, sizeof(float), 4); - // Encode shape, strides and ndim + // Encode shape, strides and ndim compute_encoder->setBytes(x.shape().data(), ndim * sizeof(int), 5); compute_encoder->setBytes(x.strides().data(), ndim * sizeof(size_t), 6); compute_encoder->setBytes(y.strides().data(), ndim * sizeof(size_t), 7); @@ -575,28 +536,25 @@ below. We can now call the :meth:`axpby` operation on both the CPU and the GPU! -A few things to note about MLX and metal before moving on. MLX keeps track -of the active ``compute_encoder``. We rely on :meth:`d.get_command_encoder` -to give us the active metal compute command encoder instead of building a -new one and calling :meth:`compute_encoder->end_encoding` at the end. -MLX keeps adding kernels (compute pipelines) to the active command encoder -until some specified limit is hit or the compute encoder needs to be flushed -for synchronization. MLX also handles enqueuing and committing the associated -command buffers as needed. We suggest taking a deeper dive into -:class:`metal::Device` if you would like to study this routine further. +A few things to note about MLX and Metal before moving on. MLX keeps track of +the active ``command_buffer`` and the ``MTLCommandBuffer`` to which it is +associated. We rely on :meth:`d.get_command_encoder` to give us the active +metal compute command encoder instead of building a new one and calling +:meth:`compute_encoder->end_encoding` at the end. MLX adds kernels (compute +pipelines) to the active command buffer until some specified limit is hit or +the command buffer needs to be flushed for synchronization. 
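With both back-ends in place, the operation can be dispatched to either device
through the ``stream`` argument exposed by the Python bindings we add below. A
usage sketch, assuming the extension package has already been built and
installed as described later in this guide:

.. code-block:: python

    import mlx.core as mx
    from mlx_sample_extensions import axpby

    x = mx.random.uniform(shape=(256, 256))
    y = mx.random.uniform(shape=(256, 256))

    # Runs Axpby::eval_cpu
    z_cpu = axpby(x, y, 4.0, 2.0, stream=mx.cpu)

    # Runs Axpby::eval_gpu
    z_gpu = axpby(x, y, 4.0, 2.0, stream=mx.gpu)

    mx.eval(z_cpu, z_gpu)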
Primitive Transforms ^^^^^^^^^^^^^^^^^^^^^ -Now that we have come this far, let's also learn how to add implementations to -transformations in a :class:`Primitive`. These transformations can be built on -top of our operations, including the one we just defined now. Which then gives -us the following :meth:`Axpby::jvp` and :meth:`Axpby::vjp` implementations. +Next, let's add implementations for transformations in a :class:`Primitive`. +These transformations can be built on top of other operations, including the +one we just defined: .. code-block:: C++ /** The Jacobian-vector product. */ - array Axpby::jvp( + std::vector Axpby::jvp( const std::vector& primals, const std::vector& tangents, const std::vector& argnums) { @@ -611,12 +569,12 @@ us the following :meth:`Axpby::jvp` and :meth:`Axpby::vjp` implementations. if (argnums.size() > 1) { auto scale = argnums[0] == 0 ? alpha_ : beta_; auto scale_arr = array(scale, tangents[0].dtype()); - return multiply(scale_arr, tangents[0], stream()); + return {multiply(scale_arr, tangents[0], stream())}; } // If, argnums = {0, 1}, we take contributions from both // which gives us jvp = tangent_x * alpha + tangent_y * beta else { - return axpby(tangents[0], tangents[1], alpha_, beta_, stream()); + return {axpby(tangents[0], tangents[1], alpha_, beta_, stream())}; } } @@ -625,34 +583,35 @@ us the following :meth:`Axpby::jvp` and :meth:`Axpby::vjp` implementations. /** The vector-Jacobian product. */ std::vector Axpby::vjp( const std::vector& primals, - const array& cotan, - const std::vector& argnums) { + const std::vector& cotangents, + const std::vector& argnums, + const std::vector& /* unused */) { // Reverse mode diff std::vector vjps; for (auto arg : argnums) { auto scale = arg == 0 ? alpha_ : beta_; - auto scale_arr = array(scale, cotan.dtype()); - vjps.push_back(multiply(scale_arr, cotan, stream())); + auto scale_arr = array(scale, cotangents[0].dtype()); + vjps.push_back(multiply(scale_arr, cotangents[0], stream())); } return vjps; } -Finally, you need not have a transformation fully defined to start using your -own :class:`Primitive`. +Note, a transformation does not need to be fully defined to start using +the :class:`Primitive`. .. code-block:: C++ /** Vectorize primitive along given axis */ - std::pair Axpby::vmap( + std::pair, std::vector> Axpby::vmap( const std::vector& inputs, const std::vector& axes) { - throw std::runtime_error("Axpby has no vmap implementation."); + throw std::runtime_error("[Axpby] vmap not implemented."); } Building and Binding -------------------- -Let's look at the overall directory structure first. +Let's look at the overall directory structure first. | extensions | ├── axpby @@ -666,40 +625,39 @@ Let's look at the overall directory structure first. 
| └── setup.py * ``extensions/axpby/`` defines the C++ extension library -* ``extensions/mlx_sample_extensions`` sets out the structure for the - associated python package -* ``extensions/bindings.cpp`` provides python bindings for our operation -* ``extensions/CMakeLists.txt`` holds CMake rules to build the library and - python bindings +* ``extensions/mlx_sample_extensions`` sets out the structure for the + associated Python package +* ``extensions/bindings.cpp`` provides Python bindings for our operation +* ``extensions/CMakeLists.txt`` holds CMake rules to build the library and + Python bindings * ``extensions/setup.py`` holds the ``setuptools`` rules to build and install - the python package + the Python package Binding to Python ^^^^^^^^^^^^^^^^^^ -We use PyBind11_ to build a Python API for the C++ library. Since bindings for +We use nanobind_ to build a Python API for the C++ library. Since bindings for components such as :class:`mlx.core.array`, :class:`mlx.core.stream`, etc. are -already provided, adding our :meth:`axpby` is simple! +already provided, adding our :meth:`axpby` is simple. .. code-block:: C++ - PYBIND11_MODULE(mlx_sample_extensions, m) { - m.doc() = "Sample C++ and metal extensions for MLX"; + NB_MODULE(_ext, m) { + m.doc() = "Sample extension for MLX"; m.def( "axpby", &axpby, "x"_a, "y"_a, - py::pos_only(), "alpha"_a, "beta"_a, - py::kw_only(), - "stream"_a = py::none(), - R"pbdoc( + nb::kw_only(), + "stream"_a = nb::none(), + R"( Scale and sum two vectors element-wise ``z = alpha * x + beta * y`` - + Follows numpy style broadcasting between ``x`` and ``y`` Inputs are upcasted to floats if needed @@ -711,17 +669,17 @@ already provided, adding our :meth:`axpby` is simple! Returns: array: ``alpha * x + beta * y`` - )pbdoc"); + )"); } -Most of the complexity in the above example comes from additional bells and +Most of the complexity in the above example comes from additional bells and whistles such as the literal names and doc-strings. .. warning:: - :mod:`mlx.core` needs to be imported before importing - :mod:`mlx_sample_extensions` as defined by the pybind11 module above to - ensure that the casters for :mod:`mlx.core` components like + :mod:`mlx.core` must be imported before importing + :mod:`mlx_sample_extensions` as defined by the nanobind module above to + ensure that the casters for :mod:`mlx.core` components like :class:`mlx.core.array` are available. .. _Building with CMake: @@ -729,8 +687,8 @@ whistles such as the literal names and doc-strings. Building with CMake ^^^^^^^^^^^^^^^^^^^^ -Building the C++ extension library itself is simple, it only requires that you -``find_package(MLX CONFIG)`` and then link it to your library. +Building the C++ extension library only requires that you ``find_package(MLX +CONFIG)`` and then link it to your library. .. code-block:: cmake @@ -752,12 +710,12 @@ Building the C++ extension library itself is simple, it only requires that you # Link to mlx target_link_libraries(mlx_ext PUBLIC mlx) -We also need to build the attached metal library. For convenience, we provide a -:meth:`mlx_build_metallib` function that builds a ``.metallib`` target given -sources, headers, destinations, etc. (defined in ``cmake/extension.cmake`` and -automatically imported with MLX package). +We also need to build the attached Metal library. For convenience, we provide a +:meth:`mlx_build_metallib` function that builds a ``.metallib`` target given +sources, headers, destinations, etc. 
(defined in ``cmake/extension.cmake`` and +automatically imported with MLX package). -Here is what that looks like in practice! +Here is what that looks like in practice: .. code-block:: cmake @@ -779,27 +737,29 @@ Here is what that looks like in practice! endif() -Finally, we build the Pybind11_ bindings +Finally, we build the nanobind_ bindings .. code-block:: cmake - pybind11_add_module( - mlx_sample_extensions - ${CMAKE_CURRENT_LIST_DIR}/bindings.cpp + nanobind_add_module( + _ext + NB_STATIC STABLE_ABI LTO NOMINSIZE + NB_DOMAIN mlx + ${CMAKE_CURRENT_LIST_DIR}/bindings.cpp ) - target_link_libraries(mlx_sample_extensions PRIVATE mlx_ext) + target_link_libraries(_ext PRIVATE mlx_ext) if(BUILD_SHARED_LIBS) - target_link_options(mlx_sample_extensions PRIVATE -Wl,-rpath,@loader_path) + target_link_options(_ext PRIVATE -Wl,-rpath,@loader_path) endif() Building with ``setuptools`` -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Once we have set out the CMake build rules as described above, we can use the -build utilities defined in :mod:`mlx.extension` for a simple build process. +build utilities defined in :mod:`mlx.extension`: -.. code-block:: python +.. code-block:: python from mlx import extension from setuptools import setup @@ -809,48 +769,50 @@ build utilities defined in :mod:`mlx.extension` for a simple build process. name="mlx_sample_extensions", version="0.0.0", description="Sample C++ and Metal extensions for MLX primitives.", - ext_modules=[extension.CMakeExtension("mlx_sample_extensions")], + ext_modules=[extension.CMakeExtension("mlx_sample_extensions._ext")], cmdclass={"build_ext": extension.CMakeBuild}, - packages = ["mlx_sample_extensions"], - package_dir = {"": "mlx_sample_extensions"}, - package_data = {"mlx_sample_extensions" : ["*.so", "*.dylib", "*.metallib"]}, + packages=["mlx_sample_extensions"], + package_data={"mlx_sample_extensions": ["*.so", "*.dylib", "*.metallib"]}, + extras_require={"dev":[]}, zip_safe=False, - python_requires=">=3.7", + python_requires=">=3.8", ) .. note:: We treat ``extensions/mlx_sample_extensions`` as the package directory even though it only contains a ``__init__.py`` to ensure the following: - - * :mod:`mlx.core` is always imported before importing :mod:`mlx_sample_extensions` - * The C++ extension library and the metal library are co-located with the python - bindings and copied together if the package is installed -You can build inplace for development using + * :mod:`mlx.core` must be imported before importing :mod:`_ext` + * The C++ extension library and the metal library are co-located with the python + bindings and copied together if the package is installed + +To build the package, first install the build dependencies with ``pip install +-r requirements.txt``. You can then build inplace for development using ``python setup.py build_ext -j8 --inplace`` (in ``extensions/``) -This will result in a directory structure as follows: +This results in the directory structure: | extensions | ├── mlx_sample_extensions | │ ├── __init__.py | │ ├── libmlx_ext.dylib # C++ extension library | │ ├── mlx_ext.metallib # Metal library -| │ └── mlx_sample_extensions.cpython-3x-darwin.so # Python Binding +| │ └── _ext.cpython-3x-darwin.so # Python Binding | ... 
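The ``__init__.py`` imports :mod:`mlx.core` before loading the compiled binding
and re-exports :meth:`axpby`. A minimal sketch of what it might contain (the
file in the example repository may differ):

.. code-block:: python

    # mlx_sample_extensions/__init__.py (illustrative sketch)
    import mlx.core as mx  # must be imported before the extension module
    from ._ext import axpby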
-When you try to install using the command ``python -m pip install .`` -(in ``extensions/``), the package will be installed with the same structure as -``extensions/mlx_sample_extensions`` and the C++ and metal library will be -copied along with the python binding since they are specified as ``package_data``. +When you try to install using the command ``python -m pip install .`` (in +``extensions/``), the package will be installed with the same structure as +``extensions/mlx_sample_extensions`` and the C++ and Metal library will be +copied along with the Python binding since they are specified as +``package_data``. Usage ----- -After installing the extension as described above, you should be able to simply -import the python package and play with it as you would any other MLX operation! +After installing the extension as described above, you should be able to simply +import the Python package and play with it as you would any other MLX operation. -Let's looks at a simple script and it's results! +Let's look at a simple script and its results: .. code-block:: python @@ -874,12 +836,12 @@ Output: c correctness: True Results -^^^^^^^^^^^^^^^^ +^^^^^^^ -Let's run a quick benchmark and see how our new ``axpby`` operation compares -with the naive :meth:`simple_axpby` we defined at first on the CPU. +Let's run a quick benchmark and see how our new ``axpby`` operation compares +with the naive :meth:`simple_axpby` we first defined on the CPU. -.. code-block:: python +.. code-block:: python import mlx.core as mx from mlx_sample_extensions import axpby @@ -898,7 +860,7 @@ with the naive :meth:`simple_axpby` we defined at first on the CPU. alpha = 4.0 beta = 2.0 - mx.eval((x, y)) + mx.eval(x, y) def bench(f): # Warm up @@ -919,30 +881,23 @@ with the naive :meth:`simple_axpby` we defined at first on the CPU. print(f"Simple axpby: {simple_time:.3f} s | Custom axpby: {custom_time:.3f} s") -Results: - -.. code-block:: - - Simple axpby: 0.114 s | Custom axpby: 0.109 s - -We see some modest improvements right away! +The results are ``Simple axpby: 0.114 s | Custom axpby: 0.109 s``. We see +modest improvements right away! This operation is now good to be used to build other operations, in :class:`mlx.nn.Module` calls, and also as a part of graph transformations like -:meth:`grad`! +:meth:`grad`. Scripts ------- .. admonition:: Download the code - The full example code is available in `mlx `_. - -.. code: `https://github.com/ml-explore/mlx/tree/main/examples/extensions/`_ + The full example code is available in `mlx `_. .. _Accelerate: https://developer.apple.com/documentation/accelerate/blas?language=objc .. _Metal: https://developer.apple.com/documentation/metal?language=objc .. _Metal-cpp: https://developer.apple.com/metal/cpp/ .. _`Metal Specification`: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf .. _`Metal Example`: https://developer.apple.com/documentation/metal/performing_calculations_on_a_gpu?language=objc -.. _PyBind11: https://pybind11.readthedocs.io/en/stable/ +.. _nanobind: https://nanobind.readthedocs.io/en/latest/ diff --git a/docs/build/html/_sources/dev/metal_debugger.rst b/docs/build/html/_sources/dev/metal_debugger.rst index b0d7db9d0..94d25258c 100644 --- a/docs/build/html/_sources/dev/metal_debugger.rst +++ b/docs/build/html/_sources/dev/metal_debugger.rst @@ -1,29 +1,46 @@ Metal Debugger ============== +.. currentmodule:: mlx.core + Profiling is a key step for performance optimization. 
You can build MLX with -the ``MLX_METAL_DEBUG`` option to improve the Metal debugging and optimization -workflow. The ``MLX_METAL_DEBUG`` debug option: +the ``MLX_METAL_DEBUG`` option to improve the Metal debugging and +optimization workflow. The ``MLX_METAL_DEBUG`` debug option: * Records source during Metal compilation, for later inspection while debugging. * Labels Metal objects such as command queues, improving capture readability. -The ``metal::start_capture`` function initiates a capture of all MLX GPU work. +To build with debugging enabled in Python prepend +``CMAKE_ARGS="-DMLX_METAL_DEBUG=ON"`` to the build call. -.. code-block:: C++ +The :func:`metal.start_capture` function initiates a capture of all MLX GPU +work. - int main() { - metal::start_capture("/Users/Jane/Developer/MLX.gputrace"); +.. note:: - auto a = arange(10.f, 20.f, 1.f, float32); - auto b = arange(30.f, 40.f, 1.f, float32); - auto c = add(a, b); + To capture a GPU trace you must run the application with + ``MTL_CAPTURE_ENABLED=1``. - eval(c); +.. code-block:: python - metal::stop_capture(); - } + import mlx.core as mx + + a = mx.random.uniform(shape=(512, 512)) + b = mx.random.uniform(shape=(512, 512)) + mx.eval(a, b) + + trace_file = "mlx_trace.gputrace" + + if not mx.metal.start_capture(trace_file): + print("Make sure to run with MTL_CAPTURE_ENABLED=1 and " + f"that the path {trace_file} does not already exist.") + exit(1) + + for _ in range(10): + mx.eval(mx.add(a, b)) + + mx.metal.stop_capture() You can open and replay the GPU trace in Xcode. The ``Dependencies`` view has a great overview of all operations. Checkout the `Metal debugger @@ -35,8 +52,8 @@ documentation`_ for more information. Xcode Workflow -------------- -You can skip saving to a path by running within Xcode. First, generate an Xcode -project using CMake. +You can skip saving to a path by running within Xcode. First, generate an +Xcode project using CMake. .. code-block:: diff --git a/docs/build/html/_sources/python/_autosummary/mlx.core.expm1.rst b/docs/build/html/_sources/python/_autosummary/mlx.core.expm1.rst new file mode 100644 index 000000000..76f1a25a5 --- /dev/null +++ b/docs/build/html/_sources/python/_autosummary/mlx.core.expm1.rst @@ -0,0 +1,6 @@ +mlx.core.expm1 +============== + +.. currentmodule:: mlx.core + +.. autofunction:: expm1 \ No newline at end of file diff --git a/docs/build/html/_sources/python/_autosummary/mlx.core.meshgrid.rst b/docs/build/html/_sources/python/_autosummary/mlx.core.meshgrid.rst new file mode 100644 index 000000000..ba81a5342 --- /dev/null +++ b/docs/build/html/_sources/python/_autosummary/mlx.core.meshgrid.rst @@ -0,0 +1,6 @@ +mlx.core.meshgrid +================= + +.. currentmodule:: mlx.core + +.. autofunction:: meshgrid \ No newline at end of file diff --git a/docs/build/html/_sources/python/_autosummary/mlx.core.metal.start_capture.rst b/docs/build/html/_sources/python/_autosummary/mlx.core.metal.start_capture.rst new file mode 100644 index 000000000..ecf158b75 --- /dev/null +++ b/docs/build/html/_sources/python/_autosummary/mlx.core.metal.start_capture.rst @@ -0,0 +1,6 @@ +mlx.core.metal.start\_capture +============================= + +.. currentmodule:: mlx.core.metal + +.. 
autofunction:: start_capture \ No newline at end of file diff --git a/docs/build/html/_sources/python/_autosummary/mlx.core.metal.stop_capture.rst b/docs/build/html/_sources/python/_autosummary/mlx.core.metal.stop_capture.rst new file mode 100644 index 000000000..a35af8123 --- /dev/null +++ b/docs/build/html/_sources/python/_autosummary/mlx.core.metal.stop_capture.rst @@ -0,0 +1,6 @@ +mlx.core.metal.stop\_capture +============================ + +.. currentmodule:: mlx.core.metal + +.. autofunction:: stop_capture \ No newline at end of file diff --git a/docs/build/html/_sources/python/_autosummary/mlx.core.random.multivariate_normal.rst b/docs/build/html/_sources/python/_autosummary/mlx.core.random.multivariate_normal.rst new file mode 100644 index 000000000..3f53fe35d --- /dev/null +++ b/docs/build/html/_sources/python/_autosummary/mlx.core.random.multivariate_normal.rst @@ -0,0 +1,6 @@ +mlx.core.random.multivariate\_normal +==================================== + +.. currentmodule:: mlx.core.random + +.. autofunction:: multivariate_normal \ No newline at end of file diff --git a/docs/build/html/_sources/python/_autosummary/mlx.core.std.rst b/docs/build/html/_sources/python/_autosummary/mlx.core.std.rst new file mode 100644 index 000000000..17c467b81 --- /dev/null +++ b/docs/build/html/_sources/python/_autosummary/mlx.core.std.rst @@ -0,0 +1,6 @@ +mlx.core.std +============ + +.. currentmodule:: mlx.core + +.. autofunction:: std \ No newline at end of file diff --git a/docs/build/html/_sources/python/metal.rst b/docs/build/html/_sources/python/metal.rst index c11deb4fa..c92b18936 100644 --- a/docs/build/html/_sources/python/metal.rst +++ b/docs/build/html/_sources/python/metal.rst @@ -3,7 +3,7 @@ Metal .. currentmodule:: mlx.core.metal -.. autosummary:: +.. autosummary:: :toctree: _autosummary is_available @@ -12,3 +12,5 @@ Metal get_cache_memory set_memory_limit set_cache_limit + start_capture + stop_capture diff --git a/docs/build/html/_sources/python/ops.rst b/docs/build/html/_sources/python/ops.rst index a10b126af..1e934befd 100644 --- a/docs/build/html/_sources/python/ops.rst +++ b/docs/build/html/_sources/python/ops.rst @@ -5,13 +5,13 @@ Operations .. currentmodule:: mlx.core -.. autosummary:: +.. autosummary:: :toctree: _autosummary abs add all - allclose + allclose any arange arccos @@ -51,6 +51,7 @@ Operations erf erfinv exp + expm1 expand_dims eye flatten @@ -83,6 +84,7 @@ Operations max maximum mean + meshgrid min minimum moveaxis @@ -117,6 +119,7 @@ Operations square squeeze stack + std stop_gradient subtract sum diff --git a/docs/build/html/_sources/python/random.rst b/docs/build/html/_sources/python/random.rst index 706378f9d..d08d5a7df 100644 --- a/docs/build/html/_sources/python/random.rst +++ b/docs/build/html/_sources/python/random.rst @@ -38,6 +38,7 @@ we use a splittable version of Threefry, which is a counter-based PRNG. gumbel key normal + multivariate_normal randint seed split diff --git a/docs/build/html/_sources/usage/lazy_evaluation.rst b/docs/build/html/_sources/usage/lazy_evaluation.rst index e41fcbe0b..bd64f919c 100644 --- a/docs/build/html/_sources/usage/lazy_evaluation.rst +++ b/docs/build/html/_sources/usage/lazy_evaluation.rst @@ -18,7 +18,7 @@ describe below. Transforming Compute Graphs ^^^^^^^^^^^^^^^^^^^^^^^^^^^ -Lazy evaluation let's us record a compute graph without actually doing any +Lazy evaluation lets us record a compute graph without actually doing any computations. 
This is useful for function transformations like :func:`grad` and :func:`vmap` and graph optimizations. diff --git a/docs/build/html/_static/documentation_options.js b/docs/build/html/_static/documentation_options.js index 3ab7fc631..0e02a9a5b 100644 --- a/docs/build/html/_static/documentation_options.js +++ b/docs/build/html/_static/documentation_options.js @@ -1,5 +1,5 @@ const DOCUMENTATION_OPTIONS = { - VERSION: '0.9.0', + VERSION: '0.10.0', LANGUAGE: 'en', COLLAPSE_INDEX: false, BUILDER: 'html', diff --git a/docs/build/html/cpp/ops.html b/docs/build/html/cpp/ops.html index 7506db223..ed9e22660 100644 --- a/docs/build/html/cpp/ops.html +++ b/docs/build/html/cpp/ops.html @@ -8,7 +8,7 @@ - Operations — MLX 0.9.0 documentation + Operations — MLX 0.10.0 documentation @@ -36,7 +36,7 @@ - + @@ -131,8 +131,8 @@ - MLX 0.9.0 documentation - Home - + MLX 0.10.0 documentation - Home + @@ -286,6 +286,7 @@
  • mlx.core.erf
  • mlx.core.erfinv
  • mlx.core.exp
  • +
  • mlx.core.expm1
  • mlx.core.expand_dims
  • mlx.core.eye
  • mlx.core.flatten
  • @@ -318,6 +319,7 @@
  • mlx.core.max
  • mlx.core.maximum
  • mlx.core.mean
  • +
  • mlx.core.meshgrid
  • mlx.core.min
  • mlx.core.minimum
  • mlx.core.moveaxis
  • @@ -352,6 +354,7 @@
  • mlx.core.square
  • mlx.core.squeeze
  • mlx.core.stack
  • +
  • mlx.core.std
  • mlx.core.stop_gradient
  • mlx.core.subtract
  • mlx.core.sum
  • @@ -379,6 +382,7 @@
  • mlx.core.random.gumbel
  • mlx.core.random.key
  • mlx.core.random.normal
  • +
  • mlx.core.random.multivariate_normal
  • mlx.core.random.randint
  • mlx.core.random.seed
  • mlx.core.random.split
  • @@ -432,6 +436,8 @@
  • mlx.core.metal.get_cache_memory
  • mlx.core.metal.set_memory_limit
  • mlx.core.metal.set_cache_limit
  • +
  • mlx.core.metal.start_capture
  • +
  • mlx.core.metal.stop_capture
  • Neural Networks
  • Neural Networks
      @@ -763,12 +769,12 @@ document.write(`
    • Operations and Primitives
    • Implementing the Primitive
    • @@ -796,61 +802,45 @@ document.write(`

      Developer Documentation#

      -

      MLX provides a open and flexible backend to which users may add operations -and specialized implementations without much hassle. While the library supplies -efficient operations that can be used and composed for any number of -applications, there may arise cases where new functionalities or highly -optimized implementations are needed. For such cases, you may design and -implement your own operations that link to and build on top of mlx.core. -We will introduce the inner-workings of MLX and go over a simple example to -learn the steps involved in adding new operations to MLX with your own CPU -and GPU implementations.

      +

      You can extend MLX with custom operations on the CPU or GPU. This guide +explains how to do that with a simple example.

      Introducing the Example#

      -

      Let’s say that you would like an operation that takes in two arrays, -x and y, scales them both by some coefficients alpha and beta -respectively, and then adds them together to get the result -z = alpha * x + beta * y. Well, you can very easily do that by just -writing out a function as follows:

      +

      Let’s say you would like an operation that takes in two arrays, x and +y, scales them both by coefficients alpha and beta respectively, +and then adds them together to get the result z = alpha * x + beta * y. +You can do that in MLX directly:

      import mlx.core as mx
       
       def simple_axpby(x: mx.array, y: mx.array, alpha: float, beta: float) -> mx.array:
           return alpha * x + beta * y
       
      -

      This function performs that operation while leaving the implementations and -differentiation to MLX.

      -

      However, you work with vector math libraries often and realize that the -axpby routine defines the same operation Y = (alpha * X) + (beta * Y). -You would really like the part of your applications that does this operation -on the CPU to be very fast - so you decide that you want it to rely on the -axpby routine provided by the Accelerate framework. Continuing to impose -our assumptions on to you, let’s also assume that you want to learn how to add -your own implementation for the gradients of your new operation while going -over the ins-and-outs of the MLX framework.

      -

      Well, what a coincidence! You are in the right place. Over the course of this -example, we will learn:

      +

      This function performs that operation while leaving the implementation and +function transformations to MLX.
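For example (a minimal sketch, not part of the original guide), you can call this function on two small arrays and force evaluation to see the result:

.. code-block:: python

   import mlx.core as mx

   def simple_axpby(x: mx.array, y: mx.array, alpha: float, beta: float) -> mx.array:
       return alpha * x + beta * y

   x = mx.array([1.0, 2.0, 3.0])
   y = mx.array([4.0, 5.0, 6.0])

   # The call only builds the (lazy) graph; mx.eval forces the computation
   z = simple_axpby(x, y, alpha=4.0, beta=2.0)
   mx.eval(z)
   print(z)  # expect [12, 18, 24]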

      +

      However, you may need to customize the underlying implementation, perhaps to +make it faster or for custom differentiation. In this tutorial, we will go +through adding custom extensions. It will cover:

        -
      • The structure of the MLX library from the frontend API to the backend implementations.

      • -
      • How to implement your own CPU backend that redirects to Accelerate when appropriate (and a fallback if needed).

      • -
      • How to implement your own GPU implementation using metal.

      • -
      • How to add your own vjp and jvp.

      • -
      • How to build your implementations, link them to MLX, and bind them to python.

      • +
      • The structure of the MLX library.

      • +
      • Implementing a CPU operation that redirects to Accelerate when appropriate.

      • +
      • Implementing a GPU operation using metal.

      • +
      • Adding the vjp and jvp function transformation.

      • +
      • Building a custom extension and binding it to python.

      Operations and Primitives#

      -

      In one sentence, operations in MLX build the computation graph, and primitives -provide the rules for evaluation and transformations of said graph. Let’s start -by discussing operations in more detail.

      +

      Operations in MLX build the computation graph. Primitives provide the rules for +evaluating and transforming the graph. Let’s start by discussing operations in +more detail.

      Operations#

      -

      Operations are the frontend functions that operate on arrays. They are defined -in the C++ API (Operations) and then we provide bindings to these -operations in the Python API (Operations).

      -

      We would like an operation, axpby() that takes in two arrays x and y, -and two scalars, alpha and beta. This is how we would define it in the -C++ API:

      +

      Operations are the front-end functions that operate on arrays. They are defined +in the C++ API (Operations), and the Python API (Operations) binds them.

      +

      We would like an operation, axpby() that takes in two arrays x and +y, and two scalars, alpha and beta. This is how to define it in +C++:

      /**
       *  Scale and sum two vectors element-wise
       *  z = alpha * x + beta * y
      @@ -867,9 +857,7 @@ C++ API:

      );
      -

      This operation itself can call other operations within it if needed. So, the -simplest way to go about implementing this operation would be do so in terms -of existing operations.

      +

      The simplest way to implement this operation is in terms of existing operations:

      array axpby(
           const array& x, // Input array x
           const array& y, // Input array y
      @@ -886,19 +874,17 @@ of existing operations.

      }
      -

      However, as we discussed earlier, this is not our goal. The operations themselves -do not contain the implementations that act on the data, nor do they contain the -rules of transformations. Rather, they are an easy to use interface that build -on top of the building blocks we call Primitive.

      +

      The operations themselves do not contain the implementations that act on the +data, nor do they contain the rules of transformations. Rather, they are an +easy-to-use interface that uses Primitive building blocks.

      Primitives#

      A Primitive is part of the computation graph of an array. It -defines how to create an output given a set of input array . Further, -a Primitive is a class that contains rules on how it is evaluated -on the CPU or GPU, and how it acts under transformations such as vjp and -jvp. These words on their own can be a bit abstract, so lets take a step -back and go to our example to give ourselves a more concrete image.

      +defines how to create output arrays given input arrays. Further, a +Primitive has methods to run on the CPU or GPU and for function +transformations such as vjp and jvp. Let’s go back to our example to be +more concrete:

      class Axpby : public Primitive {
         public:
           explicit Axpby(Stream stream, float alpha, float beta)
      @@ -911,11 +897,15 @@ back and go to our example to give ourselves a more concrete image.

      * To avoid unnecessary allocations, the evaluation function * is responsible for allocating space for the array. */ - void eval_cpu(const std::vector<array>& inputs, array& out) override; - void eval_gpu(const std::vector<array>& inputs, array& out) override; + void eval_cpu( + const std::vector<array>& inputs, + std::vector<array>& outputs) override; + void eval_gpu( + const std::vector<array>& inputs, + std::vector<array>& outputs) override; /** The Jacobian-vector product. */ - array jvp( + std::vector<array> jvp( const std::vector<array>& primals, const std::vector<array>& tangents, const std::vector<int>& argnums) override; @@ -924,7 +914,8 @@ back and go to our example to give ourselves a more concrete image.

      std::vector<array> vjp( const std::vector<array>& primals, const array& cotan, - const std::vector<int>& argnums) override; + const std::vector<int>& argnums, + const std::vector<array>& outputs) override; /** * The primitive must know how to vectorize itself across @@ -932,7 +923,7 @@ back and go to our example to give ourselves a more concrete image.

      * representing the vectorized computation and the axis which * corresponds to the output vectorized dimension. */ - std::pair<array, int> vmap( + virtual std::pair<std::vector<array>, std::vector<int>> vmap( const std::vector<array>& inputs, const std::vector<int>& axes) override; @@ -953,20 +944,20 @@ back and go to our example to give ourselves a more concrete image.

      };
      -

      The Axpby class derives from the base Primitive class and -follows the above demonstrated interface. Axpby treats alpha and -beta as parameters. It then provides implementations of how the array out -is produced given inputs through Axpby::eval_cpu() and -Axpby::eval_gpu(). Further, it provides rules of transformations in -Axpby::jvp(), Axpby::vjp(), and Axpby::vmap().

      +

      The Axpby class derives from the base Primitive class. The +Axpby treats alpha and beta as parameters. It then provides +implementations of how the output array is produced given the inputs through +Axpby::eval_cpu() and Axpby::eval_gpu(). It also provides rules +of transformations in Axpby::jvp(), Axpby::vjp(), and +Axpby::vmap().

      -
      -

      Using the Primitives#

      -

      Operations can use this Primitive to add a new array to -the computation graph. An array can be constructed by providing its -data type, shape, the Primitive that computes it, and the -array inputs that are passed to the primitive.

      -

      Let’s re-implement our operation now in terms of our Axpby primitive.

      +
      +

      Using the Primitive#

      +

      Operations can use this Primitive to add a new array to the +computation graph. An array can be constructed by providing its data +type, shape, the Primitive that computes it, and the array +inputs that are passed to the primitive.

      +

      Let’s reimplement our operation now in terms of our Axpby primitive.

      array axpby(
           const array& x, // Input array x
           const array& y, // Input array y
      @@ -1012,25 +1003,24 @@ data type, shape, the 
       

      Implementing the Primitive#

      -

      No computation happens when we call the operation alone. In effect, the -operation only builds the computation graph. When we evaluate the output -array, MLX schedules the execution of the computation graph, and calls -Axpby::eval_cpu() or Axpby::eval_gpu() depending on the -stream/device specified by the user.

      +

      No computation happens when we call the operation alone. The operation only +builds the computation graph. When we evaluate the output array, MLX schedules +the execution of the computation graph, and calls Axpby::eval_cpu() or +Axpby::eval_gpu() depending on the stream/device specified by the user.
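The same laziness is visible from Python. Here is a rough sketch (assuming the finished extension from this guide is built and installed as mlx_sample_extensions) of how the stream/device argument selects between the two evaluation paths:

.. code-block:: python

   import mlx.core as mx
   from mlx_sample_extensions import axpby  # assumes the example package is built and installed

   x = mx.ones((1024,))
   y = mx.ones((1024,))

   # Nothing runs yet: these calls only record Axpby nodes in the graph
   z_cpu = axpby(x, y, 4.0, 2.0, stream=mx.cpu)  # will be evaluated by Axpby::eval_cpu()
   z_gpu = axpby(x, y, 4.0, 2.0, stream=mx.gpu)  # will be evaluated by Axpby::eval_gpu()

   # Evaluation of the graph is scheduled here
   mx.eval(z_cpu, z_gpu)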

      Warning

      When Primitive::eval_cpu() or Primitive::eval_gpu() are called, no memory has been allocated for the output array. It falls on the implementation -of these functions to allocate memory as needed

      +of these functions to allocate memory as needed.

      -
      -

      Implementing the CPU Backend#

      -

      Let’s start by trying to implement a naive and generic version of +

      +

      Implementing the CPU Back-end#

      +

      Let’s start by implementing a naive and generic version of Axpby::eval_cpu(). We declared this as a private member function of Axpby earlier called Axpby::eval().

      Our naive method will go over each element of the output array, find the corresponding input elements of x and y and perform the operation -pointwise. This is captured in the templated function axpby_impl().

      +point-wise. This is captured in the templated function axpby_impl().

      template <typename T>
       void axpby_impl(
               const array& x,
      @@ -1066,16 +1056,16 @@ pointwise. This is captured in the templated function }
       
      -

      Now, we would like our implementation to be able to do this pointwise operation -for all incoming floating point arrays. Accordingly, we add dispatches for -float32, float16, bfloat16 and complex64. We throw an error -if we encounter an unexpected type.

      +

      Our implementation should work for all incoming floating point arrays. +Accordingly, we add dispatches for float32, float16, bfloat16 and +complex64. We throw an error if we encounter an unexpected type.

      /** Fall back implementation for evaluation on CPU */
      -void Axpby::eval(const std::vector<array>& inputs, array& out) {
      -    // Check the inputs (registered in the op while constructing the out array)
      -    assert(inputs.size() == 2);
      +void Axpby::eval(
      +  const std::vector<array>& inputs,
      +  const std::vector<array>& outputs) {
           auto& x = inputs[0];
           auto& y = inputs[1];
      +    auto& out = outputs[0];
       
           // Dispatch to the correct dtype
           if (out.dtype() == float32) {
      @@ -1088,29 +1078,27 @@ if we encounter an unexpected type.

      return axpby_impl<complex64_t>(x, y, out, alpha_, beta_); } else { throw std::runtime_error( - "Axpby is only supported for floating point types."); + "[Axpby] Only supports floating point types."); } }
      -

      We have a fallback implementation! Now, to do what we are really here to do. -Remember we wanted to use the axpby routine provided by the Accelerate -framework? Well, there are 3 complications to keep in mind:

      +

      This is good as a fallback implementation. We can use the axpby routine +provided by the Accelerate framework for a faster implementation in certain +cases:

      1. Accelerate does not provide implementations of axpby for half precision -floats. We can only direct to it for float32 types

      2. -
      3. Accelerate assumes the inputs x and y are contiguous and all elements -have fixed strides between them. Possibly due to broadcasts and transposes, -we aren’t guaranteed that the inputs fit this requirement. We can -only direct to Accelerate if both x and y are row contiguous or -column contiguous.

      4. -
      5. Accelerate performs the routine Y = (alpha * X) + (beta * Y) inplace. -MLX expects to write out the answer to a new array. We must copy the elements -of y into the output array and use that as an input to axpby

      6. +floats. We can only use it for float32 types.

        +
      7. Accelerate assumes the inputs x and y are contiguous and all +elements have fixed strides between them. We only direct to Accelerate +if both x and y are row contiguous or column contiguous.

      8. +
      9. Accelerate performs the routine Y = (alpha * X) + (beta * Y) in-place. +MLX expects to write the output to a new array. We must copy the elements +of y into the output and use that as an input to axpby.

      -

      Let’s write out an implementation that uses Accelerate in the right conditions. -It must simply allocate data for the output, copy elements of y into it, -and then call the catlas_saxpby() from accelerate.

      +

      Let’s write an implementation that uses Accelerate in the right conditions. +It allocates data for the output, copies y into it, and then calls the +catlas_saxpby() from accelerate.

      template <typename T>
       void axpby_impl_accelerate(
               const array& x,
      @@ -1121,17 +1109,7 @@ and then call the     // Accelerate library provides catlas_saxpby which does
           // Y = (alpha * X) + (beta * Y) in place
           // To use it, we first copy the data in y over to the output array
      -
      -    // This specialization requires both x and y be contiguous in the same mode
      -    // i.e: corresponding linear indices in both point to corresponding elements
      -    // The data in the output array is allocated to match the strides in y
      -    // such that x, y, and out are contiguous in the same mode and
      -    // no transposition is needed
      -    out.set_data(
      -        allocator::malloc_or_wait(y.data_size() * out.itemsize()),
      -        y.data_size(),
      -        y.strides(),
      -        y.flags());
      +    out.set_data(allocator::malloc_or_wait(out.nbytes()));
       
           // We then copy over the elements using the contiguous vector specialization
           copy_inplace(y, out, CopyType::Vector);
      @@ -1155,14 +1133,17 @@ and then call the }
       
      -

      Great! But what about the inputs that do not fit the criteria for accelerate? -Luckily, we can always just direct back to Axpby::eval().

      -

      With this in mind, lets finally implement our Axpby::eval_cpu().

      +

      For inputs that do not fit the criteria for accelerate, we fall back to +Axpby::eval(). With this in mind, let’s finish our +Axpby::eval_cpu().

      /** Evaluate primitive on CPU using accelerate specializations */
      -void Axpby::eval_cpu(const std::vector<array>& inputs, array& out) {
      +void Axpby::eval_cpu(
      +  const std::vector<array>& inputs,
      +  const std::vector<array>& outputs) {
           assert(inputs.size() == 2);
           auto& x = inputs[0];
           auto& y = inputs[1];
      +    auto& out = outputs[0];
       
           // Accelerate specialization for contiguous single precision float arrays
           if (out.dtype() == float32 &&
      @@ -1172,33 +1153,32 @@ Luckily, we can always just direct back to         return;
           }
       
      -    // Fall back to common backend if specializations are not available
      -    eval(inputs, out);
      +    // Fall back to common back-end if specializations are not available
      +    eval(inputs, outputs);
       }
       
      -

      We have now hit a milestone! Just this much is enough to run the operation -axpby() on a CPU stream!

      -

      If you do not plan on running the operation on the GPU or using transforms on +

      Just this much is enough to run the operation axpby() on a CPU stream! If +you do not plan on running the operation on the GPU or using transforms on computation graphs that contain Axpby, you can stop implementing the primitive here and enjoy the speed-ups you get from the Accelerate library.

      -
      -

      Implementing the GPU Backend#

      +
      +

      Implementing the GPU Back-end#

      Apple silicon devices address their GPUs using the Metal shading language, and -all GPU kernels in MLX are written using metal.

      +GPU kernels in MLX are written using Metal.

      Note

      -

      Here are some helpful resources if you are new to metal!

      +

      Here are some helpful resources if you are new to Metal:

      -

      Let’s keep the GPU algorithm simple. We will launch exactly as many threads -as there are elements in the output. Each thread will pick the element it needs -from x and y, do the pointwise operation, and then update its assigned +

      Let’s keep the GPU kernel simple. We will launch exactly as many threads as +there are elements in the output. Each thread will pick the element it needs +from x and y, do the point-wise operation, and update its assigned element in the output.

      template <typename T>
       [[kernel]] void axpby_general(
      @@ -1223,8 +1203,7 @@ element in the output.

      We then need to instantiate this template for all floating point types and give -each instantiation a unique host name so we can identify the right kernel for -each data type.

      +each instantiation a unique host name so we can identify it.

      #define instantiate_axpby(type_name, type)              \
           template [[host_name("axpby_general_" #type_name)]] \
           [[kernel]] void axpby_general<type>(                \
      @@ -1245,25 +1224,18 @@ each data type.

      instantiate_axpby(complex64, complex64_t);
      -

      This kernel will be compiled into a metal library mlx_ext.metallib as we -will see later in Building with CMake. In the following example, we -assume that the library mlx_ext.metallib will always be co-located with -the executable/ shared-library calling the register_library() function. -The register_library() function takes the library’s name and potential -path (or in this case, a function that can produce the path of the metal -library) and tries to load that library if it hasn’t already been registered -by the relevant static mlx::core::metal::Device object. This is why, -it is important to package your C++ library with the metal library. We will -go over this process in more detail later.

      -

      The logic to determine the kernel, set the inputs, resolve the grid dimensions -and dispatch it to the GPU are contained in Axpby::eval_gpu() as shown +

      The logic to determine the kernel, set the inputs, resolve the grid dimensions, +and dispatch to the GPU are contained in Axpby::eval_gpu() as shown below.

      /** Evaluate primitive on GPU */
      -void Axpby::eval_gpu(const std::vector<array>& inputs, array& out) {
      +void Axpby::eval_gpu(
      +  const std::vector<array>& inputs,
      +  std::vector<array>& outputs) {
           // Prepare inputs
           assert(inputs.size() == 2);
           auto& x = inputs[0];
           auto& y = inputs[1];
      +    auto& out = outputs[0];
       
           // Each primitive carries the stream it should execute on
           // and each stream carries its device identifiers
      @@ -1274,7 +1246,7 @@ below.

      // Allocate output memory out.set_data(allocator::malloc_or_wait(out.nbytes())); - // Resolve name of kernel (corresponds to axpby.metal) + // Resolve name of kernel std::ostringstream kname; kname << "axpby_" << "general_" << type_to_name(out); @@ -1328,24 +1300,21 @@ below.

      We can now call the axpby() operation on both the CPU and the GPU!

      -

      A few things to note about MLX and metal before moving on. MLX keeps track -of the active compute_encoder. We rely on d.get_command_encoder() -to give us the active metal compute command encoder instead of building a -new one and calling compute_encoder->end_encoding() at the end. -MLX keeps adding kernels (compute pipelines) to the active command encoder -until some specified limit is hit or the compute encoder needs to be flushed -for synchronization. MLX also handles enqueuing and committing the associated -command buffers as needed. We suggest taking a deeper dive into -metal::Device if you would like to study this routine further.

      +

      A few things to note about MLX and Metal before moving on. MLX keeps track of +the active command_buffer and the MTLCommandBuffer to which it is +associated. We rely on d.get_command_encoder() to give us the active +metal compute command encoder instead of building a new one and calling +compute_encoder->end_encoding() at the end. MLX adds kernels (compute +pipelines) to the active command buffer until some specified limit is hit or +the command buffer needs to be flushed for synchronization.

      Primitive Transforms#

      -

      Now that we have come this far, let’s also learn how to add implementations to -transformations in a Primitive. These transformations can be built on -top of our operations, including the one we just defined now. Which then gives -us the following Axpby::jvp() and Axpby::vjp() implementations.

      +

      Next, let’s add implementations for transformations in a Primitive. +These transformations can be built on top of other operations, including the +one we just defined:

      /** The Jacobian-vector product. */
      -array Axpby::jvp(
      +std::vector<array> Axpby::jvp(
               const std::vector<array>& primals,
               const std::vector<array>& tangents,
               const std::vector<int>& argnums) {
      @@ -1360,12 +1329,12 @@ us the following     if (argnums.size() > 1) {
               auto scale = argnums[0] == 0 ? alpha_ : beta_;
               auto scale_arr = array(scale, tangents[0].dtype());
      -        return multiply(scale_arr, tangents[0], stream());
      +        return {multiply(scale_arr, tangents[0], stream())};
           }
           // If, argnums = {0, 1}, we take contributions from both
           // which gives us jvp = tangent_x * alpha + tangent_y * beta
           else {
      -        return axpby(tangents[0], tangents[1], alpha_, beta_, stream());
      +        return {axpby(tangents[0], tangents[1], alpha_, beta_, stream())};
           }
       }
       
      @@ -1373,26 +1342,27 @@ us the following
      /** The vector-Jacobian product. */
       std::vector<array> Axpby::vjp(
               const std::vector<array>& primals,
      -        const array& cotan,
      -        const std::vector<int>& argnums) {
      +        const std::vector<array>& cotangents,
      +        const std::vector<int>& argnums,
      +        const std::vector<int>& /* unused */) {
           // Reverse mode diff
           std::vector<array> vjps;
           for (auto arg : argnums) {
               auto scale = arg == 0 ? alpha_ : beta_;
      -        auto scale_arr = array(scale, cotan.dtype());
      -        vjps.push_back(multiply(scale_arr, cotan, stream()));
      +        auto scale_arr = array(scale, cotangents[0].dtype());
      +        vjps.push_back(multiply(scale_arr, cotangents[0], stream()));
           }
           return vjps;
       }
       
      -

      Finally, you need not have a transformation fully defined to start using your -own Primitive.

      +

      Note, a transformation does not need to be fully defined to start using +the Primitive.

      /** Vectorize primitive along given axis */
      -std::pair<array, int> Axpby::vmap(
      +std::pair<std::vector<array>, std::vector<int>> Axpby::vmap(
               const std::vector<array>& inputs,
               const std::vector<int>& axes) {
      -    throw std::runtime_error("Axpby has no vmap implementation.");
      +    throw std::runtime_error("[Axpby] vmap not implemented.");
       }
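With the vjp in place, a quick Python-side sanity check (a sketch, assuming the example package is installed) is to compare gradients of the custom axpby against the pure-MLX simple_axpby:

.. code-block:: python

   import mlx.core as mx
   from mlx_sample_extensions import axpby  # assumes the example package is installed

   def f(x, y):
       return axpby(x, y, 4.0, 2.0).sum()

   def f_ref(x, y):
       return (4.0 * x + 2.0 * y).sum()

   x = mx.random.uniform(shape=(8,))
   y = mx.random.uniform(shape=(8,))

   gx, gy = mx.grad(f, argnums=(0, 1))(x, y)
   gx_ref, gy_ref = mx.grad(f_ref, argnums=(0, 1))(x, y)

   print(mx.allclose(gx, gx_ref), mx.allclose(gy, gy_ref))  # expect True True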
       
      @@ -1416,64 +1386,63 @@ own
    • extensions/axpby/ defines the C++ extension library

    • extensions/mlx_sample_extensions sets out the structure for the -associated python package

    • -
    • extensions/bindings.cpp provides python bindings for our operation

    • +associated Python package

      +
    • extensions/bindings.cpp provides Python bindings for our operation

    • extensions/CMakeLists.txt holds CMake rules to build the library and -python bindings

    • +Python bindings

    • extensions/setup.py holds the setuptools rules to build and install -the python package

    • +the Python package

    Binding to Python#

    -

    We use PyBind11 to build a Python API for the C++ library. Since bindings for +

    We use nanobind to build a Python API for the C++ library. Since bindings for components such as mlx.core.array, mlx.core.stream, etc. are -already provided, adding our axpby() is simple!

    -
    PYBIND11_MODULE(mlx_sample_extensions, m) {
    -    m.doc() = "Sample C++ and metal extensions for MLX";
    +already provided, adding our axpby() is simple.

    +
    NB_MODULE(_ext, m) {
    +     m.doc() = "Sample extension for MLX";
     
    -    m.def(
    -        "axpby",
    -        &axpby,
    -        "x"_a,
    -        "y"_a,
    -        py::pos_only(),
    -        "alpha"_a,
    -        "beta"_a,
    -        py::kw_only(),
    -        "stream"_a = py::none(),
    -        R"pbdoc(
    -            Scale and sum two vectors element-wise
    -            ``z = alpha * x + beta * y``
    +     m.def(
    +         "axpby",
    +         &axpby,
    +         "x"_a,
    +         "y"_a,
    +         "alpha"_a,
    +         "beta"_a,
    +         nb::kw_only(),
    +         "stream"_a = nb::none(),
    +         R"(
    +             Scale and sum two vectors element-wise
    +             ``z = alpha * x + beta * y``
     
    -            Follows numpy style broadcasting between ``x`` and ``y``
    -            Inputs are upcasted to floats if needed
    +             Follows numpy style broadcasting between ``x`` and ``y``
    +             Inputs are upcasted to floats if needed
     
    -            Args:
    -                x (array): Input array.
    -                y (array): Input array.
    -                alpha (float): Scaling factor for ``x``.
    -                beta (float): Scaling factor for ``y``.
    +             Args:
    +                 x (array): Input array.
    +                 y (array): Input array.
    +                 alpha (float): Scaling factor for ``x``.
    +                 beta (float): Scaling factor for ``y``.
     
    -            Returns:
    -                array: ``alpha * x + beta * y``
    -        )pbdoc");
    -}
    +             Returns:
    +                 array: ``alpha * x + beta * y``
    +         )");
    + }
     

    Most of the complexity in the above example comes from additional bells and whistles such as the literal names and doc-strings.

    Warning

    -

    mlx.core needs to be imported before importing -mlx_sample_extensions as defined by the pybind11 module above to +

    mlx.core must be imported before importing +mlx_sample_extensions as defined by the nanobind module above to ensure that the casters for mlx.core components like mlx.core.array are available.
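In practice this just means the package's __init__.py imports mlx.core before loading the compiled module, roughly along these lines (a sketch of the idea, not necessarily the exact file):

.. code-block:: python

   # mlx_sample_extensions/__init__.py
   import mlx.core as mx   # import mlx.core first so its type casters are registered
   from ._ext import *     # then load the compiled nanobind module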

    Building with CMake#

    -

    Building the C++ extension library itself is simple, it only requires that you -find_package(MLX CONFIG) and then link it to your library.

    +

    Building the C++ extension library only requires that you find_package(MLX +CONFIG) and then link it to your library.

    # Add library
     add_library(mlx_ext)
     
    @@ -1493,11 +1462,11 @@ ensure that the casters for target_link_libraries(mlx_ext PUBLIC mlx)
     
    -

    We also need to build the attached metal library. For convenience, we provide a +

    We also need to build the attached Metal library. For convenience, we provide a mlx_build_metallib() function that builds a .metallib target given sources, headers, destinations, etc. (defined in cmake/extension.cmake and automatically imported with MLX package).

    -

    Here is what that looks like in practice!

    +

    Here is what that looks like in practice:

    # Build metallib
     if(MLX_BUILD_METAL)
     
    @@ -1517,15 +1486,17 @@ automatically imported with MLX package).

    endif()
    -

    Finally, we build the Pybind11 bindings

    -
    pybind11_add_module(
    -    mlx_sample_extensions
    -    ${CMAKE_CURRENT_LIST_DIR}/bindings.cpp
    +

    Finally, we build the nanobind bindings

    +
    nanobind_add_module(
    +  _ext
    +  NB_STATIC STABLE_ABI LTO NOMINSIZE
    +  NB_DOMAIN mlx
    +  ${CMAKE_CURRENT_LIST_DIR}/bindings.cpp
     )
    -target_link_libraries(mlx_sample_extensions PRIVATE mlx_ext)
    +target_link_libraries(_ext PRIVATE mlx_ext)
     
     if(BUILD_SHARED_LIBS)
    -    target_link_options(mlx_sample_extensions PRIVATE -Wl,-rpath,@loader_path)
    +  target_link_options(_ext PRIVATE -Wl,-rpath,@loader_path)
     endif()
     
    @@ -1533,7 +1504,7 @@ automatically imported with MLX package).

    Building with setuptools#

    Once we have set out the CMake build rules as described above, we can use the -build utilities defined in mlx.extension for a simple build process.

    +build utilities defined in mlx.extension:

    from mlx import extension
     from setuptools import setup
     
    @@ -1542,13 +1513,13 @@ build utilities defined in name="mlx_sample_extensions",
             version="0.0.0",
             description="Sample C++ and Metal extensions for MLX primitives.",
    -        ext_modules=[extension.CMakeExtension("mlx_sample_extensions")],
    +        ext_modules=[extension.CMakeExtension("mlx_sample_extensions._ext")],
             cmdclass={"build_ext": extension.CMakeBuild},
    -        packages = ["mlx_sample_extensions"],
    -        package_dir = {"": "mlx_sample_extensions"},
    -        package_data = {"mlx_sample_extensions" : ["*.so", "*.dylib", "*.metallib"]},
    +        packages=["mlx_sample_extensions"],
    +        package_data={"mlx_sample_extensions": ["*.so", "*.dylib", "*.metallib"]},
    +        extras_require={"dev":[]},
             zip_safe=False,
    -        python_requires=">=3.7",
    +        python_requires=">=3.8",
         )
     
    @@ -1557,34 +1528,36 @@ build utilities defined in extensions/mlx_sample_extensions as the package directory even though it only contains a __init__.py to ensure the following:

      -
    • mlx.core is always imported before importing mlx_sample_extensions

    • +
    • mlx.core must be imported before importing _ext

    • The C++ extension library and the metal library are co-located with the python bindings and copied together if the package is installed

    -

    You can build inplace for development using +

    To build the package, first install the build dependencies with pip install +-r requirements.txt. You can then build inplace for development using python setup.py build_ext -j8 --inplace (in extensions/)

    -

    This will result in a directory structure as follows:

    +

    This results in the directory structure:

    extensions
    ├── mlx_sample_extensions
    │ ├── __init__.py
    │ ├── libmlx_ext.dylib # C++ extension library
    │ ├── mlx_ext.metallib # Metal library
    -
    │ └── mlx_sample_extensions.cpython-3x-darwin.so # Python Binding
    +
    │ └── _ext.cpython-3x-darwin.so # Python Binding
    -

    When you try to install using the command python -m pip install . -(in extensions/), the package will be installed with the same structure as -extensions/mlx_sample_extensions and the C++ and metal library will be -copied along with the python binding since they are specified as package_data.

    +

    When you try to install using the command python -m pip install . (in +extensions/), the package will be installed with the same structure as +extensions/mlx_sample_extensions and the C++ and Metal library will be +copied along with the Python binding since they are specified as +package_data.

    Usage#

    After installing the extension as described above, you should be able to simply -import the python package and play with it as you would any other MLX operation!

    -

    Let’s looks at a simple script and it’s results!

    +import the Python package and play with it as you would any other MLX operation.

    +

    Let’s look at a simple script and its results:

    import mlx.core as mx
     from mlx_sample_extensions import axpby
     
    @@ -1606,7 +1579,7 @@ import the python package and play with it as you would any other MLX operation!
     

    Results#

    Let’s run a quick benchmark and see how our new axpby operation compares -with the naive simple_axpby() we defined at first on the CPU.

    +with the naive simple_axpby() we first defined on the CPU.

    import mlx.core as mx
     from mlx_sample_extensions import axpby
     import time
    @@ -1624,7 +1597,7 @@ with the naive alpha = 4.0
     beta = 2.0
     
    -mx.eval((x, y))
    +mx.eval(x, y)
     
     def bench(f):
         # Warm up
    @@ -1646,21 +1619,18 @@ with the naive print(f"Simple axpby: {simple_time:.3f} s | Custom axpby: {custom_time:.3f} s")
     
    -

    Results:

    -
    Simple axpby: 0.114 s | Custom axpby: 0.109 s
    -
    -
    -

    We see some modest improvements right away!

    +

    The results are Simple axpby: 0.114 s | Custom axpby: 0.109 s. We see +modest improvements right away!

    This operation can now be used to build other operations, in mlx.nn.Module calls, and also as a part of graph transformations like -grad()!

    +grad().

    Scripts#

    Download the code

    -

    The full example code is available in mlx.

    +

    The full example code is available in mlx.

    @@ -1713,12 +1683,12 @@ with the naive Operations and Primitives
  • Implementing the Primitive
  • diff --git a/docs/build/html/dev/metal_debugger.html b/docs/build/html/dev/metal_debugger.html index e6a8b2fee..49a253fca 100644 --- a/docs/build/html/dev/metal_debugger.html +++ b/docs/build/html/dev/metal_debugger.html @@ -8,7 +8,7 @@ - Metal Debugger — MLX 0.9.0 documentation + Metal Debugger — MLX 0.10.0 documentation @@ -36,7 +36,7 @@ - + @@ -130,8 +130,8 @@ - MLX 0.9.0 documentation - Home - + MLX 0.10.0 documentation - Home + @@ -285,6 +285,7 @@
  • mlx.core.erf
  • mlx.core.erfinv
  • mlx.core.exp
  • +
  • mlx.core.expm1
  • mlx.core.expand_dims
  • mlx.core.eye
  • mlx.core.flatten
  • @@ -317,6 +318,7 @@
  • mlx.core.max
  • mlx.core.maximum
  • mlx.core.mean
  • +
  • mlx.core.meshgrid
  • mlx.core.min
  • mlx.core.minimum
  • mlx.core.moveaxis
  • @@ -351,6 +353,7 @@
  • mlx.core.square
  • mlx.core.squeeze
  • mlx.core.stack
  • +
  • mlx.core.std
  • mlx.core.stop_gradient
  • mlx.core.subtract
  • mlx.core.sum
  • @@ -378,6 +381,7 @@
  • mlx.core.random.gumbel
  • mlx.core.random.key
  • mlx.core.random.normal
  • +
  • mlx.core.random.multivariate_normal
  • mlx.core.random.randint
  • mlx.core.random.seed
  • mlx.core.random.split
  • @@ -431,6 +435,8 @@
  • mlx.core.metal.get_cache_memory
  • mlx.core.metal.set_memory_limit
  • mlx.core.metal.set_cache_limit
  • +
  • mlx.core.metal.start_capture
  • +
  • mlx.core.metal.stop_capture
  • Neural Networks
      @@ -773,25 +779,39 @@ document.write(`

      Metal Debugger#

      Profiling is a key step for performance optimization. You can build MLX with -the MLX_METAL_DEBUG option to improve the Metal debugging and optimization -workflow. The MLX_METAL_DEBUG debug option:

      +the MLX_METAL_DEBUG option to improve the Metal debugging and +optimization workflow. The MLX_METAL_DEBUG debug option:

      • Records source during Metal compilation, for later inspection while debugging.

      • Labels Metal objects such as command queues, improving capture readability.

      -

      The metal::start_capture function initiates a capture of all MLX GPU work.

      -
      int main() {
      -    metal::start_capture("/Users/Jane/Developer/MLX.gputrace");
      +

      To build with debugging enabled in Python prepend +CMAKE_ARGS="-DMLX_METAL_DEBUG=ON" to the build call.

      +

      The metal.start_capture() function initiates a capture of all MLX GPU +work.

      +
      +

      Note

      +

      To capture a GPU trace you must run the application with +MTL_CAPTURE_ENABLED=1.

      +
      +
      import mlx.core as mx
       
      -    auto a = arange(10.f, 20.f, 1.f, float32);
      -    auto b = arange(30.f, 40.f, 1.f, float32);
      -    auto c = add(a, b);
      +a = mx.random.uniform(shape=(512, 512))
      +b = mx.random.uniform(shape=(512, 512))
      +mx.eval(a, b)
       
      -    eval(c);
      +trace_file = "mlx_trace.gputrace"
       
      -    metal::stop_capture();
      -}
      +if not mx.metal.start_capture(trace_file):
      +  print("Make sure to run with MTL_CAPTURE_ENABLED=1 and "
      +        f"that the path {trace_file} does not already exist.")
      +  exit(1)
      +
      +for _ in range(10):
      +  mx.eval(mx.add(a, b))
      +
      +mx.metal.stop_capture()
       

      You can open and replay the GPU trace in Xcode. The Dependencies view @@ -800,8 +820,8 @@ documentation for more information.

      ../_images/capture.png

      Xcode Workflow#

      -

      You can skip saving to a path by running within Xcode. First, generate an Xcode -project using CMake.

      +

      You can skip saving to a path by running within Xcode. First, generate an +Xcode project using CMake.

      mkdir build && cd build
       cmake .. -DMLX_METAL_DEBUG=ON -G Xcode
       open mlx.xcodeproj
      diff --git a/docs/build/html/examples/linear_regression.html b/docs/build/html/examples/linear_regression.html
      index cab32348f..412df205f 100644
      --- a/docs/build/html/examples/linear_regression.html
      +++ b/docs/build/html/examples/linear_regression.html
      @@ -8,7 +8,7 @@
           
           
       
      -    Linear Regression — MLX 0.9.0 documentation
      +    Linear Regression — MLX 0.10.0 documentation
         
         
         
      @@ -36,7 +36,7 @@
       
         
       
      -    
      +    
           
           
           
      @@ -131,8 +131,8 @@
             
           
           
      -    MLX 0.9.0 documentation - Home
      -    
      +    MLX 0.10.0 documentation - Home
      +    
         
         
       
      @@ -286,6 +286,7 @@
    • mlx.core.erf
    • mlx.core.erfinv
    • mlx.core.exp
    • +
    • mlx.core.expm1
    • mlx.core.expand_dims
    • mlx.core.eye
    • mlx.core.flatten
    • @@ -318,6 +319,7 @@
    • mlx.core.max
    • mlx.core.maximum
    • mlx.core.mean
    • +
    • mlx.core.meshgrid
    • mlx.core.min
    • mlx.core.minimum
    • mlx.core.moveaxis
    • @@ -352,6 +354,7 @@
    • mlx.core.square
    • mlx.core.squeeze
    • mlx.core.stack
    • +
    • mlx.core.std
    • mlx.core.stop_gradient
    • mlx.core.subtract
    • mlx.core.sum
    • @@ -379,6 +382,7 @@
    • mlx.core.random.gumbel
    • mlx.core.random.key
    • mlx.core.random.normal
    • +
    • mlx.core.random.multivariate_normal
    • mlx.core.random.randint
    • mlx.core.random.seed
    • mlx.core.random.split
    • @@ -432,6 +436,8 @@
    • mlx.core.metal.get_cache_memory
    • mlx.core.metal.set_memory_limit
    • mlx.core.metal.set_cache_limit
    • +
    • mlx.core.metal.start_capture
    • +
    • mlx.core.metal.stop_capture
  • Neural Networks
  • Neural Networks
  • Neural Networks
  • Neural Networks - +
  • +
  • meshgrid() (in module mlx.core) +
  • min() (array method) @@ -1540,14 +1552,14 @@ document.write(`
  • smooth_l1_loss() (in module mlx.nn.losses)
  • - - +
  • stack() (in module mlx.core) +
  • +
  • start_capture() (in module mlx.core.metal)
  • state (Module property) @@ -1590,11 +1604,15 @@ document.write(`
  • (Optimizer property)
  • +
  • std() (in module mlx.core) +
  • Step (class in mlx.nn)
  • step() (in module mlx.nn)
  • step_decay() (in module mlx.optimizers) +
  • +
  • stop_capture() (in module mlx.core.metal)
  • stop_gradient() (in module mlx.core)
  • diff --git a/docs/build/html/index.html b/docs/build/html/index.html index cbf578728..58cac276f 100644 --- a/docs/build/html/index.html +++ b/docs/build/html/index.html @@ -8,7 +8,7 @@ - MLX — MLX 0.9.0 documentation + MLX — MLX 0.10.0 documentation @@ -36,7 +36,7 @@ - + @@ -130,8 +130,8 @@ - MLX 0.9.0 documentation - Home - + MLX 0.10.0 documentation - Home + @@ -285,6 +285,7 @@
  • mlx.core.erf
  • mlx.core.erfinv
  • mlx.core.exp
  • +
  • mlx.core.expm1
  • mlx.core.expand_dims
  • mlx.core.eye
  • mlx.core.flatten
  • @@ -317,6 +318,7 @@
  • mlx.core.max
  • mlx.core.maximum
  • mlx.core.mean
  • +
  • mlx.core.meshgrid
  • mlx.core.min
  • mlx.core.minimum
  • mlx.core.moveaxis
  • @@ -351,6 +353,7 @@
  • mlx.core.square
  • mlx.core.squeeze
  • mlx.core.stack
  • +
  • mlx.core.std
  • mlx.core.stop_gradient
  • mlx.core.subtract
  • mlx.core.sum
  • @@ -378,6 +381,7 @@
  • mlx.core.random.gumbel
  • mlx.core.random.key
  • mlx.core.random.normal
  • +
  • mlx.core.random.multivariate_normal
  • mlx.core.random.randint
  • mlx.core.random.seed
  • mlx.core.random.split
  • @@ -431,6 +435,8 @@
  • mlx.core.metal.get_cache_memory
  • mlx.core.metal.set_memory_limit
  • mlx.core.metal.set_cache_limit
  • +
  • mlx.core.metal.start_capture
  • +
  • mlx.core.metal.stop_capture
  • Neural Networks
  • Neural Networks
      diff --git a/docs/build/html/objects.inv b/docs/build/html/objects.inv index df2ded16865de39fb55617c3d1b0d7b4a4782cf1..6e82edc3b7e2445ceb0d6d7a69b7090a84da11ef 100644 GIT binary patch delta 9523 zcmV-3CCu9AN%BgNIRY^-kvcwq$#UE{w%zM1GNO8$ayuPSi{80xS9L|$D!1hdza0Z2 zk(nrQGC`&#fBnP20R%1xT%;Bbr+|CTxd2JbLsdTA=1oU$-_r}r>8^eLPhJwvcdy0I z{@m2BGBWfHAF7IU&#$Vr`R%!{%3uDq+uzs`S;lJCXPG3ds30oR8T;pd7U7t*_OYb6 zwZ-?2^z_tp&p7WURERaz+YSEvzN4fHxyzMRrcAsf@!kdK$OD%Y$Dz)9*3_XS3Ui30 zC{QFt0g;p?ili)IlDupldTm9>xM_N(M+yQZaG{~Sm4)k}u2F1DQ;3uX+Jr*PJBx^zX+!6mGY(*+}n zuHH!O{5_x;{olkCt-lN`spsoZ(R^wUp-8!>Au~qz_&1SwzmD5~v5LSr=&K`CoP?ks zjz(>g;vbGiZIOyhED~5HY>QSP6Iexzm$YEygr$_2O~R%oWeTap`H81D9`}M4o3A{5 zYO#(Ah?nkIlD?loGW?O6kuO zRtfXd;@gQ^jYR%~wB(7R%g_nuCGMM&ar9%-C_iaOYTk~cZz$dg=%v1p4ofN^KDvyG zmjgH2xWFLNzNDgym%YXfgU+Z8%0m+6VTm~? zL{qXBMI-EgA)2OYS&16J9U&lDofFaJ5dUz5GYlrZ;NCrr@K7DlR0khbcA_=yaNnAV z^UyZs1)F^e1JZFE6NgF8NET>Vco}JlCi0E4EaNK}DzNckjH)2GJnh(JECs;@SCIuh zlA-LgVxByWrCcL7sQEl&ho#m?kXlqCbw4)c;%O&;6ht`=B}J3L!${@;5{Ak%cvLV> z9!i=e6rUCpB)_l%$te*qRfT)IH3Fa#<#S?+*GPyO+hXyxRcLoa(R4?cgw+=$k3c|F z7i79FaI&6IRXsdTBSJ}CWTH#Ykz^ALCaqp0>8jTd(KGIqdWPPkqiA&9mgJdsS>1G~ zqRoO!`pCPAr%_KpFLBdSJZaNC8HsJ zvO@yf!B1Bw(D{nQr7<8~+u+lvF(4fu9_Bgqc)|$22?qs5g^rn6OQKzCSetB^m2qiWE5B1dDd=Jl%G!> zD{xdA1Jd=P3_elOdlgg{npLzK_gKWUgA|)$Imyr0X_*Y0QY|HqS1B0;ULnRaav`jg zqk4E16YZ7YbD~RDu^z8t5o`+SGWex>(n+VfgC&aBp0y?J+ZizcooICvSx8{$CJadTBVp2JK>?IaO;LF# z?L9Q&?SG_LE(L-J2V@6W0RW;F*U)Uic4pGXk*!@L*eOTPd}GRF+s0Je6jfySwuy}e z=sVE>Y^vJ_+SScamJCC0WBv5G-lE8Va$FzHMf)XAL-jyY*fr2k{79b@SOv;YsY<5{ z>6l2Ent({k7J$VsnXD2bmjAK3PhcLAerir>k-#fq8>G|vLoBOj3E~*tcY|6lQ=~sn zo5C+P)a-~Hnph(t>eWKuXxgnsz8w_}S6p5z1V}Y6myPH0EtQ?u2S-zR)&|ditAeA6 zJj+fw!L>)Ey@w_{orWWpOM!fZoE&LSYuA{af_Soe-BBQvL z(FzC>qe}*FX)rwW9m#R-jTn#)=N33{U~WE$X7X=5=n1axEA2fr87Z6klwn6TEGS4` z;|Ob*^7q~J183Yj`*Ch7vN<{HdpuTNs57C&LMWVSjHj4Cr9m7#En50Cq zlrYVJq!|F33n>RGCxtTXvWYbkr0#Hh(8$r{AJkNjFo#Ggf@2Z(0L_mx#Z`q001*A8 z9d1j20{Do~Lw*ZtOq{8%ouz?(f~<&-!*MB^zl3+z(<5}X@>a7-r&jJxMUxy z2U--YKE;#G5QprK4zv~FOq4_U@&4a74tO%)bNR%7^d_%J9%gUxl7USz zloTeZRGbHLOWJYH&E5j*lQ&?SoLS#wTAC-)8w@_B{LIRtgU@H#R+x}J=4tx~j@8{n zvcw%(Dc5w$w$s=$!EvwgUX{+b{XD?lxqa>>wxr~k{xnUrSD_(A_Jgw1xld6r&_0D^ z_XD2Sq>@u_sq8C%xKE=coiG!RKS5x@`qvZ*+>ov_(6VSU^~om$pRNH23+jE@w_t8m+_ z=^yGE@2bdo+W%5gyfwFaMb_yK)}jCWg<_hn%G@FO-?UPH0NbXrfDZAOPxrXy-Ux7s z?(PsIcXuuUDx1Ju`qF&SCgX^4avAB}WvFGQG)sF~)8ROFRv_%jJYlq^i94VBO>lE?%7$Ar6 z@0OFQ^)=pqA+`zOP|?yeIA{@NlLddZARV?-I^D^~mx+db*F0T8J2JrRa>MKkNEkG( zPI{&Qd2%LQ!GAE{As<1`8SgO5Lxbh{?(MyQTg>ndz$rF=H7mCAAvqI#IiOAn;J0f1KT%(T zHKezXAJa331h)j%!!!H%@fmY?XG_3p+U*f-yH(?GO9Lt4Czb0BCFMv5gJ4-LfgUYLD#W8zg3X#uj)WoH_%7d*6uQdH~Vp64gL2dV%sM z)E?b`V;iJy+ZIr~H07ZBK=xeA9y8gpp05eX9I!G6SLWc#6ytP@Gf%TeP1IPkJpmO| zTC7|joio!H&^jb&9Y8HGUDN1B zMyKHoYoWy+vC?9-R3k=fJ&LuDAI36A!D58shK2Umj|;^NL)i|`d!zhgpZlc2PtS;d zs;%GAeCRs4%yYGFU?073z*u$Q=Uz|Q@8jp}%_)CB{c2ElA3Z8-4C&h3y)huRHE_O9 z9Xa<8%5S*>4({#IqkF(mp{+VhV|ecH-yS#qXAdJTN|Zc4XI=Uz-a_E^!<$pKxzQj@ zzA(V-Tg&w0*r8SXDvon1x=E~-(7bDZ+7d~Y~jBx$a`A$1en)ftXu(rVN2pj^1kV=($ExJB2)e( zON272x{AHpr2cjt_qR=Q{h_C`+n>4-{SjV47D@imSwm(-{x$vF=AVCtiN*!n)1Fpv zbzP+AUd{+FxXJ)QXik`9|!_n$^|Ie31i^Mcl7(KVGQHuPpu#Dpt?@J@@ZE=;cr(e zfAdOz3u>FS{{EV3A8t{9Y{`H043DFm1yE~aOBiS~>?0Zt-j>IfmQ(TUuc~=G)!LQ^ ztD0$_Tr~mpouexIIxkRdEHvwqbm&KreCpM*B6$OE}cYzRu8vT-~v` z<9+d;lE}jYm*=T3^c)S}%0F%-^{!yY;~FqmBo&q zX<4p0m|VOtU0$KA_S$O1$ob13tUvGaN*-yl!WQ)MiayCxIMl+ig+hS#613N#-&REL zK8diKys0YT^YXHQd11U~4nXVa;Jv}~PKE5|ubSUXSCj#nXWc>ePW`D?PTb^eEA+Kr zMUzX+7^JL)0PXaDk~wLbO{TMeY5a@$V=pZXkZ9(TOzt6rlfFc22Ko11o4*%7cVFet zoAo0weFYxQ?mfNGQg#Y<@0)z6M1xKzRobcpf5pf8Fq}@pz+zg*YPL@@OmOP38-2bQ 
zSA!7x=pRqQT)Q!aI$|#(_7lN-9k!E4%-`#<8~UUs75S@wv9;FGlRQ36evO9Lek<+$ zI5+-f4mNm8QRVGfeWV?&a|*MrZAQ+IpLTQ5!dI0o$RWbVoRl-dd$puz_nCG%ZTn6* zB$xbyS*Y&d8eiGp#qS$=5fNy(TsHLk?_Y0#q5@T-Oc8qn|NB%D&NIvW=0s~&7xZaO zK85Nu|0#ce-YjNg<`@>AKK`GMmqiEQ<`g%j@XyIs6`K4;b-bZmsq0r}U;0n+o5>YZ zvaA@(!?jCG5vj{E^*a8kgKzZYxld@ES|6pa|6kek(rDf$g*@m2JRK& zqr`rgqi^&r9iaSpfybn+U--A>i9U4X2L5ynXHSeu@}hFTJ{0zlVDDai-f=u)|~PNg3|7;ny>r zXx_m}@O7FrcG$kw?qS8^?F+un3WN`aJ4qoY;hdM5nlIaq_^UC55deBtio8 zMG;S56gGWP1nAcVddSwM7lnKXqOTZ3Xe<|hjP0D@fLa~@^o6ahQHP$isK!r-q$qP+ zL)4NlaYvf06LG2*PJ*cVztKv@|PtI6d(V0c|VT?MxDM8@dLdw6lu(^b5GJ2yO1 z@b$m(lH;|#J18brwoQ<2y|Qhj>;bQO@QEHGHOK32(Za@!w>w{Kj0^ax*jEonR>JDi zTBIT+aET!!qBPO9zOTxUtncR=Gm)z1rI|ZVWQ>0O@Z3|Zx(0#TIx0%JVAzsbLSR)9 z5u=&VwDPvAVjES)^=dd}>xQ*|oi{qp3w@0m)`jIJwH+K6C9bYrLrTv#ukE0iC|Py& z5mI)({%8lq#K@*(v_G4pGnu1{F77G1cwEty(8TdJZ(^G^`c^EY=zKBO4vLAA#lJ}2 z_bB=`(id)FRp&dmc3@PbvO2^UEj=ILyMyDR#L*p4U+*)r>p5RNg;X_vZ=c$EB4hOX zreqxZc4-i(t)rrp=WCIW;;WmHcD~pc7v$AIIPvLrpq(2YA(&C|6(B2Oje$gUOhC3T zsXROMHKGEXA-3oo5vO^;YaV=>he*vgc+EFH%{Kwfhpr(-E-m{eTGa$ux@Oe2a!y(N zTO^GwB^eY@C|P`GPa4~Q3L+?=V64lG5-}EQ?(EQzl)B%zmb;YVCYJc zvnvpoF2RgKOnA7H9KneVVIa4`6s+PHnF-^{O)){) z>NQ5NdhJS8uLB~_h_C0_`DC60--+4-qjp!QTMwe8ESx24#YVW$RXo9g2?fbdaWUcz zVxV_MfTB$sli@rkB^9q&(`T|JHROFp=2I+Dj%+bO4sO|h2OK~4x6=lxDEG9z?Su}E zk~(0u4zSoEO0CAK)v#JbHZp8JGB_hyR8BH|~Mo*`YJE^W{dY`1xk9_ zAw-nDqoYEjq>xi4D-A~j-xma=w4>d>+*VM0&1#AtwjE2V-%YVtM%l$r)OYC1Ti#)k*o zD&Pg;C{7XEn*Utld(^_#QV}8B^d;r-U1@1-3CN(Z0!wG` z3XDfq#Kg%Wn(qt5BN}SRD-bFg;3(;T;E-A`D)AHTGj(Mq%_>@9j^ps)9bZ7J zSl4orpRe%DgnJz`jbi<3DS1qH%Rqa*GHha;Vmu=k!b&+ihxEvVg6^0GF~0CQQ%hDM zmKcL*&?Czx)+wEx$zGXouUn>3lwUpRWMbNW=&(<2g@rvbE#lm1+GKAQsK=SbAm9ss zVj+&Z3MaS(M;PF04Q=nrMARDCJxQM3NOAj+DtU}pTO!J&5n!v0!P{yzJrP-^N2^T$ zOKl1csr|q%@L3>^;*Y3>Jxz^R%Px&A9T^l>5QW7fMad`xpryhCf&*sAQ(4>==TCfLytkw%J)k7~J^DG~WTRb*j+nFKz0XER1K1FYP z(TN^lQ{6t$u5O01WYJ#OK+6j|I?9V&A0EdrdkbAC?1^Cz<;!WPOHXioJwu*9XsM!%SdcKCij*O83 zef;ig@4oBU9oHQd9bv>Ppv)aD85q*RZ8e5>b4Tst#_iyS_HIY)+Q#kIhIVR4?9t`~ z(fF6fmV^unD`cnBa74){1fZpV!XpGZIZ}DteA989t^tJPjp(o!KT*;@0ll9QB5EbK zJmnfuzS}@Z+K3K#LFFEgr2Qi<)b_+Ni1!8A)pVg7zQlArr|USQ9WXB7+}WW)sT!kG zyCQW+0VCk@hw0B4{^DHS}(O;Q(DDHsnH$kgXl5I!J%U@Sk0QK!vB9sRd%?#f-4Ertp9u zW{YC)ZcOZh+Yv=65?{7TV_QK41r*SS!1#`Ti9GnF?-)XgIU+1bUt;)6SD*|k3>byM zRT#DvtZT?)wSZg5cb>?Qeuoh49Hqca%jKW=B}2m2rvE3fOz0wi?;ecipY9yot%9}{ zew-EkExJzBP5jlBG18t3 zFL!Xg=Tz-j^U0<5H^TmQL=+>xlo!5q}X=}Yu_p3&&?+yy%wu!gq zB@@;Kk86;1y`y=5r=$O#M2N?8n<98~qH%p7KsDbgsP7b9t`LN5r`>Y$TEW$IbAYjN z51oAr-Dshfx&TPM1ZKHy@@Bpj4KN;WS#$1KTW(T=zU;Rj*>4PF^P8`}e-#oExtu;X zO$iKt61UfCw&OGi0qbg@y!mts1YU?AKB%D4_F+JYnmx8M0EFL*oPw zc)<4A1bD!Xj0S4VM;^W*&&6tfmL1jlxD!E`a z(Tf!t@?~WZ=TS@K?bSW)@a`;7$d@HH&Jne0CG|})!H_d%e55;a4TWcqeb-IH@-|1`5r(f&6be#S5eQ{<^5YfnCPd=(%d$hbiRvp^K(9ODe$!|@ z)QMHi1cY3e;v!vWi5d8;E-ld5j{J9 zX}#7F0Re@tas`CklRy!Lm;qYp8Q}rhzgteKwnWsP6JjMe1qReayP}u(ILHA%#CZ7% zw$fe2#E75DrfYitCYNZe>&7Y?;fsHNOFY=?jkVEbi}6g;vgJ&@S#qX`>^RfJtXP1c z2Z$`#I5{u|=E~bg2yU??Wxg-x@XeUZa9-p%=@x@%wD>s(lct4s`=HH2ffv-JT!f~#6x1nI*t}E+(DG^+MoQux+ z1FL0p-Zk>*$GY#D_BqWpFP-J8*ADA?E{_zY;u%z|B%P>`G8u7&^DrUgK5U6jzzb!1 zFt@MD<`Wq>XQa;Rvb;8j9WKaQ#|Z23S@gH)wXQBCkdI3&+PUF@c2_$$lYq|6u~_FG zOL=hkwcZ&lkb6cn*1z)gwZ4si43J-A6xt_I_{tmbuJ+920sS)#uwIJD6RR`Phkg{V z?zO%GY>>x*8QO1t^_M(RfFAO+ulNPZ!mT1OwdCVf)AZs6x%x9%6G1MnEwGOE9cxS4 zU+bnx1378NN4r=hwwo)v$y{Mt%WY~yxpmseEvk8(s_V$2Rew#D{ngQb{WV&cCnlHM z$>o^jUeW)6`X1PR2WN?iW_F?(Cei6V9jKRs?cqrKwj6k>zQj5X$Uu|_v7`E1Xvd4q z@qK_PYNQ{=t+oAQ{w~_H!8HoEQZ;wF3iJ$s$>->CN67T2!5|0qV6YxvmAF5}FGfoC zbG>juUB>58)^X#;67}1E1gaXz1037!3Mgwt$Z0)U4io6Oac7G9Jpr2*(iEZE_^+0B 
zWTn)!z3A2cgcG?$PVnh(7w(R;8tAtLc~8rpY-&U?EN==BiTCPA@*X@i9!uf#OUL3J z^Xv}95%`=kw+hBHR#*8``JAZ|Jo$E=E8o1nd`oZ!tnyHJuRBYB^%u$a5ne8XK!d(~ zG?vOla?NGMxEOmn41}k<=cNRF_F5+=#8ZzX@dR9!Cdq#KRy(E<1w37?g+wcO*S#e& z1)N>4ii_0v^69?qnl|0ZkpT2xXy*%k!HytW}&{xL_>F_@dt23UTsc$W25)bI=+63umsg@Pa@N`ElBG6OI2I;Eo z#a#RF7Y@+DB^vN8Z>p-Pv&fdVYpxaFso?D<`22HQVg6~g>H)?n%R^2tcAE2k{(Bv7ob z`5%bz)_5dbd`EmT3%_ryhhR%uk3r}d4?-@aVR9$Ua_f$jD{Bi)F^26& z)ZhBg-6ws2v5v;cqe`vQr&ba3HzWQ~9o`W)jTU6^yX7@n4=$|Is35E^=##a~bR9Dd zGmLiQ1JyPq`T3m5-FgGD*SmuT)0o50=iP@naOA3oYCXGb1W&nAU&LD=h5BZ9H&maW z)w}jre-b%+#$kF0B!%$^R!-19tby`j0{0N$h+5T{6!nvhh7Y zFOka&ykziOHE^j99Crdlk2G9;u;U3QF5Z6K9fn_2(0;;Ld~@ODes-(kOMP~80XCli RWvTn#Yk}V1{6BJutxm0*WpDrh delta 9383 zcmV;YBv{+>O6N(CIRQD5J3fC&liW7az4x!E2+y@z<5GkVbCV_6;|NLCsO9i?W3Y>; zD!93cAY0w~^&b{4Adnz3QGHOT8p!uv0wl37Rrz##Y&v%Pjy>`t+qKXCIhMH)yPxIH z{?gR1vV7eIV z))wE}yl3a8d!~6ep<=A5-fr;U^&QKrn7cw-W!fZ2g7+>$m){FTaT@Al&zm}yL}?C@ z6a|T-C?JxuB$1RQOmZw6!H7&61JTvbB&HgA?0J{hXPQSE0a6{04efzxM3Cx2=eodT zg4aYYi3n0%=v)E6`ka62mqY}qE_5!~uXB|@(YO>MKt;`Zg@CEryz99pp%KnFm{v4h zPh?aOkgUt`CBC)}vD|)DJZ;ystm~g~R9}N+^y*^Uxv^lTPz4HS6K6_S5*=K^nmAK1 zlIZG<#LnLXi81_5Ows1cpppi@4jnC~1`&x=1R649bWeX1^WcBiaXVHK6bD0fgpN}X z^20HxO;Y;9F{mw4nTbULtAcION@N17j0ut!JU>$@6=svLsq-?0RN=zJ(;H8FL5s~- zo;|fxMA6#j$9fX>81RF+#na1oZhw>VPV>4sL8pYL19Yq z&GaWolnNB2oRdy_{M$(v^e;+md{Jo?$7ZD=+LDuVQiFfOlnKvf`_X(G@{Z9d6>5u= zu^%a{3KpiNw-dJ-iQ;?S=1+`VhEBL3ao?0&kROvlg-N@-7VS9tM&g}-LF&i6qmoKU zh%RID<)Do=E-;9+FPZG(Wk1t~L1)wkdEYg8aa7TLLm{S0~VglFAd(%EO7}DXHZf$>k~O<&(SwA%DRHFLOm`y5Ky~scf2#MkN6uqKw|I zv2hUXKwDO&>;al{my?=E${wU)vIpDIm^36zMjy;s0Rdu}ld}BpNlU*b3`AGgv?OZ) zL?p;-+eQdT#_JP}N&-NlnzqDB17IR{COc&?AYI$g(`Yatofz)tsc<~ugaRo4c_g)2 zrGGg zAf+`r78FdmEO{U>2A4Zjd0!2ra>O7&s(&XM&l60-E83vr2!Y5Lt+op+xu{q#3(Yf= z-AzgJ$g+x6<6Z_lGf1&1cFd2L>$Gf!O{tdU4_7H!1VJG#vivdUr5Zh>tJvtEgqRau z@{0F#6^CF`$Rzg~C7S_~qD>0GGXgS{{6QHAL}slMLbHX}2+4A!=WR)Q=SB=bCx4rl zx%#*$bQ1=o`<^hVu%H0Sre>tPQ}zKG`PMuzDwhU9gafgsrvU&_OKVZJV1p`YW5d?2 z2I`bpU_LBmGH7EeXo@N_dL`^)0s2n1beih+o^^FIlqILoyVx-O(QKjPM_OmZM~5ZO zLv_zm*bOjD{J@?QSS2bvI3bsKyEmotl2G*;M z(S0}Q#T7-CoYO8vSZt{IiP#}tBO&V5s!%lT){4-M$<`MwuQdXsn%7q1xkBqI=f$LG zs=z|hc?l_+D6mFRjihym$nZh;+edyzSsDZIeE&uM)~Wgno)^0KMV8Fh5P zfr8{Ujj)DEE<_fn>@sz1K^r8VK@cHv?O6}6IFla*>9bV!5RJVl0L`_rA)xO#jf%r0 zC6T3sX$B(A0MI<<<-pV|LPniuUn4>4j>ZR#z&!p#nj4Vj5J{EOSbwBFK=b{=XtkII z07O4nM_c->0U;vt(Ak1oQ`YtEn-Ba~4xDjO16-52Px5&F+%%=x7Jww0AR)2?o$TO~ z{hIg3%crpuj3b*N9@*P@>jRGIEwJ%n3_cKPfTX?anzk8wJkJOS5m8*4AtAEYbh6hz z+576A6$P))c(NJdk$?UE{=;8*9%b!Qe4#LW0SLNgXlbwG4DrbRsAcQ4J~A4gjq)fz z+J^|FiEB30+8Fh9p|pxt&Smi3%1Fb^-X4^MIyb$5K=l`cz;=R@HrgY3KKHKJZ%%8 zvHF_@mbeoy)gnOEb{bnIH10J%sM3YD8AsSVzq`D|mMlM|KTR7QRA}-d`_B0J(x)g` z=#WCP`w`FUyi!wGsq8CwNTba=X(k_kg1~|evFqX2 z)jy^n!rXhiTYr~G*(2l$#fpY)a_W6o*E~;%Wm9o%J5jS|!iKnqI<>5skU{3Q7#}H- zSL3!>(?8TT-d9m;I{wS@;*GsG2wA6lSdaem7e;BiYIBd|f76URwN2#!J>oAP?`X}v z5#SSj`a~f4^vNe6WfPc7Uz*SF2^7-WqqrmWGwz^%#(y3AHmG;OG((U?GsdUmB#$`c z5w8c)@FFiiH$pE{$;M+fde#z%q_s!!RgDy?d2vSg zd@PVh_*X0Pstq;XA+`zOQL!>GIOq^%lZAM-AU(EoI^8+Mm&t~G*F0T8J2JrRYKP?u zNLVz!PJeo)0C{rByFz^Dp~D=aoHO2IR)@o?V<=QdzelmZf{-xpxoo zDr~M^K}djWSSFhbOHRQI4KgQtzF=YQ?ag9a%kv+aaVz*~pf%{xAZ$NM#8WCI%Ai7+#8t7#& zQ2m73qkC$D)Nk7YijPe>=su7;*RrQfwyfuCVloG+%)ysA_%g*f-QvyD>`49U)f~hL` z5B>E+(~^7Qyreh5xvdDyJne~7Pu?8Xwjw+|wWrTdLDOB^(!d`iTmc97_T$vKRjn$`Y7H);P=CubGErh zAWpt8!0cNq>~ZYSf_oh&1e4t)UMuLaYt)e#(##W5dMMQa zTrVJH5rwqMpI<;{BnfKB&s7~!{1QqwRY*@VoOzwe^7%nN)fZ3&2!q<?amGOJ(W6tJ0@rDrh#Qy`ma!!5a7q!hay&FIN!$ z;xUSJ4+Z4mb?uE%TSM@4FFEApRqd8gn?;4U8{c+6k8%w z{v=DpGOD_Yz1pPy`#SEwH-E{^hn~*ve(F~A2Y3Y;lKg|WhRllmbNac>KmUvqjSIG? 
zJ+0vCI;7`8PJ|a~?)XKd%}jh$Z)52xSrS-{3fhGiN0KF>^^%t1mydTUUFk33NDvgb z`~6)yxdaD^$DVM(j(MPKixjbmf@i)VqHB@CsoNz1M{gt(Duv(A!v96c*S5?e`T0i1Jb)9?_(ypSy ze_yHmcToE8sJ30}@2{!$;TFY~{D;8sIJsE>wKledfi}Y)(P#*^JhrsFif4aS%@e5B zwmewXO#9@j38?QpRe#mjd4XzUp;?!tL(fY!*uhvvyNAFB28r<-?bo0n;ZP_0Iztz7 zeV5`+^u>S5TphxYo%kEgIBTZJ=f<3-sPk-tZ4!v+}p^%`x0__dx z*A>y*k237$*i@DDd3jmDys+Lg4`B36@ZREit3!74S1oR)tEhm?bMBgYtN+w1CvJ{y zEA@?EMRQb`F-TiW0oIu%bILTEOlJYx_$T?tURhWm+04x|wTBE&`Vwsz)Zcq!{!aef zeN{hi){nsU6@Pd*yLar7m8w&)d)FL?N;c?BQkAVfJXe0K_rv)t4IHL*tYQ18!X&2& zyD>+D2|Wm5j{XTG%(oj;Xd?D9Vm}eQGhsV*Ci$HSyJ1gyQc=7bTWcLVtCP>v*Jybi zwldz2bK_s;V2ihuRo^`&Z z$lAV>4yh&oU>532o5ol6H~ITU-CP42F1OSC_S@GRpr{~~C{x7V!v8jvg!9a@xH+?$ z*9Ci8lTV>0&3~w$H;dVrIhMu85C3Q4RnY;sImJyW{A04!g{HXC9d9UC>iU)4m;OWk z=4!>1Dt{}+@^J0aQbd}vOuvqQnBW^Tc^(oPr`AU){fLEx{CM8!&D=n;RAimyzybuHZS~}^33kL{09DbW{Sw?^OFhQTX6q_GLss?oEaw; zNAJ|!$@DIXo9UDB{h2-*i3xl*m0ESUoW$|*t~~k^>Jp}Rq#5# zjp5e|n`qvufSbDB1pjOIi4E!^*iWqg-gFNNunY``GHNz3Bz3)CAWM$bL)AV*?~{&L zD!BSlzJcm_$tK08q5Ax+U*(Sb3Yqq=n(uXCJko!&Wt7@Z4I=zjt#U7}LzJXnI#bRO)^3B>4Jh(Z@oDV`*C zOe8^RIwo@G1Y&e9M4=0)6iw1u4QSz&&ePD?&Iyd`klu!K1-uO^FsifM>p)Oj-|Mgg zqoT_Cm|%kFd`$4p4UI`n!zDPOX$WPv4nv6CekIC$g@kUuVx&m5>VyF0;x()>qkmL{ z69N{DR*J@~ram>G%9_Gag4C~j0l~7UFMYBkeeVk>s2uz<+A0jY0{!3KrB>ktqFyVxtI_8U?n{P?s#yM5h}T zcYZiVaGJNFiP}>(J2(y_&U;3Ep%ny_RS{TCq3+Y7>x$`0tz94zqpvP*rAbeBx9-*=PQPm9y)HYEl<$_a7W(9#& zWdueup=lLtS794f*3D=%W#@9VT`(Hwg}DKZ>cVmf+76CGiR&BFh|=?=X*(zeC97{O zBFfG;7ww=JjBH9qhqD=-$$x@ebcsOGC2&PoLK7$0yumhY%!O7&(fLlR9TbC-rN2nQ z_XzzOnY*&6s`GVOJ1`2VtPk-eOV0=R?%+6-IJpBF>U|=+p7YgHL{}dLIpTOT+ulKr+J`j9zvQ2q~>e7=IfB=>xkxk*W|@fS@sQD)do4bcGR}=k#YQ6 zq>LRU6%kCA{2>@g(P%sO%%*yE`sN z1@4@{xXyv5bAWUX2!EaI%xcyV$}k%QXiiaa<)(ZjJv|Ko6jg9cO%Q~Hry-*B7YhHy z7l=%96K5{KISb5e1#zANoT0$XO;Beg;Q0vDOayTb0-Sxo%sV(2GGq0Co0xZjNSyw7 zu`5;Iy4`gbh>Yn^=b#a^Cy1vzMMN|yvpm53PJCvEJf{*j?m9A}PwF2)Jp)9afORLlM0aup zx>J;%*N+qhY3_^2z=+NxC959&oO+CA6k)={pX7*6Y)AvO4W{6gpvX)ZUv7$t$UYNZ z>}T(7evZ7tdVhP0-uR#|BtT0!I7ilrMYzyaBEf+P1<6lw81aTM20J4_(Wb>@xH#q| zlTW|tGu6l$j(wKT2TYP2* zhR8-v%|{MrBuC{WCub!O{dZB1(x@$BZz2wPe+HMq=>T_D->8P#Bfhrpu0dk#VKN2^Pfk2XII)f zDl%l7zGMR5%T>mXfC`E$5R0S*0;pjMqj8fxsehkoVO-TE1||Egql$j(V99S?R8X`z zVW4^f_;NZpqSg>9OJ(e;G_lH(vPRioog7`4Ok0r*twlzxLZ+=hhE^UU)*O>o8w2Z% z<137%MP3w?p|%;}YnkwE0Ppz$T47x~=Eutwp4srAW4002ua@Nx>26u*pjVa+#wjkc z{4wXHn#v-2WJ5uBYy*rhV$Rf(SA_MrAb%S4$g#mXWwJBbD;pkk%QlMmMOd2Qugi!o z%8&-M!phkDGKj*8x=SUnKLxilrBVk=wI$)h8v(A`7`&|}JD<5KxX@}7z)_omV`|^2 zOFkBer}!PU)T5~htDBXvqoac23i6A1U?drh0JL;?L~x*tv}tQsXo!xowS-Pk>woou z+E^Eeer!nS%n;;(A*3Z-!1}HD;w|!$qQHV8+;XD%BFBu_{B(f<^(lJWi%xbtn(FqR zb#*h8B}aSV0v#{hXp|R6bDS8b>^AeEuqT!Q%9ry{siWkGZWt8ofr&$jKd@)2vV_30 zG6JP(uKuOU8WgN-;!xu1KVJ7#^?wlt%a4&LeZg2k6;~KoS;gRFNgKj!P(V1oC2X|p zyFo8QFS4}f?&4v8Tq~4oL(NZ=0ns%Kc4Ul<=+pPQ26wq)_qTRT_M`}}fHwEERA5Yp zwy_i5JB!*iOWQ9C?UY6Ak)`d9h4#fFcEpOD(6CmU!cC z;C&%><9z6bFPLuKbon9L0qeTMof{gJswq0PFH*-8C;~ozoc==LzxV=?EkDyGk7yC= zTF0Fm8j~Uv+`y5K@$(*Ot$!-0IfZCPk?J3y`a#q_v@P`D=6TGvdFqCFc&j{Wlf3Bj z;{#quS-UbSEGE*l^Cf!(F$?91T5&OPTCoZirI_M#xbM6CNJ!rZaJ9zZm>MBT-XLvs zHyim7gTV0xxPuG&oG{X-j2#sf6j#XdvZ>D*Y24Yt0lHRf%!QgDTYq~}b(H=>;lKC- zks8UgBMZW8_l$70rtpX$WyeH{7H}!^E)W^h?+BuurxclHc>D*yOhekb^#4Sb zW<08^*wFmb6>Pgz(0{hV_X{Jmvb43Mp~JRCez1;kDuA|jM06NJgf@}2Kp?s8sVx7- zafB*--S`fSy20P%hmj84|F?tVZg4mL5rtPT=sNs{au*;`c72E7E))y8U|OTr?M552 zwypC*-K`q6p(FXN!}s6VXYM=q-K{~+ZFlw+cj0T}AfCYGa)01Oa@K8X0M&fSnz>?a zxls+WopvaxTh>;$Zvn>EwOsC{T%(0v>j5Cm8j$7ky_@-xD8P8UbjrJOYPnbn`m*1A z;J9oET&+HihB@4|Kr;D0qko{+jp4%O`jWBnDXt&Y=;( 
zx1p1UHtknZZ+}+@4Q=2KM4f}L6V1SPUDGy0j~5yNE2S|ECq+rs41wgZX?slPldO-7rfH*AvUWs7`NQ4Uk9b98y;4#EF$HyK1Ai@HGi0qbg~kL;t%_IE_G>8} z6w&#ZKT+$k8M0EFLvex!y5Ip6j22ZNIe1dphV@#8F9(%{F#0yf+q{L>TOv?vMJs}! zob^p+uJ#mZPFMs%e=HlEYsU*Oi;mS~mnI20GE%^DWyfCCg9Auc!@%9I@k(?lgIkeP{f9=@ZN&Qg+mNrB-Vuda*)7zN`#z9<{l;%(rJ9-JJyr`Le{~9MQ{4Qr{FC3^`-R zBi&Jp90L2Qkx0;yDH7>L51&MfK!NJLbrkJ3M}ObtoL-@56r`*p2+H~(yer;Kh$QHj zV@Gh4)K_wVL3iT)rqOz+6RVmD2)QuDAzf&TlsDACKzSWQ5l=2o*YL12*aB7U;`uMj zi#Hp`siCoJcf9C`FQ30KO3zxGuM~AiMC5)gJ^h>m!BJ<_ebUB`})|>hT^tZ9FM|=F;AOQFD@Og z9|(TUn6(#ka7?pjgr1$UUhBw!h{9L30)Ikn8eoJ%%mS_Stni5JU#-Ziwj|V^6JjMe z1xD0lyP{Xe+N%LS!g%=#w$fe2V8qX5(=~l~7YZ8dyRnKU_!8cLf(Hk^adx6?Fqm8rBa_dd;$_~+3Mo?& zS2zz7Lhi$srbl!8s%$>zgAhEgv$`y=&0&uV^42q=dVG=nEq1M|3kl@o5=(Y& zxM$tf&dnsCb8{@!xrb66)_kpZ77OH_6;1W8{P|knMh3{QF^cSyEPU04V1HM8X7hml z*#=ZEW#pOHne0P9$XEATUja7AW5A5;x48OCo+v;MMcP;V0%hS=5tv%?iK=OO`GQ>i znXHK*7uOb4NBfSqCF`$s)24x(wByk(R*CKA%55@Nl-6>a+E8x2HgZd99;fPhvSig? zQe}Vfbbmn$i^Sv#H@Sk6+E4V>5c_<2hs>aPT5-p@r>0~{!}4nssvB|zRs1ugTDNo;0##hk?=ux zmg*15_W@omLqLPRe1EW(${@MsvSM6}Jsk!j(%lPEf<6bWlM@oCN0N9VE-RB{KSQe> z(}*ITuGT`56@u&D5}6{-u2;n&HNJejYrCdRcXA{EeH-J5F5Uh9F5Mjo23?WiL^sB9 z@q!(}t;q>2pkuZH(L?ssHnmxvbgi2n73iyHMRfR|hS!-mUVm6p=x3>y z70&Q9pzr z#wp8VPA_(b^KSloA~na#0xaWsvl{kA|DGWIv;+(iUm^{3UN0sVxH_pp9ohiGr^pG; z*zPa&_l@OgnSZMjtwCRZ`q$*G$RE{kx7ypJcG66X4QK2i@-utoTNxS?zL(+q34N!7 z^w|j?#vpyPgSc@XhwEPWfkFYo%lx_fXs%8%ICYk&bNbXOVt@W-5X2bA^)EW*xhOC z-v@OMY)t6&_sIgE$j0{oy+l4Q@RGrA)gY8UaNG$HJ< - mlx.core.Device — MLX 0.9.0 documentation + mlx.core.Device — MLX 0.10.0 documentation @@ -36,7 +36,7 @@ - + @@ -131,8 +131,8 @@ - MLX 0.9.0 documentation - Home - + MLX 0.10.0 documentation - Home + @@ -286,6 +286,7 @@
    • mlx.core.erf
    • mlx.core.erfinv
    • mlx.core.exp
    • +
    • mlx.core.expm1
    • mlx.core.expand_dims
    • mlx.core.eye
    • mlx.core.flatten
    • @@ -318,6 +319,7 @@
    • mlx.core.max
    • mlx.core.maximum
    • mlx.core.mean
    • +
    • mlx.core.meshgrid
    • mlx.core.min
    • mlx.core.minimum
    • mlx.core.moveaxis
    • @@ -352,6 +354,7 @@
    • mlx.core.square
    • mlx.core.squeeze
    • mlx.core.stack
    • +
    • mlx.core.std
    • mlx.core.stop_gradient
    • mlx.core.subtract
    • mlx.core.sum
    • @@ -379,6 +382,7 @@
    • mlx.core.random.gumbel
    • mlx.core.random.key
    • mlx.core.random.normal
    • +
    • mlx.core.random.multivariate_normal
    • mlx.core.random.randint
    • mlx.core.random.seed
    • mlx.core.random.split
    • @@ -432,6 +436,8 @@
    • mlx.core.metal.get_cache_memory
    • mlx.core.metal.set_memory_limit
    • mlx.core.metal.set_cache_limit
    • +
    • mlx.core.metal.start_capture
    • +
    • mlx.core.metal.stop_capture
  • Neural Networks
      @@ -782,13 +788,14 @@ document.write(` dimension.

For example, an audio signal would be 3D with 1 spatial dimension, an image 4D with 2, and so on.

      -

There are two upsampling algorithms implemented: nearest neighbor upsampling and linear interpolation. Both can be applied to any number of spatial dimensions, and the linear interpolation will be bilinear, trilinear, etc. when applied to more than one spatial dimension.

      +

There are three upsampling algorithms implemented: nearest neighbor upsampling, linear interpolation, and cubic interpolation. All can be applied to any number of spatial dimensions. The linear interpolation will be bilinear, trilinear, etc. when applied to more than one spatial dimension, and cubic interpolation will be bicubic when there are 2 spatial dimensions.

      Note

      -

      When using one of the linear interpolation modes the align_corners +

      When using one of the linear or cubic interpolation modes the align_corners argument changes how the corners are treated in the input image. If align_corners=True then the top and left edge of the input and output will be matching as will the bottom right edge.

      @@ -800,10 +807,10 @@ output will be matching as will the bottom right edge.

      If a float is provided, it is the multiplier for all spatial dimensions. Otherwise, the number of scale factors provided must match the number of spatial dimensions.

      -
• mode (str, optional) – The upsampling algorithm, either “nearest” or “linear”. Default: “nearest”.

    • +
• mode (str, optional) – The upsampling algorithm, either “nearest”, “linear” or “cubic”. Default: “nearest”.

• align_corners (bool, optional) – Changes the way the corners are treated during “linear” and “cubic” upsampling. See the note above and the examples below for more details. Default: False.
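As a concrete illustration of the modes and the align_corners flag described above, here is a minimal sketch using mlx.nn.Upsample on a tiny input with 2 spatial dimensions. It assumes the (scale_factor, mode, align_corners) signature shown on this page and MLX's channels-last layout (batch, spatial dimensions, channels).

```python
import mlx.core as mx
import mlx.nn as nn

# A 1 x 2 x 2 x 1 input: one image with 2 spatial dimensions and 1 channel.
x = mx.arange(1.0, 5.0).reshape(1, 2, 2, 1)

nearest = nn.Upsample(scale_factor=2, mode="nearest")
bilinear = nn.Upsample(scale_factor=2, mode="linear", align_corners=True)
bicubic = nn.Upsample(scale_factor=2, mode="cubic", align_corners=True)  # cubic mode added in this release

print(nearest(x).squeeze())   # each input pixel becomes a 2 x 2 block
print(bilinear(x).squeeze())  # "linear" is bilinear for 2 spatial dimensions
print(bicubic(x).squeeze())   # "cubic" is bicubic for 2 spatial dimensions
```

With align_corners=True the corner values of the input map exactly onto the corners of the upsampled output; with the default align_corners=False the samples are treated as pixel centers, which changes the interpolated values but not the output shape.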

    diff --git a/docs/build/html/python/nn/_autosummary/mlx.nn.init.constant.html b/docs/build/html/python/nn/_autosummary/mlx.nn.init.constant.html index d1f420cb8..982cf049c 100644 --- a/docs/build/html/python/nn/_autosummary/mlx.nn.init.constant.html +++ b/docs/build/html/python/nn/_autosummary/mlx.nn.init.constant.html @@ -8,7 +8,7 @@ - mlx.nn.init.constant — MLX 0.9.0 documentation + mlx.nn.init.constant — MLX 0.10.0 documentation @@ -36,7 +36,7 @@ - + @@ -131,8 +131,8 @@ - MLX 0.9.0 documentation - Home - + MLX 0.10.0 documentation - Home + @@ -286,6 +286,7 @@
  • mlx.core.erf
  • mlx.core.erfinv
  • mlx.core.exp
  • +
  • mlx.core.expm1
  • mlx.core.expand_dims
  • mlx.core.eye
  • mlx.core.flatten
  • @@ -318,6 +319,7 @@
  • mlx.core.max
  • mlx.core.maximum
  • mlx.core.mean
  • +
  • mlx.core.meshgrid
  • mlx.core.min
  • mlx.core.minimum
  • mlx.core.moveaxis
  • @@ -352,6 +354,7 @@
  • mlx.core.square
  • mlx.core.squeeze
  • mlx.core.stack
  • +
  • mlx.core.std
  • mlx.core.stop_gradient
  • mlx.core.subtract
  • mlx.core.sum
  • @@ -379,6 +382,7 @@
  • mlx.core.random.gumbel
  • mlx.core.random.key
  • mlx.core.random.normal
  • +
  • mlx.core.random.multivariate_normal
  • mlx.core.random.randint
  • mlx.core.random.seed
  • mlx.core.random.split
  • @@ -432,6 +436,8 @@
  • mlx.core.metal.get_cache_memory
  • mlx.core.metal.set_memory_limit
  • mlx.core.metal.set_cache_limit
  • +
  • mlx.core.metal.start_capture
  • +
  • mlx.core.metal.stop_capture
  • Neural Networks
      @@ -894,102 +900,108 @@ document.write(`

      exp(a, /, *[, stream])

      Element-wise exponential.

      -

      expand_dims(a, /, axis, *[, stream])

      +

      expm1(a, /, *[, stream])

      +

      Element-wise exponential minus 1.

      + +

      expand_dims(a, /, axis, *[, stream])

      Add a size one dimension at the given axis.

      -

      eye(n[, m, k, dtype, stream])

      +

      eye(n[, m, k, dtype, stream])

      Create an identity matrix or a general diagonal matrix.

      -

      flatten(a, /[, start_axis, end_axis, stream])

      +

      flatten(a, /[, start_axis, end_axis, stream])

      Flatten an array.

      -

      floor(a, /, *[, stream])

      +

      floor(a, /, *[, stream])

      Element-wise floor.

      -

      floor_divide(a, b[, stream])

      +

      floor_divide(a, b[, stream])

      Element-wise integer division.

      -

      full(shape, vals[, dtype, stream])

      +

      full(shape, vals[, dtype, stream])

      Construct an array with the given value.

      -

      greater(a, b[, stream])

      +

      greater(a, b[, stream])

      Element-wise greater than.

      -

      greater_equal(a, b[, stream])

      +

      greater_equal(a, b[, stream])

      Element-wise greater or equal.

      -

      identity(n[, dtype, stream])

      +

      identity(n[, dtype, stream])

      Create a square identity matrix.

      -

      inner(a, b, /, *[, stream])

      +

      inner(a, b, /, *[, stream])

      Ordinary inner product of vectors for 1-D arrays, in higher dimensions a sum product over the last axes.

      -

      isclose(a, b, /[, rtol, atol, equal_nan, stream])

      +

      isclose(a, b, /[, rtol, atol, equal_nan, stream])

      Returns a boolean array where two arrays are element-wise equal within a tolerance.

      -

      isinf(a[, stream])

      +

      isinf(a[, stream])

Return a boolean array indicating which elements are +/- infinity.

      -

      isnan(a[, stream])

      +

      isnan(a[, stream])

      Return a boolean array indicating which elements are NaN.

      -

      isneginf(a[, stream])

      +

      isneginf(a[, stream])

      Return a boolean array indicating which elements are negative infinity.

      -

      isposinf(a[, stream])

      +

      isposinf(a[, stream])

      Return a boolean array indicating which elements are positive infinity.

      -

      less(a, b[, stream])

      +

      less(a, b[, stream])

      Element-wise less than.

      -

      less_equal(a, b[, stream])

      +

      less_equal(a, b[, stream])

      Element-wise less than or equal.

      -

      linspace(start, stop[, num, dtype, stream])

      +

      linspace(start, stop[, num, dtype, stream])

      Generate num evenly spaced numbers over interval [start, stop].

      -

      load(file, /[, format, return_metadata, stream])

      +

      load(file, /[, format, return_metadata, stream])

      Load array(s) from a binary file.

      -

      log(a, /, *[, stream])

      +

      log(a, /, *[, stream])

      Element-wise natural logarithm.

      -

      log2(a, /, *[, stream])

      +

      log2(a, /, *[, stream])

      Element-wise base-2 logarithm.

      -

      log10(a, /, *[, stream])

      +

      log10(a, /, *[, stream])

      Element-wise base-10 logarithm.

      -

      log1p(a, /, *[, stream])

      +

      log1p(a, /, *[, stream])

      Element-wise natural log of one plus the array.

      -

      logaddexp(a, b, /, *[, stream])

      +

      logaddexp(a, b, /, *[, stream])

      Element-wise log-add-exp.

      -

      logical_not(a, /, *[, stream])

      +

      logical_not(a, /, *[, stream])

      Element-wise logical not.

      -

      logical_and(a, b, /, *[, stream])

      +

      logical_and(a, b, /, *[, stream])

      Element-wise logical and.

      -

      logical_or(a, b, /, *[, stream])

      +

      logical_or(a, b, /, *[, stream])

      Element-wise logical or.

      -

      logsumexp(a, /[, axis, keepdims, stream])

      +

      logsumexp(a, /[, axis, keepdims, stream])

      A log-sum-exp reduction over the given axes.

      -

      matmul(a, b, /, *[, stream])

      +

      matmul(a, b, /, *[, stream])

      Matrix multiplication.

      -

      max(a, /[, axis, keepdims, stream])

      +

      max(a, /[, axis, keepdims, stream])

      A max reduction over the given axes.

      -

      maximum(a, b, /, *[, stream])

      +

      maximum(a, b, /, *[, stream])

      Element-wise maximum.

      -

      mean(a, /[, axis, keepdims, stream])

      +

      mean(a, /[, axis, keepdims, stream])

      Compute the mean(s) over the given axes.

      +

      meshgrid(*arrays[, sparse, indexing, stream])

      +

Generate multidimensional coordinate grids from 1-D coordinate arrays.

      +

      min(a, /[, axis, keepdims, stream])

      A min reduction over the given axes.

      @@ -1092,61 +1104,64 @@ document.write(`

      stack(arrays[, axis, stream])

      Stacks the arrays along a new axis.

      -

      stop_gradient(a, /, *[, stream])

      +

      std(a, /[, axis, keepdims, ddof, stream])

      +

      Compute the standard deviation(s) over the given axes.

      + +

      stop_gradient(a, /, *[, stream])

      Stop gradients from being computed.

      -

      subtract(a, b[, stream])

      +

      subtract(a, b[, stream])

      Element-wise subtraction.

      -

      sum(a, /[, axis, keepdims, stream])

      +

      sum(a, /[, axis, keepdims, stream])

      Sum reduce the array over the given axes.

      -

      swapaxes(a, /, axis1, axis2, *[, stream])

      +

      swapaxes(a, /, axis1, axis2, *[, stream])

      Swap two axes of an array.

      -

      take(a, /, indices[, axis, stream])

      +

      take(a, /, indices[, axis, stream])

      Take elements along an axis.

      -

      take_along_axis(a, /, indices[, axis, stream])

      +

      take_along_axis(a, /, indices[, axis, stream])

      Take values along an axis at the specified indices.

      -

      tan(a, /, *[, stream])

      +

      tan(a, /, *[, stream])

      Element-wise tangent.

      -

      tanh(a, /, *[, stream])

      +

      tanh(a, /, *[, stream])

      Element-wise hyperbolic tangent.

      -

      tensordot(a, b, /[, axes, stream])

      +

      tensordot(a, b, /[, axes, stream])

      Compute the tensor dot product along the specified axes.

      -

      tile(a, reps, /, *[, stream])

      +

      tile(a, reps, /, *[, stream])

      Construct an array by repeating a the number of times given by reps.

      -

      topk(a, /, k[, axis, stream])

      +

      topk(a, /, k[, axis, stream])

      Returns the k largest elements from the input along a given axis.

      -

      transpose(a, /[, axes, stream])

      +

      transpose(a, /[, axes, stream])

      Transpose the dimensions of the array.

      -

      tri(n, m, k[, dtype, stream])

      +

      tri(n, m, k[, dtype, stream])

      An array with ones at and below the given diagonal and zeros elsewhere.

      -

      tril(x, k, *[, stream])

      +

      tril(x, k, *[, stream])

      Zeros the array above the given diagonal.

      -

      triu(x, k, *[, stream])

      +

      triu(x, k, *[, stream])

      Zeros the array below the given diagonal.

      -

      var(a, /[, axis, keepdims, ddof, stream])

      +

      var(a, /[, axis, keepdims, ddof, stream])

      Compute the variance(s) over the given axes.

      -

      where(condition, x, y, /, *[, stream])

      +

      where(condition, x, y, /, *[, stream])

      Select from x or y according to condition.

      -

      zeros(shape[, dtype, stream])

      +

      zeros(shape[, dtype, stream])

      Construct an array of zeros.

      -

      zeros_like(a, /, *[, stream])

      +

      zeros_like(a, /, *[, stream])

      An array of zeros like the input.
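Three of the entries above are new in this release: expm1, meshgrid, and std. The following is a small sketch of how they might be called, assuming the signatures listed in the table; it is illustrative rather than taken from the documentation itself.

```python
import mlx.core as mx

# expm1: numerically accurate exp(x) - 1 for values near zero.
small = mx.array([1e-5, 0.0, -1e-5])
print(mx.expm1(small))

# meshgrid: coordinate grids from 1-D arrays (default indexing="xy").
xs, ys = mx.meshgrid(mx.arange(3), mx.arange(2))
print(xs.shape, ys.shape)  # (2, 3) (2, 3)

# std: standard deviation over the given axes, with an optional ddof.
a = mx.random.normal(shape=(4, 5))
print(mx.std(a, axis=0, ddof=1))
```

expm1 is preferred over computing exp(x) - 1 directly when x is close to zero, since it avoids the cancellation error in the subtraction.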

      diff --git a/docs/build/html/python/optimizers.html b/docs/build/html/python/optimizers.html index c89abaed5..b9ff5c6f6 100644 --- a/docs/build/html/python/optimizers.html +++ b/docs/build/html/python/optimizers.html @@ -8,7 +8,7 @@ - Optimizers — MLX 0.9.0 documentation + Optimizers — MLX 0.10.0 documentation @@ -36,7 +36,7 @@ - + @@ -131,8 +131,8 @@ - MLX 0.9.0 documentation - Home - + MLX 0.10.0 documentation - Home + @@ -286,6 +286,7 @@
    • mlx.core.erf
    • mlx.core.erfinv
    • mlx.core.exp
    • +
    • mlx.core.expm1
    • mlx.core.expand_dims
    • mlx.core.eye
    • mlx.core.flatten
    • @@ -318,6 +319,7 @@
    • mlx.core.max
    • mlx.core.maximum
    • mlx.core.mean
    • +
    • mlx.core.meshgrid
    • mlx.core.min
    • mlx.core.minimum
    • mlx.core.moveaxis
    • @@ -352,6 +354,7 @@
    • mlx.core.square
    • mlx.core.squeeze
    • mlx.core.stack
    • +
    • mlx.core.std
    • mlx.core.stop_gradient
    • mlx.core.subtract
    • mlx.core.sum
    • @@ -379,6 +382,7 @@
    • mlx.core.random.gumbel
    • mlx.core.random.key
    • mlx.core.random.normal
    • +
    • mlx.core.random.multivariate_normal
    • mlx.core.random.randint
    • mlx.core.random.seed
    • mlx.core.random.split
    • @@ -432,6 +436,8 @@
    • mlx.core.metal.get_cache_memory
    • mlx.core.metal.set_memory_limit
    • mlx.core.metal.set_cache_limit
    • +
    • mlx.core.metal.start_capture
    • +
    • mlx.core.metal.stop_capture
  • Neural Networks
      diff --git a/docs/build/html/searchindex.js b/docs/build/html/searchindex.js index 896f41a7a..0db2c1d81 100644 --- a/docs/build/html/searchindex.js +++ b/docs/build/html/searchindex.js @@ -1 +1 @@ -Search.setIndex({"docnames": ["cpp/ops", "dev/extensions", "dev/metal_debugger", "examples/linear_regression", "examples/llama-inference", "examples/mlp", "index", "install", "python/_autosummary/mlx.core.Device", "python/_autosummary/mlx.core.Dtype", "python/_autosummary/mlx.core.DtypeCategory", "python/_autosummary/mlx.core.abs", "python/_autosummary/mlx.core.add", "python/_autosummary/mlx.core.all", "python/_autosummary/mlx.core.allclose", "python/_autosummary/mlx.core.any", "python/_autosummary/mlx.core.arange", "python/_autosummary/mlx.core.arccos", "python/_autosummary/mlx.core.arccosh", "python/_autosummary/mlx.core.arcsin", "python/_autosummary/mlx.core.arcsinh", "python/_autosummary/mlx.core.arctan", "python/_autosummary/mlx.core.arctanh", "python/_autosummary/mlx.core.argmax", "python/_autosummary/mlx.core.argmin", "python/_autosummary/mlx.core.argpartition", "python/_autosummary/mlx.core.argsort", "python/_autosummary/mlx.core.array", "python/_autosummary/mlx.core.array.T", "python/_autosummary/mlx.core.array.abs", "python/_autosummary/mlx.core.array.all", "python/_autosummary/mlx.core.array.any", "python/_autosummary/mlx.core.array.argmax", "python/_autosummary/mlx.core.array.argmin", "python/_autosummary/mlx.core.array.astype", "python/_autosummary/mlx.core.array.at", "python/_autosummary/mlx.core.array.cos", "python/_autosummary/mlx.core.array.cummax", "python/_autosummary/mlx.core.array.cummin", "python/_autosummary/mlx.core.array.cumprod", "python/_autosummary/mlx.core.array.cumsum", "python/_autosummary/mlx.core.array.diag", "python/_autosummary/mlx.core.array.diagonal", "python/_autosummary/mlx.core.array.dtype", "python/_autosummary/mlx.core.array.exp", "python/_autosummary/mlx.core.array.flatten", "python/_autosummary/mlx.core.array.item", "python/_autosummary/mlx.core.array.itemsize", "python/_autosummary/mlx.core.array.log", "python/_autosummary/mlx.core.array.log10", "python/_autosummary/mlx.core.array.log1p", "python/_autosummary/mlx.core.array.log2", "python/_autosummary/mlx.core.array.logsumexp", "python/_autosummary/mlx.core.array.max", "python/_autosummary/mlx.core.array.mean", "python/_autosummary/mlx.core.array.min", "python/_autosummary/mlx.core.array.moveaxis", "python/_autosummary/mlx.core.array.nbytes", "python/_autosummary/mlx.core.array.ndim", "python/_autosummary/mlx.core.array.prod", "python/_autosummary/mlx.core.array.reciprocal", "python/_autosummary/mlx.core.array.reshape", "python/_autosummary/mlx.core.array.round", "python/_autosummary/mlx.core.array.rsqrt", "python/_autosummary/mlx.core.array.shape", "python/_autosummary/mlx.core.array.sin", "python/_autosummary/mlx.core.array.size", "python/_autosummary/mlx.core.array.split", "python/_autosummary/mlx.core.array.sqrt", "python/_autosummary/mlx.core.array.square", "python/_autosummary/mlx.core.array.squeeze", "python/_autosummary/mlx.core.array.sum", "python/_autosummary/mlx.core.array.swapaxes", "python/_autosummary/mlx.core.array.tolist", "python/_autosummary/mlx.core.array.transpose", "python/_autosummary/mlx.core.array.var", "python/_autosummary/mlx.core.array_equal", "python/_autosummary/mlx.core.atleast_1d", "python/_autosummary/mlx.core.atleast_2d", "python/_autosummary/mlx.core.atleast_3d", "python/_autosummary/mlx.core.broadcast_to", "python/_autosummary/mlx.core.ceil", 
"python/_autosummary/mlx.core.clip", "python/_autosummary/mlx.core.compile", "python/_autosummary/mlx.core.concatenate", "python/_autosummary/mlx.core.conv1d", "python/_autosummary/mlx.core.conv2d", "python/_autosummary/mlx.core.conv_general", "python/_autosummary/mlx.core.convolve", "python/_autosummary/mlx.core.cos", "python/_autosummary/mlx.core.cosh", "python/_autosummary/mlx.core.cummax", "python/_autosummary/mlx.core.cummin", "python/_autosummary/mlx.core.cumprod", "python/_autosummary/mlx.core.cumsum", "python/_autosummary/mlx.core.default_device", "python/_autosummary/mlx.core.default_stream", "python/_autosummary/mlx.core.dequantize", "python/_autosummary/mlx.core.diag", "python/_autosummary/mlx.core.diagonal", "python/_autosummary/mlx.core.disable_compile", "python/_autosummary/mlx.core.divide", "python/_autosummary/mlx.core.divmod", "python/_autosummary/mlx.core.enable_compile", "python/_autosummary/mlx.core.equal", "python/_autosummary/mlx.core.erf", "python/_autosummary/mlx.core.erfinv", "python/_autosummary/mlx.core.eval", "python/_autosummary/mlx.core.exp", "python/_autosummary/mlx.core.expand_dims", "python/_autosummary/mlx.core.eye", "python/_autosummary/mlx.core.fast.layer_norm", "python/_autosummary/mlx.core.fast.rms_norm", "python/_autosummary/mlx.core.fast.rope", "python/_autosummary/mlx.core.fast.scaled_dot_product_attention", "python/_autosummary/mlx.core.fft.fft", "python/_autosummary/mlx.core.fft.fft2", "python/_autosummary/mlx.core.fft.fftn", "python/_autosummary/mlx.core.fft.ifft", "python/_autosummary/mlx.core.fft.ifft2", "python/_autosummary/mlx.core.fft.ifftn", "python/_autosummary/mlx.core.fft.irfft", "python/_autosummary/mlx.core.fft.irfft2", "python/_autosummary/mlx.core.fft.irfftn", "python/_autosummary/mlx.core.fft.rfft", "python/_autosummary/mlx.core.fft.rfft2", "python/_autosummary/mlx.core.fft.rfftn", "python/_autosummary/mlx.core.flatten", "python/_autosummary/mlx.core.floor", "python/_autosummary/mlx.core.floor_divide", "python/_autosummary/mlx.core.full", "python/_autosummary/mlx.core.grad", "python/_autosummary/mlx.core.greater", "python/_autosummary/mlx.core.greater_equal", "python/_autosummary/mlx.core.identity", "python/_autosummary/mlx.core.inner", "python/_autosummary/mlx.core.isclose", "python/_autosummary/mlx.core.isinf", "python/_autosummary/mlx.core.isnan", "python/_autosummary/mlx.core.isneginf", "python/_autosummary/mlx.core.isposinf", "python/_autosummary/mlx.core.issubdtype", "python/_autosummary/mlx.core.jvp", "python/_autosummary/mlx.core.less", "python/_autosummary/mlx.core.less_equal", "python/_autosummary/mlx.core.linalg.norm", "python/_autosummary/mlx.core.linalg.qr", "python/_autosummary/mlx.core.linspace", "python/_autosummary/mlx.core.load", "python/_autosummary/mlx.core.log", "python/_autosummary/mlx.core.log10", "python/_autosummary/mlx.core.log1p", "python/_autosummary/mlx.core.log2", "python/_autosummary/mlx.core.logaddexp", "python/_autosummary/mlx.core.logical_and", "python/_autosummary/mlx.core.logical_not", "python/_autosummary/mlx.core.logical_or", "python/_autosummary/mlx.core.logsumexp", "python/_autosummary/mlx.core.matmul", "python/_autosummary/mlx.core.max", "python/_autosummary/mlx.core.maximum", "python/_autosummary/mlx.core.mean", "python/_autosummary/mlx.core.metal.get_active_memory", "python/_autosummary/mlx.core.metal.get_cache_memory", "python/_autosummary/mlx.core.metal.get_peak_memory", "python/_autosummary/mlx.core.metal.is_available", "python/_autosummary/mlx.core.metal.set_cache_limit", 
"python/_autosummary/mlx.core.metal.set_memory_limit", "python/_autosummary/mlx.core.min", "python/_autosummary/mlx.core.minimum", "python/_autosummary/mlx.core.moveaxis", "python/_autosummary/mlx.core.multiply", "python/_autosummary/mlx.core.negative", "python/_autosummary/mlx.core.new_stream", "python/_autosummary/mlx.core.ones", "python/_autosummary/mlx.core.ones_like", "python/_autosummary/mlx.core.outer", "python/_autosummary/mlx.core.pad", "python/_autosummary/mlx.core.partition", "python/_autosummary/mlx.core.prod", "python/_autosummary/mlx.core.quantize", "python/_autosummary/mlx.core.quantized_matmul", "python/_autosummary/mlx.core.random.bernoulli", "python/_autosummary/mlx.core.random.categorical", "python/_autosummary/mlx.core.random.gumbel", "python/_autosummary/mlx.core.random.key", "python/_autosummary/mlx.core.random.normal", "python/_autosummary/mlx.core.random.randint", "python/_autosummary/mlx.core.random.seed", "python/_autosummary/mlx.core.random.split", "python/_autosummary/mlx.core.random.truncated_normal", "python/_autosummary/mlx.core.random.uniform", "python/_autosummary/mlx.core.reciprocal", "python/_autosummary/mlx.core.repeat", "python/_autosummary/mlx.core.reshape", "python/_autosummary/mlx.core.round", "python/_autosummary/mlx.core.rsqrt", "python/_autosummary/mlx.core.save", "python/_autosummary/mlx.core.save_gguf", "python/_autosummary/mlx.core.save_safetensors", "python/_autosummary/mlx.core.savez", "python/_autosummary/mlx.core.savez_compressed", "python/_autosummary/mlx.core.set_default_device", "python/_autosummary/mlx.core.set_default_stream", "python/_autosummary/mlx.core.sigmoid", "python/_autosummary/mlx.core.sign", "python/_autosummary/mlx.core.sin", "python/_autosummary/mlx.core.sinh", "python/_autosummary/mlx.core.softmax", "python/_autosummary/mlx.core.sort", "python/_autosummary/mlx.core.split", "python/_autosummary/mlx.core.sqrt", "python/_autosummary/mlx.core.square", "python/_autosummary/mlx.core.squeeze", "python/_autosummary/mlx.core.stack", "python/_autosummary/mlx.core.stop_gradient", "python/_autosummary/mlx.core.stream", "python/_autosummary/mlx.core.subtract", "python/_autosummary/mlx.core.sum", "python/_autosummary/mlx.core.swapaxes", "python/_autosummary/mlx.core.take", "python/_autosummary/mlx.core.take_along_axis", "python/_autosummary/mlx.core.tan", "python/_autosummary/mlx.core.tanh", "python/_autosummary/mlx.core.tensordot", "python/_autosummary/mlx.core.tile", "python/_autosummary/mlx.core.topk", "python/_autosummary/mlx.core.transpose", "python/_autosummary/mlx.core.tri", "python/_autosummary/mlx.core.tril", "python/_autosummary/mlx.core.triu", "python/_autosummary/mlx.core.value_and_grad", "python/_autosummary/mlx.core.var", "python/_autosummary/mlx.core.vjp", "python/_autosummary/mlx.core.vmap", "python/_autosummary/mlx.core.where", "python/_autosummary/mlx.core.zeros", "python/_autosummary/mlx.core.zeros_like", "python/_autosummary/mlx.nn.value_and_grad", "python/_autosummary/mlx.utils.tree_flatten", "python/_autosummary/mlx.utils.tree_map", "python/_autosummary/mlx.utils.tree_unflatten", "python/_autosummary/stream_class", "python/array", "python/data_types", "python/devices_and_streams", "python/fast", "python/fft", "python/linalg", "python/metal", "python/nn", "python/nn/_autosummary/mlx.nn.ALiBi", "python/nn/_autosummary/mlx.nn.AvgPool1d", "python/nn/_autosummary/mlx.nn.AvgPool2d", "python/nn/_autosummary/mlx.nn.BatchNorm", "python/nn/_autosummary/mlx.nn.Conv1d", "python/nn/_autosummary/mlx.nn.Conv2d", 
"python/nn/_autosummary/mlx.nn.Dropout", "python/nn/_autosummary/mlx.nn.Dropout2d", "python/nn/_autosummary/mlx.nn.Dropout3d", "python/nn/_autosummary/mlx.nn.Embedding", "python/nn/_autosummary/mlx.nn.GELU", "python/nn/_autosummary/mlx.nn.GRU", "python/nn/_autosummary/mlx.nn.GroupNorm", "python/nn/_autosummary/mlx.nn.InstanceNorm", "python/nn/_autosummary/mlx.nn.LSTM", "python/nn/_autosummary/mlx.nn.LayerNorm", "python/nn/_autosummary/mlx.nn.Linear", "python/nn/_autosummary/mlx.nn.MaxPool1d", "python/nn/_autosummary/mlx.nn.MaxPool2d", "python/nn/_autosummary/mlx.nn.Mish", "python/nn/_autosummary/mlx.nn.Module.apply", "python/nn/_autosummary/mlx.nn.Module.apply_to_modules", "python/nn/_autosummary/mlx.nn.Module.children", "python/nn/_autosummary/mlx.nn.Module.eval", "python/nn/_autosummary/mlx.nn.Module.filter_and_map", "python/nn/_autosummary/mlx.nn.Module.freeze", "python/nn/_autosummary/mlx.nn.Module.leaf_modules", "python/nn/_autosummary/mlx.nn.Module.load_weights", "python/nn/_autosummary/mlx.nn.Module.modules", "python/nn/_autosummary/mlx.nn.Module.named_modules", "python/nn/_autosummary/mlx.nn.Module.parameters", "python/nn/_autosummary/mlx.nn.Module.save_weights", "python/nn/_autosummary/mlx.nn.Module.set_dtype", "python/nn/_autosummary/mlx.nn.Module.state", "python/nn/_autosummary/mlx.nn.Module.train", "python/nn/_autosummary/mlx.nn.Module.trainable_parameters", "python/nn/_autosummary/mlx.nn.Module.training", "python/nn/_autosummary/mlx.nn.Module.unfreeze", "python/nn/_autosummary/mlx.nn.Module.update", "python/nn/_autosummary/mlx.nn.Module.update_modules", "python/nn/_autosummary/mlx.nn.MultiHeadAttention", "python/nn/_autosummary/mlx.nn.PReLU", "python/nn/_autosummary/mlx.nn.QuantizedLinear", "python/nn/_autosummary/mlx.nn.RMSNorm", "python/nn/_autosummary/mlx.nn.RNN", "python/nn/_autosummary/mlx.nn.ReLU", "python/nn/_autosummary/mlx.nn.RoPE", "python/nn/_autosummary/mlx.nn.SELU", "python/nn/_autosummary/mlx.nn.Sequential", "python/nn/_autosummary/mlx.nn.SiLU", "python/nn/_autosummary/mlx.nn.SinusoidalPositionalEncoding", "python/nn/_autosummary/mlx.nn.Softshrink", "python/nn/_autosummary/mlx.nn.Step", "python/nn/_autosummary/mlx.nn.Transformer", "python/nn/_autosummary/mlx.nn.Upsample", "python/nn/_autosummary/mlx.nn.init.constant", "python/nn/_autosummary/mlx.nn.init.glorot_normal", "python/nn/_autosummary/mlx.nn.init.glorot_uniform", "python/nn/_autosummary/mlx.nn.init.he_normal", "python/nn/_autosummary/mlx.nn.init.he_uniform", "python/nn/_autosummary/mlx.nn.init.identity", "python/nn/_autosummary/mlx.nn.init.normal", "python/nn/_autosummary/mlx.nn.init.uniform", "python/nn/_autosummary_functions/mlx.nn.elu", "python/nn/_autosummary_functions/mlx.nn.gelu", "python/nn/_autosummary_functions/mlx.nn.gelu_approx", "python/nn/_autosummary_functions/mlx.nn.gelu_fast_approx", "python/nn/_autosummary_functions/mlx.nn.glu", "python/nn/_autosummary_functions/mlx.nn.hardswish", "python/nn/_autosummary_functions/mlx.nn.leaky_relu", "python/nn/_autosummary_functions/mlx.nn.log_sigmoid", "python/nn/_autosummary_functions/mlx.nn.log_softmax", "python/nn/_autosummary_functions/mlx.nn.losses.binary_cross_entropy", "python/nn/_autosummary_functions/mlx.nn.losses.cosine_similarity_loss", "python/nn/_autosummary_functions/mlx.nn.losses.cross_entropy", "python/nn/_autosummary_functions/mlx.nn.losses.gaussian_nll_loss", "python/nn/_autosummary_functions/mlx.nn.losses.hinge_loss", "python/nn/_autosummary_functions/mlx.nn.losses.huber_loss", 
"python/nn/_autosummary_functions/mlx.nn.losses.kl_div_loss", "python/nn/_autosummary_functions/mlx.nn.losses.l1_loss", "python/nn/_autosummary_functions/mlx.nn.losses.log_cosh_loss", "python/nn/_autosummary_functions/mlx.nn.losses.margin_ranking_loss", "python/nn/_autosummary_functions/mlx.nn.losses.mse_loss", "python/nn/_autosummary_functions/mlx.nn.losses.nll_loss", "python/nn/_autosummary_functions/mlx.nn.losses.smooth_l1_loss", "python/nn/_autosummary_functions/mlx.nn.losses.triplet_loss", "python/nn/_autosummary_functions/mlx.nn.mish", "python/nn/_autosummary_functions/mlx.nn.prelu", "python/nn/_autosummary_functions/mlx.nn.relu", "python/nn/_autosummary_functions/mlx.nn.relu6", "python/nn/_autosummary_functions/mlx.nn.selu", "python/nn/_autosummary_functions/mlx.nn.sigmoid", "python/nn/_autosummary_functions/mlx.nn.silu", "python/nn/_autosummary_functions/mlx.nn.softmax", "python/nn/_autosummary_functions/mlx.nn.softplus", "python/nn/_autosummary_functions/mlx.nn.softshrink", "python/nn/_autosummary_functions/mlx.nn.step", "python/nn/_autosummary_functions/mlx.nn.tanh", "python/nn/functions", "python/nn/init", "python/nn/layers", "python/nn/losses", "python/nn/module", "python/ops", "python/optimizers", "python/optimizers/_autosummary/mlx.optimizers.AdaDelta", "python/optimizers/_autosummary/mlx.optimizers.Adafactor", "python/optimizers/_autosummary/mlx.optimizers.Adagrad", "python/optimizers/_autosummary/mlx.optimizers.Adam", "python/optimizers/_autosummary/mlx.optimizers.AdamW", "python/optimizers/_autosummary/mlx.optimizers.Adamax", "python/optimizers/_autosummary/mlx.optimizers.Lion", "python/optimizers/_autosummary/mlx.optimizers.Optimizer.apply_gradients", "python/optimizers/_autosummary/mlx.optimizers.Optimizer.init", "python/optimizers/_autosummary/mlx.optimizers.Optimizer.state", "python/optimizers/_autosummary/mlx.optimizers.Optimizer.update", "python/optimizers/_autosummary/mlx.optimizers.RMSprop", "python/optimizers/_autosummary/mlx.optimizers.SGD", "python/optimizers/_autosummary/mlx.optimizers.cosine_decay", "python/optimizers/_autosummary/mlx.optimizers.exponential_decay", "python/optimizers/_autosummary/mlx.optimizers.join_schedules", "python/optimizers/_autosummary/mlx.optimizers.linear_schedule", "python/optimizers/_autosummary/mlx.optimizers.step_decay", "python/optimizers/common_optimizers", "python/optimizers/optimizer", "python/optimizers/schedulers", "python/random", "python/transforms", "python/tree_utils", "usage/compile", "usage/function_transforms", "usage/indexing", "usage/lazy_evaluation", "usage/numpy", "usage/quick_start", "usage/saving_and_loading", "usage/unified_memory", "usage/using_streams"], "filenames": ["cpp/ops.rst", "dev/extensions.rst", "dev/metal_debugger.rst", "examples/linear_regression.rst", "examples/llama-inference.rst", "examples/mlp.rst", "index.rst", "install.rst", "python/_autosummary/mlx.core.Device.rst", "python/_autosummary/mlx.core.Dtype.rst", "python/_autosummary/mlx.core.DtypeCategory.rst", "python/_autosummary/mlx.core.abs.rst", "python/_autosummary/mlx.core.add.rst", "python/_autosummary/mlx.core.all.rst", "python/_autosummary/mlx.core.allclose.rst", "python/_autosummary/mlx.core.any.rst", "python/_autosummary/mlx.core.arange.rst", "python/_autosummary/mlx.core.arccos.rst", "python/_autosummary/mlx.core.arccosh.rst", "python/_autosummary/mlx.core.arcsin.rst", "python/_autosummary/mlx.core.arcsinh.rst", "python/_autosummary/mlx.core.arctan.rst", "python/_autosummary/mlx.core.arctanh.rst", 
"python/_autosummary/mlx.core.argmax.rst", "python/_autosummary/mlx.core.argmin.rst", "python/_autosummary/mlx.core.argpartition.rst", "python/_autosummary/mlx.core.argsort.rst", "python/_autosummary/mlx.core.array.rst", "python/_autosummary/mlx.core.array.T.rst", "python/_autosummary/mlx.core.array.abs.rst", "python/_autosummary/mlx.core.array.all.rst", "python/_autosummary/mlx.core.array.any.rst", "python/_autosummary/mlx.core.array.argmax.rst", "python/_autosummary/mlx.core.array.argmin.rst", "python/_autosummary/mlx.core.array.astype.rst", "python/_autosummary/mlx.core.array.at.rst", "python/_autosummary/mlx.core.array.cos.rst", "python/_autosummary/mlx.core.array.cummax.rst", "python/_autosummary/mlx.core.array.cummin.rst", "python/_autosummary/mlx.core.array.cumprod.rst", "python/_autosummary/mlx.core.array.cumsum.rst", "python/_autosummary/mlx.core.array.diag.rst", "python/_autosummary/mlx.core.array.diagonal.rst", "python/_autosummary/mlx.core.array.dtype.rst", "python/_autosummary/mlx.core.array.exp.rst", "python/_autosummary/mlx.core.array.flatten.rst", "python/_autosummary/mlx.core.array.item.rst", "python/_autosummary/mlx.core.array.itemsize.rst", "python/_autosummary/mlx.core.array.log.rst", "python/_autosummary/mlx.core.array.log10.rst", "python/_autosummary/mlx.core.array.log1p.rst", "python/_autosummary/mlx.core.array.log2.rst", "python/_autosummary/mlx.core.array.logsumexp.rst", "python/_autosummary/mlx.core.array.max.rst", "python/_autosummary/mlx.core.array.mean.rst", "python/_autosummary/mlx.core.array.min.rst", "python/_autosummary/mlx.core.array.moveaxis.rst", "python/_autosummary/mlx.core.array.nbytes.rst", "python/_autosummary/mlx.core.array.ndim.rst", "python/_autosummary/mlx.core.array.prod.rst", "python/_autosummary/mlx.core.array.reciprocal.rst", "python/_autosummary/mlx.core.array.reshape.rst", "python/_autosummary/mlx.core.array.round.rst", "python/_autosummary/mlx.core.array.rsqrt.rst", "python/_autosummary/mlx.core.array.shape.rst", "python/_autosummary/mlx.core.array.sin.rst", "python/_autosummary/mlx.core.array.size.rst", "python/_autosummary/mlx.core.array.split.rst", "python/_autosummary/mlx.core.array.sqrt.rst", "python/_autosummary/mlx.core.array.square.rst", "python/_autosummary/mlx.core.array.squeeze.rst", "python/_autosummary/mlx.core.array.sum.rst", "python/_autosummary/mlx.core.array.swapaxes.rst", "python/_autosummary/mlx.core.array.tolist.rst", "python/_autosummary/mlx.core.array.transpose.rst", "python/_autosummary/mlx.core.array.var.rst", "python/_autosummary/mlx.core.array_equal.rst", "python/_autosummary/mlx.core.atleast_1d.rst", "python/_autosummary/mlx.core.atleast_2d.rst", "python/_autosummary/mlx.core.atleast_3d.rst", "python/_autosummary/mlx.core.broadcast_to.rst", "python/_autosummary/mlx.core.ceil.rst", "python/_autosummary/mlx.core.clip.rst", "python/_autosummary/mlx.core.compile.rst", "python/_autosummary/mlx.core.concatenate.rst", "python/_autosummary/mlx.core.conv1d.rst", "python/_autosummary/mlx.core.conv2d.rst", "python/_autosummary/mlx.core.conv_general.rst", "python/_autosummary/mlx.core.convolve.rst", "python/_autosummary/mlx.core.cos.rst", "python/_autosummary/mlx.core.cosh.rst", "python/_autosummary/mlx.core.cummax.rst", "python/_autosummary/mlx.core.cummin.rst", "python/_autosummary/mlx.core.cumprod.rst", "python/_autosummary/mlx.core.cumsum.rst", "python/_autosummary/mlx.core.default_device.rst", "python/_autosummary/mlx.core.default_stream.rst", "python/_autosummary/mlx.core.dequantize.rst", 
"python/_autosummary/mlx.core.diag.rst", "python/_autosummary/mlx.core.diagonal.rst", "python/_autosummary/mlx.core.disable_compile.rst", "python/_autosummary/mlx.core.divide.rst", "python/_autosummary/mlx.core.divmod.rst", "python/_autosummary/mlx.core.enable_compile.rst", "python/_autosummary/mlx.core.equal.rst", "python/_autosummary/mlx.core.erf.rst", "python/_autosummary/mlx.core.erfinv.rst", "python/_autosummary/mlx.core.eval.rst", "python/_autosummary/mlx.core.exp.rst", "python/_autosummary/mlx.core.expand_dims.rst", "python/_autosummary/mlx.core.eye.rst", "python/_autosummary/mlx.core.fast.layer_norm.rst", "python/_autosummary/mlx.core.fast.rms_norm.rst", "python/_autosummary/mlx.core.fast.rope.rst", "python/_autosummary/mlx.core.fast.scaled_dot_product_attention.rst", "python/_autosummary/mlx.core.fft.fft.rst", "python/_autosummary/mlx.core.fft.fft2.rst", "python/_autosummary/mlx.core.fft.fftn.rst", "python/_autosummary/mlx.core.fft.ifft.rst", "python/_autosummary/mlx.core.fft.ifft2.rst", "python/_autosummary/mlx.core.fft.ifftn.rst", "python/_autosummary/mlx.core.fft.irfft.rst", "python/_autosummary/mlx.core.fft.irfft2.rst", "python/_autosummary/mlx.core.fft.irfftn.rst", "python/_autosummary/mlx.core.fft.rfft.rst", "python/_autosummary/mlx.core.fft.rfft2.rst", "python/_autosummary/mlx.core.fft.rfftn.rst", "python/_autosummary/mlx.core.flatten.rst", "python/_autosummary/mlx.core.floor.rst", "python/_autosummary/mlx.core.floor_divide.rst", "python/_autosummary/mlx.core.full.rst", "python/_autosummary/mlx.core.grad.rst", "python/_autosummary/mlx.core.greater.rst", "python/_autosummary/mlx.core.greater_equal.rst", "python/_autosummary/mlx.core.identity.rst", "python/_autosummary/mlx.core.inner.rst", "python/_autosummary/mlx.core.isclose.rst", "python/_autosummary/mlx.core.isinf.rst", "python/_autosummary/mlx.core.isnan.rst", "python/_autosummary/mlx.core.isneginf.rst", "python/_autosummary/mlx.core.isposinf.rst", "python/_autosummary/mlx.core.issubdtype.rst", "python/_autosummary/mlx.core.jvp.rst", "python/_autosummary/mlx.core.less.rst", "python/_autosummary/mlx.core.less_equal.rst", "python/_autosummary/mlx.core.linalg.norm.rst", "python/_autosummary/mlx.core.linalg.qr.rst", "python/_autosummary/mlx.core.linspace.rst", "python/_autosummary/mlx.core.load.rst", "python/_autosummary/mlx.core.log.rst", "python/_autosummary/mlx.core.log10.rst", "python/_autosummary/mlx.core.log1p.rst", "python/_autosummary/mlx.core.log2.rst", "python/_autosummary/mlx.core.logaddexp.rst", "python/_autosummary/mlx.core.logical_and.rst", "python/_autosummary/mlx.core.logical_not.rst", "python/_autosummary/mlx.core.logical_or.rst", "python/_autosummary/mlx.core.logsumexp.rst", "python/_autosummary/mlx.core.matmul.rst", "python/_autosummary/mlx.core.max.rst", "python/_autosummary/mlx.core.maximum.rst", "python/_autosummary/mlx.core.mean.rst", "python/_autosummary/mlx.core.metal.get_active_memory.rst", "python/_autosummary/mlx.core.metal.get_cache_memory.rst", "python/_autosummary/mlx.core.metal.get_peak_memory.rst", "python/_autosummary/mlx.core.metal.is_available.rst", "python/_autosummary/mlx.core.metal.set_cache_limit.rst", "python/_autosummary/mlx.core.metal.set_memory_limit.rst", "python/_autosummary/mlx.core.min.rst", "python/_autosummary/mlx.core.minimum.rst", "python/_autosummary/mlx.core.moveaxis.rst", "python/_autosummary/mlx.core.multiply.rst", "python/_autosummary/mlx.core.negative.rst", "python/_autosummary/mlx.core.new_stream.rst", "python/_autosummary/mlx.core.ones.rst", 
"python/_autosummary/mlx.core.ones_like.rst", "python/_autosummary/mlx.core.outer.rst", "python/_autosummary/mlx.core.pad.rst", "python/_autosummary/mlx.core.partition.rst", "python/_autosummary/mlx.core.prod.rst", "python/_autosummary/mlx.core.quantize.rst", "python/_autosummary/mlx.core.quantized_matmul.rst", "python/_autosummary/mlx.core.random.bernoulli.rst", "python/_autosummary/mlx.core.random.categorical.rst", "python/_autosummary/mlx.core.random.gumbel.rst", "python/_autosummary/mlx.core.random.key.rst", "python/_autosummary/mlx.core.random.normal.rst", "python/_autosummary/mlx.core.random.randint.rst", "python/_autosummary/mlx.core.random.seed.rst", "python/_autosummary/mlx.core.random.split.rst", "python/_autosummary/mlx.core.random.truncated_normal.rst", "python/_autosummary/mlx.core.random.uniform.rst", "python/_autosummary/mlx.core.reciprocal.rst", "python/_autosummary/mlx.core.repeat.rst", "python/_autosummary/mlx.core.reshape.rst", "python/_autosummary/mlx.core.round.rst", "python/_autosummary/mlx.core.rsqrt.rst", "python/_autosummary/mlx.core.save.rst", "python/_autosummary/mlx.core.save_gguf.rst", "python/_autosummary/mlx.core.save_safetensors.rst", "python/_autosummary/mlx.core.savez.rst", "python/_autosummary/mlx.core.savez_compressed.rst", "python/_autosummary/mlx.core.set_default_device.rst", "python/_autosummary/mlx.core.set_default_stream.rst", "python/_autosummary/mlx.core.sigmoid.rst", "python/_autosummary/mlx.core.sign.rst", "python/_autosummary/mlx.core.sin.rst", "python/_autosummary/mlx.core.sinh.rst", "python/_autosummary/mlx.core.softmax.rst", "python/_autosummary/mlx.core.sort.rst", "python/_autosummary/mlx.core.split.rst", "python/_autosummary/mlx.core.sqrt.rst", "python/_autosummary/mlx.core.square.rst", "python/_autosummary/mlx.core.squeeze.rst", "python/_autosummary/mlx.core.stack.rst", "python/_autosummary/mlx.core.stop_gradient.rst", "python/_autosummary/mlx.core.stream.rst", "python/_autosummary/mlx.core.subtract.rst", "python/_autosummary/mlx.core.sum.rst", "python/_autosummary/mlx.core.swapaxes.rst", "python/_autosummary/mlx.core.take.rst", "python/_autosummary/mlx.core.take_along_axis.rst", "python/_autosummary/mlx.core.tan.rst", "python/_autosummary/mlx.core.tanh.rst", "python/_autosummary/mlx.core.tensordot.rst", "python/_autosummary/mlx.core.tile.rst", "python/_autosummary/mlx.core.topk.rst", "python/_autosummary/mlx.core.transpose.rst", "python/_autosummary/mlx.core.tri.rst", "python/_autosummary/mlx.core.tril.rst", "python/_autosummary/mlx.core.triu.rst", "python/_autosummary/mlx.core.value_and_grad.rst", "python/_autosummary/mlx.core.var.rst", "python/_autosummary/mlx.core.vjp.rst", "python/_autosummary/mlx.core.vmap.rst", "python/_autosummary/mlx.core.where.rst", "python/_autosummary/mlx.core.zeros.rst", "python/_autosummary/mlx.core.zeros_like.rst", "python/_autosummary/mlx.nn.value_and_grad.rst", "python/_autosummary/mlx.utils.tree_flatten.rst", "python/_autosummary/mlx.utils.tree_map.rst", "python/_autosummary/mlx.utils.tree_unflatten.rst", "python/_autosummary/stream_class.rst", "python/array.rst", "python/data_types.rst", "python/devices_and_streams.rst", "python/fast.rst", "python/fft.rst", "python/linalg.rst", "python/metal.rst", "python/nn.rst", "python/nn/_autosummary/mlx.nn.ALiBi.rst", "python/nn/_autosummary/mlx.nn.AvgPool1d.rst", "python/nn/_autosummary/mlx.nn.AvgPool2d.rst", "python/nn/_autosummary/mlx.nn.BatchNorm.rst", "python/nn/_autosummary/mlx.nn.Conv1d.rst", "python/nn/_autosummary/mlx.nn.Conv2d.rst", 
"python/nn/_autosummary/mlx.nn.Dropout.rst", "python/nn/_autosummary/mlx.nn.Dropout2d.rst", "python/nn/_autosummary/mlx.nn.Dropout3d.rst", "python/nn/_autosummary/mlx.nn.Embedding.rst", "python/nn/_autosummary/mlx.nn.GELU.rst", "python/nn/_autosummary/mlx.nn.GRU.rst", "python/nn/_autosummary/mlx.nn.GroupNorm.rst", "python/nn/_autosummary/mlx.nn.InstanceNorm.rst", "python/nn/_autosummary/mlx.nn.LSTM.rst", "python/nn/_autosummary/mlx.nn.LayerNorm.rst", "python/nn/_autosummary/mlx.nn.Linear.rst", "python/nn/_autosummary/mlx.nn.MaxPool1d.rst", "python/nn/_autosummary/mlx.nn.MaxPool2d.rst", "python/nn/_autosummary/mlx.nn.Mish.rst", "python/nn/_autosummary/mlx.nn.Module.apply.rst", "python/nn/_autosummary/mlx.nn.Module.apply_to_modules.rst", "python/nn/_autosummary/mlx.nn.Module.children.rst", "python/nn/_autosummary/mlx.nn.Module.eval.rst", "python/nn/_autosummary/mlx.nn.Module.filter_and_map.rst", "python/nn/_autosummary/mlx.nn.Module.freeze.rst", "python/nn/_autosummary/mlx.nn.Module.leaf_modules.rst", "python/nn/_autosummary/mlx.nn.Module.load_weights.rst", "python/nn/_autosummary/mlx.nn.Module.modules.rst", "python/nn/_autosummary/mlx.nn.Module.named_modules.rst", "python/nn/_autosummary/mlx.nn.Module.parameters.rst", "python/nn/_autosummary/mlx.nn.Module.save_weights.rst", "python/nn/_autosummary/mlx.nn.Module.set_dtype.rst", "python/nn/_autosummary/mlx.nn.Module.state.rst", "python/nn/_autosummary/mlx.nn.Module.train.rst", "python/nn/_autosummary/mlx.nn.Module.trainable_parameters.rst", "python/nn/_autosummary/mlx.nn.Module.training.rst", "python/nn/_autosummary/mlx.nn.Module.unfreeze.rst", "python/nn/_autosummary/mlx.nn.Module.update.rst", "python/nn/_autosummary/mlx.nn.Module.update_modules.rst", "python/nn/_autosummary/mlx.nn.MultiHeadAttention.rst", "python/nn/_autosummary/mlx.nn.PReLU.rst", "python/nn/_autosummary/mlx.nn.QuantizedLinear.rst", "python/nn/_autosummary/mlx.nn.RMSNorm.rst", "python/nn/_autosummary/mlx.nn.RNN.rst", "python/nn/_autosummary/mlx.nn.ReLU.rst", "python/nn/_autosummary/mlx.nn.RoPE.rst", "python/nn/_autosummary/mlx.nn.SELU.rst", "python/nn/_autosummary/mlx.nn.Sequential.rst", "python/nn/_autosummary/mlx.nn.SiLU.rst", "python/nn/_autosummary/mlx.nn.SinusoidalPositionalEncoding.rst", "python/nn/_autosummary/mlx.nn.Softshrink.rst", "python/nn/_autosummary/mlx.nn.Step.rst", "python/nn/_autosummary/mlx.nn.Transformer.rst", "python/nn/_autosummary/mlx.nn.Upsample.rst", "python/nn/_autosummary/mlx.nn.init.constant.rst", "python/nn/_autosummary/mlx.nn.init.glorot_normal.rst", "python/nn/_autosummary/mlx.nn.init.glorot_uniform.rst", "python/nn/_autosummary/mlx.nn.init.he_normal.rst", "python/nn/_autosummary/mlx.nn.init.he_uniform.rst", "python/nn/_autosummary/mlx.nn.init.identity.rst", "python/nn/_autosummary/mlx.nn.init.normal.rst", "python/nn/_autosummary/mlx.nn.init.uniform.rst", "python/nn/_autosummary_functions/mlx.nn.elu.rst", "python/nn/_autosummary_functions/mlx.nn.gelu.rst", "python/nn/_autosummary_functions/mlx.nn.gelu_approx.rst", "python/nn/_autosummary_functions/mlx.nn.gelu_fast_approx.rst", "python/nn/_autosummary_functions/mlx.nn.glu.rst", "python/nn/_autosummary_functions/mlx.nn.hardswish.rst", "python/nn/_autosummary_functions/mlx.nn.leaky_relu.rst", "python/nn/_autosummary_functions/mlx.nn.log_sigmoid.rst", "python/nn/_autosummary_functions/mlx.nn.log_softmax.rst", "python/nn/_autosummary_functions/mlx.nn.losses.binary_cross_entropy.rst", "python/nn/_autosummary_functions/mlx.nn.losses.cosine_similarity_loss.rst", 
"python/nn/_autosummary_functions/mlx.nn.losses.cross_entropy.rst", "python/nn/_autosummary_functions/mlx.nn.losses.gaussian_nll_loss.rst", "python/nn/_autosummary_functions/mlx.nn.losses.hinge_loss.rst", "python/nn/_autosummary_functions/mlx.nn.losses.huber_loss.rst", "python/nn/_autosummary_functions/mlx.nn.losses.kl_div_loss.rst", "python/nn/_autosummary_functions/mlx.nn.losses.l1_loss.rst", "python/nn/_autosummary_functions/mlx.nn.losses.log_cosh_loss.rst", "python/nn/_autosummary_functions/mlx.nn.losses.margin_ranking_loss.rst", "python/nn/_autosummary_functions/mlx.nn.losses.mse_loss.rst", "python/nn/_autosummary_functions/mlx.nn.losses.nll_loss.rst", "python/nn/_autosummary_functions/mlx.nn.losses.smooth_l1_loss.rst", "python/nn/_autosummary_functions/mlx.nn.losses.triplet_loss.rst", "python/nn/_autosummary_functions/mlx.nn.mish.rst", "python/nn/_autosummary_functions/mlx.nn.prelu.rst", "python/nn/_autosummary_functions/mlx.nn.relu.rst", "python/nn/_autosummary_functions/mlx.nn.relu6.rst", "python/nn/_autosummary_functions/mlx.nn.selu.rst", "python/nn/_autosummary_functions/mlx.nn.sigmoid.rst", "python/nn/_autosummary_functions/mlx.nn.silu.rst", "python/nn/_autosummary_functions/mlx.nn.softmax.rst", "python/nn/_autosummary_functions/mlx.nn.softplus.rst", "python/nn/_autosummary_functions/mlx.nn.softshrink.rst", "python/nn/_autosummary_functions/mlx.nn.step.rst", "python/nn/_autosummary_functions/mlx.nn.tanh.rst", "python/nn/functions.rst", "python/nn/init.rst", "python/nn/layers.rst", "python/nn/losses.rst", "python/nn/module.rst", "python/ops.rst", "python/optimizers.rst", "python/optimizers/_autosummary/mlx.optimizers.AdaDelta.rst", "python/optimizers/_autosummary/mlx.optimizers.Adafactor.rst", "python/optimizers/_autosummary/mlx.optimizers.Adagrad.rst", "python/optimizers/_autosummary/mlx.optimizers.Adam.rst", "python/optimizers/_autosummary/mlx.optimizers.AdamW.rst", "python/optimizers/_autosummary/mlx.optimizers.Adamax.rst", "python/optimizers/_autosummary/mlx.optimizers.Lion.rst", "python/optimizers/_autosummary/mlx.optimizers.Optimizer.apply_gradients.rst", "python/optimizers/_autosummary/mlx.optimizers.Optimizer.init.rst", "python/optimizers/_autosummary/mlx.optimizers.Optimizer.state.rst", "python/optimizers/_autosummary/mlx.optimizers.Optimizer.update.rst", "python/optimizers/_autosummary/mlx.optimizers.RMSprop.rst", "python/optimizers/_autosummary/mlx.optimizers.SGD.rst", "python/optimizers/_autosummary/mlx.optimizers.cosine_decay.rst", "python/optimizers/_autosummary/mlx.optimizers.exponential_decay.rst", "python/optimizers/_autosummary/mlx.optimizers.join_schedules.rst", "python/optimizers/_autosummary/mlx.optimizers.linear_schedule.rst", "python/optimizers/_autosummary/mlx.optimizers.step_decay.rst", "python/optimizers/common_optimizers.rst", "python/optimizers/optimizer.rst", "python/optimizers/schedulers.rst", "python/random.rst", "python/transforms.rst", "python/tree_utils.rst", "usage/compile.rst", "usage/function_transforms.rst", "usage/indexing.rst", "usage/lazy_evaluation.rst", "usage/numpy.rst", "usage/quick_start.rst", "usage/saving_and_loading.rst", "usage/unified_memory.rst", "usage/using_streams.rst"], "titles": ["Operations", "Developer Documentation", "Metal Debugger", "Linear Regression", "LLM inference", "Multi-Layer Perceptron", "MLX", "Build and Install", "mlx.core.Device", "mlx.core.Dtype", "mlx.core.DtypeCategory", "mlx.core.abs", "mlx.core.add", "mlx.core.all", "mlx.core.allclose", "mlx.core.any", "mlx.core.arange", "mlx.core.arccos", 
"mlx.core.arccosh", "mlx.core.arcsin", "mlx.core.arcsinh", "mlx.core.arctan", "mlx.core.arctanh", "mlx.core.argmax", "mlx.core.argmin", "mlx.core.argpartition", "mlx.core.argsort", "mlx.core.array", "mlx.core.array.T", "mlx.core.array.abs", "mlx.core.array.all", "mlx.core.array.any", "mlx.core.array.argmax", "mlx.core.array.argmin", "mlx.core.array.astype", "mlx.core.array.at", "mlx.core.array.cos", "mlx.core.array.cummax", "mlx.core.array.cummin", "mlx.core.array.cumprod", "mlx.core.array.cumsum", "mlx.core.array.diag", "mlx.core.array.diagonal", "mlx.core.array.dtype", "mlx.core.array.exp", "mlx.core.array.flatten", "mlx.core.array.item", "mlx.core.array.itemsize", "mlx.core.array.log", "mlx.core.array.log10", "mlx.core.array.log1p", "mlx.core.array.log2", "mlx.core.array.logsumexp", "mlx.core.array.max", "mlx.core.array.mean", "mlx.core.array.min", "mlx.core.array.moveaxis", "mlx.core.array.nbytes", "mlx.core.array.ndim", "mlx.core.array.prod", "mlx.core.array.reciprocal", "mlx.core.array.reshape", "mlx.core.array.round", "mlx.core.array.rsqrt", "mlx.core.array.shape", "mlx.core.array.sin", "mlx.core.array.size", "mlx.core.array.split", "mlx.core.array.sqrt", "mlx.core.array.square", "mlx.core.array.squeeze", "mlx.core.array.sum", "mlx.core.array.swapaxes", "mlx.core.array.tolist", "mlx.core.array.transpose", "mlx.core.array.var", "mlx.core.array_equal", "mlx.core.atleast_1d", "mlx.core.atleast_2d", "mlx.core.atleast_3d", "mlx.core.broadcast_to", "mlx.core.ceil", "mlx.core.clip", "mlx.core.compile", "mlx.core.concatenate", "mlx.core.conv1d", "mlx.core.conv2d", "mlx.core.conv_general", "mlx.core.convolve", "mlx.core.cos", "mlx.core.cosh", "mlx.core.cummax", "mlx.core.cummin", "mlx.core.cumprod", "mlx.core.cumsum", "mlx.core.default_device", "mlx.core.default_stream", "mlx.core.dequantize", "mlx.core.diag", "mlx.core.diagonal", "mlx.core.disable_compile", "mlx.core.divide", "mlx.core.divmod", "mlx.core.enable_compile", "mlx.core.equal", "mlx.core.erf", "mlx.core.erfinv", "mlx.core.eval", "mlx.core.exp", "mlx.core.expand_dims", "mlx.core.eye", "mlx.core.fast.layer_norm", "mlx.core.fast.rms_norm", "mlx.core.fast.rope", "mlx.core.fast.scaled_dot_product_attention", "mlx.core.fft.fft", "mlx.core.fft.fft2", "mlx.core.fft.fftn", "mlx.core.fft.ifft", "mlx.core.fft.ifft2", "mlx.core.fft.ifftn", "mlx.core.fft.irfft", "mlx.core.fft.irfft2", "mlx.core.fft.irfftn", "mlx.core.fft.rfft", "mlx.core.fft.rfft2", "mlx.core.fft.rfftn", "mlx.core.flatten", "mlx.core.floor", "mlx.core.floor_divide", "mlx.core.full", "mlx.core.grad", "mlx.core.greater", "mlx.core.greater_equal", "mlx.core.identity", "mlx.core.inner", "mlx.core.isclose", "mlx.core.isinf", "mlx.core.isnan", "mlx.core.isneginf", "mlx.core.isposinf", "mlx.core.issubdtype", "mlx.core.jvp", "mlx.core.less", "mlx.core.less_equal", "mlx.core.linalg.norm", "mlx.core.linalg.qr", "mlx.core.linspace", "mlx.core.load", "mlx.core.log", "mlx.core.log10", "mlx.core.log1p", "mlx.core.log2", "mlx.core.logaddexp", "mlx.core.logical_and", "mlx.core.logical_not", "mlx.core.logical_or", "mlx.core.logsumexp", "mlx.core.matmul", "mlx.core.max", "mlx.core.maximum", "mlx.core.mean", "mlx.core.metal.get_active_memory", "mlx.core.metal.get_cache_memory", "mlx.core.metal.get_peak_memory", "mlx.core.metal.is_available", "mlx.core.metal.set_cache_limit", "mlx.core.metal.set_memory_limit", "mlx.core.min", "mlx.core.minimum", "mlx.core.moveaxis", "mlx.core.multiply", "mlx.core.negative", "mlx.core.new_stream", "mlx.core.ones", "mlx.core.ones_like", "mlx.core.outer", 
"mlx.core.pad", "mlx.core.partition", "mlx.core.prod", "mlx.core.quantize", "mlx.core.quantized_matmul", "mlx.core.random.bernoulli", "mlx.core.random.categorical", "mlx.core.random.gumbel", "mlx.core.random.key", "mlx.core.random.normal", "mlx.core.random.randint", "mlx.core.random.seed", "mlx.core.random.split", "mlx.core.random.truncated_normal", "mlx.core.random.uniform", "mlx.core.reciprocal", "mlx.core.repeat", "mlx.core.reshape", "mlx.core.round", "mlx.core.rsqrt", "mlx.core.save", "mlx.core.save_gguf", "mlx.core.save_safetensors", "mlx.core.savez", "mlx.core.savez_compressed", "mlx.core.set_default_device", "mlx.core.set_default_stream", "mlx.core.sigmoid", "mlx.core.sign", "mlx.core.sin", "mlx.core.sinh", "mlx.core.softmax", "mlx.core.sort", "mlx.core.split", "mlx.core.sqrt", "mlx.core.square", "mlx.core.squeeze", "mlx.core.stack", "mlx.core.stop_gradient", "mlx.core.stream", "mlx.core.subtract", "mlx.core.sum", "mlx.core.swapaxes", "mlx.core.take", "mlx.core.take_along_axis", "mlx.core.tan", "mlx.core.tanh", "mlx.core.tensordot", "mlx.core.tile", "mlx.core.topk", "mlx.core.transpose", "mlx.core.tri", "mlx.core.tril", "mlx.core.triu", "mlx.core.value_and_grad", "mlx.core.var", "mlx.core.vjp", "mlx.core.vmap", "mlx.core.where", "mlx.core.zeros", "mlx.core.zeros_like", "mlx.nn.value_and_grad", "mlx.utils.tree_flatten", "mlx.utils.tree_map", "mlx.utils.tree_unflatten", "mlx.core.Stream", "Array", "Data Types", "Devices and Streams", "Fast", "FFT", "Linear Algebra", "Metal", "Neural Networks", "mlx.nn.ALiBi", "mlx.nn.AvgPool1d", "mlx.nn.AvgPool2d", "mlx.nn.BatchNorm", "mlx.nn.Conv1d", "mlx.nn.Conv2d", "mlx.nn.Dropout", "mlx.nn.Dropout2d", "mlx.nn.Dropout3d", "mlx.nn.Embedding", "mlx.nn.GELU", "mlx.nn.GRU", "mlx.nn.GroupNorm", "mlx.nn.InstanceNorm", "mlx.nn.LSTM", "mlx.nn.LayerNorm", "mlx.nn.Linear", "mlx.nn.MaxPool1d", "mlx.nn.MaxPool2d", "mlx.nn.Mish", "mlx.nn.Module.apply", "mlx.nn.Module.apply_to_modules", "mlx.nn.Module.children", "mlx.nn.Module.eval", "mlx.nn.Module.filter_and_map", "mlx.nn.Module.freeze", "mlx.nn.Module.leaf_modules", "mlx.nn.Module.load_weights", "mlx.nn.Module.modules", "mlx.nn.Module.named_modules", "mlx.nn.Module.parameters", "mlx.nn.Module.save_weights", "mlx.nn.Module.set_dtype", "mlx.nn.Module.state", "mlx.nn.Module.train", "mlx.nn.Module.trainable_parameters", "mlx.nn.Module.training", "mlx.nn.Module.unfreeze", "mlx.nn.Module.update", "mlx.nn.Module.update_modules", "mlx.nn.MultiHeadAttention", "mlx.nn.PReLU", "mlx.nn.QuantizedLinear", "mlx.nn.RMSNorm", "mlx.nn.RNN", "mlx.nn.ReLU", "mlx.nn.RoPE", "mlx.nn.SELU", "mlx.nn.Sequential", "mlx.nn.SiLU", "mlx.nn.SinusoidalPositionalEncoding", "mlx.nn.Softshrink", "mlx.nn.Step", "mlx.nn.Transformer", "mlx.nn.Upsample", "mlx.nn.init.constant", "mlx.nn.init.glorot_normal", "mlx.nn.init.glorot_uniform", "mlx.nn.init.he_normal", "mlx.nn.init.he_uniform", "mlx.nn.init.identity", "mlx.nn.init.normal", "mlx.nn.init.uniform", "mlx.nn.elu", "mlx.nn.gelu", "mlx.nn.gelu_approx", "mlx.nn.gelu_fast_approx", "mlx.nn.glu", "mlx.nn.hardswish", "mlx.nn.leaky_relu", "mlx.nn.log_sigmoid", "mlx.nn.log_softmax", "mlx.nn.losses.binary_cross_entropy", "mlx.nn.losses.cosine_similarity_loss", "mlx.nn.losses.cross_entropy", "mlx.nn.losses.gaussian_nll_loss", "mlx.nn.losses.hinge_loss", "mlx.nn.losses.huber_loss", "mlx.nn.losses.kl_div_loss", "mlx.nn.losses.l1_loss", "mlx.nn.losses.log_cosh_loss", "mlx.nn.losses.margin_ranking_loss", "mlx.nn.losses.mse_loss", "mlx.nn.losses.nll_loss", "mlx.nn.losses.smooth_l1_loss", 
"mlx.nn.losses.triplet_loss", "mlx.nn.mish", "mlx.nn.prelu", "mlx.nn.relu", "mlx.nn.relu6", "mlx.nn.selu", "mlx.nn.sigmoid", "mlx.nn.silu", "mlx.nn.softmax", "mlx.nn.softplus", "mlx.nn.softshrink", "mlx.nn.step", "mlx.nn.tanh", "Functions", "Initializers", "Layers", "Loss Functions", "Module", "Operations", "Optimizers", "mlx.optimizers.AdaDelta", "mlx.optimizers.Adafactor", "mlx.optimizers.Adagrad", "mlx.optimizers.Adam", "mlx.optimizers.AdamW", "mlx.optimizers.Adamax", "mlx.optimizers.Lion", "mlx.optimizers.Optimizer.apply_gradients", "mlx.optimizers.Optimizer.init", "mlx.optimizers.Optimizer.state", "mlx.optimizers.Optimizer.update", "mlx.optimizers.RMSprop", "mlx.optimizers.SGD", "mlx.optimizers.cosine_decay", "mlx.optimizers.exponential_decay", "mlx.optimizers.join_schedules", "mlx.optimizers.linear_schedule", "mlx.optimizers.step_decay", "Common Optimizers", "Optimizer", "Schedulers", "Random", "Transforms", "Tree Utils", "Compilation", "Function Transforms", "Indexing Arrays", "Lazy Evaluation", "Conversion to NumPy and Other Frameworks", "Quick Start Guide", "Saving and Loading Arrays", "Unified Memory", "Using Streams"], "terms": {"mlx": [1, 2, 3, 4, 5, 7, 250, 350, 353, 355, 377, 379, 380, 381, 382, 383, 384, 385, 386, 387], "provid": [1, 4, 97, 131, 224, 231, 240, 250, 271, 276, 278, 288, 289, 290, 293, 304, 305, 349, 353, 386, 388], "open": [1, 2, 7, 16, 187, 191], "flexibl": [1, 6], "which": [1, 4, 5, 6, 7, 16, 34, 83, 87, 99, 107, 113, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 131, 137, 138, 139, 140, 142, 145, 146, 148, 180, 183, 184, 193, 194, 197, 198, 199, 200, 201, 213, 214, 220, 231, 233, 234, 253, 258, 259, 261, 269, 271, 275, 297, 325, 328, 332, 335, 350, 363, 364, 377, 380, 381, 382, 383, 387, 388], "user": [1, 2, 4, 250], "mai": [1, 145, 258, 381, 382], "add": [1, 2, 4, 35, 109, 153, 177, 180, 255, 256, 381, 387], "special": 1, "without": [1, 4, 6, 215, 291, 349, 379, 380, 383, 384, 387], "much": [1, 4, 252, 253, 268, 269, 380, 383], "hassl": 1, "while": [1, 2, 4, 7, 194, 297, 383, 384], "librari": [1, 7, 250], "suppli": 1, "effici": [1, 4, 6, 258, 297, 383, 385], "can": [1, 2, 4, 6, 7, 12, 16, 61, 74, 83, 99, 100, 101, 102, 104, 107, 132, 133, 143, 144, 145, 153, 160, 169, 171, 182, 183, 187, 190, 191, 198, 217, 231, 250, 253, 260, 269, 275, 288, 299, 305, 325, 350, 353, 355, 363, 364, 377, 379, 380, 381, 382, 383, 384, 385, 386, 387, 388], "compos": [1, 6, 250, 380, 381, 385], "ani": [1, 4, 6, 16, 83, 239, 240, 241, 250, 261, 271, 272, 275, 284, 293, 304, 305, 350, 372, 379, 380, 381, 383, 385, 386, 387], "number": [1, 10, 16, 57, 66, 83, 86, 87, 97, 110, 131, 134, 142, 147, 177, 180, 181, 183, 186, 189, 191, 193, 195, 224, 225, 228, 231, 233, 234, 250, 254, 255, 256, 258, 259, 263, 264, 291, 292, 304, 305, 307, 308, 309, 310, 369, 371, 372, 377, 380, 381, 388], "applic": [1, 7], "aris": [1, 384], "case": [1, 4, 117, 120, 121, 123, 124, 125, 126, 127, 146, 158, 194, 213, 253, 258, 269, 303, 335, 341, 346, 347, 363, 364, 380, 381, 385, 386, 387, 388], "where": [1, 5, 110, 136, 180, 231, 234, 252, 253, 254, 255, 256, 257, 258, 259, 261, 262, 263, 264, 265, 266, 267, 268, 269, 275, 292, 294, 295, 303, 309, 310, 314, 315, 316, 317, 326, 332, 338, 341, 343, 347, 364, 381, 382], "new": [1, 5, 80, 99, 170, 173, 194, 214, 227, 240, 283, 291, 353, 355, 366, 371, 380, 382, 383, 384], "function": [1, 2, 3, 4, 5, 6, 14, 83, 102, 105, 106, 131, 136, 142, 145, 146, 158, 204, 231, 233, 234, 238, 240, 250, 261, 270, 272, 276, 283, 288, 292, 295, 296, 298, 
299, 300, 302, 303, 304, 315, 316, 317, 318, 319, 321, 322, 337, 342, 344, 345, 346, 347, 348, 350, 355, 364, 377, 379, 382, 383, 384, 386], "highli": [1, 7], "optim": [1, 2, 3, 5, 6, 289, 380, 381, 383], "ar": [1, 3, 4, 5, 6, 7, 14, 16, 76, 80, 82, 83, 87, 88, 99, 107, 110, 116, 117, 119, 120, 122, 123, 125, 126, 127, 131, 136, 137, 138, 139, 140, 141, 142, 145, 146, 148, 158, 167, 176, 177, 178, 180, 181, 182, 183, 184, 187, 190, 191, 200, 201, 213, 214, 220, 231, 233, 234, 239, 240, 244, 254, 255, 256, 257, 258, 259, 263, 264, 266, 267, 278, 291, 293, 305, 323, 325, 326, 349, 353, 362, 364, 379, 380, 381, 382, 383, 384, 385, 386, 387], "need": [1, 4, 5, 6, 76, 180, 250, 289, 290, 301, 304, 377, 381, 383, 384, 385, 387], "For": [1, 4, 7, 35, 114, 141, 145, 180, 241, 250, 254, 258, 271, 276, 285, 288, 293, 297, 301, 305, 307, 308, 309, 310, 350, 377, 380, 381, 382, 383, 384, 385, 386, 387], "you": [1, 2, 4, 5, 6, 7, 250, 301, 304, 350, 377, 380, 381, 382, 384, 386, 387], "design": [1, 3, 6, 377, 387], "your": [1, 4, 7, 353, 381, 383], "own": [1, 7, 384], "link": [1, 7], "top": [1, 226, 267, 305], "core": [1, 3, 4, 5, 250, 252, 253, 254, 264, 268, 269, 278, 281, 283, 286, 305, 306, 307, 308, 309, 310, 311, 312, 313, 323, 325, 332, 350, 353, 355, 380, 384, 385], "we": [1, 3, 4, 5, 97, 180, 181, 250, 260, 299, 360, 362, 377, 379, 380, 381, 383, 387], "inner": [1, 380], "work": [1, 2, 4, 7, 167, 380, 381, 382, 383], "go": [1, 4, 381], "over": [1, 4, 5, 13, 15, 23, 24, 25, 26, 85, 86, 87, 91, 92, 93, 94, 117, 120, 123, 126, 135, 145, 147, 157, 159, 161, 168, 178, 179, 196, 208, 209, 218, 224, 226, 232, 254, 255, 256, 263, 266, 294, 325, 369, 372, 381], "simpl": [1, 4, 5, 250, 260, 349, 380, 381, 383], "learn": [1, 3, 5, 6, 254, 263, 264, 266, 292, 294, 356, 357, 358, 359, 360, 361, 362, 367, 368], "step": [1, 2, 4, 5, 16, 250, 262, 265, 295, 357, 364, 369, 371, 372, 373, 380], "involv": [1, 355, 380], "ad": [1, 3, 7, 111, 264, 353, 356, 357, 358, 359, 360, 361, 367, 383, 386], "let": [1, 3, 4, 380, 381, 383, 384], "": [1, 3, 4, 5, 43, 47, 58, 83, 96, 97, 116, 117, 119, 120, 122, 123, 125, 126, 131, 145, 148, 161, 176, 180, 183, 195, 198, 199, 216, 231, 232, 234, 238, 250, 253, 262, 265, 269, 275, 276, 278, 282, 283, 284, 288, 295, 355, 364, 365, 377, 380, 381, 383, 384, 385, 386, 387], "sai": [1, 4, 350, 383], "would": [1, 4, 305, 382, 383, 384, 387], "like": [1, 4, 6, 141, 175, 237, 259, 331, 364, 366, 380, 381, 383, 384, 385, 387], "an": [1, 2, 4, 5, 7, 9, 13, 15, 27, 77, 78, 79, 80, 85, 86, 87, 107, 110, 111, 114, 127, 130, 134, 145, 148, 167, 170, 174, 175, 177, 179, 180, 181, 193, 194, 195, 210, 213, 219, 220, 221, 224, 225, 228, 234, 236, 237, 239, 240, 250, 252, 253, 257, 263, 265, 266, 267, 268, 269, 271, 291, 292, 293, 295, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 316, 338, 350, 356, 366, 370, 375, 377, 379, 380, 381, 382, 383, 384, 385, 386, 387, 388], "take": [1, 4, 5, 83, 131, 142, 160, 169, 175, 181, 221, 231, 233, 234, 237, 291, 377, 381, 382, 386, 387, 388], "two": [1, 12, 14, 76, 78, 99, 101, 104, 116, 119, 125, 132, 133, 136, 143, 144, 146, 153, 158, 160, 169, 171, 176, 219, 253, 265, 269, 293, 305, 318, 324, 380, 381, 382, 387], "arrai": [1, 4, 5, 6, 9, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 97, 98, 99, 101, 102, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 
135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 168, 169, 170, 171, 172, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 250, 254, 265, 271, 278, 281, 286, 292, 305, 306, 307, 308, 309, 310, 311, 312, 313, 315, 318, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 347, 350, 353, 356, 357, 358, 359, 360, 361, 362, 367, 368, 369, 370, 371, 372, 373, 380, 381, 383, 384, 385, 387], "x": [1, 3, 4, 5, 35, 105, 111, 112, 134, 145, 181, 184, 195, 200, 204, 229, 230, 235, 240, 250, 252, 253, 254, 261, 263, 264, 266, 267, 268, 269, 270, 271, 292, 294, 296, 301, 303, 305, 314, 315, 316, 317, 318, 319, 320, 321, 322, 335, 337, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 353, 355, 362, 380, 381, 382, 383, 384, 385, 387], "y": [1, 3, 4, 5, 35, 235, 250, 254, 258, 263, 264, 266, 267, 294, 327, 332, 335, 355, 358, 380, 381, 383, 384], "scale": [1, 4, 97, 111, 112, 113, 114, 180, 181, 186, 258, 259, 266, 291, 297, 298, 301, 305, 341, 357], "them": [1, 4, 250, 276, 288, 387], "both": [1, 12, 101, 102, 104, 132, 133, 141, 143, 144, 145, 153, 160, 169, 171, 183, 217, 252, 253, 264, 265, 268, 269, 305, 355, 380, 381, 385, 387], "some": [1, 3, 4, 5, 276, 288, 364, 380, 381, 383], "coeffici": [1, 356, 357, 359, 360, 361, 362], "alpha": [1, 180, 314, 336, 338, 341, 360, 367], "beta": [1, 97, 180, 254, 263, 264, 266, 335, 359, 360, 361, 362], "respect": [1, 3, 5, 111, 112, 131, 180, 231, 240, 250, 254, 261, 263, 264, 266, 353, 381, 385], "togeth": [1, 5, 180, 240], "get": [1, 3, 5, 7, 86, 87, 95, 96, 162, 163, 164, 185, 250, 380, 381, 383, 387], "z": [1, 262, 380, 383], "well": [1, 4, 250, 276, 288, 291, 383], "veri": [1, 4, 291, 383, 387], "easili": 1, "do": [1, 4, 7, 250, 277, 288, 350, 353, 360, 380, 381, 383], "just": [1, 5, 266, 380, 382], "write": [1, 4, 250, 384], "out": [1, 7, 252, 253, 258, 259, 268, 269, 285, 380, 381, 382], "follow": [1, 4, 5, 6, 7, 16, 88, 97, 145, 180, 250, 316, 317, 329, 356, 357, 358, 359, 360, 361, 362, 368, 377, 380, 381, 387], "import": [1, 3, 4, 5, 7, 145, 200, 231, 239, 240, 241, 250, 252, 253, 254, 264, 268, 269, 278, 305, 323, 325, 332, 350, 353, 380, 381, 382, 383, 384, 385], "mx": [1, 3, 4, 5, 35, 127, 141, 145, 146, 148, 200, 231, 250, 252, 253, 254, 264, 268, 269, 271, 278, 282, 296, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 320, 323, 324, 325, 329, 332, 339, 348, 350, 353, 355, 377, 380, 381, 382, 383, 384, 385, 386, 387, 388], "def": [1, 3, 4, 5, 231, 250, 353, 380, 381, 382, 383, 384, 387], "simple_axpbi": 1, "float": [1, 10, 14, 16, 73, 111, 112, 113, 114, 129, 130, 136, 141, 145, 181, 182, 186, 244, 254, 257, 258, 259, 263, 264, 266, 271, 283, 294, 297, 301, 303, 304, 305, 306, 307, 308, 309, 310, 312, 313, 324, 325, 326, 328, 332, 335, 336, 346, 347, 356, 357, 358, 359, 360, 361, 362, 367, 368, 369, 370, 372, 373], "return": [1, 3, 4, 5, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 34, 46, 64, 73, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 97, 98, 99, 101, 102, 104, 105, 106, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 
140, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 163, 166, 167, 168, 169, 170, 171, 172, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 189, 190, 191, 192, 193, 194, 195, 196, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 250, 262, 265, 271, 272, 273, 275, 276, 277, 278, 279, 280, 281, 285, 286, 288, 289, 290, 293, 295, 306, 307, 308, 309, 310, 311, 312, 313, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 350, 353, 363, 379, 380, 381, 382, 383, 384, 386, 387], "thi": [1, 4, 5, 7, 13, 14, 15, 16, 23, 24, 25, 26, 103, 136, 142, 145, 146, 153, 157, 158, 159, 161, 162, 168, 178, 179, 183, 203, 208, 209, 210, 218, 220, 226, 232, 250, 257, 258, 259, 262, 265, 272, 273, 275, 276, 279, 280, 281, 286, 288, 289, 290, 291, 293, 295, 303, 307, 308, 309, 310, 316, 317, 318, 331, 347, 353, 364, 379, 380, 381, 383, 384, 386], "perform": [1, 2, 4, 6, 87, 91, 92, 93, 94, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 158, 181, 195, 208, 220, 250, 263, 304, 309, 310, 380, 382, 383, 387], "leav": [1, 107, 240], "differenti": [1, 6], "howev": [1, 250, 261, 263, 364, 377, 380, 383, 384], "vector": [1, 3, 6, 135, 142, 145, 220, 233, 234, 260, 325, 385], "math": [1, 4, 336, 380], "often": [1, 259], "realiz": 1, "axpbi": 1, "routin": 1, "defin": [1, 3, 4, 5, 7, 145, 181, 239, 384], "same": [1, 4, 7, 14, 35, 76, 80, 83, 86, 87, 88, 111, 112, 121, 124, 125, 126, 131, 136, 142, 177, 183, 195, 233, 235, 250, 253, 254, 257, 263, 264, 269, 293, 306, 307, 308, 309, 310, 311, 312, 313, 325, 336, 353, 363, 377, 380, 382, 387], "realli": [1, 266], "part": [1, 381, 382], "doe": [1, 4, 7, 162, 250, 380, 382, 383, 384], "fast": [1, 6, 261, 317, 387], "so": [1, 4, 7, 131, 231, 257, 305, 355, 380, 383, 387], "decid": [1, 240, 275], "want": [1, 4, 381, 387], "reli": 1, "acceler": [1, 254], "framework": [1, 6], "continu": [1, 381], "impos": 1, "our": [1, 4, 5, 299, 356, 357, 358, 359, 361, 362], "assumpt": 1, "also": [1, 4, 5, 6, 7, 10, 12, 100, 101, 102, 104, 117, 120, 123, 126, 132, 133, 143, 144, 153, 160, 169, 171, 180, 217, 238, 250, 275, 289, 291, 293, 300, 315, 341, 343, 349, 355, 380, 381, 382, 383, 384, 385, 388], "assum": [1, 4, 146, 240, 250, 252, 253, 263, 268, 269], "how": [1, 4, 5, 250, 252, 253, 255, 256, 260, 268, 269, 305, 363, 380, 382, 387], "gradient": [1, 3, 5, 131, 215, 231, 238, 250, 276, 289, 293, 304, 331, 353, 355, 356, 357, 359, 360, 361, 362, 363, 366, 368, 380, 381, 382, 383, 384, 385], "ins": 1, "what": [1, 4, 240], "coincid": 1, "right": [1, 7, 180, 252, 253, 261, 268, 269, 305, 316, 317, 326, 328, 336], "place": [1, 4, 35, 195, 383, 384], "cours": [1, 381], "The": [1, 2, 4, 5, 6, 7, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 34, 43, 47, 57, 58, 64, 73, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 97, 98, 99, 101, 102, 104, 105, 106, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 132, 133, 134, 135, 136, 137, 138, 139, 140, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 163, 164, 166, 167, 168, 169, 170, 171, 172, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 189, 190, 191, 192, 193, 194, 198, 199, 204, 205, 206, 207, 208, 209, 211, 212, 213, 214, 
215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 244, 252, 253, 254, 255, 256, 257, 258, 259, 260, 262, 263, 264, 265, 266, 267, 268, 269, 271, 272, 276, 278, 282, 283, 284, 285, 288, 289, 290, 291, 293, 294, 295, 297, 299, 301, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 318, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 347, 350, 353, 355, 356, 357, 358, 359, 360, 361, 362, 365, 367, 368, 369, 372, 375, 380, 381, 382, 383, 384, 385, 386, 387, 388], "structur": [1, 363, 381], "from": [1, 4, 5, 6, 97, 99, 122, 123, 125, 126, 130, 145, 148, 158, 164, 166, 175, 180, 182, 183, 184, 185, 187, 190, 200, 213, 215, 217, 220, 221, 226, 235, 237, 239, 240, 241, 250, 267, 276, 278, 291, 307, 308, 309, 310, 312, 313, 326, 335, 350, 379, 380, 381, 383, 384, 385, 386, 387], "frontend": 1, "api": [1, 381], "redirect": 1, "when": [1, 4, 6, 7, 83, 87, 145, 148, 255, 256, 305, 309, 310, 329, 335, 353, 371, 377, 380, 387], "appropri": [1, 380], "fallback": 1, "metal": [1, 6], "vjp": [1, 385], "jvp": [1, 385], "In": [1, 4, 5, 35, 158, 180, 240, 250, 258, 263, 353, 356, 358, 359, 361, 362, 363, 379, 380, 381, 383, 386, 387], "one": [1, 4, 7, 35, 73, 77, 82, 86, 87, 109, 110, 111, 112, 145, 151, 158, 181, 183, 213, 217, 244, 288, 305, 325, 387], "sentenc": 1, "comput": [1, 3, 4, 5, 6, 7, 91, 92, 93, 94, 97, 113, 131, 142, 145, 153, 161, 176, 180, 208, 215, 224, 231, 232, 233, 238, 250, 254, 262, 263, 264, 265, 266, 276, 289, 293, 294, 297, 304, 307, 308, 309, 310, 316, 317, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 355, 356, 357, 359, 360, 361, 362, 366, 380, 381, 385, 387], "graph": [1, 4, 5, 6, 381], "rule": 1, "evalu": [1, 4, 5, 6, 107, 142, 233, 250, 274, 285, 353, 355, 380, 385], "said": [1, 4], "start": [1, 3, 4, 6, 7, 16, 113, 147, 210, 380, 382, 387], "discuss": 1, "more": [1, 2, 5, 9, 73, 99, 158, 166, 167, 198, 199, 244, 250, 254, 258, 297, 301, 304, 305, 307, 308, 309, 310, 377, 380, 381, 382, 385, 387], "detail": [1, 9, 166, 250, 258, 297, 301, 305, 307, 308, 309, 310, 356, 358, 359, 361, 362, 382, 385], "thei": [1, 3, 4, 14, 88, 136, 299, 327, 353, 362, 379, 380, 383, 385, 386, 387], "c": [1, 2, 4, 145, 252, 253, 254, 255, 256, 258, 259, 264, 265, 268, 269, 384, 385, 387], "scalar": [1, 12, 14, 27, 46, 73, 76, 80, 82, 101, 102, 104, 129, 130, 131, 132, 133, 136, 143, 144, 145, 147, 153, 154, 155, 156, 158, 160, 169, 171, 177, 182, 187, 190, 191, 198, 217, 231, 235, 238, 336, 381, 383, 385], "i": [1, 2, 4, 5, 6, 7, 14, 16, 25, 34, 73, 82, 85, 86, 87, 88, 91, 92, 93, 94, 98, 99, 102, 107, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 129, 130, 136, 141, 142, 145, 146, 148, 153, 157, 158, 164, 165, 167, 177, 178, 180, 181, 182, 183, 186, 189, 190, 191, 194, 197, 198, 199, 204, 208, 210, 215, 220, 221, 224, 227, 231, 232, 233, 234, 235, 239, 240, 244, 250, 252, 253, 254, 255, 256, 257, 258, 259, 261, 262, 263, 264, 265, 266, 267, 268, 269, 275, 276, 282, 284, 285, 287, 288, 290, 291, 292, 293, 294, 295, 297, 301, 303, 304, 305, 309, 310, 315, 316, 317, 323, 324, 326, 331, 332, 335, 336, 338, 343, 347, 353, 357, 360, 362, 363, 364, 369, 371, 372, 377, 380, 381, 382, 383, 384, 385, 386, 387, 388], "sum": [1, 3, 12, 94, 135, 145, 157, 208, 224, 250, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 382, 384], "element": [1, 11, 12, 17, 18, 19, 20, 21, 22, 25, 66, 81, 89, 90, 91, 92, 93, 94, 97, 101, 
102, 104, 105, 106, 108, 110, 128, 129, 132, 133, 136, 137, 138, 139, 140, 143, 144, 149, 150, 151, 152, 153, 154, 155, 156, 160, 169, 171, 172, 178, 180, 181, 192, 193, 196, 204, 205, 206, 207, 211, 212, 217, 220, 222, 223, 226, 231, 235, 257, 258, 259, 262, 265, 270, 292, 295, 297, 319, 321, 322, 337, 338, 340, 343, 344, 345, 380, 381], "wise": [1, 11, 12, 17, 18, 19, 20, 21, 22, 81, 89, 90, 101, 102, 104, 105, 106, 108, 128, 129, 132, 133, 136, 143, 144, 149, 150, 151, 152, 153, 154, 155, 156, 160, 169, 171, 172, 192, 196, 204, 205, 206, 207, 211, 212, 217, 222, 223, 258, 259, 270, 292, 319, 321, 322, 337, 338, 340, 343, 344, 345, 380], "numpi": [1, 4, 5, 6, 12, 14, 16, 80, 101, 102, 104, 132, 133, 136, 143, 144, 153, 158, 160, 169, 171, 217, 383, 385, 386], "style": [1, 12, 14, 101, 102, 104, 132, 133, 136, 143, 144, 153, 158, 160, 169, 171, 217], "broadcast": [1, 12, 14, 80, 82, 101, 102, 104, 130, 132, 133, 136, 143, 144, 153, 158, 160, 169, 171, 182, 183, 190, 191, 217, 221, 235, 291], "between": [1, 6, 82, 127, 304, 324, 327, 328, 331, 371, 383, 387], "input": [1, 3, 4, 11, 12, 13, 14, 15, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 98, 99, 101, 102, 104, 105, 106, 108, 109, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 131, 132, 133, 135, 136, 137, 138, 139, 140, 142, 143, 144, 145, 146, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 168, 169, 170, 171, 172, 175, 176, 177, 178, 179, 180, 181, 189, 192, 193, 194, 195, 196, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 229, 230, 231, 232, 234, 235, 237, 252, 253, 254, 255, 256, 258, 259, 260, 262, 263, 264, 265, 266, 267, 268, 269, 291, 293, 294, 295, 297, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 318, 323, 324, 326, 327, 328, 329, 331, 332, 334, 336, 347, 350, 380, 381, 382, 385, 386], "upcast": 1, "const": [1, 326], "factor": [1, 146, 305, 325, 370, 373], "streamordevic": 1, "stream": [1, 6, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 29, 30, 31, 32, 33, 34, 36, 37, 38, 39, 40, 41, 42, 44, 45, 48, 49, 50, 51, 52, 53, 54, 55, 56, 59, 60, 61, 62, 63, 65, 67, 68, 69, 70, 71, 72, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 96, 97, 98, 99, 101, 102, 104, 105, 106, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 132, 133, 134, 135, 136, 137, 138, 139, 140, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 186, 187, 189, 190, 191, 192, 193, 194, 195, 196, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 232, 235, 236, 237, 387], "schedul": [1, 167, 355, 369, 370, 371, 372, 373, 375, 387], "itself": [1, 364], "call": [1, 4, 5, 28, 129, 250, 260, 276, 288, 299, 353, 355, 364, 380, 381, 383], "other": [1, 4, 6, 141, 145, 250, 277, 353, 362, 380, 382, 383, 385], "within": [1, 2, 25, 136], "simplest": [1, 250], "wai": [1, 4, 7, 250, 305, 380, 381, 382], "about": [1, 4, 5, 383, 387], "term": [1, 326, 356, 357, 358, 359, 360, 361, 367], "exist": [1, 4, 276, 288], "auto": [1, 2, 7], "ax": [1, 13, 15, 23, 24, 74, 109, 116, 117, 119, 120, 122, 123, 125, 126, 127, 135, 145, 157, 159, 161, 168, 177, 
179, 208, 213, 218, 219, 224, 227, 232, 381], "multipli": [1, 35, 180, 181, 257, 301, 305], "earlier": 1, "goal": 1, "themselv": [1, 380], "contain": [1, 4, 25, 26, 64, 83, 99, 121, 122, 123, 145, 154, 155, 156, 180, 210, 235, 250, 275, 277, 278, 284, 304, 332, 350, 353, 380, 381], "act": [1, 331], "data": [1, 5, 6, 9, 16, 110, 124, 125, 130, 134, 147, 174, 190, 228, 236, 259, 306, 307, 308, 309, 310, 311, 312, 313, 380, 382, 384], "nor": [1, 131, 231], "rather": [1, 381, 387], "easi": [1, 250], "interfac": 1, "block": [1, 4, 304], "A": [1, 4, 6, 7, 8, 64, 76, 83, 111, 112, 114, 131, 142, 145, 146, 148, 157, 158, 159, 168, 180, 182, 183, 184, 186, 187, 190, 191, 210, 214, 216, 231, 233, 234, 238, 239, 240, 241, 242, 250, 254, 258, 262, 263, 264, 266, 275, 279, 280, 283, 289, 290, 294, 299, 301, 304, 307, 308, 310, 317, 336, 337, 353, 355, 359, 361, 363, 364, 366, 371, 380, 381, 383, 384], "It": [1, 4, 7, 131, 203, 231, 250, 290, 293, 363, 375, 384, 386], "creat": [1, 4, 7, 110, 134, 216, 250, 353, 355, 371, 380, 382, 384], "output": [1, 4, 7, 13, 14, 15, 16, 25, 80, 83, 91, 92, 93, 94, 110, 111, 112, 113, 114, 121, 124, 125, 126, 130, 131, 134, 136, 145, 147, 157, 159, 161, 168, 174, 175, 178, 179, 182, 183, 184, 186, 187, 190, 191, 200, 201, 208, 213, 218, 221, 228, 231, 232, 233, 234, 235, 236, 237, 252, 253, 254, 255, 256, 264, 267, 268, 269, 291, 293, 303, 304, 305, 307, 308, 309, 310, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 347, 350, 380, 381, 382, 383, 384, 385, 386, 387], "given": [1, 13, 15, 25, 35, 80, 82, 84, 91, 92, 93, 94, 97, 99, 107, 109, 115, 116, 117, 118, 119, 120, 124, 125, 126, 130, 145, 157, 159, 161, 166, 168, 173, 179, 187, 195, 203, 208, 210, 218, 225, 226, 228, 229, 230, 232, 242, 252, 253, 257, 268, 269, 275, 291, 324, 326, 332], "set": [1, 4, 5, 7, 83, 100, 103, 111, 113, 166, 167, 202, 203, 216, 261, 266, 267, 274, 276, 283, 284, 285, 288, 289, 293, 297, 303, 324, 336, 347, 353, 357, 364, 377, 381, 383], "further": [1, 7, 381], "class": [1, 4, 5, 8, 9, 10, 27, 242, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 325, 353, 356, 357, 358, 359, 360, 361, 362, 367, 368, 375], "under": [1, 145], "These": [1, 83, 221, 325, 387], "word": 1, "bit": [1, 97, 180, 181, 244, 271, 293, 294], "abstract": 1, "back": [1, 4, 165, 384], "give": [1, 4, 5, 25, 380], "ourselv": 1, "concret": [1, 262, 265, 267, 295, 383, 387], "imag": [1, 256, 258, 259, 305], "public": [1, 250], "explicit": [1, 364, 377, 384], "alpha_": 1, "beta_": 1, "must": [1, 7, 82, 130, 145, 182, 183, 187, 190, 191, 235, 305, 384], "know": [1, 4], "popul": 1, "To": [1, 3, 4, 5, 7, 166, 250, 350, 380, 381, 385], "avoid": [1, 283, 380], "unnecessari": [1, 4], "alloc": [1, 163, 166, 167, 353], "respons": 1, "space": [1, 147, 334], "void": 1, "eval_cpu": 1, "std": [1, 312], "overrid": [1, 103], "eval_gpu": 1, "jacobian": [1, 142, 233, 385], "product": [1, 93, 135, 142, 158, 176, 179, 224, 233, 291, 385], "primal": [1, 142, 233], "tangent": [1, 21, 22, 142, 222, 223, 348], "int": [1, 2, 4, 5, 8, 13, 15, 16, 23, 24, 25, 26, 30, 31, 32, 33, 37, 38, 39, 40, 41, 42, 45, 52, 53, 54, 55, 56, 59, 62, 64, 67, 70, 71, 72, 73, 75, 80, 84, 85, 86, 87, 91, 92, 93, 94, 97, 98, 99, 109, 110, 113, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 130, 131, 134, 141, 145, 147, 157, 159, 161, 162, 163, 164, 166, 167, 168, 170, 174, 177, 178, 179, 180, 181, 182, 183, 
184, 185, 186, 187, 188, 189, 190, 191, 193, 194, 195, 208, 209, 210, 213, 214, 218, 219, 220, 221, 224, 225, 226, 227, 228, 229, 230, 231, 232, 234, 236, 242, 250, 252, 253, 254, 255, 256, 260, 262, 263, 264, 265, 266, 267, 268, 269, 291, 293, 294, 295, 297, 301, 304, 318, 324, 325, 329, 334, 336, 353, 369, 371, 372, 373], "argnum": [1, 131, 231, 381], "cotan": 1, "across": [1, 263], "pair": [1, 177, 278, 297], "repres": [1, 4, 332, 336, 384], "axi": [1, 4, 5, 13, 15, 23, 24, 25, 26, 30, 31, 32, 33, 37, 38, 39, 40, 52, 53, 54, 55, 59, 67, 70, 71, 75, 84, 91, 92, 93, 94, 99, 109, 111, 112, 115, 118, 121, 122, 123, 124, 125, 126, 127, 145, 157, 159, 161, 168, 170, 177, 178, 179, 183, 193, 208, 209, 210, 213, 214, 218, 219, 220, 221, 225, 226, 227, 232, 234, 252, 253, 268, 269, 295, 318, 322, 324, 325, 329, 334, 336, 344, 382], "correspond": [1, 13, 15, 73, 82, 97, 99, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 157, 159, 168, 179, 218, 224, 234, 240, 381], "dimens": [1, 4, 13, 15, 23, 24, 58, 64, 73, 77, 78, 79, 83, 86, 87, 99, 109, 113, 122, 123, 125, 126, 127, 135, 145, 146, 157, 158, 159, 161, 168, 179, 180, 183, 189, 218, 221, 224, 227, 232, 254, 255, 256, 258, 259, 262, 263, 264, 265, 266, 291, 294, 295, 297, 304, 305, 318, 325, 380, 381], "vmap": [1, 381, 383, 385], "print": [1, 3, 4, 5, 7, 239, 240, 241, 250, 377, 380, 381, 382, 383, 384, 385], "ostream": 1, "o": [1, 7, 114, 265], "equival": [1, 28, 61, 74, 102, 129, 220, 261, 290, 292, 293, 296, 298, 300, 302], "check": [1, 7, 76, 141, 165, 278, 381, 382], "bool": [1, 13, 14, 15, 23, 24, 30, 31, 32, 33, 37, 38, 39, 40, 52, 53, 54, 55, 59, 71, 73, 75, 76, 83, 87, 91, 92, 93, 94, 113, 136, 141, 145, 148, 157, 159, 161, 165, 167, 168, 179, 181, 218, 232, 254, 255, 256, 262, 263, 264, 265, 266, 267, 271, 275, 276, 278, 283, 285, 288, 291, 293, 295, 297, 301, 304, 305, 323, 326, 357, 368], "is_equival": 1, "privat": 1, "fall": 1, "eval": [1, 2, 3, 4, 5, 250, 353, 355, 380, 381, 383, 385], "deriv": [1, 381, 383], "base": [1, 113, 145, 150, 152, 297, 304, 353, 355, 361, 375, 377, 380, 382], "abov": [1, 4, 180, 229, 250, 305, 360, 381, 382, 383, 387], "demonstr": [1, 384], "treat": [1, 122, 123, 125, 126, 220, 305, 380], "paramet": [1, 3, 4, 5, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 34, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 97, 98, 99, 101, 102, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 166, 167, 168, 169, 170, 171, 172, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 271, 272, 275, 276, 278, 283, 284, 285, 288, 289, 290, 291, 292, 293, 294, 295, 297, 299, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 318, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 347, 349, 350, 353, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 366, 367, 368, 369, 370, 371, 372, 373, 375, 380, 381, 383], "produc": [1, 83, 291, 350], 
"through": [1, 215, 304, 362, 380, 381, 384], "construct": [1, 5, 41, 98, 130, 174, 225, 236], "its": [1, 7, 158, 178, 189, 228, 238, 241, 250, 293, 359, 360, 361, 384, 387], "type": [1, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 34, 64, 73, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 97, 98, 99, 101, 102, 104, 105, 106, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 166, 167, 168, 169, 170, 171, 172, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 189, 190, 191, 192, 193, 194, 195, 196, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 239, 250, 283, 304, 306, 307, 308, 309, 310, 311, 312, 313, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 380, 382], "shape": [1, 4, 5, 61, 76, 80, 83, 85, 86, 87, 99, 114, 115, 118, 121, 124, 125, 126, 130, 142, 158, 174, 175, 182, 183, 184, 186, 187, 190, 191, 194, 221, 233, 235, 236, 237, 250, 252, 253, 254, 255, 256, 258, 259, 262, 264, 265, 267, 268, 269, 278, 295, 306, 307, 308, 309, 310, 311, 312, 313, 325, 336, 355, 380, 381, 382, 385, 387], "pass": [1, 4, 5, 61, 74, 176, 177, 231, 238, 239, 240, 250, 276, 288, 289, 290, 293, 299, 380, 383], "re": [1, 5, 7, 350], "now": [1, 4, 7, 293, 380, 384], "promot": 1, "dtype": [1, 4, 10, 16, 27, 34, 35, 73, 110, 127, 130, 134, 141, 145, 146, 147, 174, 184, 186, 187, 190, 191, 228, 236, 244, 283, 305, 306, 307, 308, 309, 310, 311, 312, 313, 323, 325, 332, 369, 370, 371, 372, 373, 380, 381, 382, 384, 385, 386], "promoted_dtyp": 1, "promote_typ": 1, "float32": [1, 2, 10, 16, 110, 114, 134, 141, 145, 146, 147, 174, 184, 186, 190, 191, 228, 236, 244, 305, 306, 307, 308, 309, 310, 311, 312, 313, 323, 325, 332, 369, 370, 371, 372, 373, 380, 381, 382, 383, 384, 385, 386], "non": [1, 7, 286, 295, 337, 353], "point": [1, 3, 4, 7, 129, 181, 244], "out_dtyp": 1, "is_floating_point": 1, "cast": [1, 34, 124, 125, 126, 148, 271, 283, 384], "up": [1, 4, 293, 380], "determin": [1, 99, 244, 282, 386], "x_cast": 1, "astyp": [1, 4, 271, 384], "y_cast": 1, "broadcasted_input": 1, "broadcast_arrai": 1, "out_shap": 1, "0": [1, 3, 4, 5, 7, 8, 16, 35, 41, 42, 45, 62, 67, 75, 84, 85, 86, 87, 98, 99, 110, 114, 127, 131, 145, 146, 166, 177, 182, 186, 191, 193, 195, 210, 214, 228, 229, 230, 231, 232, 234, 239, 250, 252, 253, 254, 255, 256, 257, 258, 259, 261, 263, 264, 266, 268, 269, 292, 296, 297, 301, 302, 303, 304, 306, 307, 308, 309, 310, 311, 312, 313, 314, 316, 317, 319, 320, 323, 325, 327, 328, 332, 335, 336, 338, 339, 340, 341, 346, 347, 350, 353, 356, 357, 359, 360, 361, 362, 364, 367, 368, 369, 370, 371, 372, 373, 377, 380, 381, 382, 383, 384, 385, 386], "unique_ptr": 1, "make_shar": 1, "to_stream": 1, "handl": [1, 250, 380], "resolv": 1, "No": [1, 4], "happen": [1, 4, 111, 304, 355, 380, 383], "alon": [1, 384], "effect": [1, 258, 380, 383], "onli": [1, 4, 6, 7, 76, 85, 86, 87, 145, 180, 250, 275, 276, 278, 283, 285, 288, 289, 290, 353, 380, 381, 386, 387], "execut": [1, 7, 77, 78, 79, 164, 384, 387], "depend": [1, 2, 3, 73, 145, 262, 265, 295, 382, 386, 387], "devic": [1, 6, 7, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 29, 30, 31, 32, 33, 34, 36, 37, 38, 39, 40, 41, 42, 44, 
45, 48, 49, 50, 51, 52, 53, 54, 55, 56, 59, 60, 61, 62, 63, 65, 67, 68, 69, 70, 71, 72, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 101, 102, 104, 105, 106, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 132, 133, 134, 135, 136, 137, 138, 139, 140, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 186, 187, 189, 190, 191, 192, 193, 194, 195, 196, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 232, 235, 236, 237, 242, 387, 388], "specifi": [1, 16, 34, 86, 87, 99, 122, 123, 130, 131, 145, 147, 170, 174, 183, 193, 219, 220, 221, 224, 227, 231, 234, 236, 254, 303, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 347, 381, 387], "memori": [1, 6, 162, 163, 164, 166, 167, 304, 353, 357, 380, 383, 384], "ha": [1, 2, 4, 5, 6, 73, 83, 99, 121, 122, 124, 125, 126, 131, 163, 183, 254, 262, 265, 267, 295, 353, 355, 380, 382, 383, 385, 387], "been": [1, 4, 163, 383], "try": [1, 7], "naiv": [1, 381], "gener": [1, 2, 3, 10, 16, 87, 110, 122, 123, 147, 182, 186, 187, 190, 191, 304, 377, 380, 382, 383, 388], "version": [1, 7, 97, 153, 157, 180, 208, 234, 377, 381, 382], "declar": 1, "member": [1, 250, 281, 286], "method": [1, 4, 8, 9, 10, 27, 242, 250, 282, 353, 356, 357, 358, 359, 360, 361, 362, 364, 367, 368, 375], "each": [1, 64, 97, 107, 113, 141, 158, 177, 180, 181, 183, 193, 200, 201, 210, 225, 227, 234, 235, 258, 259, 260, 262, 263, 265, 295, 297, 304, 323, 325, 377, 380, 383], "find": [1, 3, 7], "pointwis": 1, "captur": [1, 2, 83, 250, 380], "templat": 1, "axpby_impl": 1, "typenam": 1, "t": [1, 4, 105, 114, 181, 231, 250, 252, 262, 265, 268, 295, 356, 357, 358, 359, 360, 361, 362, 367, 368, 380, 381, 387], "readi": 1, "fill": [1, 130, 175, 228, 237, 306, 307, 308, 309, 310, 312, 313], "malloc_or_wait": 1, "synchron": [1, 380], "avail": [1, 3, 4, 5, 7, 9, 165, 387], "There": [1, 250, 305, 380], "wait": [1, 4, 167], "here": [1, 4, 380, 381, 383, 386, 387], "request": 1, "pressur": 1, "condit": [1, 235, 387], "set_data": 1, "nbyte": 1, "collect": [1, 240, 379], "pointer": 1, "x_ptr": 1, "y_ptr": 1, "out_ptr": 1, "relev": 1, "static_cast": 1, "size_t": 1, "out_idx": 1, "size": [1, 4, 5, 47, 64, 86, 97, 109, 111, 112, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 130, 134, 141, 145, 163, 167, 180, 181, 183, 194, 210, 213, 250, 252, 253, 255, 256, 260, 264, 268, 269, 293, 305, 357, 383, 384], "map": [1, 5, 35, 148, 240, 260, 271], "linear": [1, 4, 5, 6, 240, 250, 261, 278, 293, 295, 296, 298, 300, 305, 314, 315, 316, 317, 318, 320, 339, 340, 341, 343, 350, 353, 364, 372, 380], "indic": [1, 14, 23, 24, 25, 26, 35, 131, 136, 137, 138, 139, 140, 210, 220, 221, 231, 285, 287, 325, 332, 371, 382], "offset": [1, 4, 42, 99, 111, 113], "x_offset": 1, "elem_to_loc": 1, "stride": [1, 85, 86, 87, 252, 253, 255, 256, 268, 269, 297, 382], "y_offset": 1, "contigu": 1, "regularli": 1, "default": [1, 7, 13, 14, 15, 16, 23, 24, 25, 26, 76, 83, 84, 85, 86, 87, 95, 96, 97, 98, 99, 110, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 131, 134, 136, 145, 146, 147, 148, 157, 159, 161, 166, 167, 168, 174, 178, 179, 180, 181, 182, 183, 184, 186, 187, 189, 190, 191, 193, 194, 195, 202, 203, 209, 210, 213, 214, 216, 218, 224, 
226, 227, 228, 229, 230, 231, 232, 234, 236, 244, 252, 253, 254, 255, 256, 262, 264, 265, 267, 268, 269, 271, 276, 278, 283, 285, 288, 291, 292, 293, 295, 297, 301, 302, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 318, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 353, 356, 357, 358, 359, 360, 361, 362, 367, 368, 369, 377, 379, 380, 381, 384, 386, 388], "row": [1, 110, 134, 180, 228], "major": 1, "henc": [1, 180, 380], "doesn": [1, 250], "addit": [1, 4, 12, 111, 112, 114, 148, 254, 263, 266, 291, 294, 353, 381], "abl": [1, 180], "all": [1, 2, 5, 7, 14, 25, 35, 77, 78, 79, 83, 86, 87, 110, 117, 120, 123, 126, 158, 177, 178, 213, 250, 271, 272, 276, 279, 280, 281, 286, 288, 291, 293, 301, 304, 305, 350, 353, 375, 377, 380, 382, 383, 385, 388], "incom": 1, "accordingli": 1, "dispatch": 1, "float16": [1, 10, 148, 244, 271, 383, 384], "bfloat16": [1, 10, 244, 384], "complex64": [1, 244], "throw": [1, 83], "error": [1, 7, 105, 106, 167, 210, 261, 293, 315, 316, 317, 331, 333, 381, 384], "encount": [1, 381], "unexpect": [1, 16], "regist": [1, 5], "op": [1, 176, 276, 383], "assert": 1, "2": [1, 3, 4, 5, 35, 86, 98, 99, 105, 116, 119, 121, 122, 123, 124, 125, 126, 127, 141, 145, 146, 152, 158, 180, 189, 224, 228, 229, 230, 244, 250, 252, 253, 256, 261, 268, 269, 294, 301, 305, 306, 307, 308, 309, 310, 311, 312, 313, 316, 325, 326, 328, 335, 336, 350, 353, 356, 358, 359, 360, 364, 367, 380, 381, 382, 383, 384, 385, 386, 387], "1": [1, 2, 4, 5, 16, 25, 26, 35, 42, 45, 85, 86, 87, 98, 99, 114, 115, 116, 118, 119, 121, 122, 123, 124, 125, 126, 127, 135, 141, 145, 146, 158, 167, 176, 178, 180, 183, 186, 191, 204, 209, 220, 226, 231, 244, 250, 252, 253, 254, 255, 256, 257, 258, 259, 261, 262, 263, 264, 265, 266, 267, 268, 269, 292, 294, 295, 297, 301, 303, 305, 307, 308, 309, 310, 311, 312, 313, 314, 316, 317, 318, 321, 322, 323, 324, 325, 326, 327, 328, 329, 331, 332, 334, 335, 336, 341, 342, 344, 345, 347, 350, 353, 355, 356, 357, 358, 359, 360, 361, 362, 364, 367, 368, 369, 370, 371, 372, 373, 380, 381, 382, 384, 385, 386, 387], "correct": [1, 7, 359, 360, 361, 382, 383], "els": [1, 4, 250, 276, 383], "float16_t": 1, "bfloat16_t": 1, "complex64_t": 1, "runtime_error": 1, "support": [1, 4, 6, 7, 14, 85, 86, 87, 114, 127, 136, 146, 148, 158, 180, 381, 382, 384, 386], "have": [1, 4, 7, 14, 76, 77, 78, 79, 122, 123, 125, 126, 136, 158, 183, 239, 265, 291, 299, 362, 364, 379, 380, 382, 383, 387], "rememb": 1, "3": [1, 4, 7, 127, 141, 145, 146, 305, 308, 310, 319, 357, 362, 377, 380, 382, 384, 385], "complic": 1, "keep": [1, 13, 15, 23, 24, 157, 159, 161, 168, 179, 218, 232, 250, 275, 381, 383], "mind": [1, 4], "half": [1, 16, 187, 191, 297, 383], "precis": [1, 4, 114, 250, 261, 294, 363, 380], "direct": [1, 4, 273, 362, 387], "fix": [1, 4, 7, 383], "possibli": [1, 4, 158], "due": 1, "transpos": [1, 4, 28, 181], "aren": 1, "guarante": 1, "fit": [1, 180, 387], "requir": [1, 4, 250, 383, 384], "column": [1, 110, 134, 180], "inplac": 1, "expect": [1, 4, 255, 256, 257, 258, 259, 301, 304, 326, 380, 382], "answer": 1, "copi": [1, 4, 6, 178, 209, 384], "simpli": [1, 4, 7, 296, 314, 320, 339, 348, 353, 380, 381], "catlas_saxpbi": 1, "axpby_impl_acceler": 1, "first": [1, 2, 3, 4, 5, 7, 99, 127, 131, 154, 156, 158, 178, 189, 219, 224, 231, 239, 250, 253, 263, 269, 305, 324, 332, 357, 359, 360, 361, 364, 380, 381, 384, 387], "mode": [1, 88, 274, 285, 287, 305, 309, 310], "e": [1, 5, 7, 105, 142, 204, 254, 255, 256, 258, 259, 263, 264, 266, 276, 294, 321, 322, 344, 349, 355, 358, 380, 
383, 388], "match": [1, 7, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 162, 278, 305, 325, 382, 384], "transposit": 1, "data_s": 1, "items": 1, "flag": [1, 380, 384], "copy_inplac": 1, "copytyp": 1, "n": [1, 4, 27, 85, 86, 87, 110, 115, 117, 118, 120, 121, 124, 126, 134, 228, 232, 252, 253, 254, 255, 256, 258, 259, 262, 265, 268, 269, 295, 305, 331, 336], "incx": 1, "inci": 1, "great": [1, 2], "But": [1, 387], "criteria": 1, "luckili": [1, 383], "alwai": [1, 162, 239, 381], "With": 1, "final": [1, 3, 4, 5, 372], "singl": [1, 5, 107, 142, 148, 177, 233, 253, 269, 380, 382, 386], "row_contigu": 1, "col_contigu": 1, "common": [1, 355, 380, 383], "hit": 1, "mileston": 1, "enough": [1, 383], "run": [1, 2, 4, 5, 6, 7, 8, 176, 242, 254, 271, 356, 357, 359, 360, 361, 380, 383, 387, 388], "If": [1, 4, 7, 13, 14, 15, 16, 23, 24, 25, 26, 73, 76, 82, 84, 88, 91, 92, 93, 94, 98, 99, 107, 111, 113, 124, 125, 126, 129, 130, 131, 136, 145, 148, 157, 158, 159, 161, 166, 167, 168, 174, 177, 178, 179, 183, 193, 208, 209, 210, 218, 220, 221, 224, 226, 231, 232, 234, 236, 240, 254, 255, 256, 263, 266, 267, 276, 278, 288, 293, 295, 297, 299, 301, 305, 323, 325, 336, 357, 380, 381, 383, 386, 387, 388], "plan": [1, 380], "stop": [1, 4, 16, 147, 215, 381, 382], "enjoi": 1, "speed": 1, "appl": [1, 4, 6, 7, 387], "silicon": [1, 4, 6, 7, 387], "address": 1, "shade": 1, "languag": 1, "kernel": [1, 85, 86, 87, 252, 253, 268, 269, 380, 382], "written": 1, "help": [1, 4, 380, 387], "resourc": 1, "walkthrough": 1, "pipelin": 1, "specif": [1, 7, 381], "cpp": 1, "algorithm": [1, 305, 362], "launch": [1, 382], "exactli": [1, 4, 278, 381], "mani": [1, 210, 255, 256, 260, 380, 383], "thread": 1, "pick": 1, "updat": [1, 3, 4, 5, 35, 83, 240, 254, 271, 272, 278, 283, 284, 285, 290, 355, 357, 360, 362, 363, 364, 368, 369, 370, 371, 372, 373, 380, 383], "assign": [1, 35, 353], "axpby_gener": 1, "buffer": [1, 162, 384], "constant": [1, 4, 7, 111, 112, 177, 250, 254, 263, 266, 294, 326, 336, 367, 369, 380, 384], "4": [1, 4, 97, 127, 145, 180, 181, 200, 244, 252, 253, 254, 264, 268, 269, 293, 304, 305, 307, 308, 309, 323, 380, 382, 385, 387], "5": [1, 3, 4, 7, 145, 167, 182, 252, 254, 257, 258, 259, 264, 268, 302, 305, 306, 309, 310, 335, 346, 350, 367, 369, 370, 380, 381, 382], "x_stride": 1, "6": [1, 4, 145, 200, 304, 308, 316, 317, 319, 326, 336, 340, 367, 380, 382, 385], "y_stride": 1, "7": [1, 4, 145, 180, 382], "ndim": [1, 127, 145, 305], "8": [1, 4, 7, 145, 180, 244, 253, 264, 269, 304, 324, 356, 357, 358, 359, 360, 361, 367, 380, 382, 385, 387], "uint": 1, "index": [1, 6, 8, 25, 35, 109, 110, 131, 178, 220, 221, 231, 242], "thread_position_in_grid": 1, "convert": [1, 73, 77, 78, 79, 127, 293, 383, 384, 385], "instanti": [1, 5, 383], "uniqu": [1, 377], "host": 1, "name": [1, 148, 180, 181, 198, 199, 200, 201, 250, 263, 275, 278, 280, 382, 386], "identifi": [1, 239, 379], "instantiate_axpbi": 1, "type_nam": 1, "host_nam": 1, "axpby_general_": 1, "compil": [1, 2, 6, 7, 100, 103, 381, 383], "mlx_ext": 1, "metallib": [1, 7], "see": [1, 4, 5, 7, 9, 10, 29, 30, 31, 32, 33, 36, 37, 38, 39, 40, 42, 44, 45, 48, 49, 50, 51, 52, 53, 54, 55, 56, 59, 60, 61, 62, 63, 65, 67, 68, 69, 70, 71, 72, 74, 75, 145, 166, 198, 199, 244, 250, 254, 258, 261, 274, 292, 293, 296, 297, 298, 300, 301, 302, 305, 307, 308, 309, 310, 315, 316, 317, 341, 380, 381, 382, 385, 387], "later": [1, 2, 7], "co": [1, 301, 381], "locat": [1, 289, 290, 387], "share": [1, 6, 97, 180, 181], "register_librari": 1, "potenti": [1, 167], "path": [1, 2, 7, 
200, 201, 278], "tri": 1, "load": [1, 5, 6, 278], "hasn": 1, "alreadi": [1, 4], "static": [1, 7], "object": [1, 2, 9, 27, 46, 73, 83, 141, 200, 234, 239, 240, 244, 258, 304, 379], "why": [1, 4], "packag": [1, 3, 5, 350], "process": [1, 4, 87, 88, 240, 259, 260, 304, 379], "logic": [1, 154, 155, 156], "grid": 1, "shown": 1, "below": [1, 7, 145, 228, 230, 244, 305, 383], "prepar": [1, 4], "carri": 1, "should": [1, 3, 4, 5, 7, 99, 111, 112, 114, 142, 180, 221, 231, 233, 239, 250, 255, 256, 258, 259, 285, 291, 299, 325, 327, 332, 353, 379, 380, 381, 383, 384, 388], "d": [1, 4, 98, 99, 135, 145, 158, 176, 220, 228, 229, 230, 241, 259, 262, 265, 295, 356, 359, 361, 387], "ostringstream": 1, "kname": 1, "axpby_": 1, "general_": 1, "type_to_nam": 1, "make": [1, 4, 5, 7, 158, 173, 203, 250, 369, 370, 372, 373, 380, 383, 385, 387], "sure": [1, 4, 7, 250, 380], "look": [1, 4], "folder": 1, "get_colocated_mtllib_path": 1, "get_kernel": 1, "str": [1, 88, 131, 145, 148, 197, 198, 199, 200, 201, 231, 239, 241, 271, 272, 275, 276, 278, 280, 282, 288, 305, 309, 310, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336], "encod": [1, 113, 297, 301, 304, 325], "compute_encod": 1, "get_command_encod": 1, "setcomputepipelinest": 1, "those": [1, 4, 250], "nelem": 1, "set_array_buff": 1, "setbyt": 1, "sizeof": 1, "threadgroup": 1, "higher": [1, 135, 332, 381], "than": [1, 4, 73, 88, 99, 102, 113, 132, 133, 143, 144, 158, 166, 240, 297, 303, 305, 332, 335, 347, 357, 362, 380, 381, 387], "max": [1, 145, 160, 268, 269, 292, 319, 324, 326, 327, 332, 336, 338, 340, 357, 361, 380, 381, 387], "allow": [1, 141, 250, 290, 353, 375, 382, 385], "tgp_size": 1, "min": [1, 145, 169, 292, 319, 338, 340], "maxtotalthreadsperthreadgroup": 1, "3d": [1, 254, 259, 305], "mtl": 1, "group_dim": 1, "grid_dim": 1, "divid": [1, 35, 129, 180], "among": 1, "dispatchthread": 1, "few": [1, 4, 5, 6, 383, 385], "thing": [1, 4], "note": [1, 4, 7, 14, 83, 85, 86, 114, 122, 123, 136, 145, 162, 180, 183, 250, 294, 305, 384, 386], "befor": [1, 4, 7, 25, 178, 275, 304, 364, 382, 383], "move": [1, 170, 387], "track": [1, 250, 254], "activ": [1, 7, 162, 258, 303, 304, 337, 346, 347, 349, 380], "u": [1, 267, 290, 375, 383], "command": [1, 2, 7], "instead": [1, 7, 250, 290, 301, 381, 383], "end_encod": 1, "end": [1, 99, 165, 180, 253, 262, 265, 269, 303, 328, 335, 341, 346, 347, 372], "until": [1, 383, 385], "limit": [1, 82, 166, 167, 382], "flush": 1, "enqueu": 1, "commit": 1, "associ": [1, 200, 201, 383], "suggest": 1, "deeper": 1, "dive": 1, "studi": 1, "come": [1, 4, 381], "far": [1, 355], "built": [1, 7, 383], "includ": [1, 91, 92, 93, 94, 162, 163, 167, 266, 272, 284, 293, 326, 380, 381, 382, 385, 386, 388], "forward": [1, 231, 380, 383], "diff": 1, "push": 1, "along": [1, 23, 24, 83, 84, 91, 92, 93, 94, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 145, 193, 208, 210, 214, 220, 221, 224, 225, 226, 250, 295, 318], "similarli": [1, 7, 158, 381, 383], "scale_arr": 1, "contribut": 1, "tangent_x": 1, "tangent_i": 1, "revers": [1, 37, 38, 39, 40, 91, 92, 93, 94, 227, 301], "arg": [1, 4, 9, 10, 107, 200, 201], "push_back": 1, "fulli": [1, 6, 380, 384, 387], "overal": 1, "directori": [1, 4, 7], "extens": [1, 148, 282, 386], "h": [1, 85, 86, 145, 253, 254, 256, 258, 259, 262, 265, 269, 295, 381, 383], "mlx_sample_extens": 1, "__init__": [1, 4, 5, 8, 9, 10, 27, 242, 250, 353], "py": [1, 4, 7], "cmakelist": 1, "txt": 1, "setup": [1, 3, 5, 7, 380], "hold": [1, 4, 9, 10, 145, 380], "instal": 1, "pybind11": 1, "sinc": [1, 4, 5, 
353, 362, 371, 384, 387], "compon": [1, 4], "etc": [1, 180, 250, 305], "pybind11_modul": 1, "m": [1, 4, 7, 110, 145, 228, 252, 253, 268, 269, 356, 380], "doc": [1, 5], "sampl": [1, 3, 4, 147, 182, 183, 184, 187, 190, 191, 307, 308, 309, 310, 312, 313, 326, 332, 336, 377, 380], "_a": 1, "pos_onli": 1, "kw_onli": 1, "none": [1, 4, 8, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 29, 30, 31, 32, 33, 34, 36, 37, 38, 39, 40, 41, 42, 44, 45, 48, 49, 50, 51, 52, 53, 54, 55, 56, 59, 60, 61, 62, 63, 65, 67, 68, 69, 70, 71, 72, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 168, 169, 170, 171, 172, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 199, 200, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 234, 235, 236, 237, 239, 240, 242, 252, 253, 261, 268, 269, 271, 275, 276, 283, 288, 291, 295, 301, 304, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 357, 375, 382], "r": [1, 4, 146, 231, 258, 262], "pbdoc": 1, "most": [1, 183, 250, 366, 380, 381, 382, 383], "complex": [1, 122, 123, 124, 125, 126, 239, 244, 250, 290, 380, 381], "bell": 1, "whistl": 1, "liter": [1, 305, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336], "string": [1, 384, 386], "modul": [1, 4, 5, 238, 293, 299, 304, 350, 366, 379, 380, 383], "ensur": [1, 7, 331], "caster": 1, "find_packag": 1, "config": 1, "add_librari": 1, "sourc": [1, 2, 56, 170, 227], "target_sourc": 1, "cmake_current_list_dir": 1, "header": 1, "target_include_directori": 1, "target_link_librari": 1, "attach": 1, "conveni": [1, 5, 141], "mlx_build_metallib": 1, "target": [1, 231, 323, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 380], "destin": [1, 56, 170], "automat": [1, 6, 148, 385, 386, 387], "practic": [1, 380], "mlx_build_met": [1, 7], "mlx_ext_metallib": 1, "titl": 1, "include_dir": 1, "project_source_dir": 1, "mlx_include_dir": 1, "output_directori": 1, "cmake_library_output_directori": 1, "add_depend": 1, "endif": 1, "pybind11_add_modul": 1, "build_shared_lib": 1, "target_link_opt": 1, "wl": 1, "rpath": 1, "loader_path": 1, "onc": [1, 380], "describ": [1, 383], "util": [1, 4, 6, 7, 200, 250], "__name__": [1, 4], "__main__": [1, 4], "descript": [1, 4, 244], "ext_modul": 1, "cmakeextens": 1, "cmdclass": 1, "build_ext": 1, "cmakebuild": 1, "package_dir": 1, "package_data": 1, "dylib": 1, "zip_saf": 1, "fals": [1, 4, 13, 14, 15, 23, 24, 30, 31, 32, 33, 37, 38, 39, 40, 52, 53, 54, 55, 59, 71, 75, 76, 83, 87, 91, 92, 93, 94, 136, 141, 145, 148, 157, 159, 161, 167, 168, 179, 218, 232, 235, 239, 240, 244, 263, 264, 266, 267, 276, 278, 288, 291, 293, 297, 301, 304, 305, 323, 326, 357, 368, 384], "python_requir": 1, "even": [1, 4, 83, 380, 383, 384], "though": [1, 4, 380, 383, 384], "j8": 1, "libmlx_ext": 1, "cpython": 1, "3x": 1, "darwin": 1, "pip": [1, 7], "after": [1, 4, 5, 25, 127, 129, 178, 180, 254, 263, 266, 271, 272, 276, 278, 285, 288, 289, 290, 291, 304, 335, 380, 387], "plai": [1, 4], "ones": [1, 4, 175, 200, 228, 289, 290, 293, 382], "b": [1, 2, 4, 12, 14, 76, 101, 102, 
104, 129, 132, 133, 135, 136, 143, 144, 145, 153, 154, 156, 158, 160, 169, 171, 176, 180, 217, 224, 231, 267, 295, 305, 318, 381, 382, 383, 384, 385, 386, 387], "f": [1, 2, 3, 5, 145, 250, 265, 360, 380, 384], "item": [1, 3, 4, 5, 240, 383, 384, 385], "true": [1, 3, 4, 14, 37, 38, 39, 40, 76, 83, 91, 92, 93, 94, 113, 136, 141, 145, 148, 167, 181, 208, 235, 239, 240, 244, 250, 254, 255, 256, 262, 263, 264, 265, 266, 267, 275, 276, 278, 285, 288, 293, 295, 297, 301, 304, 305, 323, 331, 357], "quick": [1, 6], "benchmark": [1, 380], "compar": [1, 76, 380], "time": [1, 4, 7, 167, 225, 250, 252, 253, 262, 265, 268, 269, 295, 380, 381, 383, 387], "set_default_devic": 1, "256": [1, 5], "512": [1, 4, 304, 387], "random": [1, 3, 4, 5, 6, 252, 253, 254, 264, 268, 269, 278, 285, 380, 381, 387, 388], "normal": [1, 3, 4, 111, 112, 190, 250, 252, 253, 254, 263, 264, 266, 268, 269, 294, 304, 307, 309, 384, 387], "bench": 1, "warm": [1, 380], "rang": [1, 3, 4, 5, 7, 16, 127, 147, 308, 310, 316, 317, 355, 369, 370, 371, 372, 373, 377, 380, 381, 383, 387], "100": [1, 3, 4, 372, 380, 381, 383, 387], "5000": 1, "simple_tim": 1, "custom_tim": 1, "3f": [1, 5, 380], "custom": [1, 304], "114": 1, "109": 1, "modest": 1, "improv": [1, 2, 4, 356, 357, 358, 359, 360, 361, 367, 380], "awai": [1, 4], "good": [1, 7, 380, 387], "nn": [1, 4, 5, 200, 240, 250, 350, 353, 355, 364, 366, 380, 383], "grad": [1, 3, 5, 231, 355, 363, 380, 381, 382, 383, 385], "full": [1, 5, 61, 74, 88, 208, 289, 290, 326, 380, 383], "profil": 2, "kei": [2, 4, 114, 182, 183, 184, 186, 187, 189, 190, 191, 239, 240, 275, 276, 288, 291, 364, 377, 379, 381], "build": [2, 4, 6, 309, 353, 380], "mlx_metal_debug": [2, 7], "option": [2, 4, 13, 15, 16, 23, 24, 25, 26, 29, 30, 31, 32, 33, 34, 36, 37, 38, 39, 40, 41, 42, 44, 45, 48, 49, 50, 51, 52, 53, 54, 55, 56, 59, 60, 61, 62, 63, 65, 67, 68, 69, 70, 71, 72, 74, 75, 77, 78, 79, 83, 84, 85, 86, 87, 88, 91, 92, 93, 94, 97, 98, 99, 110, 111, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 130, 131, 134, 139, 140, 145, 146, 147, 148, 157, 159, 161, 167, 168, 174, 177, 178, 179, 180, 181, 182, 183, 184, 186, 187, 189, 190, 191, 193, 194, 208, 209, 210, 213, 214, 218, 220, 224, 226, 227, 228, 229, 230, 231, 232, 234, 236, 239, 240, 252, 253, 254, 255, 256, 262, 265, 267, 268, 269, 271, 275, 276, 278, 283, 288, 291, 293, 295, 297, 301, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 356, 357, 358, 359, 360, 361, 362, 364, 367, 368, 369, 377, 380, 386, 388], "debug": 2, "record": [2, 164, 383], "dure": [2, 83, 257, 258, 259, 305, 384], "inspect": [2, 380, 385], "label": [2, 3, 325, 332], "queue": 2, "readabl": 2, "start_captur": 2, "initi": [2, 3, 4, 250, 254, 263, 264, 266, 267, 292, 294, 306, 307, 308, 309, 310, 311, 312, 313, 353, 364, 369, 370, 372, 373, 380, 383], "gpu": [2, 6, 380, 382, 387], "main": [2, 6, 99, 110, 240, 250], "jane": 2, "develop": [2, 6, 7], "gputrac": 2, "arang": [2, 145, 244, 305, 382, 384], "10": [2, 4, 5, 150, 195, 200, 240, 250, 278, 350, 371, 373, 380, 382], "20": [2, 145], "30": [2, 357], "40": 2, "stop_captur": 2, "replai": 2, "trace": [2, 380], "view": [2, 384], "overview": 2, "oper": [2, 4, 6, 8, 34, 77, 78, 79, 87, 114, 208, 215, 221, 242, 250, 304, 362, 380, 381, 382, 383, 384, 385, 387, 388], "checkout": [2, 380], "document": [2, 6, 61, 74, 198, 199, 244, 380, 381, 382], "inform": [2, 4, 5, 7, 198, 199, 244, 250, 254, 261, 291, 381, 387], "skip": 2, "save": [2, 4, 6, 148, 180, 198, 
199, 200, 201, 282, 383], "project": [2, 4, 291], "us": [2, 3, 4, 5, 6, 7, 16, 35, 97, 100, 102, 113, 127, 145, 146, 158, 162, 163, 164, 166, 180, 181, 193, 194, 239, 244, 250, 253, 258, 260, 261, 262, 265, 267, 269, 271, 275, 282, 289, 291, 293, 295, 297, 301, 304, 305, 309, 310, 316, 317, 324, 350, 353, 355, 356, 357, 359, 360, 361, 362, 363, 364, 377, 379, 380, 381, 382, 385, 387], "cmake": [2, 7], "mkdir": [2, 7], "cd": [2, 7], "dmlx_metal_debug": 2, "ON": [2, 7], "g": [2, 7, 145, 180, 265, 349, 367, 368, 383, 388], "xcodeproj": 2, "select": [2, 7, 226, 235, 271, 275, 283], "metal_captur": 2, "exampl": [2, 3, 4, 5, 16, 35, 127, 145, 146, 216, 220, 250, 252, 253, 254, 264, 268, 269, 276, 278, 285, 288, 305, 306, 307, 308, 309, 310, 311, 312, 313, 323, 325, 332, 350, 355, 364, 369, 370, 371, 372, 373, 377, 381, 382, 383, 384, 385, 386], "schema": 2, "implement": [3, 5, 113, 114, 145, 260, 275, 291, 297, 299, 301, 303, 304, 305, 347, 356, 357, 358, 359, 361, 362, 363, 375, 380, 381, 384], "basic": [3, 195, 381], "model": [3, 5, 6, 200, 238, 240, 250, 271, 274, 276, 278, 282, 285, 287, 288, 289, 291, 304, 350, 353, 355, 363, 364, 366, 380, 383], "problem": [3, 5, 250], "metadata": [3, 148, 198, 199], "num_featur": [3, 254], "num_exampl": 3, "1_000": 3, "num_it": 3, "10_000": 3, "iter": [3, 5, 240, 377, 380, 383], "sgd": [3, 5, 355, 362, 364, 369, 370, 373, 380], "lr": [3, 362], "01": [3, 320, 360], "rate": [3, 356, 357, 358, 359, 360, 361, 362, 367, 368], "ll": [3, 5, 328, 380, 381], "synthet": 3, "dataset": [3, 383], "matrix": [3, 41, 97, 98, 110, 134, 145, 146, 158, 180, 181, 293, 311, 350], "ground": [3, 4, 325, 335], "truth": [3, 325, 335], "w_star": 3, "valu": [3, 4, 11, 14, 16, 23, 24, 46, 73, 76, 82, 110, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 130, 136, 145, 147, 177, 182, 183, 184, 186, 187, 190, 191, 198, 220, 221, 231, 234, 238, 239, 240, 244, 253, 257, 258, 259, 264, 267, 269, 275, 291, 292, 302, 303, 304, 306, 323, 324, 325, 326, 327, 328, 330, 331, 332, 333, 334, 335, 347, 353, 357, 360, 369, 370, 372, 373, 381], "gaussian": [3, 261, 315, 316, 317, 326], "nois": 3, "noisi": 3, "ep": [3, 111, 112, 254, 263, 264, 266, 294, 324, 326, 336, 356, 357, 358, 359, 360, 361, 367], "1e": [3, 5, 14, 136, 254, 263, 264, 266, 294, 324, 326, 336, 356, 357, 358, 359, 360, 361, 364, 367, 369, 370, 371, 372, 373], "weight": [3, 85, 86, 87, 111, 112, 240, 250, 278, 282, 293, 323, 325, 353, 357, 360, 362, 364, 368, 381, 383], "squar": [3, 4, 112, 134, 196, 211, 231, 240, 250, 294, 333, 335, 356, 357, 359, 360, 361, 381, 384], "loss": [3, 5, 231, 250, 355, 380, 381, 383], "loss_fn": [3, 5, 355, 380, 381], "w": [3, 86, 97, 180, 181, 231, 253, 254, 256, 258, 259, 267, 269, 368, 381], "mean": [3, 4, 5, 112, 186, 231, 250, 254, 263, 276, 294, 312, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 380, 381, 384], "grad_fn": [3, 380, 381], "randomli": [3, 4, 257, 258, 259], "Then": [3, 7], "repeatedli": 3, "_": [3, 4, 250, 369, 370, 371, 372, 373, 377, 380, 383, 387], "verifi": [3, 7], "close": [3, 6, 7, 14, 136], "error_norm": 3, "5f": 3, "someth": [3, 4, 382], "00005": 3, "00364": 3, "complet": [3, 4, 7, 167, 289, 290, 381, 387], "logist": [3, 204, 316, 317, 343], "github": [3, 5, 7, 380], "repo": [3, 5, 7, 380], "enabl": [4, 7, 83, 103, 368], "larg": [4, 250, 291, 331, 380, 383], "ish": 4, "transform": [4, 6, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 238, 250, 254, 263, 266, 267, 275, 276, 288, 293, 297, 382], "compromis": 4, "eas": 
4, "llama": 4, "famili": 4, "less": [4, 25, 144, 178, 297, 335], "200": [4, 371], "line": [4, 383, 384], "python": [4, 46, 64, 73, 107, 239, 240, 241, 353, 363, 364, 366, 379, 381, 384], "neural": [4, 6, 260, 307, 308, 337, 350, 353, 367], "network": [4, 6, 254, 258, 260, 307, 308, 350, 353, 367], "concis": 4, "architectur": [4, 7, 250, 290, 387], "notabl": [4, 6], "rope": [4, 250], "posit": [4, 25, 99, 113, 127, 131, 140, 170, 178, 231, 240, 250, 255, 256, 291, 297, 301, 326, 336], "cach": [4, 162, 163, 166, 380], "concaten": 4, "llamaattent": 4, "self": [4, 5, 8, 27, 28, 29, 30, 31, 32, 33, 34, 36, 37, 38, 39, 40, 41, 42, 44, 45, 46, 48, 49, 50, 51, 52, 53, 54, 55, 56, 59, 60, 61, 62, 63, 65, 67, 68, 69, 70, 71, 72, 73, 74, 75, 242, 250, 337, 353], "dim": [4, 113, 114, 260, 263, 264, 266, 291, 294, 297, 301, 304], "num_head": [4, 291, 304], "super": [4, 5, 250, 353], "tradit": [4, 113, 258, 259, 297], "query_proj": 4, "bia": [4, 97, 111, 180, 181, 240, 250, 255, 256, 262, 265, 266, 267, 276, 278, 288, 291, 293, 295, 359, 360, 361, 364, 381], "key_proj": 4, "value_proj": 4, "out_proj": [4, 353], "__call__": [4, 5, 250, 353], "queri": [4, 114, 291], "mask": [4, 114, 285, 291, 382], "extract": [4, 41, 98, 99, 250, 275, 353], "l": [4, 5, 250, 252, 254, 255, 262, 265, 268, 295, 335], "reshap": [4, 145, 305, 382], "combin": 4, "key_cach": 4, "value_cach": 4, "sqrt": [4, 105, 114, 254, 263, 264, 266, 267, 294, 301, 307, 308, 309, 310, 356, 358, 359, 360, 367, 380], "score": [4, 114, 332], "softmax": [4, 114, 250, 322, 325], "values_hat": 4, "rm": [4, 7, 112, 357], "swiglu": 4, "rmsnorm": [4, 250], "llamaencoderlay": 4, "mlp_dim": [4, 304], "norm1": 4, "norm2": 4, "linear1": 4, "linear2": 4, "linear3": 4, "sigmoid": [4, 250, 300, 316, 317, 321, 343], "instanc": [4, 35, 180, 241, 250, 264, 271, 272, 273, 276, 278, 279, 280, 285, 288, 289, 290, 299, 353, 384], "embed": [4, 250, 297, 301, 324], "emb": [4, 260, 301], "token": [4, 260], "num_lay": [4, 5, 355], "vocab_s": 4, "norm": [4, 112, 263, 336, 361, 362], "multiheadattent": [4, 250], "create_additive_causal_mask": 4, "list": [4, 9, 13, 15, 27, 67, 73, 77, 78, 79, 80, 83, 84, 87, 107, 116, 117, 119, 120, 122, 123, 125, 126, 130, 131, 142, 145, 157, 159, 161, 168, 174, 177, 179, 182, 183, 184, 186, 187, 190, 191, 198, 208, 210, 214, 218, 224, 225, 227, 231, 232, 233, 236, 239, 241, 250, 276, 278, 279, 280, 281, 286, 288, 289, 290, 353, 359, 360, 361, 362, 371, 379, 380, 381, 383], "still": [4, 7, 145, 380, 383], "consid": [4, 14, 76, 136, 239, 240, 263, 379], "train": [4, 5, 250, 254, 257, 258, 259, 274, 276, 288, 307, 308], "ignor": [4, 35, 82, 83, 107, 357], "whatsoev": 4, "rest": [4, 113, 240, 297], "subsect": 4, "prompt": 4, "autoregress": 4, "yield": [4, 5, 377], "temp": 4, "causal": 4, "append": [4, 158, 380, 383], "store": 4, "per": [4, 5, 97, 180, 181, 254, 263, 264, 266, 294, 375, 380, 383], "care": [4, 383], "last": [4, 26, 73, 111, 112, 117, 120, 122, 123, 125, 126, 127, 135, 146, 158, 183, 209, 224, 255, 256, 258, 259, 263, 305, 384], "logit": [4, 183, 323, 325, 380], "next": [4, 5, 166], "categor": 4, "lazili": [4, 250], "noth": [4, 250, 383], "yet": [4, 145, 250, 353, 364, 381, 382, 383, 385], "forc": [4, 5, 250, 385], "choos": [4, 113, 297], "pars": 4, "feed": 4, "loop": [4, 5, 380, 381, 383], "unsqueez": 4, "sequenc": [4, 13, 15, 30, 31, 52, 53, 54, 55, 59, 67, 70, 71, 75, 80, 87, 109, 116, 117, 119, 120, 122, 123, 125, 126, 130, 157, 159, 161, 168, 174, 179, 182, 183, 184, 186, 187, 190, 191, 194, 208, 210, 213, 218, 224, 225, 
227, 232, 236, 254, 255, 262, 265, 295, 304, 377, 387], "length": [4, 213, 254, 255, 262, 265, 295, 371], "len": [4, 117, 120, 123, 126, 371], "overwrit": 4, "discard": [4, 239], "old": 4, "moment": [4, 87, 357, 359, 360, 361], "anymor": 4, "everyth": 4, "small": [4, 111, 112, 254, 263, 266, 294, 326, 331, 336, 380, 387], "12": [4, 371], "8192": 4, "1024": 4, "actual": [4, 16, 278, 353, 383], "materi": [4, 6], "could": [4, 250], "20_000": 4, "machin": [4, 6, 7, 367], "8gb": 4, "ram": 4, "32": [4, 5, 180, 181, 244, 253, 269, 294, 380], "44": 4, "doubl": 4, "bracket": 4, "becaus": [4, 162, 250, 383], "batch": [4, 158, 254, 255, 256, 258, 259, 262, 265, 291, 295, 305, 383], "zip": [4, 5], "haven": 4, "anyth": [4, 231, 383], "result": [4, 16, 35, 73, 83, 97, 111, 112, 135, 145, 148, 158, 176, 181, 193, 195, 214, 224, 225, 235, 240, 301, 380, 381, 384], "similar": [4, 141, 240, 289, 290, 291, 324, 384, 386], "runtim": [4, 380], "section": [4, 7, 210, 336, 380, 381], "access": [4, 46, 250, 353, 364, 383, 387], "origin": [4, 99, 254, 284, 307, 308, 309, 310, 356, 357, 358, 359, 361, 362, 384], "sentencepiec": 4, "pytorch": [4, 6, 263, 381], "compat": [4, 183, 386], "npz": [4, 148, 200, 201, 278, 282, 386], "file": [4, 7, 148, 197, 198, 199, 200, 201, 278, 282, 381, 386], "directli": 4, "argpars": 4, "itertool": [4, 240], "starmap": [4, 240], "np": [4, 5, 384, 385], "torch": [4, 384], "map_torch_to_mlx": 4, "tok_embed": 4, "elif": 4, "replac": [4, 289, 290, 304, 335], "attention_norm": 4, "ffn_norm": 4, "wq": 4, "wk": 4, "wv": 4, "wo": 4, "w1": 4, "w2": 4, "w3": 4, "ffn": 4, "separ": [4, 61, 74, 263, 332], "submodul": [4, 5, 250, 272, 276, 277, 288, 290], "feed_forward": 4, "parser": 4, "argumentpars": 4, "add_argu": 4, "torch_weight": 4, "output_fil": 4, "parse_arg": 4, "state": [4, 5, 250, 262, 265, 295, 355, 364, 377, 380], "savez": [4, 282, 386], "k": [4, 41, 98, 110, 114, 226, 228, 229, 230, 252, 267, 268, 276], "v": [4, 88, 114, 250, 276, 384], "left": [4, 113, 145, 180, 252, 253, 261, 268, 269, 297, 305, 316, 317, 326, 328, 336], "disk": 4, "text": [4, 252, 253, 262, 265, 268, 269, 270, 295, 303, 307, 308, 309, 310, 319, 326, 327, 328, 331, 332, 335, 337, 338, 341, 342, 346, 347, 357, 362], "format": [4, 148, 197, 198, 199, 200, 201, 384], "dictionari": [4, 83, 148, 198, 199, 239, 250, 275, 284, 289, 290, 365, 379, 386], "represent": [4, 180, 239, 241], "tree_unflatten": 4, "helper": [4, 380], "weight_fil": 4, "incur": 4, "sever": [4, 85, 86, 87, 200, 201, 380, 386], "futur": [4, 293, 382, 383], "pth": 4, "current": [4, 6, 7, 85, 86, 87, 163, 180, 250, 357, 383], "around": 4, "m1": [4, 380, 381, 387], "ultra": 4, "7b": 4, "me": 4, "ishmael": 4, "year": 4, "ago": 4, "never": [4, 383], "long": 4, "info": [4, 7], "247": 4, "press": [4, 145], "enter": 4, "littl": 4, "monei": 4, "my": [4, 7], "purs": 4, "greater": [4, 25, 133, 178, 303, 347], "consequ": 4, "walk": 4, "down": 4, "gower": 4, "street": 4, "afternoon": 4, "heavi": 4, "rain": 4, "saw": [4, 381], "off": [4, 7, 383], "man": 4, "rag": 4, "who": 4, "sat": 4, "upon": [4, 240], "hi": [4, 265], "bundl": 4, "hard": 4, "wet": 4, "he": [4, 309, 310], "were": [4, 387], "cry": 4, "watch": [4, 380], "him": 4, "observ": 4, "numer": [4, 111, 112, 145, 153, 157, 208, 254, 263, 264, 266, 294, 324, 326, 336, 356, 357, 358, 359, 360, 361, 367, 380, 383], "crowd": 4, "wa": [4, 383], "hurri": 4, "437": 4, "330": 4, "second": [4, 99, 154, 156, 158, 219, 231, 253, 269, 324, 332, 357, 359, 360, 361, 381, 387], "spent": 4, "amount": [4, 164, 252, 268], "39": 
4, "By": [4, 283, 381, 384], "bigger": [4, 357], "remain": [4, 231, 257, 258, 259], "almost": 4, "nobodi": 4, "took": 4, "least": [4, 77, 78, 79, 82, 146, 180], "notic": [4, 381, 386], "distanc": [4, 336], "had": 4, "doubt": 4, "minut": 4, "straight": 4, "slowli": 4, "rais": [4, 145, 167, 210, 278], "ey": 4, "speak": [4, 145], "resum": 4, "postur": 4, "stood": 4, "feel": 4, "pain": 4, "heart": 4, "smile": 4, "face": 4, "am": 4, "someon": 4, "three": [4, 79], "quarter": 4, "hour": 4, "made": 4, "immedi": [4, 271], "repli": 4, "again": [4, 7, 250, 380], "hand": [4, 381, 383], "did": 4, "accustom": 4, "thu": [4, 250], "question": [4, 383], "reason": [4, 382], "tell": [4, 380, 384], "understand": [4, 307, 308], "579": 4, "690": 4, "num": [4, 147, 189], "500": [4, 387], "628": 4, "went": 4, "nervou": 4, "trembl": 4, "told": 4, "And": 4, "perhap": 4, "surpris": 4, "matter": [4, 250], "shall": 4, "anyhow": 4, "friend": 4, "ye": 4, "slight": [4, 383], "kind": 4, "longer": [4, 88, 381], "soon": 4, "unless": [4, 14, 136, 145, 353], "unlik": [4, 14, 136, 258, 259, 284], "strang": 4, "amus": 4, "That": 4, "secret": 4, "disappoint": 4, "mine": 4, "cannot": [4, 82, 382, 384], "happi": 4, "ask": 4, "shop": 4, "bui": 4, "food": 4, "633": 4, "21": [4, 373], "475": 4, "su": 4, "j": [4, 7, 145, 258, 358, 359, 361], "lu": 4, "pan": 4, "murtadha": 4, "wen": 4, "liu": 4, "2021": 4, "roform": [4, 297], "enhanc": [4, 297, 383], "rotari": [4, 113, 297], "arxiv": [4, 263, 264, 266, 270, 294, 317, 337, 356, 362], "preprint": [4, 356, 362], "2104": 4, "09864": 4, "zhang": 4, "sennrich": 4, "2019": [4, 360], "root": [4, 112, 196, 211, 294], "advanc": [4, 380], "system": [4, 7, 162, 163], "shazeer": 4, "2020": 4, "glu": [4, 250], "variant": [4, 335, 361], "2002": 4, "05202": 4, "classifi": 5, "mnist": 5, "As": [5, 35, 220, 250, 380], "mlp": [5, 250, 304, 355], "inherit": [5, 379], "standard": [5, 46, 73, 158, 184, 186, 304, 307, 309, 312, 385], "idiom": [5, 380], "input_dim": [5, 250, 267, 293], "hidden_dim": [5, 353, 355], "output_dim": [5, 250, 267, 293], "layer_s": 5, "idim": 5, "odim": 5, "maximum": [5, 23, 35, 82, 91, 164, 167, 250, 296, 301, 316, 317, 320, 339, 353, 383], "cross": [5, 87, 323, 325], "entropi": [5, 323, 325], "sub": [5, 99, 189], "commonli": [5, 289, 350, 380], "cross_entropi": [5, 250], "accuraci": 5, "valid": [5, 88, 127, 234, 239, 276, 288, 379], "eval_fn": 5, "argmax": 5, "loader": 5, "num_class": [5, 355], "batch_siz": [5, 355], "num_epoch": [5, 355], "learning_r": [5, 355, 356, 357, 358, 359, 360, 361, 362, 364, 367, 368, 369, 370, 371, 372, 373, 380], "train_imag": [5, 355], "train_label": [5, 355], "test_imag": 5, "test_label": 5, "shuffl": 5, "minibatch": 5, "batch_iter": [5, 355], "perm": 5, "permut": 5, "id": [5, 7], "put": [5, 380], "trainabl": [5, 238, 250, 353], "loss_and_grad_fn": [5, 355, 380, 381], "value_and_grad": [5, 250, 289, 353, 355, 366, 380, 381, 384, 385], "epoch": 5, "test": [5, 7], "confus": 5, "decent": 5, "95": 5, "brought": 6, "research": 6, "except": [6, 110, 121, 122, 124, 125, 126, 263, 278, 382, 384], "featur": [6, 85, 86, 87, 113, 254, 262, 263, 264, 265, 266, 267, 293, 294, 295, 297, 304, 305, 380, 383], "differ": [6, 141, 217, 335, 381], "lazi": [6, 353, 385], "multi": [6, 114, 255, 256, 382, 384], "cpu": [6, 146, 380, 387], "inspir": 6, "jax": [6, 377], "arrayfir": 6, "unifi": 6, "live": [6, 387], "guid": 6, "convers": 6, "regress": [6, 331], "layer": [6, 111, 250, 252, 253, 258, 259, 262, 263, 265, 266, 267, 268, 269, 285, 290, 293, 295, 299, 304, 349, 353], 
"perceptron": 6, "llm": 6, "infer": [6, 130, 148], "fft": 6, "algebra": 6, "tree": [6, 83, 107, 131, 231, 234, 239, 240, 241, 363, 364, 366, 375, 381], "debugg": 6, "pypi": 7, "meet": 7, "seri": 7, "chip": 7, "nativ": 7, "maco": 7, "13": 7, "recommend": [7, 167, 362], "14": 7, "sonoma": 7, "conda": 7, "forg": 7, "distribut": [7, 182, 183, 184, 186, 190, 191, 267, 307, 308, 309, 310, 312, 313, 326, 329, 334, 336, 350], "probabl": [7, 187, 257, 258, 259, 293, 323, 325, 329, 387], "platform": 7, "processor": 7, "arm": 7, "i386": 7, "switch": 7, "17": 7, "clang": 7, "24": 7, "xcode": 7, "15": [7, 145, 380], "sdk": 7, "environ": [7, 100, 103], "via": [7, 363, 366, 383, 384], "rosetta": 7, "unam": 7, "p": [7, 182, 250, 257, 258, 259, 336, 359, 361], "clone": 7, "git": 7, "com": 7, "ml": 7, "explor": 7, "nanobind": [7, 304], "http": [7, 263, 264, 266, 270, 294, 317, 337], "wjakob": 7, "env": 7, "cmake_build_parallel_level": 7, "edit": [7, 290], "unittest": 7, "discov": 7, "stub": 7, "dev": 7, "generate_stub": 7, "either": [7, 12, 61, 73, 74, 82, 101, 102, 104, 129, 132, 133, 143, 144, 145, 153, 158, 160, 169, 171, 217, 231, 253, 269, 299, 305, 309, 310], "libmlx": 7, "preprocessor": 7, "metal_path": 7, "mlx_build_test": 7, "mlx_build_exampl": 7, "mlx_build_benchmark": 7, "mlx_build_python_bind": 7, "multipl": [7, 111, 112, 158, 171, 180, 181, 291, 301, 370, 371, 373, 380, 383, 386], "wish": 7, "variabl": [7, 83, 100, 103, 131, 142, 231, 233, 234], "export": 7, "developer_dir": 7, "app": 7, "content": [7, 275, 380], "xcrun": 7, "macosx": 7, "show": [7, 244, 380], "unabl": 7, "tool": 7, "sudo": 7, "ouptut": 7, "finder": 7, "iterm": 7, "termin": 7, "click": 7, "uncheck": 7, "window": [7, 252, 253, 268, 269], "restart": 7, "grep": 7, "cmake_host_system_processor": 7, "arm64": 7, "x86_64": 7, "wipe": 7, "cahc": 7, "rf": 7, "devicetyp": 8, "attribut": [8, 9, 27, 242, 284, 353, 375], "kwarg": [9, 10, 200, 201, 388], "categori": [10, 244], "bool_": [10, 244], "integ": [10, 129, 141, 145, 177, 180, 181, 182, 187, 210, 224, 234, 244, 260, 283, 371, 382], "unsignedinteg": 10, "uint8": [10, 244], "uint16": [10, 244], "uint32": [10, 23, 24, 25, 26, 183, 244], "uint64": [10, 244], "signedinteg": [10, 141], "int8": [10, 244], "int32": [10, 16, 35, 127, 141, 145, 187, 244, 305, 382, 385], "int64": [10, 244], "inexact": [10, 141], "complexflo": 10, "complex128": 10, "issubdtyp": [10, 244], "absolut": [11, 14, 136, 316, 317, 335], "semant": [12, 80, 101, 102, 104, 132, 133, 143, 144, 153, 158, 160, 169, 171, 217, 387], "keepdim": [13, 15, 23, 24, 30, 31, 32, 33, 52, 53, 54, 55, 59, 71, 75, 145, 157, 159, 161, 168, 179, 208, 218, 232], "reduct": [13, 15, 157, 159, 168, 179, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336], "reduc": [13, 15, 23, 24, 157, 159, 161, 168, 179, 218, 232, 254, 304, 331], "unspecifi": [13, 15, 16, 23, 24, 25, 26, 84, 91, 92, 93, 94, 130, 157, 159, 161, 168, 174, 178, 179, 193, 208, 209, 218, 220, 226, 232, 236, 388], "entir": [13, 15, 23, 24, 157, 159, 161, 168, 179, 218, 232, 258, 259], "singleton": [13, 15, 23, 24, 157, 158, 159, 161, 168, 179, 218, 232], "rtol": [14, 136], "05": [14, 136, 254, 263, 264, 266, 294], "atol": [14, 136], "08": [14, 136, 324, 358, 359, 360, 361, 367], "equal_nan": [14, 76, 136], "approxim": [14, 261, 315, 316, 317], "comparison": [14, 104, 132, 133, 143, 144], "infinit": [14, 136], "equal": [14, 25, 76, 110, 133, 136, 144, 178, 187, 210, 264, 267], "sign": [14, 136, 244, 362], "nan": [14, 76, 136, 138], "ab": [14, 136, 145, 231, 263, 
264, 266, 270, 294, 317, 337, 380], "array_equ": [14, 136], "rel": [14, 136, 357, 380], "toler": [14, 136], "boolean": [14, 76, 136, 137, 138, 139, 140, 154, 155, 156, 244, 287, 382], "interv": [16, 147, 187, 191], "increment": 16, "otherwis": [16, 87, 167, 239, 240, 276, 278, 288, 303, 304, 305, 323, 328, 335, 346, 347, 383, 384], "convent": [16, 88, 305, 360], "lead": [16, 380], "fraction": 16, "integr": [16, 220, 383], "invers": [17, 18, 19, 20, 21, 22, 106, 118, 119, 120, 121, 122, 123], "cosin": [17, 18, 89, 90, 324, 369, 371, 381], "hyperbol": [18, 20, 22, 90, 207, 223, 348], "sine": [19, 20, 206, 207, 381], "minimum": [24, 35, 82, 92, 301, 324, 369], "kth": [25, 178], "partit": 25, "order": [25, 87, 145, 178, 180, 226, 250, 263, 289, 299, 364, 380, 381], "undefin": [25, 178, 382], "sort": [25, 26, 178, 226], "flatten": [25, 26, 91, 92, 93, 94, 145, 176, 178, 193, 209, 220, 221, 226, 239], "dimension": [27, 111, 112, 115, 116, 117, 118, 119, 120, 124, 125, 126, 252, 253, 254, 255, 256, 260, 267, 268, 269, 293, 301, 382, 384], "val": [27, 130], "tupl": [27, 61, 64, 74, 84, 86, 87, 102, 107, 109, 142, 145, 146, 177, 180, 194, 213, 231, 233, 239, 240, 241, 252, 253, 256, 268, 269, 278, 280, 299, 305, 357, 359, 360, 361, 362, 379, 381], "ndarrai": [27, 382, 383, 385], "properti": [28, 35, 43, 47, 57, 58, 64, 66, 284, 287, 365, 381], "argument": [28, 61, 74, 83, 107, 131, 231, 240, 250, 305, 377, 381, 386, 387, 388], "union": [29, 30, 31, 32, 33, 34, 36, 37, 38, 39, 40, 41, 42, 44, 45, 48, 49, 50, 51, 52, 53, 54, 55, 56, 59, 60, 61, 62, 63, 65, 67, 68, 69, 70, 71, 72, 74, 75, 77, 78, 79, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 139, 140, 189, 190, 216], "appli": [35, 113, 114, 240, 250, 252, 253, 254, 255, 256, 258, 259, 261, 263, 264, 266, 267, 268, 269, 270, 272, 285, 292, 293, 294, 295, 296, 298, 300, 302, 303, 305, 314, 315, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 350, 363, 366, 372, 375, 380], "regular": [35, 258, 337, 360, 380, 382], "idx": [35, 382], "correctli": 35, "syntax": [35, 382], "subtract": 35, "inclus": [37, 38, 39, 40, 91, 92, 93, 94, 127], "diagon": [41, 98, 110, 228, 229, 230], "axis1": [42, 72, 99, 219], "axis2": [42, 72, 99, 219], "start_axi": [45, 127], "end_axi": [45, 127], "datatyp": 47, "byte": [47, 57, 162, 163, 164, 166, 167, 244], "decim": [62, 195], "indices_or_sect": [67, 210], "nest": [73, 83, 250, 353, 379, 381], "ddof": [75, 232], "ari": [77, 78, 79], "a_min": 82, "a_max": 82, "edg": [82, 177, 305, 380], "At": 82, "anoth": [82, 141, 158, 217, 235, 244, 250, 271, 380, 381, 382, 387], "fun": [83, 131, 142, 231, 233, 234, 380, 382, 383, 387], "callabl": [83, 131, 142, 231, 233, 234, 238, 239, 240, 271, 272, 275, 283, 295, 299, 304, 306, 307, 308, 309, 310, 311, 312, 313, 356, 357, 358, 359, 360, 361, 362, 367, 368, 369, 370, 371, 372, 373], "shapeless": 83, "dict": [83, 107, 148, 198, 199, 200, 281, 286, 289, 290, 353, 363, 364, 366, 379, 381, 386], "arbitrarili": [83, 250, 379, 381, 385], "leaf": [83, 239, 240, 275], "node": [83, 107, 234], "recompil": [83, 380], "chang": [83, 203, 289, 293, 305, 328, 335, 380, 384], "Not": [83, 380], "attempt": 83, "pad": [85, 86, 87, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 252, 253, 255, 256, 268, 269], "dilat": [85, 86, 87, 255, 256], "group": [85, 86, 87, 97, 114, 180, 181, 263, 293], "1d": [85, 87, 88, 198, 221], "convolut": [85, 86, 87, 88, 255, 256, 258, 259], 
"channel": [85, 86, 87, 254, 255, 256, 258, 259], "c_in": [85, 86, 87], "c_out": [85, 86, 87], "convolv": [85, 86, 87], "2d": [86, 87, 99, 180, 254, 258], "spatial": [86, 87, 252, 263, 268, 305], "symmetr": 86, "kernel_dil": 87, "input_dil": 87, "flip": [87, 88], "correl": [87, 258], "discret": [88, 115, 116, 117, 118, 119, 120, 124, 125, 126, 260], "swap": [88, 167, 219, 290, 293], "conv": 88, "filter": [88, 255, 256, 271, 275], "signal": [88, 305], "cumul": [91, 92, 93, 94], "th": [91, 92, 93, 94, 98, 110, 371], "bias": [97, 180, 181, 262, 265, 276, 288, 291], "group_siz": [97, 180, 181, 293], "64": [97, 180, 181, 244, 293], "configur": 97, "formal": [97, 180], "notat": [97, 239, 280], "quantiz": [97, 148, 181, 293], "w_i": [97, 180], "hat": [97, 180], "occupi": [97, 180, 181], "subarrai": [99, 210], "remov": [99, 158, 183, 213, 325], "insert": [99, 109, 387], "neg": [99, 127, 139, 268, 269, 291, 326, 334, 336, 382], "taken": [99, 220], "global": [100, 103, 188, 377, 380], "disabl": [100, 166, 380], "mlx_disable_compil": [100, 103, 380], "divis": [101, 129, 180], "quotient": [101, 102, 129], "remaind": 102, "fuction": 102, "faster": [102, 315, 380, 381], "mathrm": [105, 204, 264], "frac": [105, 180, 204, 252, 253, 254, 257, 258, 259, 263, 264, 266, 267, 268, 269, 294, 307, 308, 309, 310, 324, 326, 328, 331, 342, 344, 356, 358, 359, 360, 361, 367], "pi": [105, 301, 381], "int_0": 105, "dt": 105, "erf": [106, 380], "exponenti": [108, 298, 314, 341, 370], "ident": [110, 215, 250, 285], "zero": [110, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 228, 229, 230, 237, 250, 252, 253, 257, 258, 259, 278, 306, 307, 308, 309, 310, 311, 312, 313, 350, 357, 382], "whose": [110, 238], "translat": [111, 266], "stabil": [111, 112, 254, 263, 264, 266, 294, 324, 326, 356, 357, 358, 359, 360, 361, 367], "traditino": 113, "rotat": [113, 297], "larger": [113, 297, 362], "unchang": [113, 215, 297], "consecut": [113, 180, 297], "angular": [113, 297], "frequenc": [113, 297, 301], "q": [114, 146], "head": [114, 291, 304], "attent": [114, 276, 291, 301, 304], "regardless": 114, "pre": 114, "tile": 114, "typic": [114, 260, 355, 380, 383], "One": [115, 118, 124, 196, 380, 381], "fourier": [115, 116, 117, 118, 119, 120, 124, 125, 126], "truncat": [115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 190], "dft": [115, 116, 117, 118, 119, 120, 124, 125, 126], "rfft": 121, "real": [121, 122, 123, 124, 125, 126], "rfft2": 122, "rfftn": 123, "silent": [124, 125, 126], "outsid": 127, "clamp": 127, "floor": 129, "argnam": [131, 231], "neither": [131, 231], "keyword": [131, 200, 201, 231, 240, 250, 377, 386, 388], "strict": [132, 143, 276, 278, 288], "ordinari": 135, "inifn": 137, "infin": [137, 139, 140, 268, 269, 361], "dtypecategori": [141, 244], "subtyp": [141, 244], "subdtyp": 141, "float64": 141, "too": [141, 380, 383], "ord": 145, "tabl": [145, 244, 260], "frobeniu": 145, "matric": [145, 146], "strictli": 145, "mathemat": 145, "variou": 145, "purpos": 145, "calcul": [145, 326, 332, 357], "fro": 145, "inf": [145, 291], "largest": [145, 226], "sing": 145, "smallest": 145, "singular": 145, "nuclear": 145, "_f": 145, "sum_": [145, 252, 253, 331], "a_": 145, "valueerror": [145, 278, 381], "refer": [145, 264, 270, 284, 307, 308, 309, 310, 317, 337, 382], "golub": 145, "van": 145, "loan": 145, "baltimor": 145, "md": 145, "john": 145, "hopkin": 145, "univers": 145, "1985": 145, "pg": 145, "la": 145, "9": [145, 325, 356, 359, 360, 361, 362, 364, 370, 373, 384], "74597": 145, "84804": 145, "41421": 145, 
"23607": [145, 146], "74166": 145, "24264": 145, "11": 145, "225": 145, "894427": 146, "447214": 146, "57771": 146, "50": 147, "evenli": 147, "return_metadata": 148, "binari": [148, 197, 198, 199, 200, 201, 303, 323, 347, 380], "npy": [148, 197, 386], "safetensor": [148, 199, 278, 282, 383, 386], "gguf": [148, 198, 386], "matadata": 148, "unsupport": 148, "tensor": [148, 224, 252, 253, 268, 269, 336, 384], "natur": [149, 151, 383], "logarithm": [149, 150, 151, 152], "log": [151, 153, 157, 321, 322, 326, 329, 331, 334, 345], "plu": 151, "exp": [153, 157, 184, 208, 314, 329, 341, 342, 345, 380, 387], "stabl": [153, 157, 208, 331], "prepend": 158, "report": [162, 167], "peak": 164, "begin": [164, 180, 253, 262, 265, 269, 303, 328, 335, 341, 346, 347], "program": 164, "free": 166, "reclaim": 166, "set_memory_limit": 166, "previou": [166, 167], "relax": 167, "task": [167, 331], "exceed": 167, "negat": 172, "beforehand": 176, "pad_with": 177, "constant_valu": 177, "pad_width": 177, "before_1": 177, "after_1": 177, "before_2": 177, "after_2": 177, "before_n": 177, "after_n": 177, "before_i": 177, "after_i": 177, "extend": 177, "side": [177, 252, 253, 268, 269, 380], "smaller": [178, 362, 380], "everi": [180, 240, 373, 381], "particular": [180, 263], "w_1": 180, "w_g": 180, "align": [180, 253, 262, 265, 269], "max_i": 180, "min_i": 180, "textrm": [180, 261, 315, 318], "round": 180, "pack": [180, 181], "unsign": [180, 181, 244], "lower": [180, 187, 190, 191, 228, 313], "upper": [180, 187, 190, 191, 313], "1st": 180, "signific": 180, "2nd": 180, "dequant": 180, "w_q": 180, "whether": [181, 262, 265, 275, 291, 295, 323, 326, 332], "prng": [182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 377], "num_sampl": 183, "unnorm": [183, 323, 325], "draw": 183, "cdf": [184, 261, 315], "accord": [184, 235, 291, 307, 308, 309, 310], "seed": 185, "loc": 186, "deviat": [186, 307, 309, 312], "low": [187, 191, 313, 350], "high": [187, 191, 250, 260, 313, 350], "bound": [187, 190, 191, 261, 313, 380, 382, 387], "roadcast": 187, "domain": 190, "uniformli": 191, "repetit": 193, "preserv": [194, 381], "reciproc": 196, "arr": [197, 382], "obj": 198, "uncompress": 200, "my_path": 200, "tree_flatten": [200, 240, 241, 250], "transformerencod": 200, "128": [200, 250], "flat_param": 200, "compress": 201, "possibl": [210, 260, 380, 382, 387], "being": [215, 250], "prevent": [215, 336, 384], "flow": [215, 383], "streamcontext": 216, "context": 216, "manag": [216, 377, 387], "prior": [220, 221], "exclud": 221, "dot": [224, 239, 280, 291], "rep": 225, "repeat": 225, "necessarili": 226, "elsewher": [228, 382], "col": 228, "triangl": 228, "mse": 231, "param": [231, 250, 350, 381], "lvalu": 231, "dlvalu": 231, "dparam": 231, "lasso": 231, "l1": [231, 328, 330, 331, 335], "varianc": [232, 254, 263, 326], "divisor": 232, "cotang": 233, "in_ax": [234, 381], "out_ax": [234, 381], "prefix": [234, 239], "fn": [238, 240, 385], "wrt": 238, "is_leaf": [239, 240], "arbitrari": [239, 353], "depth": [239, 259, 381], "hello": [239, 241], "charact": 239, "flat": [239, 241], "superset": [240, 363], "extra": 240, "closer": 240, "constitut": 240, "dict_kei": [240, 364], "lambda": [240, 250, 271, 276, 283, 302, 341, 346, 356, 357, 358, 359, 360, 361, 362, 367, 368, 380, 381], "recreat": 241, "world": 241, "42": 241, "16": [244, 252, 264, 268, 271, 353], "int16": 244, "brain": 244, "e8": 244, "m7": 244, "ieee": 244, "e5": 244, "m10": 244, "hierarchi": 244, "done": [250, 257, 294, 380, 383, 384], "manual": 250, "explicitli": [250, 377], "solv": 250, 
"intuit": 250, "freez": [250, 288, 353], "finetun": 250, "in_dim": [250, 353], "out_dim": [250, 353], "enumer": 250, "caus": [250, 380, 383], "local": [250, 258], "scope": 250, "l2_loss": 250, "y_hat": 250, "trainable_paramet": [250, 275, 364], "loss_and_grad": 250, "workhors": 250, "Its": 250, "recurs": [250, 275, 276, 281, 286, 288, 353], "frozen": [250, 276, 286, 288, 293, 353], "individu": [250, 258, 259], "subset": [250, 275], "action": 250, "displai": 250, "tree_map": 250, "count": [250, 371], "num_param": 250, "preclud": 250, "pure": [250, 355], "pattern": [250, 383], "achiev": 250, "other_input": 250, "necessari": 250, "wrap": 250, "apply_to_modul": [250, 276], "children": 250, "filter_and_map": 250, "leaf_modul": 250, "load_weight": [250, 383], "named_modul": 250, "save_weight": 250, "set_dtyp": 250, "unfreez": [250, 276], "update_modul": 250, "alibi": 250, "avgpool1d": 250, "avgpool2d": 250, "batchnorm": 250, "conv1d": 250, "conv2d": 250, "dropout": [250, 258, 259, 285, 304, 380], "dropout2d": 250, "dropout3d": 250, "gelu": [250, 316, 317, 380], "groupnorm": 250, "gru": 250, "instancenorm": 250, "layernorm": 250, "lstm": 250, "maxpool1d": 250, "maxpool2d": [250, 253], "mish": 250, "prelu": 250, "quantizedlinear": 250, "relu": [250, 292, 304, 338, 350], "rnn": [250, 262], "selu": 250, "sequenti": [250, 350], "silu": 250, "sinusoidalpositionalencod": 250, "softshrink": 250, "upsampl": 250, "elu": [250, 341], "gelu_approx": [250, 261, 315], "gelu_fast_approx": [250, 261, 315], "hardswish": 250, "leaky_relu": 250, "log_sigmoid": 250, "log_softmax": 250, "relu6": 250, "softplu": [250, 270, 337], "tanh": [250, 262, 265, 270, 295, 337], "binary_cross_entropi": [250, 380], "cosine_similarity_loss": 250, "gaussian_nll_loss": 250, "hinge_loss": 250, "huber_loss": 250, "kl_div_loss": 250, "l1_loss": 250, "log_cosh_loss": 250, "margin_ranking_loss": 250, "mse_loss": 250, "nll_loss": 250, "smooth_l1_loss": 250, "triplet_loss": 250, "init": [250, 292, 350, 355, 369, 370, 372, 373], "uniform": [250, 267, 278, 308, 310, 350, 377, 380, 381, 387], "glorot_norm": 250, "glorot_uniform": 250, "he_norm": 250, "he_uniform": 250, "kernel_s": [252, 253, 255, 256, 268, 269], "averag": [252, 253, 356, 357, 359, 360, 361], "pool": [252, 253, 268, 269, 387], "l_": [252, 268, 328], "n_i": [252, 253, 268, 269], "c_j": [252, 253, 268, 269], "ldot": [252, 253, 268, 269], "lfloor": [252, 253, 268, 269], "_size": [252, 253, 268, 269], "rfloor": [252, 253, 268, 269], "k_h": [253, 269], "k_w": [253, 269], "h_": [253, 262, 265, 269, 295], "w_": [253, 262, 265, 269, 295, 356, 357, 358, 359, 360, 361, 362, 367, 368], "height": [253, 254, 256, 258, 259, 269], "width": [253, 254, 256, 258, 259, 269, 293], "momentum": [254, 362, 364, 368, 380], "affin": [254, 263, 264, 266, 267, 293], "track_running_stat": 254, "var": [254, 263, 264, 266, 326], "epsilon": [254, 263, 264, 266, 294, 324, 326, 356, 358, 359, 360, 361, 367], "gamma": [254, 263, 264, 266, 294, 307, 308, 309, 310], "nc": 254, "nlc": [254, 255], "four": 254, "nhwc": [254, 256], "paper": [254, 301, 356, 357, 358, 359, 361, 362], "deep": [254, 307, 308, 309, 310], "intern": 254, "covari": 254, "shift": 254, "bn": 254, "in_channel": [255, 256], "out_channel": [255, 256], "learnabl": [255, 256, 299], "portion": 257, "independ": [258, 259], "nwhc": 258, "whc": 258, "maintain": [258, 259, 362], "entri": [258, 259], "benefici": [258, 259, 383], "earli": 258, "adjac": 258, "pixel": 258, "thompson": 258, "goroshin": 258, "jain": 258, "lecun": 258, "bregler": 258, "2015": 
[258, 359, 361], "cvpr": 258, "ndhwc": 259, "dhwc": 259, "medic": 259, "video": 259, "num_embed": 260, "lookup": 260, "usual": [260, 379, 383], "vocabulari": 260, "approx": 261, "unit": [261, 262, 296, 298, 300, 307, 308, 309, 310, 314, 315, 316, 317, 318, 320, 339, 340, 341, 343], "phi": [261, 315], "geluapprox": 261, "sigma": [261, 262, 265, 307, 308, 309, 310, 316, 317, 318, 321, 342, 343], "60033": [261, 316], "0433603": [261, 316], "gelufast": 261, "773": 261, "regard": 261, "input_s": [262, 265, 295], "hidden_s": [262, 265, 295], "gate": [262, 318], "recurr": [262, 265, 295], "nld": [262, 265, 295], "ld": [262, 265, 295], "r_t": 262, "xr": 262, "x_t": [262, 265, 295], "hr": 262, "h_t": [262, 265, 295], "b_": [262, 265], "z_t": 262, "xz": 262, "hz": 262, "n_t": 262, "xn": 262, "odot": [262, 265], "hn": 262, "hidden": [262, 265, 295, 304], "nh": [262, 265, 295], "nlh": [262, 265, 295], "lh": [262, 265, 295], "num_group": 263, "pytorch_compat": 263, "split": [263, 318], "preced": 263, "org": [263, 264, 266, 270, 294, 317, 337], "1803": 263, "08494": 263, "denomin": [264, 324, 356, 358, 359, 360, 361, 367], "inorm": 264, "1607": [264, 266], "08022": 264, "i_t": 265, "xi": 265, "f_t": 265, "xf": 265, "hf": 265, "g_t": [265, 356, 358, 359, 360, 361, 362, 367, 368], "xg": 265, "hg": 265, "o_t": 265, "xo": 265, "ho": 265, "c_": [265, 362], "c_t": [265, 362], "cell": 265, "06450": 266, "mathcal": 267, "d_i": 267, "max_": [268, 269], "1908": [270, 337], "08681": [270, 337], "map_fn": [271, 275], "filter_fn": [271, 275], "valid_parameter_filt": 271, "apply_fn": 272, "descend": 273, "is_leaf_fn": 275, "found": 275, "drop": 275, "idempot": [276, 288], "endswith": 276, "file_or_weight": 278, "miss": [278, 386], "ok": [278, 381], "save_safetensor": [282, 386], "predic": 283, "reflect": [284, 380, 382, 384], "certain": [285, 380], "ie": 288, "noop": 288, "unfrozen": 288, "tracer": 289, "partial": [289, 290, 380, 383], "child": 290, "flexibli": 290, "programmat": 290, "query_input_dim": 291, "key_input_dim": 291, "value_input_dim": 291, "value_dim": 291, "value_output_dim": 291, "aggreg": 291, "linearli": 291, "attend": 291, "num_paramet": 292, "25": [292, 305], "parametr": [292, 338], "classmethod": 293, "from_linear": 293, "quantize_modul": 293, "accumul": 294, "1910": 294, "07467": 294, "nonlinear": [295, 380], "elman": 295, "ih": 295, "hh": 295, "func": 295, "rectifi": [296, 309, 310, 320, 339, 340], "10000": 297, "slightli": [297, 387], "plain": 299, "known": [300, 343], "swish": [300, 343], "min_freq": 301, "0001": 301, "max_freq": 301, "cos_first": 301, "full_turn": 301, "sinusoid": 301, "sin": [301, 381, 385], "lambd": [302, 346], "threshold": [303, 328, 335, 347], "geq": [303, 347], "num_encoder_lay": 304, "num_decoder_lay": 304, "nb_func": 304, "custom_encod": 304, "custom_decod": 304, "norm_first": 304, "checkpoint": 304, "decod": 304, "interact": 304, "mechan": 304, "chekpoint": 304, "usag": [304, 380], "expens": 304, "scale_factor": 305, "nearest": 305, "align_corn": 305, "audio": 305, "4d": 305, "forth": 305, "neighbor": 305, "interpol": 305, "bilinear": 305, "trilinear": 305, "corner": 305, "bottom": 305, "squeez": [305, 380], "75": 305, "33333": 305, "66667": 305, "init_fn": [306, 307, 308, 309, 310, 311, 312, 313, 350], "glorot": [307, 308], "fan_in": [307, 308, 309, 310], "fan_out": [307, 308, 309, 310], "fan": [307, 308, 309, 310], "_in": [307, 308], "_out": [307, 308], "difficulti": [307, 308], "feedforward": [307, 308], "191107": 307, "61278": 307, "150594": 307, "363207": 307, 
"gain": [307, 308, 309, 310], "89613": 307, "53947": 307, "48095": 307, "995016": 307, "223404": 308, "890597": 308, "379159": 308, "776856": 308, "90041": 308, "02264": 308, "912766": 308, "12451": 308, "delv": [309, 310], "surpass": [309, 310], "human": [309, 310], "level": [309, 310], "imagenet": [309, 310], "classif": [309, 310], "25211": 309, "458835": 309, "177208": 309, "0137595": 309, "6967": 309, "02765": 309, "15268": 309, "75787": 309, "kaim": 310, "0300242": 310, "0184009": 310, "793615": 310, "666329": 310, "64331": 310, "16506": 310, "08619": 310, "79854": 310, "982273": 312, "534422": 312, "380709": 312, "0645099": 312, "883935": 313, "863726": 313, "617261": 313, "417497": 313, "exact": [316, 317], "0003": 316, "cdot": [316, 317, 324, 327, 343], "015": 317, "702": 317, "hendryck": 317, "1606": 317, "08415": 317, "halv": 318, "negative_slop": 320, "leaki": 320, "sum_i": 322, "x_i": [322, 344], "with_logit": 323, "predict": [323, 326, 327, 328, 329, 330, 331, 333, 334, 335], "105361": 323, "223144": 323, "20397": 323, "916291": 323, "539245": 323, "prob": 323, "510826": 323, "x1": 324, "x2": 324, "x_1": [324, 332], "x_2": [324, 332], "label_smooth": 325, "hot": 325, "smooth": [325, 335, 367], "0485873": 325, "348587": 325, "06": [326, 336, 356], "likelihood": [326, 334], "nll": [326, 334], "hing": 327, "y_": [327, 331], "pred": [327, 331], "delta": [328, 356], "huber": 328, "leq": [328, 341], "l2": [328, 331, 368], "kullback": 329, "leibler": 329, "diverg": 329, "cosh": 331, "logcosh": 331, "sensit": 331, "outlier": 331, "dual": 331, "behavior": [331, 382, 383], "offer": 331, "balanc": 331, "robust": 331, "approach": [331, 381], "inputs1": 332, "inputs2": 332, "margin": [332, 336], "rank": 332, "573409": 332, "765166": 332, "0638": 332, "75596": 332, "225763": 332, "256995": 332, "773433": 332, "formula": 335, "anchor": 336, "triplet": 336, "_p": 336, "degre": 336, "pairwis": 336, "instabl": 336, "monoton": 337, "0507": 341, "67326": 341, "sum_j": 344, "x_j": 344, "subclass": 353, "concept": 353, "mymlp": 353, "in_proj": 353, "subsequ": 355, "apply_gradi": 355, "rmsprop": 355, "adagrad": 355, "adafactor": 355, "adadelta": 355, "adam": [355, 361, 362, 371, 372], "adamw": [355, 362], "adamax": 355, "lion": 355, "cosine_decai": [355, 371], "exponential_decai": 355, "join_schedul": 355, "linear_schedul": [355, 371], "step_decai": 355, "rho": 356, "zeiler": 356, "2012": [356, 367], "adapt": [356, 357, 358], "1212": 356, "5701": 356, "v_": [356, 358, 359, 360, 361, 367, 368], "v_t": [356, 358, 359, 360, 361, 367, 368], "u_t": 356, "u_": 356, "w_t": [356, 358, 359, 360, 361, 362, 367, 368], "001": 357, "clip_threshold": 357, "decay_r": [357, 370, 373], "beta_1": [357, 359, 360, 361, 362], "weight_decai": [357, 360, 362, 368], "scale_paramet": 357, "relative_step": 357, "warmup_init": 357, "sublinear": 357, "cost": [357, 383], "epsilon_1": 357, "epsilon_2": 357, "parameter_scal": 357, "clip": 357, "unscal": 357, "decai": [357, 360, 362, 368, 369, 370, 373], "duchi": 358, "hazan": 358, "singer": 358, "2011": 358, "subgradi": 358, "onlin": 358, "stochast": [358, 359, 361, 368, 383], "jmlr": 358, "999": [359, 360, 361], "omit": [359, 361], "estim": [359, 361], "kingma": [359, 361], "ba": [359, 361], "iclr": [359, 360, 361], "m_": [359, 360, 361, 362], "m_t": [359, 360, 361, 362], "beta_2": [359, 360, 361, 362], "contrast": 360, "loshchilov": 360, "hutter": 360, "decoupl": 360, "99": [362, 367], "tend": 362, "10x": 362, "strength": [362, 368], "wd": 362, "chen": 362, "symbol": 362, 
"discoveri": 362, "2302": 362, "06675": 362, "eta": 362, "opt": 363, "tieleman": 367, "hinton": 367, "lectur": 367, "coursera": 367, "dampen": 368, "nesterov": 368, "descent": [368, 380, 383], "mu": 368, "tau": 368, "penalti": 368, "decay_step": 369, "beyond": [369, 372], "minim": 369, "lr_schedul": [369, 370, 371, 373], "1000": [369, 380], "0999961": 369, "06561": 370, "boundari": 371, "join": 371, "receiv": [371, 384], "transit": 371, "warmup": [371, 372], "0999938": 371, "101": 372, "step_siz": 373, "081": 373, "basi": 375, "implicit": [377, 380, 381], "fine": [377, 383], "grain": 377, "control": [377, 383], "pseudo": 377, "altern": 377, "splittabl": 377, "threefri": 377, "counter": 377, "cycl": 379, "merg": 380, "fuse": 380, "big": 380, "awar": [380, 383], "36788": 380, "compiled_fun": 380, "code": [380, 383], "slow": 380, "stack": 380, "rerun": [380, 383], "frequent": [380, 383], "destroi": 380, "anonym": 380, "don": [380, 387], "unari": 380, "overhead": [380, 383, 387], "bandwidth": 380, "fusibl": 380, "consider": 380, "versu": 380, "timeit": [380, 381], "tic": 380, "perf_count": 380, "toc": 380, "tpi": 380, "1e3": 380, "4096": [380, 381, 387], "On": [380, 381, 383], "millisecond": [380, 387], "five": 380, "latest": 380, "won": 380, "placehold": 380, "insid": 380, "crash": 380, "disable_compil": 380, "okai": [380, 383], "intend": 380, "deal": 380, "pretti": [380, 383], "inconveni": 380, "functool": 380, "particularli": 380, "backward": [380, 381], "compiled_grad_fn": 380, "71828": 380, "outer": [380, 383], "opportun": 380, "idea": [381, 383], "behind": 381, "dfdx": [381, 382], "d2fdx2": 381, "zero_grad": 381, "detach": 381, "requires_grad": 381, "dloss_dw": 381, "dloss_dx": 381, "lot": 381, "redund": 381, "suppos": [381, 387], "nice": [381, 383], "propag": [381, 382], "stop_gradi": 381, "autom": 381, "contriv": [381, 387], "sake": 381, "clariti": 381, "quit": [381, 384], "power": [381, 384], "difficult": 381, "primit": 381, "issu": [381, 384], "priorit": 381, "naive_add": 381, "vmap_add": 381, "total": 381, "390": 381, "wherea": 381, "025": 381, "ten": [381, 383], "Of": 381, "better": [381, 387], "handi": 381, "slice": 382, "ellipsi": 382, "mix": 382, "take_along_axi": 382, "lack": 382, "extrem": [382, 383], "ineffici": [382, 383], "nonzero": 382, "dynam": 383, "easier": 383, "worri": 383, "fun1": 383, "expensive_fun": 383, "consum": 383, "eager": 383, "thank": 383, "weights_fp16": 383, "trade": 383, "bad": 383, "grow": 383, "computation": 383, "costli": 383, "wide": 383, "thousand": 383, "value_and_grad_fn": 383, "implicitli": 383, "anytim": 383, "memoryview": [383, 384], "perfectli": 383, "first_lay": 383, "second_layer_a": 383, "second_layer_b": 383, "protocol": 384, "pep": 384, "3118": 384, "a_view": 384, "owndata": 384, "extern": 384, "x_view": 384, "modifi": 384, "df": 384, "x\u00b2": 384, "2x": 384, "indirectli": 384, "modif": 384, "seen": 384, "occur": 384, "incorpor": 384, "incorrect": 384, "experiment": 384, "break": 384, "advis": 384, "intermedi": 384, "jnp": 384, "tf": 384, "page": 385, "composit": 385, "archiv": 386, "savez_compress": 386, "save_gguf": 386, "arr_0": 386, "advantag": 387, "parallel": 387, "race": 387, "interest": 387, "albeit": 387, "d1": 387, "d2": 387, "matmul": 387, "dens": 387, "twice": 387, "measur": 387, "default_stream": 388, "default_devic": 388, "my_devic": 388}, "objects": {"mlx.core": [[8, 0, 1, "", "Device"], [9, 0, 1, "", "Dtype"], [10, 0, 1, "", "DtypeCategory"], [242, 0, 1, "", "Stream"], [11, 2, 1, "", "abs"], [12, 2, 1, "", "add"], [13, 
2, 1, "", "all"], [14, 2, 1, "", "allclose"], [15, 2, 1, "", "any"], [16, 2, 1, "", "arange"], [17, 2, 1, "", "arccos"], [18, 2, 1, "", "arccosh"], [19, 2, 1, "", "arcsin"], [20, 2, 1, "", "arcsinh"], [21, 2, 1, "", "arctan"], [22, 2, 1, "", "arctanh"], [23, 2, 1, "", "argmax"], [24, 2, 1, "", "argmin"], [25, 2, 1, "", "argpartition"], [26, 2, 1, "", "argsort"], [27, 0, 1, "", "array"], [76, 2, 1, "", "array_equal"], [77, 2, 1, "", "atleast_1d"], [78, 2, 1, "", "atleast_2d"], [79, 2, 1, "", "atleast_3d"], [80, 2, 1, "", "broadcast_to"], [81, 2, 1, "", "ceil"], [82, 2, 1, "", "clip"], [83, 2, 1, "", "compile"], [84, 2, 1, "", "concatenate"], [85, 2, 1, "", "conv1d"], [86, 2, 1, "", "conv2d"], [87, 2, 1, "", "conv_general"], [88, 2, 1, "", "convolve"], [89, 2, 1, "", "cos"], [90, 2, 1, "", "cosh"], [91, 2, 1, "", "cummax"], [92, 2, 1, "", "cummin"], [93, 2, 1, "", "cumprod"], [94, 2, 1, "", "cumsum"], [95, 2, 1, "", "default_device"], [96, 2, 1, "", "default_stream"], [97, 2, 1, "", "dequantize"], [98, 2, 1, "", "diag"], [99, 2, 1, "", "diagonal"], [100, 2, 1, "", "disable_compile"], [101, 2, 1, "", "divide"], [102, 2, 1, "", "divmod"], [103, 2, 1, "", "enable_compile"], [104, 2, 1, "", "equal"], [105, 2, 1, "", "erf"], [106, 2, 1, "", "erfinv"], [107, 2, 1, "", "eval"], [108, 2, 1, "", "exp"], [109, 2, 1, "", "expand_dims"], [110, 2, 1, "", "eye"], [127, 2, 1, "", "flatten"], [128, 2, 1, "", "floor"], [129, 2, 1, "", "floor_divide"], [130, 2, 1, "", "full"], [131, 2, 1, "", "grad"], [132, 2, 1, "", "greater"], [133, 2, 1, "", "greater_equal"], [134, 2, 1, "", "identity"], [135, 2, 1, "", "inner"], [136, 2, 1, "", "isclose"], [137, 2, 1, "", "isinf"], [138, 2, 1, "", "isnan"], [139, 2, 1, "", "isneginf"], [140, 2, 1, "", "isposinf"], [141, 2, 1, "", "issubdtype"], [142, 2, 1, "", "jvp"], [143, 2, 1, "", "less"], [144, 2, 1, "", "less_equal"], [147, 2, 1, "", "linspace"], [148, 2, 1, "", "load"], [149, 2, 1, "", "log"], [150, 2, 1, "", "log10"], [151, 2, 1, "", "log1p"], [152, 2, 1, "", "log2"], [153, 2, 1, "", "logaddexp"], [154, 2, 1, "", "logical_and"], [155, 2, 1, "", "logical_not"], [156, 2, 1, "", "logical_or"], [157, 2, 1, "", "logsumexp"], [158, 2, 1, "", "matmul"], [159, 2, 1, "", "max"], [160, 2, 1, "", "maximum"], [161, 2, 1, "", "mean"], [168, 2, 1, "", "min"], [169, 2, 1, "", "minimum"], [170, 2, 1, "", "moveaxis"], [171, 2, 1, "", "multiply"], [172, 2, 1, "", "negative"], [173, 2, 1, "", "new_stream"], [174, 2, 1, "", "ones"], [175, 2, 1, "", "ones_like"], [176, 2, 1, "", "outer"], [177, 2, 1, "", "pad"], [178, 2, 1, "", "partition"], [179, 2, 1, "", "prod"], [180, 2, 1, "", "quantize"], [181, 2, 1, "", "quantized_matmul"], [192, 2, 1, "", "reciprocal"], [193, 2, 1, "", "repeat"], [194, 2, 1, "", "reshape"], [195, 2, 1, "", "round"], [196, 2, 1, "", "rsqrt"], [197, 2, 1, "", "save"], [198, 2, 1, "", "save_gguf"], [199, 2, 1, "", "save_safetensors"], [200, 2, 1, "", "savez"], [201, 2, 1, "", "savez_compressed"], [202, 2, 1, "", "set_default_device"], [203, 2, 1, "", "set_default_stream"], [204, 2, 1, "", "sigmoid"], [205, 2, 1, "", "sign"], [206, 2, 1, "", "sin"], [207, 2, 1, "", "sinh"], [208, 2, 1, "", "softmax"], [209, 2, 1, "", "sort"], [210, 2, 1, "", "split"], [211, 2, 1, "", "sqrt"], [212, 2, 1, "", "square"], [213, 2, 1, "", "squeeze"], [214, 2, 1, "", "stack"], [215, 2, 1, "", "stop_gradient"], [216, 2, 1, "", "stream"], [217, 2, 1, "", "subtract"], [218, 2, 1, "", "sum"], [219, 2, 1, "", "swapaxes"], [220, 2, 1, "", "take"], [221, 2, 1, "", "take_along_axis"], [222, 2, 
1, "", "tan"], [223, 2, 1, "", "tanh"], [224, 2, 1, "", "tensordot"], [225, 2, 1, "", "tile"], [226, 2, 1, "", "topk"], [227, 2, 1, "", "transpose"], [228, 2, 1, "", "tri"], [229, 2, 1, "", "tril"], [230, 2, 1, "", "triu"], [231, 2, 1, "", "value_and_grad"], [232, 2, 1, "", "var"], [233, 2, 1, "", "vjp"], [234, 2, 1, "", "vmap"], [235, 2, 1, "", "where"], [236, 2, 1, "", "zeros"], [237, 2, 1, "", "zeros_like"]], "mlx.core.Device": [[8, 1, 1, "", "__init__"]], "mlx.core.Dtype": [[9, 1, 1, "", "__init__"]], "mlx.core.DtypeCategory": [[10, 1, 1, "", "__init__"]], "mlx.core.Stream": [[242, 1, 1, "", "__init__"]], "mlx.core.array": [[28, 3, 1, "", "T"], [27, 1, 1, "", "__init__"], [29, 1, 1, "", "abs"], [30, 1, 1, "", "all"], [31, 1, 1, "", "any"], [32, 1, 1, "", "argmax"], [33, 1, 1, "", "argmin"], [34, 1, 1, "", "astype"], [35, 3, 1, "", "at"], [36, 1, 1, "", "cos"], [37, 1, 1, "", "cummax"], [38, 1, 1, "", "cummin"], [39, 1, 1, "", "cumprod"], [40, 1, 1, "", "cumsum"], [41, 1, 1, "", "diag"], [42, 1, 1, "", "diagonal"], [43, 3, 1, "", "dtype"], [44, 1, 1, "", "exp"], [45, 1, 1, "", "flatten"], [46, 1, 1, "", "item"], [47, 3, 1, "", "itemsize"], [48, 1, 1, "", "log"], [49, 1, 1, "", "log10"], [50, 1, 1, "", "log1p"], [51, 1, 1, "", "log2"], [52, 1, 1, "", "logsumexp"], [53, 1, 1, "", "max"], [54, 1, 1, "", "mean"], [55, 1, 1, "", "min"], [56, 1, 1, "", "moveaxis"], [57, 3, 1, "", "nbytes"], [58, 3, 1, "", "ndim"], [59, 1, 1, "", "prod"], [60, 1, 1, "", "reciprocal"], [61, 1, 1, "", "reshape"], [62, 1, 1, "", "round"], [63, 1, 1, "", "rsqrt"], [64, 3, 1, "", "shape"], [65, 1, 1, "", "sin"], [66, 3, 1, "", "size"], [67, 1, 1, "", "split"], [68, 1, 1, "", "sqrt"], [69, 1, 1, "", "square"], [70, 1, 1, "", "squeeze"], [71, 1, 1, "", "sum"], [72, 1, 1, "", "swapaxes"], [73, 1, 1, "", "tolist"], [74, 1, 1, "", "transpose"], [75, 1, 1, "", "var"]], "mlx.core.fast": [[111, 2, 1, "", "layer_norm"], [112, 2, 1, "", "rms_norm"], [113, 2, 1, "", "rope"], [114, 2, 1, "", "scaled_dot_product_attention"]], "mlx.core.fft": [[115, 2, 1, "", "fft"], [116, 2, 1, "", "fft2"], [117, 2, 1, "", "fftn"], [118, 2, 1, "", "ifft"], [119, 2, 1, "", "ifft2"], [120, 2, 1, "", "ifftn"], [121, 2, 1, "", "irfft"], [122, 2, 1, "", "irfft2"], [123, 2, 1, "", "irfftn"], [124, 2, 1, "", "rfft"], [125, 2, 1, "", "rfft2"], [126, 2, 1, "", "rfftn"]], "mlx.core.linalg": [[145, 2, 1, "", "norm"], [146, 2, 1, "", "qr"]], "mlx.core.metal": [[162, 2, 1, "", "get_active_memory"], [163, 2, 1, "", "get_cache_memory"], [164, 2, 1, "", "get_peak_memory"], [165, 2, 1, "", "is_available"], [166, 2, 1, "", "set_cache_limit"], [167, 2, 1, "", "set_memory_limit"]], "mlx.core.random": [[182, 2, 1, "", "bernoulli"], [183, 2, 1, "", "categorical"], [184, 2, 1, "", "gumbel"], [185, 2, 1, "", "key"], [186, 2, 1, "", "normal"], [187, 2, 1, "", "randint"], [188, 2, 1, "", "seed"], [189, 2, 1, "", "split"], [190, 2, 1, "", "truncated_normal"], [191, 2, 1, "", "uniform"]], "mlx.nn": [[251, 0, 1, "", "ALiBi"], [252, 0, 1, "", "AvgPool1d"], [253, 0, 1, "", "AvgPool2d"], [254, 0, 1, "", "BatchNorm"], [255, 0, 1, "", "Conv1d"], [256, 0, 1, "", "Conv2d"], [257, 0, 1, "", "Dropout"], [258, 0, 1, "", "Dropout2d"], [259, 0, 1, "", "Dropout3d"], [260, 0, 1, "", "Embedding"], [261, 0, 1, "", "GELU"], [262, 0, 1, "", "GRU"], [263, 0, 1, "", "GroupNorm"], [264, 0, 1, "", "InstanceNorm"], [265, 0, 1, "", "LSTM"], [266, 0, 1, "", "LayerNorm"], [267, 0, 1, "", "Linear"], [268, 0, 1, "", "MaxPool1d"], [269, 0, 1, "", "MaxPool2d"], [270, 0, 1, "", "Mish"], [353, 0, 1, "", 
"Module"], [291, 0, 1, "", "MultiHeadAttention"], [292, 0, 1, "", "PReLU"], [293, 0, 1, "", "QuantizedLinear"], [294, 0, 1, "", "RMSNorm"], [295, 0, 1, "", "RNN"], [296, 0, 1, "", "ReLU"], [297, 0, 1, "", "RoPE"], [298, 0, 1, "", "SELU"], [299, 0, 1, "", "Sequential"], [300, 0, 1, "", "SiLU"], [301, 0, 1, "", "SinusoidalPositionalEncoding"], [302, 0, 1, "", "Softshrink"], [303, 0, 1, "", "Step"], [304, 0, 1, "", "Transformer"], [305, 0, 1, "", "Upsample"], [314, 2, 1, "", "elu"], [315, 2, 1, "", "gelu"], [316, 2, 1, "", "gelu_approx"], [317, 2, 1, "", "gelu_fast_approx"], [318, 2, 1, "", "glu"], [319, 2, 1, "", "hardswish"], [320, 2, 1, "", "leaky_relu"], [321, 2, 1, "", "log_sigmoid"], [322, 2, 1, "", "log_softmax"], [337, 2, 1, "", "mish"], [338, 2, 1, "", "prelu"], [339, 2, 1, "", "relu"], [340, 2, 1, "", "relu6"], [341, 2, 1, "", "selu"], [342, 2, 1, "", "sigmoid"], [343, 2, 1, "", "silu"], [344, 2, 1, "", "softmax"], [345, 2, 1, "", "softplus"], [346, 2, 1, "", "softshrink"], [347, 2, 1, "", "step"], [348, 2, 1, "", "tanh"], [238, 2, 1, "", "value_and_grad"]], "mlx.nn.Module": [[271, 1, 1, "", "apply"], [272, 1, 1, "", "apply_to_modules"], [273, 1, 1, "", "children"], [274, 1, 1, "", "eval"], [275, 1, 1, "", "filter_and_map"], [276, 1, 1, "", "freeze"], [277, 1, 1, "", "leaf_modules"], [278, 1, 1, "", "load_weights"], [279, 1, 1, "", "modules"], [280, 1, 1, "", "named_modules"], [281, 1, 1, "", "parameters"], [282, 1, 1, "", "save_weights"], [283, 1, 1, "", "set_dtype"], [284, 3, 1, "", "state"], [285, 1, 1, "", "train"], [286, 1, 1, "", "trainable_parameters"], [287, 3, 1, "", "training"], [288, 1, 1, "", "unfreeze"], [289, 1, 1, "", "update"], [290, 1, 1, "", "update_modules"]], "mlx.nn.init": [[306, 2, 1, "", "constant"], [307, 2, 1, "", "glorot_normal"], [308, 2, 1, "", "glorot_uniform"], [309, 2, 1, "", "he_normal"], [310, 2, 1, "", "he_uniform"], [311, 2, 1, "", "identity"], [312, 2, 1, "", "normal"], [313, 2, 1, "", "uniform"]], "mlx.nn.losses": [[323, 2, 1, "", "binary_cross_entropy"], [324, 2, 1, "", "cosine_similarity_loss"], [325, 2, 1, "", "cross_entropy"], [326, 2, 1, "", "gaussian_nll_loss"], [327, 2, 1, "", "hinge_loss"], [328, 2, 1, "", "huber_loss"], [329, 2, 1, "", "kl_div_loss"], [330, 2, 1, "", "l1_loss"], [331, 2, 1, "", "log_cosh_loss"], [332, 2, 1, "", "margin_ranking_loss"], [333, 2, 1, "", "mse_loss"], [334, 2, 1, "", "nll_loss"], [335, 2, 1, "", "smooth_l1_loss"], [336, 2, 1, "", "triplet_loss"]], "mlx.optimizers": [[356, 0, 1, "", "AdaDelta"], [357, 0, 1, "", "Adafactor"], [358, 0, 1, "", "Adagrad"], [359, 0, 1, "", "Adam"], [360, 0, 1, "", "AdamW"], [361, 0, 1, "", "Adamax"], [362, 0, 1, "", "Lion"], [375, 0, 1, "", "Optimizer"], [367, 0, 1, "", "RMSprop"], [368, 0, 1, "", "SGD"], [369, 2, 1, "", "cosine_decay"], [370, 2, 1, "", "exponential_decay"], [371, 2, 1, "", "join_schedules"], [372, 2, 1, "", "linear_schedule"], [373, 2, 1, "", "step_decay"]], "mlx.optimizers.Optimizer": [[363, 1, 1, "", "apply_gradients"], [364, 1, 1, "", "init"], [365, 3, 1, "", "state"], [366, 1, 1, "", "update"]], "mlx.utils": [[239, 2, 1, "", "tree_flatten"], [240, 2, 1, "", "tree_map"], [241, 2, 1, "", "tree_unflatten"]]}, "objtypes": {"0": "py:class", "1": "py:method", "2": "py:function", "3": "py:property"}, "objnames": {"0": ["py", "class", "Python class"], "1": ["py", "method", "Python method"], "2": ["py", "function", "Python function"], "3": ["py", "property", "Python property"]}, "titleterms": {"oper": [0, 1, 354], "develop": 1, "document": 1, "introduc": 1, "exampl": 
[1, 6, 380, 387], "primit": 1, "us": [1, 383, 388], "implement": [1, 4], "cpu": 1, "backend": 1, "gpu": 1, "transform": [1, 304, 378, 380, 381, 383, 385], "build": [1, 7], "bind": 1, "python": [1, 6, 7], "cmake": 1, "setuptool": 1, "usag": [1, 6], "result": 1, "script": [1, 4], "download": [1, 4], "code": [1, 4], "metal": [2, 7, 162, 163, 164, 165, 166, 167, 249], "debugg": 2, "xcode": 2, "workflow": 2, "linear": [3, 248, 267], "regress": 3, "llm": 4, "infer": 4, "model": 4, "attent": 4, "layer": [4, 5, 351], "encod": 4, "full": [4, 130], "gener": 4, "put": 4, "all": [4, 13, 30], "togeth": 4, "convert": 4, "weight": 4, "load": [4, 148, 386], "benchmark": 4, "multi": 5, "perceptron": 5, "mlx": [6, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373], "instal": [6, 7], "api": [6, 7], "refer": 6, "c": [6, 7], "further": 6, "read": 6, "troubleshoot": 7, "from": [7, 382], "sourc": 7, "requir": 7, "option": 7, "found": 7, "x86": 7, "shell": 7, "core": [8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 
224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 242], "devic": [8, 245], "dtype": [9, 43], "dtypecategori": 10, "ab": [11, 29], "add": 12, "allclos": 14, "ani": [15, 31], "arang": 16, "arcco": 17, "arccosh": 18, "arcsin": 19, "arcsinh": 20, "arctan": 21, "arctanh": 22, "argmax": [23, 32], "argmin": [24, 33], "argpartit": 25, "argsort": 26, "arrai": [27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 243, 382, 386], "t": 28, "astyp": 34, "co": [36, 89], "cummax": [37, 91], "cummin": [38, 92], "cumprod": [39, 93], "cumsum": [40, 94], "diag": [41, 98], "diagon": [42, 99], "exp": [44, 108], "flatten": [45, 127], "item": 46, "items": 47, "log": [48, 149], "log10": [49, 150], "log1p": [50, 151], "log2": [51, 152], "logsumexp": [52, 157], "max": [53, 159], "mean": [54, 161], "min": [55, 168], "moveaxi": [56, 170], "nbyte": 57, "ndim": 58, "prod": [59, 179], "reciproc": [60, 192], "reshap": [61, 194], "round": [62, 195], "rsqrt": [63, 196], "shape": 64, "sin": [65, 206], "size": 66, "split": [67, 189, 210], "sqrt": [68, 211], "squar": [69, 212], "squeez": [70, 213], "sum": [71, 218], "swapax": [72, 219], "tolist": 73, "transpos": [74, 227], "var": [75, 232], "array_equ": 76, "atleast_1d": 77, "atleast_2d": 78, "atleast_3d": 79, "broadcast_to": 80, "ceil": 81, "clip": 82, "compil": [83, 380], "concaten": 84, "conv1d": [85, 255], "conv2d": [86, 256], "conv_gener": 87, "convolv": 88, "cosh": 90, "default_devic": 95, "default_stream": 96, "dequant": 97, "disable_compil": 100, "divid": 101, "divmod": 102, "enable_compil": 103, "equal": 104, "erf": 105, "erfinv": 106, "eval": [107, 274], "expand_dim": 109, "ey": 110, "fast": [111, 112, 113, 114, 246], "layer_norm": 111, "rms_norm": 112, "rope": [113, 297], "scaled_dot_product_attent": 114, "fft": [115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 247], "fft2": 116, "fftn": 117, "ifft": 118, "ifft2": 119, "ifftn": 120, "irfft": 121, "irfft2": 122, "irfftn": 123, "rfft": 124, "rfft2": 125, "rfftn": 126, "floor": 128, "floor_divid": 129, "grad": [131, 250], "greater": 132, "greater_equ": 133, "ident": [134, 311], "inner": 135, "isclos": 136, "isinf": 137, "isnan": 138, "isneginf": 139, "isposinf": 140, "issubdtyp": 141, "jvp": 142, "less": 143, "less_equ": 144, "linalg": [145, 146], "norm": 145, "qr": 146, "linspac": 147, "logaddexp": 153, "logical_and": 154, "logical_not": 155, "logical_or": 156, "matmul": 158, "maximum": 160, "get_active_memori": 162, "get_cache_memori": 163, "get_peak_memori": 164, "is_avail": 165, "set_cache_limit": 166, "set_memory_limit": 167, "minimum": 169, "multipli": 171, "neg": 172, "new_stream": 173, "ones": 174, "ones_lik": 175, "outer": 176, "pad": 177, "partit": 178, "quantiz": 180, "quantized_matmul": 181, "random": [182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 377], "bernoulli": 182, "categor": 183, "gumbel": 184, "kei": 185, "normal": [186, 312], "randint": 187, "seed": 188, "truncated_norm": 190, "uniform": [191, 313], "repeat": 193, "save": [197, 386], "save_gguf": 198, "save_safetensor": 199, "savez": 200, "savez_compress": 201, "set_default_devic": 202, "set_default_stream": 203, "sigmoid": [204, 342], "sign": 205, "sinh": 207, "softmax": [208, 344], "sort": 209, "stack": 214, "stop_gradi": 215, "stream": [216, 242, 245, 388], "subtract": 217, "take": 220, "take_along_axi": 221, "tan": 222, "tanh": [223, 348], "tensordot": 224, "tile": 
225, "topk": 226, "tri": 228, "tril": 229, "triu": 230, "value_and_grad": [231, 238], "vjp": 233, "vmap": 234, "where": 235, "zero": 236, "zeros_lik": 237, "nn": [238, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348], "util": [239, 240, 241, 379], "tree_flatten": 239, "tree_map": 240, "tree_unflatten": 241, "data": 244, "type": 244, "support": 244, "algebra": 248, "neural": 250, "network": 250, "quick": [250, 385], "start": [250, 385], "The": 250, "modul": [250, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 353], "class": 250, "paramet": [250, 281], "updat": [250, 289, 366, 382], "inspect": 250, "valu": 250, "alibi": 251, "avgpool1d": 252, "avgpool2d": 253, "batchnorm": 254, "dropout": 257, "dropout2d": 258, "dropout3d": 259, "embed": 260, "gelu": [261, 315], "gru": 262, "groupnorm": 263, "instancenorm": 264, "lstm": 265, "layernorm": 266, "maxpool1d": 268, "maxpool2d": 269, "mish": [270, 337], "appli": 271, "apply_to_modul": 272, "children": 273, "filter_and_map": 275, "freez": 276, "leaf_modul": 277, "load_weight": 278, "named_modul": 280, "save_weight": 282, "set_dtyp": 283, "state": [284, 365], "train": [285, 287, 380], "trainable_paramet": 286, "unfreez": 288, "update_modul": 290, "multiheadattent": 291, "prelu": [292, 338], "quantizedlinear": 293, "rmsnorm": 294, "rnn": 295, "relu": [296, 339], "selu": [298, 341], "sequenti": 299, "silu": [300, 343], "sinusoidalpositionalencod": 301, "softshrink": [302, 346], "step": [303, 347], "upsampl": 305, "init": [306, 307, 308, 309, 310, 311, 312, 313, 364], "constant": 306, "glorot_norm": 307, "glorot_uniform": 308, "he_norm": 309, "he_uniform": 310, "elu": 314, "gelu_approx": 316, "gelu_fast_approx": 317, "glu": 318, "hardswish": 319, "leaky_relu": 320, "log_sigmoid": 321, "log_softmax": 322, "loss": [323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 352], "binary_cross_entropi": 323, "cosine_similarity_loss": 324, "cross_entropi": 325, "gaussian_nll_loss": 326, "hinge_loss": 327, "huber_loss": 328, "kl_div_loss": 329, "l1_loss": 330, "log_cosh_loss": 331, "margin_ranking_loss": 332, "mse_loss": 333, "nll_loss": 334, "smooth_l1_loss": 335, "triplet_loss": 336, "relu6": 340, "softplu": 345, "function": [349, 352, 380, 381, 385], "initi": 350, "optim": [355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375], "adadelta": 356, "adafactor": 357, "adagrad": 358, "adam": 359, "adamw": 360, "adamax": 361, "lion": 362, "apply_gradi": 363, "rmsprop": 367, "sgd": 368, "cosine_decai": 369, "exponential_decai": 370, "join_schedul": 371, "linear_schedul": 372, "step_decai": 373, "common": 374, "schedul": 376, "tree": 379, "basic": [380, 385], "speedup": 380, "debug": 380, "pure": 380, "graph": [380, 383, 385], "automat": 381, "differenti": 381, "vector": 381, "index": 382, "differ": 382, "numpi": [382, 384], "In": 382, "place": 382, "lazi": 383, "evalu": 383, "why": 383, "comput": 383, "onli": 383, "what": 383, "you": 383, "when": 383, "convers": 384, "other": 384, "framework": 384, "pytorch": 384, 
"jax": 384, "tensorflow": 384, "guid": 385, "serial": 386, "format": 386, "unifi": 387, "memori": 387, "A": 387, "simpl": 387, "specifi": 388}, "envversion": {"sphinx.domains.c": 3, "sphinx.domains.changeset": 1, "sphinx.domains.citation": 1, "sphinx.domains.cpp": 9, "sphinx.domains.index": 1, "sphinx.domains.javascript": 3, "sphinx.domains.math": 2, "sphinx.domains.python": 4, "sphinx.domains.rst": 2, "sphinx.domains.std": 2, "sphinx.ext.intersphinx": 1, "sphinx": 60}, "alltitles": {"Operations": [[0, "operations"], [1, "operations"], [354, "operations"]], "Developer Documentation": [[1, "developer-documentation"]], "Introducing the Example": [[1, "introducing-the-example"]], "Operations and Primitives": [[1, "operations-and-primitives"]], "Primitives": [[1, "primitives"]], "Using the Primitives": [[1, "using-the-primitives"]], "Implementing the Primitive": [[1, "implementing-the-primitive"]], "Implementing the CPU Backend": [[1, "implementing-the-cpu-backend"]], "Implementing the GPU Backend": [[1, "implementing-the-gpu-backend"]], "Primitive Transforms": [[1, "primitive-transforms"]], "Building and Binding": [[1, "building-and-binding"]], "Binding to Python": [[1, "binding-to-python"]], "Building with CMake": [[1, "building-with-cmake"]], "Building with setuptools": [[1, "building-with-setuptools"]], "Usage": [[1, "usage"], [6, null]], "Results": [[1, "results"]], "Scripts": [[1, "scripts"], [4, "scripts"]], "Download the code": [[1, null], [4, null]], "Metal Debugger": [[2, "metal-debugger"]], "Xcode Workflow": [[2, "xcode-workflow"]], "Linear Regression": [[3, "linear-regression"]], "LLM inference": [[4, "llm-inference"]], "Implementing the model": [[4, "implementing-the-model"]], "Attention layer": [[4, "attention-layer"]], "Encoder layer": [[4, "encoder-layer"]], "Full model": [[4, "full-model"]], "Generation": [[4, "generation"]], "Putting it all together": [[4, "putting-it-all-together"]], "Converting the weights": [[4, "converting-the-weights"]], "Weight loading and benchmarking": [[4, "weight-loading-and-benchmarking"]], "Multi-Layer Perceptron": [[5, "multi-layer-perceptron"]], "MLX": [[6, "mlx"]], "Install": [[6, null]], "Examples": [[6, null]], "Python API Reference": [[6, null]], "C++ API Reference": [[6, null]], "Further Reading": [[6, null]], "Build and Install": [[7, "build-and-install"]], "Python Installation": [[7, "python-installation"]], "Troubleshooting": [[7, "troubleshooting"], [7, "id2"]], "Build from source": [[7, "build-from-source"]], "Build Requirements": [[7, "build-requirements"]], "Python API": [[7, "python-api"]], "C++ API": [[7, "c-api"]], "Build Options": [[7, "id3"]], "Metal not found": [[7, "metal-not-found"]], "x86 Shell": [[7, "x86-shell"]], "mlx.core.Device": [[8, "mlx-core-device"]], "mlx.core.Dtype": [[9, "mlx-core-dtype"]], "mlx.core.DtypeCategory": [[10, "mlx-core-dtypecategory"]], "mlx.core.abs": [[11, "mlx-core-abs"]], "mlx.core.add": [[12, "mlx-core-add"]], "mlx.core.all": [[13, "mlx-core-all"]], "mlx.core.allclose": [[14, "mlx-core-allclose"]], "mlx.core.any": [[15, "mlx-core-any"]], "mlx.core.arange": [[16, "mlx-core-arange"]], "mlx.core.arccos": [[17, "mlx-core-arccos"]], "mlx.core.arccosh": [[18, "mlx-core-arccosh"]], "mlx.core.arcsin": [[19, "mlx-core-arcsin"]], "mlx.core.arcsinh": [[20, "mlx-core-arcsinh"]], "mlx.core.arctan": [[21, "mlx-core-arctan"]], "mlx.core.arctanh": [[22, "mlx-core-arctanh"]], "mlx.core.argmax": [[23, "mlx-core-argmax"]], "mlx.core.argmin": [[24, "mlx-core-argmin"]], "mlx.core.argpartition": [[25, 
"mlx-core-argpartition"]], "mlx.core.argsort": [[26, "mlx-core-argsort"]], "mlx.core.array": [[27, "mlx-core-array"]], "mlx.core.array.T": [[28, "mlx-core-array-t"]], "mlx.core.array.abs": [[29, "mlx-core-array-abs"]], "mlx.core.array.all": [[30, "mlx-core-array-all"]], "mlx.core.array.any": [[31, "mlx-core-array-any"]], "mlx.core.array.argmax": [[32, "mlx-core-array-argmax"]], "mlx.core.array.argmin": [[33, "mlx-core-array-argmin"]], "mlx.core.array.astype": [[34, "mlx-core-array-astype"]], "mlx.core.array.at": [[35, "mlx-core-array-at"]], "mlx.core.array.cos": [[36, "mlx-core-array-cos"]], "mlx.core.array.cummax": [[37, "mlx-core-array-cummax"]], "mlx.core.array.cummin": [[38, "mlx-core-array-cummin"]], "mlx.core.array.cumprod": [[39, "mlx-core-array-cumprod"]], "mlx.core.array.cumsum": [[40, "mlx-core-array-cumsum"]], "mlx.core.array.diag": [[41, "mlx-core-array-diag"]], "mlx.core.array.diagonal": [[42, "mlx-core-array-diagonal"]], "mlx.core.array.dtype": [[43, "mlx-core-array-dtype"]], "mlx.core.array.exp": [[44, "mlx-core-array-exp"]], "mlx.core.array.flatten": [[45, "mlx-core-array-flatten"]], "mlx.core.array.item": [[46, "mlx-core-array-item"]], "mlx.core.array.itemsize": [[47, "mlx-core-array-itemsize"]], "mlx.core.array.log": [[48, "mlx-core-array-log"]], "mlx.core.array.log10": [[49, "mlx-core-array-log10"]], "mlx.core.array.log1p": [[50, "mlx-core-array-log1p"]], "mlx.core.array.log2": [[51, "mlx-core-array-log2"]], "mlx.core.array.logsumexp": [[52, "mlx-core-array-logsumexp"]], "mlx.core.array.max": [[53, "mlx-core-array-max"]], "mlx.core.array.mean": [[54, "mlx-core-array-mean"]], "mlx.core.array.min": [[55, "mlx-core-array-min"]], "mlx.core.array.moveaxis": [[56, "mlx-core-array-moveaxis"]], "mlx.core.array.nbytes": [[57, "mlx-core-array-nbytes"]], "mlx.core.array.ndim": [[58, "mlx-core-array-ndim"]], "mlx.core.array.prod": [[59, "mlx-core-array-prod"]], "mlx.core.array.reciprocal": [[60, "mlx-core-array-reciprocal"]], "mlx.core.array.reshape": [[61, "mlx-core-array-reshape"]], "mlx.core.array.round": [[62, "mlx-core-array-round"]], "mlx.core.array.rsqrt": [[63, "mlx-core-array-rsqrt"]], "mlx.core.array.shape": [[64, "mlx-core-array-shape"]], "mlx.core.array.sin": [[65, "mlx-core-array-sin"]], "mlx.core.array.size": [[66, "mlx-core-array-size"]], "mlx.core.array.split": [[67, "mlx-core-array-split"]], "mlx.core.array.sqrt": [[68, "mlx-core-array-sqrt"]], "mlx.core.array.square": [[69, "mlx-core-array-square"]], "mlx.core.array.squeeze": [[70, "mlx-core-array-squeeze"]], "mlx.core.array.sum": [[71, "mlx-core-array-sum"]], "mlx.core.array.swapaxes": [[72, "mlx-core-array-swapaxes"]], "mlx.core.array.tolist": [[73, "mlx-core-array-tolist"]], "mlx.core.array.transpose": [[74, "mlx-core-array-transpose"]], "mlx.core.array.var": [[75, "mlx-core-array-var"]], "mlx.core.array_equal": [[76, "mlx-core-array-equal"]], "mlx.core.atleast_1d": [[77, "mlx-core-atleast-1d"]], "mlx.core.atleast_2d": [[78, "mlx-core-atleast-2d"]], "mlx.core.atleast_3d": [[79, "mlx-core-atleast-3d"]], "mlx.core.broadcast_to": [[80, "mlx-core-broadcast-to"]], "mlx.core.ceil": [[81, "mlx-core-ceil"]], "mlx.core.clip": [[82, "mlx-core-clip"]], "mlx.core.compile": [[83, "mlx-core-compile"]], "mlx.core.concatenate": [[84, "mlx-core-concatenate"]], "mlx.core.conv1d": [[85, "mlx-core-conv1d"]], "mlx.core.conv2d": [[86, "mlx-core-conv2d"]], "mlx.core.conv_general": [[87, "mlx-core-conv-general"]], "mlx.core.convolve": [[88, "mlx-core-convolve"]], "mlx.core.cos": [[89, "mlx-core-cos"]], "mlx.core.cosh": [[90, 
"mlx-core-cosh"]], "mlx.core.cummax": [[91, "mlx-core-cummax"]], "mlx.core.cummin": [[92, "mlx-core-cummin"]], "mlx.core.cumprod": [[93, "mlx-core-cumprod"]], "mlx.core.cumsum": [[94, "mlx-core-cumsum"]], "mlx.core.default_device": [[95, "mlx-core-default-device"]], "mlx.core.default_stream": [[96, "mlx-core-default-stream"]], "mlx.core.dequantize": [[97, "mlx-core-dequantize"]], "mlx.core.diag": [[98, "mlx-core-diag"]], "mlx.core.diagonal": [[99, "mlx-core-diagonal"]], "mlx.core.disable_compile": [[100, "mlx-core-disable-compile"]], "mlx.core.divide": [[101, "mlx-core-divide"]], "mlx.core.divmod": [[102, "mlx-core-divmod"]], "mlx.core.enable_compile": [[103, "mlx-core-enable-compile"]], "mlx.core.equal": [[104, "mlx-core-equal"]], "mlx.core.erf": [[105, "mlx-core-erf"]], "mlx.core.erfinv": [[106, "mlx-core-erfinv"]], "mlx.core.eval": [[107, "mlx-core-eval"]], "mlx.core.exp": [[108, "mlx-core-exp"]], "mlx.core.expand_dims": [[109, "mlx-core-expand-dims"]], "mlx.core.eye": [[110, "mlx-core-eye"]], "mlx.core.fast.layer_norm": [[111, "mlx-core-fast-layer-norm"]], "mlx.core.fast.rms_norm": [[112, "mlx-core-fast-rms-norm"]], "mlx.core.fast.rope": [[113, "mlx-core-fast-rope"]], "mlx.core.fast.scaled_dot_product_attention": [[114, "mlx-core-fast-scaled-dot-product-attention"]], "mlx.core.fft.fft": [[115, "mlx-core-fft-fft"]], "mlx.core.fft.fft2": [[116, "mlx-core-fft-fft2"]], "mlx.core.fft.fftn": [[117, "mlx-core-fft-fftn"]], "mlx.core.fft.ifft": [[118, "mlx-core-fft-ifft"]], "mlx.core.fft.ifft2": [[119, "mlx-core-fft-ifft2"]], "mlx.core.fft.ifftn": [[120, "mlx-core-fft-ifftn"]], "mlx.core.fft.irfft": [[121, "mlx-core-fft-irfft"]], "mlx.core.fft.irfft2": [[122, "mlx-core-fft-irfft2"]], "mlx.core.fft.irfftn": [[123, "mlx-core-fft-irfftn"]], "mlx.core.fft.rfft": [[124, "mlx-core-fft-rfft"]], "mlx.core.fft.rfft2": [[125, "mlx-core-fft-rfft2"]], "mlx.core.fft.rfftn": [[126, "mlx-core-fft-rfftn"]], "mlx.core.flatten": [[127, "mlx-core-flatten"]], "mlx.core.floor": [[128, "mlx-core-floor"]], "mlx.core.floor_divide": [[129, "mlx-core-floor-divide"]], "mlx.core.full": [[130, "mlx-core-full"]], "mlx.core.grad": [[131, "mlx-core-grad"]], "mlx.core.greater": [[132, "mlx-core-greater"]], "mlx.core.greater_equal": [[133, "mlx-core-greater-equal"]], "mlx.core.identity": [[134, "mlx-core-identity"]], "mlx.core.inner": [[135, "mlx-core-inner"]], "mlx.core.isclose": [[136, "mlx-core-isclose"]], "mlx.core.isinf": [[137, "mlx-core-isinf"]], "mlx.core.isnan": [[138, "mlx-core-isnan"]], "mlx.core.isneginf": [[139, "mlx-core-isneginf"]], "mlx.core.isposinf": [[140, "mlx-core-isposinf"]], "mlx.core.issubdtype": [[141, "mlx-core-issubdtype"]], "mlx.core.jvp": [[142, "mlx-core-jvp"]], "mlx.core.less": [[143, "mlx-core-less"]], "mlx.core.less_equal": [[144, "mlx-core-less-equal"]], "mlx.core.linalg.norm": [[145, "mlx-core-linalg-norm"]], "mlx.core.linalg.qr": [[146, "mlx-core-linalg-qr"]], "mlx.core.linspace": [[147, "mlx-core-linspace"]], "mlx.core.load": [[148, "mlx-core-load"]], "mlx.core.log": [[149, "mlx-core-log"]], "mlx.core.log10": [[150, "mlx-core-log10"]], "mlx.core.log1p": [[151, "mlx-core-log1p"]], "mlx.core.log2": [[152, "mlx-core-log2"]], "mlx.core.logaddexp": [[153, "mlx-core-logaddexp"]], "mlx.core.logical_and": [[154, "mlx-core-logical-and"]], "mlx.core.logical_not": [[155, "mlx-core-logical-not"]], "mlx.core.logical_or": [[156, "mlx-core-logical-or"]], "mlx.core.logsumexp": [[157, "mlx-core-logsumexp"]], "mlx.core.matmul": [[158, "mlx-core-matmul"]], "mlx.core.max": [[159, "mlx-core-max"]], 
"mlx.core.maximum": [[160, "mlx-core-maximum"]], "mlx.core.mean": [[161, "mlx-core-mean"]], "mlx.core.metal.get_active_memory": [[162, "mlx-core-metal-get-active-memory"]], "mlx.core.metal.get_cache_memory": [[163, "mlx-core-metal-get-cache-memory"]], "mlx.core.metal.get_peak_memory": [[164, "mlx-core-metal-get-peak-memory"]], "mlx.core.metal.is_available": [[165, "mlx-core-metal-is-available"]], "mlx.core.metal.set_cache_limit": [[166, "mlx-core-metal-set-cache-limit"]], "mlx.core.metal.set_memory_limit": [[167, "mlx-core-metal-set-memory-limit"]], "mlx.core.min": [[168, "mlx-core-min"]], "mlx.core.minimum": [[169, "mlx-core-minimum"]], "mlx.core.moveaxis": [[170, "mlx-core-moveaxis"]], "mlx.core.multiply": [[171, "mlx-core-multiply"]], "mlx.core.negative": [[172, "mlx-core-negative"]], "mlx.core.new_stream": [[173, "mlx-core-new-stream"]], "mlx.core.ones": [[174, "mlx-core-ones"]], "mlx.core.ones_like": [[175, "mlx-core-ones-like"]], "mlx.core.outer": [[176, "mlx-core-outer"]], "mlx.core.pad": [[177, "mlx-core-pad"]], "mlx.core.partition": [[178, "mlx-core-partition"]], "mlx.core.prod": [[179, "mlx-core-prod"]], "mlx.core.quantize": [[180, "mlx-core-quantize"]], "mlx.core.quantized_matmul": [[181, "mlx-core-quantized-matmul"]], "mlx.core.random.bernoulli": [[182, "mlx-core-random-bernoulli"]], "mlx.core.random.categorical": [[183, "mlx-core-random-categorical"]], "mlx.core.random.gumbel": [[184, "mlx-core-random-gumbel"]], "mlx.core.random.key": [[185, "mlx-core-random-key"]], "mlx.core.random.normal": [[186, "mlx-core-random-normal"]], "mlx.core.random.randint": [[187, "mlx-core-random-randint"]], "mlx.core.random.seed": [[188, "mlx-core-random-seed"]], "mlx.core.random.split": [[189, "mlx-core-random-split"]], "mlx.core.random.truncated_normal": [[190, "mlx-core-random-truncated-normal"]], "mlx.core.random.uniform": [[191, "mlx-core-random-uniform"]], "mlx.core.reciprocal": [[192, "mlx-core-reciprocal"]], "mlx.core.repeat": [[193, "mlx-core-repeat"]], "mlx.core.reshape": [[194, "mlx-core-reshape"]], "mlx.core.round": [[195, "mlx-core-round"]], "mlx.core.rsqrt": [[196, "mlx-core-rsqrt"]], "mlx.core.save": [[197, "mlx-core-save"]], "mlx.core.save_gguf": [[198, "mlx-core-save-gguf"]], "mlx.core.save_safetensors": [[199, "mlx-core-save-safetensors"]], "mlx.core.savez": [[200, "mlx-core-savez"]], "mlx.core.savez_compressed": [[201, "mlx-core-savez-compressed"]], "mlx.core.set_default_device": [[202, "mlx-core-set-default-device"]], "mlx.core.set_default_stream": [[203, "mlx-core-set-default-stream"]], "mlx.core.sigmoid": [[204, "mlx-core-sigmoid"]], "mlx.core.sign": [[205, "mlx-core-sign"]], "mlx.core.sin": [[206, "mlx-core-sin"]], "mlx.core.sinh": [[207, "mlx-core-sinh"]], "mlx.core.softmax": [[208, "mlx-core-softmax"]], "mlx.core.sort": [[209, "mlx-core-sort"]], "mlx.core.split": [[210, "mlx-core-split"]], "mlx.core.sqrt": [[211, "mlx-core-sqrt"]], "mlx.core.square": [[212, "mlx-core-square"]], "mlx.core.squeeze": [[213, "mlx-core-squeeze"]], "mlx.core.stack": [[214, "mlx-core-stack"]], "mlx.core.stop_gradient": [[215, "mlx-core-stop-gradient"]], "mlx.core.stream": [[216, "mlx-core-stream"]], "mlx.core.subtract": [[217, "mlx-core-subtract"]], "mlx.core.sum": [[218, "mlx-core-sum"]], "mlx.core.swapaxes": [[219, "mlx-core-swapaxes"]], "mlx.core.take": [[220, "mlx-core-take"]], "mlx.core.take_along_axis": [[221, "mlx-core-take-along-axis"]], "mlx.core.tan": [[222, "mlx-core-tan"]], "mlx.core.tanh": [[223, "mlx-core-tanh"]], "mlx.core.tensordot": [[224, "mlx-core-tensordot"]], 
"mlx.core.tile": [[225, "mlx-core-tile"]], "mlx.core.topk": [[226, "mlx-core-topk"]], "mlx.core.transpose": [[227, "mlx-core-transpose"]], "mlx.core.tri": [[228, "mlx-core-tri"]], "mlx.core.tril": [[229, "mlx-core-tril"]], "mlx.core.triu": [[230, "mlx-core-triu"]], "mlx.core.value_and_grad": [[231, "mlx-core-value-and-grad"]], "mlx.core.var": [[232, "mlx-core-var"]], "mlx.core.vjp": [[233, "mlx-core-vjp"]], "mlx.core.vmap": [[234, "mlx-core-vmap"]], "mlx.core.where": [[235, "mlx-core-where"]], "mlx.core.zeros": [[236, "mlx-core-zeros"]], "mlx.core.zeros_like": [[237, "mlx-core-zeros-like"]], "mlx.nn.value_and_grad": [[238, "mlx-nn-value-and-grad"]], "mlx.utils.tree_flatten": [[239, "mlx-utils-tree-flatten"]], "mlx.utils.tree_map": [[240, "mlx-utils-tree-map"]], "mlx.utils.tree_unflatten": [[241, "mlx-utils-tree-unflatten"]], "mlx.core.Stream": [[242, "mlx-core-stream"]], "Array": [[243, "array"]], "Data Types": [[244, "data-types"]], "Supported Data Types": [[244, "id2"]], "Devices and Streams": [[245, "devices-and-streams"]], "Fast": [[246, "fast"]], "FFT": [[247, "fft"]], "Linear Algebra": [[248, "linear-algebra"]], "Metal": [[249, "metal"]], "Neural Networks": [[250, "neural-networks"]], "Quick Start with Neural Networks": [[250, "quick-start-with-neural-networks"]], "The Module Class": [[250, "the-module-class"]], "Parameters": [[250, "parameters"]], "Updating the Parameters": [[250, "updating-the-parameters"]], "Inspecting Modules": [[250, "inspecting-modules"]], "Value and Grad": [[250, "value-and-grad"]], "mlx.nn.ALiBi": [[251, "mlx-nn-alibi"]], "mlx.nn.AvgPool1d": [[252, "mlx-nn-avgpool1d"]], "mlx.nn.AvgPool2d": [[253, "mlx-nn-avgpool2d"]], "mlx.nn.BatchNorm": [[254, "mlx-nn-batchnorm"]], "mlx.nn.Conv1d": [[255, "mlx-nn-conv1d"]], "mlx.nn.Conv2d": [[256, "mlx-nn-conv2d"]], "mlx.nn.Dropout": [[257, "mlx-nn-dropout"]], "mlx.nn.Dropout2d": [[258, "mlx-nn-dropout2d"]], "mlx.nn.Dropout3d": [[259, "mlx-nn-dropout3d"]], "mlx.nn.Embedding": [[260, "mlx-nn-embedding"]], "mlx.nn.GELU": [[261, "mlx-nn-gelu"]], "mlx.nn.GRU": [[262, "mlx-nn-gru"]], "mlx.nn.GroupNorm": [[263, "mlx-nn-groupnorm"]], "mlx.nn.InstanceNorm": [[264, "mlx-nn-instancenorm"]], "mlx.nn.LSTM": [[265, "mlx-nn-lstm"]], "mlx.nn.LayerNorm": [[266, "mlx-nn-layernorm"]], "mlx.nn.Linear": [[267, "mlx-nn-linear"]], "mlx.nn.MaxPool1d": [[268, "mlx-nn-maxpool1d"]], "mlx.nn.MaxPool2d": [[269, "mlx-nn-maxpool2d"]], "mlx.nn.Mish": [[270, "mlx-nn-mish"]], "mlx.nn.Module.apply": [[271, "mlx-nn-module-apply"]], "mlx.nn.Module.apply_to_modules": [[272, "mlx-nn-module-apply-to-modules"]], "mlx.nn.Module.children": [[273, "mlx-nn-module-children"]], "mlx.nn.Module.eval": [[274, "mlx-nn-module-eval"]], "mlx.nn.Module.filter_and_map": [[275, "mlx-nn-module-filter-and-map"]], "mlx.nn.Module.freeze": [[276, "mlx-nn-module-freeze"]], "mlx.nn.Module.leaf_modules": [[277, "mlx-nn-module-leaf-modules"]], "mlx.nn.Module.load_weights": [[278, "mlx-nn-module-load-weights"]], "mlx.nn.Module.modules": [[279, "mlx-nn-module-modules"]], "mlx.nn.Module.named_modules": [[280, "mlx-nn-module-named-modules"]], "mlx.nn.Module.parameters": [[281, "mlx-nn-module-parameters"]], "mlx.nn.Module.save_weights": [[282, "mlx-nn-module-save-weights"]], "mlx.nn.Module.set_dtype": [[283, "mlx-nn-module-set-dtype"]], "mlx.nn.Module.state": [[284, "mlx-nn-module-state"]], "mlx.nn.Module.train": [[285, "mlx-nn-module-train"]], "mlx.nn.Module.trainable_parameters": [[286, "mlx-nn-module-trainable-parameters"]], "mlx.nn.Module.training": [[287, "mlx-nn-module-training"]], 
"mlx.nn.Module.unfreeze": [[288, "mlx-nn-module-unfreeze"]], "mlx.nn.Module.update": [[289, "mlx-nn-module-update"]], "mlx.nn.Module.update_modules": [[290, "mlx-nn-module-update-modules"]], "mlx.nn.MultiHeadAttention": [[291, "mlx-nn-multiheadattention"]], "mlx.nn.PReLU": [[292, "mlx-nn-prelu"]], "mlx.nn.QuantizedLinear": [[293, "mlx-nn-quantizedlinear"]], "mlx.nn.RMSNorm": [[294, "mlx-nn-rmsnorm"]], "mlx.nn.RNN": [[295, "mlx-nn-rnn"]], "mlx.nn.ReLU": [[296, "mlx-nn-relu"]], "mlx.nn.RoPE": [[297, "mlx-nn-rope"]], "mlx.nn.SELU": [[298, "mlx-nn-selu"]], "mlx.nn.Sequential": [[299, "mlx-nn-sequential"]], "mlx.nn.SiLU": [[300, "mlx-nn-silu"]], "mlx.nn.SinusoidalPositionalEncoding": [[301, "mlx-nn-sinusoidalpositionalencoding"]], "mlx.nn.Softshrink": [[302, "mlx-nn-softshrink"]], "mlx.nn.Step": [[303, "mlx-nn-step"]], "mlx.nn.Transformer": [[304, "mlx-nn-transformer"]], "mlx.nn.Upsample": [[305, "mlx-nn-upsample"]], "mlx.nn.init.constant": [[306, "mlx-nn-init-constant"]], "mlx.nn.init.glorot_normal": [[307, "mlx-nn-init-glorot-normal"]], "mlx.nn.init.glorot_uniform": [[308, "mlx-nn-init-glorot-uniform"]], "mlx.nn.init.he_normal": [[309, "mlx-nn-init-he-normal"]], "mlx.nn.init.he_uniform": [[310, "mlx-nn-init-he-uniform"]], "mlx.nn.init.identity": [[311, "mlx-nn-init-identity"]], "mlx.nn.init.normal": [[312, "mlx-nn-init-normal"]], "mlx.nn.init.uniform": [[313, "mlx-nn-init-uniform"]], "mlx.nn.elu": [[314, "mlx-nn-elu"]], "mlx.nn.gelu": [[315, "mlx-nn-gelu"]], "mlx.nn.gelu_approx": [[316, "mlx-nn-gelu-approx"]], "mlx.nn.gelu_fast_approx": [[317, "mlx-nn-gelu-fast-approx"]], "mlx.nn.glu": [[318, "mlx-nn-glu"]], "mlx.nn.hardswish": [[319, "mlx-nn-hardswish"]], "mlx.nn.leaky_relu": [[320, "mlx-nn-leaky-relu"]], "mlx.nn.log_sigmoid": [[321, "mlx-nn-log-sigmoid"]], "mlx.nn.log_softmax": [[322, "mlx-nn-log-softmax"]], "mlx.nn.losses.binary_cross_entropy": [[323, "mlx-nn-losses-binary-cross-entropy"]], "mlx.nn.losses.cosine_similarity_loss": [[324, "mlx-nn-losses-cosine-similarity-loss"]], "mlx.nn.losses.cross_entropy": [[325, "mlx-nn-losses-cross-entropy"]], "mlx.nn.losses.gaussian_nll_loss": [[326, "mlx-nn-losses-gaussian-nll-loss"]], "mlx.nn.losses.hinge_loss": [[327, "mlx-nn-losses-hinge-loss"]], "mlx.nn.losses.huber_loss": [[328, "mlx-nn-losses-huber-loss"]], "mlx.nn.losses.kl_div_loss": [[329, "mlx-nn-losses-kl-div-loss"]], "mlx.nn.losses.l1_loss": [[330, "mlx-nn-losses-l1-loss"]], "mlx.nn.losses.log_cosh_loss": [[331, "mlx-nn-losses-log-cosh-loss"]], "mlx.nn.losses.margin_ranking_loss": [[332, "mlx-nn-losses-margin-ranking-loss"]], "mlx.nn.losses.mse_loss": [[333, "mlx-nn-losses-mse-loss"]], "mlx.nn.losses.nll_loss": [[334, "mlx-nn-losses-nll-loss"]], "mlx.nn.losses.smooth_l1_loss": [[335, "mlx-nn-losses-smooth-l1-loss"]], "mlx.nn.losses.triplet_loss": [[336, "mlx-nn-losses-triplet-loss"]], "mlx.nn.mish": [[337, "mlx-nn-mish"]], "mlx.nn.prelu": [[338, "mlx-nn-prelu"]], "mlx.nn.relu": [[339, "mlx-nn-relu"]], "mlx.nn.relu6": [[340, "mlx-nn-relu6"]], "mlx.nn.selu": [[341, "mlx-nn-selu"]], "mlx.nn.sigmoid": [[342, "mlx-nn-sigmoid"]], "mlx.nn.silu": [[343, "mlx-nn-silu"]], "mlx.nn.softmax": [[344, "mlx-nn-softmax"]], "mlx.nn.softplus": [[345, "mlx-nn-softplus"]], "mlx.nn.softshrink": [[346, "mlx-nn-softshrink"]], "mlx.nn.step": [[347, "mlx-nn-step"]], "mlx.nn.tanh": [[348, "mlx-nn-tanh"]], "Functions": [[349, "functions"]], "Initializers": [[350, "initializers"]], "Layers": [[351, "layers"]], "Loss Functions": [[352, "loss-functions"]], "Module": [[353, "module"]], "Optimizers": [[355, 
"optimizers"]], "mlx.optimizers.AdaDelta": [[356, "mlx-optimizers-adadelta"]], "mlx.optimizers.Adafactor": [[357, "mlx-optimizers-adafactor"]], "mlx.optimizers.Adagrad": [[358, "mlx-optimizers-adagrad"]], "mlx.optimizers.Adam": [[359, "mlx-optimizers-adam"]], "mlx.optimizers.AdamW": [[360, "mlx-optimizers-adamw"]], "mlx.optimizers.Adamax": [[361, "mlx-optimizers-adamax"]], "mlx.optimizers.Lion": [[362, "mlx-optimizers-lion"]], "mlx.optimizers.Optimizer.apply_gradients": [[363, "mlx-optimizers-optimizer-apply-gradients"]], "mlx.optimizers.Optimizer.init": [[364, "mlx-optimizers-optimizer-init"]], "mlx.optimizers.Optimizer.state": [[365, "mlx-optimizers-optimizer-state"]], "mlx.optimizers.Optimizer.update": [[366, "mlx-optimizers-optimizer-update"]], "mlx.optimizers.RMSprop": [[367, "mlx-optimizers-rmsprop"]], "mlx.optimizers.SGD": [[368, "mlx-optimizers-sgd"]], "mlx.optimizers.cosine_decay": [[369, "mlx-optimizers-cosine-decay"]], "mlx.optimizers.exponential_decay": [[370, "mlx-optimizers-exponential-decay"]], "mlx.optimizers.join_schedules": [[371, "mlx-optimizers-join-schedules"]], "mlx.optimizers.linear_schedule": [[372, "mlx-optimizers-linear-schedule"]], "mlx.optimizers.step_decay": [[373, "mlx-optimizers-step-decay"]], "Common Optimizers": [[374, "common-optimizers"]], "Optimizer": [[375, "optimizer"]], "Schedulers": [[376, "schedulers"]], "Random": [[377, "random"]], "Transforms": [[378, "transforms"]], "Tree Utils": [[379, "tree-utils"]], "Compilation": [[380, "compilation"]], "Basics of Compile": [[380, "basics-of-compile"]], "Example Speedup": [[380, "example-speedup"]], "Debugging": [[380, "debugging"]], "Pure Functions": [[380, "pure-functions"]], "Compiling Training Graphs": [[380, "compiling-training-graphs"]], "Transformations with Compile": [[380, "transformations-with-compile"]], "Function Transforms": [[381, "function-transforms"]], "Automatic Differentiation": [[381, "automatic-differentiation"]], "Automatic Vectorization": [[381, "automatic-vectorization"]], "Indexing Arrays": [[382, "indexing-arrays"]], "Differences from NumPy": [[382, "differences-from-numpy"]], "In Place Updates": [[382, "in-place-updates"]], "Lazy Evaluation": [[383, "lazy-evaluation"]], "Why Lazy Evaluation": [[383, "why-lazy-evaluation"]], "Transforming Compute Graphs": [[383, "transforming-compute-graphs"]], "Only Compute What You Use": [[383, "only-compute-what-you-use"]], "When to Evaluate": [[383, "when-to-evaluate"]], "Conversion to NumPy and Other Frameworks": [[384, "conversion-to-numpy-and-other-frameworks"]], "PyTorch": [[384, "pytorch"]], "JAX": [[384, "jax"]], "TensorFlow": [[384, "tensorflow"]], "Quick Start Guide": [[385, "quick-start-guide"]], "Basics": [[385, "basics"]], "Function and Graph Transformations": [[385, "function-and-graph-transformations"]], "Saving and Loading Arrays": [[386, "saving-and-loading-arrays"]], "Serialization Formats": [[386, "id1"]], "Unified Memory": [[387, "unified-memory"]], "A Simple Example": [[387, "a-simple-example"]], "Using Streams": [[388, "using-streams"]], "Specifying the Stream": [[388, "specifying-the-stream"]]}, "indexentries": {"device (class in mlx.core)": [[8, "mlx.core.Device"]], "__init__() (device method)": [[8, "mlx.core.Device.__init__"]], "dtype (class in mlx.core)": [[9, "mlx.core.Dtype"]], "__init__() (dtype method)": [[9, "mlx.core.Dtype.__init__"]], "dtypecategory (class in mlx.core)": [[10, "mlx.core.DtypeCategory"]], "__init__() (dtypecategory method)": [[10, "mlx.core.DtypeCategory.__init__"]], "abs() (in module mlx.core)": 
[[11, "mlx.core.abs"]], "add() (in module mlx.core)": [[12, "mlx.core.add"]], "all() (in module mlx.core)": [[13, "mlx.core.all"]], "allclose() (in module mlx.core)": [[14, "mlx.core.allclose"]], "any() (in module mlx.core)": [[15, "mlx.core.any"]], "arange() (in module mlx.core)": [[16, "mlx.core.arange"]], "arccos() (in module mlx.core)": [[17, "mlx.core.arccos"]], "arccosh() (in module mlx.core)": [[18, "mlx.core.arccosh"]], "arcsin() (in module mlx.core)": [[19, "mlx.core.arcsin"]], "arcsinh() (in module mlx.core)": [[20, "mlx.core.arcsinh"]], "arctan() (in module mlx.core)": [[21, "mlx.core.arctan"]], "arctanh() (in module mlx.core)": [[22, "mlx.core.arctanh"]], "argmax() (in module mlx.core)": [[23, "mlx.core.argmax"]], "argmin() (in module mlx.core)": [[24, "mlx.core.argmin"]], "argpartition() (in module mlx.core)": [[25, "mlx.core.argpartition"]], "argsort() (in module mlx.core)": [[26, "mlx.core.argsort"]], "__init__() (array method)": [[27, "mlx.core.array.__init__"]], "array (class in mlx.core)": [[27, "mlx.core.array"]], "t (array property)": [[28, "mlx.core.array.T"]], "abs() (array method)": [[29, "mlx.core.array.abs"]], "all() (array method)": [[30, "mlx.core.array.all"]], "any() (array method)": [[31, "mlx.core.array.any"]], "argmax() (array method)": [[32, "mlx.core.array.argmax"]], "argmin() (array method)": [[33, "mlx.core.array.argmin"]], "astype() (array method)": [[34, "mlx.core.array.astype"]], "at (array property)": [[35, "mlx.core.array.at"]], "cos() (array method)": [[36, "mlx.core.array.cos"]], "cummax() (array method)": [[37, "mlx.core.array.cummax"]], "cummin() (array method)": [[38, "mlx.core.array.cummin"]], "cumprod() (array method)": [[39, "mlx.core.array.cumprod"]], "cumsum() (array method)": [[40, "mlx.core.array.cumsum"]], "diag() (array method)": [[41, "mlx.core.array.diag"]], "diagonal() (array method)": [[42, "mlx.core.array.diagonal"]], "dtype (array property)": [[43, "mlx.core.array.dtype"]], "exp() (array method)": [[44, "mlx.core.array.exp"]], "flatten() (array method)": [[45, "mlx.core.array.flatten"]], "item() (array method)": [[46, "mlx.core.array.item"]], "itemsize (array property)": [[47, "mlx.core.array.itemsize"]], "log() (array method)": [[48, "mlx.core.array.log"]], "log10() (array method)": [[49, "mlx.core.array.log10"]], "log1p() (array method)": [[50, "mlx.core.array.log1p"]], "log2() (array method)": [[51, "mlx.core.array.log2"]], "logsumexp() (array method)": [[52, "mlx.core.array.logsumexp"]], "max() (array method)": [[53, "mlx.core.array.max"]], "mean() (array method)": [[54, "mlx.core.array.mean"]], "min() (array method)": [[55, "mlx.core.array.min"]], "moveaxis() (array method)": [[56, "mlx.core.array.moveaxis"]], "nbytes (array property)": [[57, "mlx.core.array.nbytes"]], "ndim (array property)": [[58, "mlx.core.array.ndim"]], "prod() (array method)": [[59, "mlx.core.array.prod"]], "reciprocal() (array method)": [[60, "mlx.core.array.reciprocal"]], "reshape() (array method)": [[61, "mlx.core.array.reshape"]], "round() (array method)": [[62, "mlx.core.array.round"]], "rsqrt() (array method)": [[63, "mlx.core.array.rsqrt"]], "shape (array property)": [[64, "mlx.core.array.shape"]], "sin() (array method)": [[65, "mlx.core.array.sin"]], "size (array property)": [[66, "mlx.core.array.size"]], "split() (array method)": [[67, "mlx.core.array.split"]], "sqrt() (array method)": [[68, "mlx.core.array.sqrt"]], "square() (array method)": [[69, "mlx.core.array.square"]], "squeeze() (array method)": [[70, "mlx.core.array.squeeze"]], "sum() 
(array method)": [[71, "mlx.core.array.sum"]], "swapaxes() (array method)": [[72, "mlx.core.array.swapaxes"]], "tolist() (array method)": [[73, "mlx.core.array.tolist"]], "transpose() (array method)": [[74, "mlx.core.array.transpose"]], "var() (array method)": [[75, "mlx.core.array.var"]], "array_equal() (in module mlx.core)": [[76, "mlx.core.array_equal"]], "atleast_1d() (in module mlx.core)": [[77, "mlx.core.atleast_1d"]], "atleast_2d() (in module mlx.core)": [[78, "mlx.core.atleast_2d"]], "atleast_3d() (in module mlx.core)": [[79, "mlx.core.atleast_3d"]], "broadcast_to() (in module mlx.core)": [[80, "mlx.core.broadcast_to"]], "ceil() (in module mlx.core)": [[81, "mlx.core.ceil"]], "clip() (in module mlx.core)": [[82, "mlx.core.clip"]], "compile() (in module mlx.core)": [[83, "mlx.core.compile"]], "concatenate() (in module mlx.core)": [[84, "mlx.core.concatenate"]], "conv1d() (in module mlx.core)": [[85, "mlx.core.conv1d"]], "conv2d() (in module mlx.core)": [[86, "mlx.core.conv2d"]], "conv_general() (in module mlx.core)": [[87, "mlx.core.conv_general"]], "convolve() (in module mlx.core)": [[88, "mlx.core.convolve"]], "cos() (in module mlx.core)": [[89, "mlx.core.cos"]], "cosh() (in module mlx.core)": [[90, "mlx.core.cosh"]], "cummax() (in module mlx.core)": [[91, "mlx.core.cummax"]], "cummin() (in module mlx.core)": [[92, "mlx.core.cummin"]], "cumprod() (in module mlx.core)": [[93, "mlx.core.cumprod"]], "cumsum() (in module mlx.core)": [[94, "mlx.core.cumsum"]], "default_device() (in module mlx.core)": [[95, "mlx.core.default_device"]], "default_stream() (in module mlx.core)": [[96, "mlx.core.default_stream"]], "dequantize() (in module mlx.core)": [[97, "mlx.core.dequantize"]], "diag() (in module mlx.core)": [[98, "mlx.core.diag"]], "diagonal() (in module mlx.core)": [[99, "mlx.core.diagonal"]], "disable_compile() (in module mlx.core)": [[100, "mlx.core.disable_compile"]], "divide() (in module mlx.core)": [[101, "mlx.core.divide"]], "divmod() (in module mlx.core)": [[102, "mlx.core.divmod"]], "enable_compile() (in module mlx.core)": [[103, "mlx.core.enable_compile"]], "equal() (in module mlx.core)": [[104, "mlx.core.equal"]], "erf() (in module mlx.core)": [[105, "mlx.core.erf"]], "erfinv() (in module mlx.core)": [[106, "mlx.core.erfinv"]], "eval() (in module mlx.core)": [[107, "mlx.core.eval"]], "exp() (in module mlx.core)": [[108, "mlx.core.exp"]], "expand_dims() (in module mlx.core)": [[109, "mlx.core.expand_dims"]], "eye() (in module mlx.core)": [[110, "mlx.core.eye"]], "layer_norm() (in module mlx.core.fast)": [[111, "mlx.core.fast.layer_norm"]], "rms_norm() (in module mlx.core.fast)": [[112, "mlx.core.fast.rms_norm"]], "rope() (in module mlx.core.fast)": [[113, "mlx.core.fast.rope"]], "scaled_dot_product_attention() (in module mlx.core.fast)": [[114, "mlx.core.fast.scaled_dot_product_attention"]], "fft() (in module mlx.core.fft)": [[115, "mlx.core.fft.fft"]], "fft2() (in module mlx.core.fft)": [[116, "mlx.core.fft.fft2"]], "fftn() (in module mlx.core.fft)": [[117, "mlx.core.fft.fftn"]], "ifft() (in module mlx.core.fft)": [[118, "mlx.core.fft.ifft"]], "ifft2() (in module mlx.core.fft)": [[119, "mlx.core.fft.ifft2"]], "ifftn() (in module mlx.core.fft)": [[120, "mlx.core.fft.ifftn"]], "irfft() (in module mlx.core.fft)": [[121, "mlx.core.fft.irfft"]], "irfft2() (in module mlx.core.fft)": [[122, "mlx.core.fft.irfft2"]], "irfftn() (in module mlx.core.fft)": [[123, "mlx.core.fft.irfftn"]], "rfft() (in module mlx.core.fft)": [[124, "mlx.core.fft.rfft"]], "rfft2() (in module mlx.core.fft)": 
[[125, "mlx.core.fft.rfft2"]], "rfftn() (in module mlx.core.fft)": [[126, "mlx.core.fft.rfftn"]], "flatten() (in module mlx.core)": [[127, "mlx.core.flatten"]], "floor() (in module mlx.core)": [[128, "mlx.core.floor"]], "floor_divide() (in module mlx.core)": [[129, "mlx.core.floor_divide"]], "full() (in module mlx.core)": [[130, "mlx.core.full"]], "grad() (in module mlx.core)": [[131, "mlx.core.grad"]], "greater() (in module mlx.core)": [[132, "mlx.core.greater"]], "greater_equal() (in module mlx.core)": [[133, "mlx.core.greater_equal"]], "identity() (in module mlx.core)": [[134, "mlx.core.identity"]], "inner() (in module mlx.core)": [[135, "mlx.core.inner"]], "isclose() (in module mlx.core)": [[136, "mlx.core.isclose"]], "isinf() (in module mlx.core)": [[137, "mlx.core.isinf"]], "isnan() (in module mlx.core)": [[138, "mlx.core.isnan"]], "isneginf() (in module mlx.core)": [[139, "mlx.core.isneginf"]], "isposinf() (in module mlx.core)": [[140, "mlx.core.isposinf"]], "issubdtype() (in module mlx.core)": [[141, "mlx.core.issubdtype"]], "jvp() (in module mlx.core)": [[142, "mlx.core.jvp"]], "less() (in module mlx.core)": [[143, "mlx.core.less"]], "less_equal() (in module mlx.core)": [[144, "mlx.core.less_equal"]], "norm() (in module mlx.core.linalg)": [[145, "mlx.core.linalg.norm"]], "qr() (in module mlx.core.linalg)": [[146, "mlx.core.linalg.qr"]], "linspace() (in module mlx.core)": [[147, "mlx.core.linspace"]], "load() (in module mlx.core)": [[148, "mlx.core.load"]], "log() (in module mlx.core)": [[149, "mlx.core.log"]], "log10() (in module mlx.core)": [[150, "mlx.core.log10"]], "log1p() (in module mlx.core)": [[151, "mlx.core.log1p"]], "log2() (in module mlx.core)": [[152, "mlx.core.log2"]], "logaddexp() (in module mlx.core)": [[153, "mlx.core.logaddexp"]], "logical_and() (in module mlx.core)": [[154, "mlx.core.logical_and"]], "logical_not() (in module mlx.core)": [[155, "mlx.core.logical_not"]], "logical_or() (in module mlx.core)": [[156, "mlx.core.logical_or"]], "logsumexp() (in module mlx.core)": [[157, "mlx.core.logsumexp"]], "matmul() (in module mlx.core)": [[158, "mlx.core.matmul"]], "max() (in module mlx.core)": [[159, "mlx.core.max"]], "maximum() (in module mlx.core)": [[160, "mlx.core.maximum"]], "mean() (in module mlx.core)": [[161, "mlx.core.mean"]], "get_active_memory() (in module mlx.core.metal)": [[162, "mlx.core.metal.get_active_memory"]], "get_cache_memory() (in module mlx.core.metal)": [[163, "mlx.core.metal.get_cache_memory"]], "get_peak_memory() (in module mlx.core.metal)": [[164, "mlx.core.metal.get_peak_memory"]], "is_available() (in module mlx.core.metal)": [[165, "mlx.core.metal.is_available"]], "set_cache_limit() (in module mlx.core.metal)": [[166, "mlx.core.metal.set_cache_limit"]], "set_memory_limit() (in module mlx.core.metal)": [[167, "mlx.core.metal.set_memory_limit"]], "min() (in module mlx.core)": [[168, "mlx.core.min"]], "minimum() (in module mlx.core)": [[169, "mlx.core.minimum"]], "moveaxis() (in module mlx.core)": [[170, "mlx.core.moveaxis"]], "multiply() (in module mlx.core)": [[171, "mlx.core.multiply"]], "negative() (in module mlx.core)": [[172, "mlx.core.negative"]], "new_stream() (in module mlx.core)": [[173, "mlx.core.new_stream"]], "ones() (in module mlx.core)": [[174, "mlx.core.ones"]], "ones_like() (in module mlx.core)": [[175, "mlx.core.ones_like"]], "outer() (in module mlx.core)": [[176, "mlx.core.outer"]], "pad() (in module mlx.core)": [[177, "mlx.core.pad"]], "partition() (in module mlx.core)": [[178, "mlx.core.partition"]], "prod() (in module 
mlx.core)": [[179, "mlx.core.prod"]], "quantize() (in module mlx.core)": [[180, "mlx.core.quantize"]], "quantized_matmul() (in module mlx.core)": [[181, "mlx.core.quantized_matmul"]], "bernoulli() (in module mlx.core.random)": [[182, "mlx.core.random.bernoulli"]], "categorical() (in module mlx.core.random)": [[183, "mlx.core.random.categorical"]], "gumbel() (in module mlx.core.random)": [[184, "mlx.core.random.gumbel"]], "key() (in module mlx.core.random)": [[185, "mlx.core.random.key"]], "normal() (in module mlx.core.random)": [[186, "mlx.core.random.normal"]], "randint() (in module mlx.core.random)": [[187, "mlx.core.random.randint"]], "seed() (in module mlx.core.random)": [[188, "mlx.core.random.seed"]], "split() (in module mlx.core.random)": [[189, "mlx.core.random.split"]], "truncated_normal() (in module mlx.core.random)": [[190, "mlx.core.random.truncated_normal"]], "uniform() (in module mlx.core.random)": [[191, "mlx.core.random.uniform"]], "reciprocal() (in module mlx.core)": [[192, "mlx.core.reciprocal"]], "repeat() (in module mlx.core)": [[193, "mlx.core.repeat"]], "reshape() (in module mlx.core)": [[194, "mlx.core.reshape"]], "round() (in module mlx.core)": [[195, "mlx.core.round"]], "rsqrt() (in module mlx.core)": [[196, "mlx.core.rsqrt"]], "save() (in module mlx.core)": [[197, "mlx.core.save"]], "save_gguf() (in module mlx.core)": [[198, "mlx.core.save_gguf"]], "save_safetensors() (in module mlx.core)": [[199, "mlx.core.save_safetensors"]], "savez() (in module mlx.core)": [[200, "mlx.core.savez"]], "savez_compressed() (in module mlx.core)": [[201, "mlx.core.savez_compressed"]], "set_default_device() (in module mlx.core)": [[202, "mlx.core.set_default_device"]], "set_default_stream() (in module mlx.core)": [[203, "mlx.core.set_default_stream"]], "sigmoid() (in module mlx.core)": [[204, "mlx.core.sigmoid"]], "sign() (in module mlx.core)": [[205, "mlx.core.sign"]], "sin() (in module mlx.core)": [[206, "mlx.core.sin"]], "sinh() (in module mlx.core)": [[207, "mlx.core.sinh"]], "softmax() (in module mlx.core)": [[208, "mlx.core.softmax"]], "sort() (in module mlx.core)": [[209, "mlx.core.sort"]], "split() (in module mlx.core)": [[210, "mlx.core.split"]], "sqrt() (in module mlx.core)": [[211, "mlx.core.sqrt"]], "square() (in module mlx.core)": [[212, "mlx.core.square"]], "squeeze() (in module mlx.core)": [[213, "mlx.core.squeeze"]], "stack() (in module mlx.core)": [[214, "mlx.core.stack"]], "stop_gradient() (in module mlx.core)": [[215, "mlx.core.stop_gradient"]], "stream() (in module mlx.core)": [[216, "mlx.core.stream"]], "subtract() (in module mlx.core)": [[217, "mlx.core.subtract"]], "sum() (in module mlx.core)": [[218, "mlx.core.sum"]], "swapaxes() (in module mlx.core)": [[219, "mlx.core.swapaxes"]], "take() (in module mlx.core)": [[220, "mlx.core.take"]], "take_along_axis() (in module mlx.core)": [[221, "mlx.core.take_along_axis"]], "tan() (in module mlx.core)": [[222, "mlx.core.tan"]], "tanh() (in module mlx.core)": [[223, "mlx.core.tanh"]], "tensordot() (in module mlx.core)": [[224, "mlx.core.tensordot"]], "tile() (in module mlx.core)": [[225, "mlx.core.tile"]], "topk() (in module mlx.core)": [[226, "mlx.core.topk"]], "transpose() (in module mlx.core)": [[227, "mlx.core.transpose"]], "tri() (in module mlx.core)": [[228, "mlx.core.tri"]], "tril() (in module mlx.core)": [[229, "mlx.core.tril"]], "triu() (in module mlx.core)": [[230, "mlx.core.triu"]], "value_and_grad() (in module mlx.core)": [[231, "mlx.core.value_and_grad"]], "var() (in module mlx.core)": [[232, 
"mlx.core.var"]], "vjp() (in module mlx.core)": [[233, "mlx.core.vjp"]], "vmap() (in module mlx.core)": [[234, "mlx.core.vmap"]], "where() (in module mlx.core)": [[235, "mlx.core.where"]], "zeros() (in module mlx.core)": [[236, "mlx.core.zeros"]], "zeros_like() (in module mlx.core)": [[237, "mlx.core.zeros_like"]], "value_and_grad() (in module mlx.nn)": [[238, "mlx.nn.value_and_grad"]], "tree_flatten() (in module mlx.utils)": [[239, "mlx.utils.tree_flatten"]], "tree_map() (in module mlx.utils)": [[240, "mlx.utils.tree_map"]], "tree_unflatten() (in module mlx.utils)": [[241, "mlx.utils.tree_unflatten"]], "stream (class in mlx.core)": [[242, "mlx.core.Stream"]], "__init__() (stream method)": [[242, "mlx.core.Stream.__init__"]], "alibi (class in mlx.nn)": [[251, "mlx.nn.ALiBi"]], "avgpool1d (class in mlx.nn)": [[252, "mlx.nn.AvgPool1d"]], "avgpool2d (class in mlx.nn)": [[253, "mlx.nn.AvgPool2d"]], "batchnorm (class in mlx.nn)": [[254, "mlx.nn.BatchNorm"]], "conv1d (class in mlx.nn)": [[255, "mlx.nn.Conv1d"]], "conv2d (class in mlx.nn)": [[256, "mlx.nn.Conv2d"]], "dropout (class in mlx.nn)": [[257, "mlx.nn.Dropout"]], "dropout2d (class in mlx.nn)": [[258, "mlx.nn.Dropout2d"]], "dropout3d (class in mlx.nn)": [[259, "mlx.nn.Dropout3d"]], "embedding (class in mlx.nn)": [[260, "mlx.nn.Embedding"]], "gelu (class in mlx.nn)": [[261, "mlx.nn.GELU"]], "gru (class in mlx.nn)": [[262, "mlx.nn.GRU"]], "groupnorm (class in mlx.nn)": [[263, "mlx.nn.GroupNorm"]], "instancenorm (class in mlx.nn)": [[264, "mlx.nn.InstanceNorm"]], "lstm (class in mlx.nn)": [[265, "mlx.nn.LSTM"]], "layernorm (class in mlx.nn)": [[266, "mlx.nn.LayerNorm"]], "linear (class in mlx.nn)": [[267, "mlx.nn.Linear"]], "maxpool1d (class in mlx.nn)": [[268, "mlx.nn.MaxPool1d"]], "maxpool2d (class in mlx.nn)": [[269, "mlx.nn.MaxPool2d"]], "mish (class in mlx.nn)": [[270, "mlx.nn.Mish"]], "apply() (module method)": [[271, "mlx.nn.Module.apply"]], "apply_to_modules() (module method)": [[272, "mlx.nn.Module.apply_to_modules"]], "children() (module method)": [[273, "mlx.nn.Module.children"]], "eval() (module method)": [[274, "mlx.nn.Module.eval"]], "filter_and_map() (module method)": [[275, "mlx.nn.Module.filter_and_map"]], "freeze() (module method)": [[276, "mlx.nn.Module.freeze"]], "leaf_modules() (module method)": [[277, "mlx.nn.Module.leaf_modules"]], "load_weights() (module method)": [[278, "mlx.nn.Module.load_weights"]], "modules() (module method)": [[279, "mlx.nn.Module.modules"]], "named_modules() (module method)": [[280, "mlx.nn.Module.named_modules"]], "parameters() (module method)": [[281, "mlx.nn.Module.parameters"]], "save_weights() (module method)": [[282, "mlx.nn.Module.save_weights"]], "set_dtype() (module method)": [[283, "mlx.nn.Module.set_dtype"]], "state (module property)": [[284, "mlx.nn.Module.state"]], "train() (module method)": [[285, "mlx.nn.Module.train"]], "trainable_parameters() (module method)": [[286, "mlx.nn.Module.trainable_parameters"]], "training (module property)": [[287, "mlx.nn.Module.training"]], "unfreeze() (module method)": [[288, "mlx.nn.Module.unfreeze"]], "update() (module method)": [[289, "mlx.nn.Module.update"]], "update_modules() (module method)": [[290, "mlx.nn.Module.update_modules"]], "multiheadattention (class in mlx.nn)": [[291, "mlx.nn.MultiHeadAttention"]], "prelu (class in mlx.nn)": [[292, "mlx.nn.PReLU"]], "quantizedlinear (class in mlx.nn)": [[293, "mlx.nn.QuantizedLinear"]], "rmsnorm (class in mlx.nn)": [[294, "mlx.nn.RMSNorm"]], "rnn (class in mlx.nn)": [[295, "mlx.nn.RNN"]], "relu 
(class in mlx.nn)": [[296, "mlx.nn.ReLU"]], "rope (class in mlx.nn)": [[297, "mlx.nn.RoPE"]], "selu (class in mlx.nn)": [[298, "mlx.nn.SELU"]], "sequential (class in mlx.nn)": [[299, "mlx.nn.Sequential"]], "silu (class in mlx.nn)": [[300, "mlx.nn.SiLU"]], "sinusoidalpositionalencoding (class in mlx.nn)": [[301, "mlx.nn.SinusoidalPositionalEncoding"]], "softshrink (class in mlx.nn)": [[302, "mlx.nn.Softshrink"]], "step (class in mlx.nn)": [[303, "mlx.nn.Step"]], "transformer (class in mlx.nn)": [[304, "mlx.nn.Transformer"]], "upsample (class in mlx.nn)": [[305, "mlx.nn.Upsample"]], "constant() (in module mlx.nn.init)": [[306, "mlx.nn.init.constant"]], "glorot_normal() (in module mlx.nn.init)": [[307, "mlx.nn.init.glorot_normal"]], "glorot_uniform() (in module mlx.nn.init)": [[308, "mlx.nn.init.glorot_uniform"]], "he_normal() (in module mlx.nn.init)": [[309, "mlx.nn.init.he_normal"]], "he_uniform() (in module mlx.nn.init)": [[310, "mlx.nn.init.he_uniform"]], "identity() (in module mlx.nn.init)": [[311, "mlx.nn.init.identity"]], "normal() (in module mlx.nn.init)": [[312, "mlx.nn.init.normal"]], "uniform() (in module mlx.nn.init)": [[313, "mlx.nn.init.uniform"]], "elu() (in module mlx.nn)": [[314, "mlx.nn.elu"]], "gelu() (in module mlx.nn)": [[315, "mlx.nn.gelu"]], "gelu_approx() (in module mlx.nn)": [[316, "mlx.nn.gelu_approx"]], "gelu_fast_approx() (in module mlx.nn)": [[317, "mlx.nn.gelu_fast_approx"]], "glu() (in module mlx.nn)": [[318, "mlx.nn.glu"]], "hardswish() (in module mlx.nn)": [[319, "mlx.nn.hardswish"]], "leaky_relu() (in module mlx.nn)": [[320, "mlx.nn.leaky_relu"]], "log_sigmoid() (in module mlx.nn)": [[321, "mlx.nn.log_sigmoid"]], "log_softmax() (in module mlx.nn)": [[322, "mlx.nn.log_softmax"]], "binary_cross_entropy() (in module mlx.nn.losses)": [[323, "mlx.nn.losses.binary_cross_entropy"]], "cosine_similarity_loss() (in module mlx.nn.losses)": [[324, "mlx.nn.losses.cosine_similarity_loss"]], "cross_entropy() (in module mlx.nn.losses)": [[325, "mlx.nn.losses.cross_entropy"]], "gaussian_nll_loss() (in module mlx.nn.losses)": [[326, "mlx.nn.losses.gaussian_nll_loss"]], "hinge_loss() (in module mlx.nn.losses)": [[327, "mlx.nn.losses.hinge_loss"]], "huber_loss() (in module mlx.nn.losses)": [[328, "mlx.nn.losses.huber_loss"]], "kl_div_loss() (in module mlx.nn.losses)": [[329, "mlx.nn.losses.kl_div_loss"]], "l1_loss() (in module mlx.nn.losses)": [[330, "mlx.nn.losses.l1_loss"]], "log_cosh_loss() (in module mlx.nn.losses)": [[331, "mlx.nn.losses.log_cosh_loss"]], "margin_ranking_loss() (in module mlx.nn.losses)": [[332, "mlx.nn.losses.margin_ranking_loss"]], "mse_loss() (in module mlx.nn.losses)": [[333, "mlx.nn.losses.mse_loss"]], "nll_loss() (in module mlx.nn.losses)": [[334, "mlx.nn.losses.nll_loss"]], "smooth_l1_loss() (in module mlx.nn.losses)": [[335, "mlx.nn.losses.smooth_l1_loss"]], "triplet_loss() (in module mlx.nn.losses)": [[336, "mlx.nn.losses.triplet_loss"]], "mish() (in module mlx.nn)": [[337, "mlx.nn.mish"]], "prelu() (in module mlx.nn)": [[338, "mlx.nn.prelu"]], "relu() (in module mlx.nn)": [[339, "mlx.nn.relu"]], "relu6() (in module mlx.nn)": [[340, "mlx.nn.relu6"]], "selu() (in module mlx.nn)": [[341, "mlx.nn.selu"]], "sigmoid() (in module mlx.nn)": [[342, "mlx.nn.sigmoid"]], "silu() (in module mlx.nn)": [[343, "mlx.nn.silu"]], "softmax() (in module mlx.nn)": [[344, "mlx.nn.softmax"]], "softplus() (in module mlx.nn)": [[345, "mlx.nn.softplus"]], "softshrink() (in module mlx.nn)": [[346, "mlx.nn.softshrink"]], "step() (in module mlx.nn)": [[347, "mlx.nn.step"]], 
"tanh() (in module mlx.nn)": [[348, "mlx.nn.tanh"]], "module (class in mlx.nn)": [[353, "mlx.nn.Module"]], "adadelta (class in mlx.optimizers)": [[356, "mlx.optimizers.AdaDelta"]], "adafactor (class in mlx.optimizers)": [[357, "mlx.optimizers.Adafactor"]], "adagrad (class in mlx.optimizers)": [[358, "mlx.optimizers.Adagrad"]], "adam (class in mlx.optimizers)": [[359, "mlx.optimizers.Adam"]], "adamw (class in mlx.optimizers)": [[360, "mlx.optimizers.AdamW"]], "adamax (class in mlx.optimizers)": [[361, "mlx.optimizers.Adamax"]], "lion (class in mlx.optimizers)": [[362, "mlx.optimizers.Lion"]], "apply_gradients() (optimizer method)": [[363, "mlx.optimizers.Optimizer.apply_gradients"]], "init() (optimizer method)": [[364, "mlx.optimizers.Optimizer.init"]], "state (optimizer property)": [[365, "mlx.optimizers.Optimizer.state"]], "update() (optimizer method)": [[366, "mlx.optimizers.Optimizer.update"]], "rmsprop (class in mlx.optimizers)": [[367, "mlx.optimizers.RMSprop"]], "sgd (class in mlx.optimizers)": [[368, "mlx.optimizers.SGD"]], "cosine_decay() (in module mlx.optimizers)": [[369, "mlx.optimizers.cosine_decay"]], "exponential_decay() (in module mlx.optimizers)": [[370, "mlx.optimizers.exponential_decay"]], "join_schedules() (in module mlx.optimizers)": [[371, "mlx.optimizers.join_schedules"]], "linear_schedule() (in module mlx.optimizers)": [[372, "mlx.optimizers.linear_schedule"]], "step_decay() (in module mlx.optimizers)": [[373, "mlx.optimizers.step_decay"]], "optimizer (class in mlx.optimizers)": [[375, "mlx.optimizers.Optimizer"]]}}) \ No newline at end of file +Search.setIndex({"docnames": ["cpp/ops", "dev/extensions", "dev/metal_debugger", "examples/linear_regression", "examples/llama-inference", "examples/mlp", "index", "install", "python/_autosummary/mlx.core.Device", "python/_autosummary/mlx.core.Dtype", "python/_autosummary/mlx.core.DtypeCategory", "python/_autosummary/mlx.core.abs", "python/_autosummary/mlx.core.add", "python/_autosummary/mlx.core.all", "python/_autosummary/mlx.core.allclose", "python/_autosummary/mlx.core.any", "python/_autosummary/mlx.core.arange", "python/_autosummary/mlx.core.arccos", "python/_autosummary/mlx.core.arccosh", "python/_autosummary/mlx.core.arcsin", "python/_autosummary/mlx.core.arcsinh", "python/_autosummary/mlx.core.arctan", "python/_autosummary/mlx.core.arctanh", "python/_autosummary/mlx.core.argmax", "python/_autosummary/mlx.core.argmin", "python/_autosummary/mlx.core.argpartition", "python/_autosummary/mlx.core.argsort", "python/_autosummary/mlx.core.array", "python/_autosummary/mlx.core.array.T", "python/_autosummary/mlx.core.array.abs", "python/_autosummary/mlx.core.array.all", "python/_autosummary/mlx.core.array.any", "python/_autosummary/mlx.core.array.argmax", "python/_autosummary/mlx.core.array.argmin", "python/_autosummary/mlx.core.array.astype", "python/_autosummary/mlx.core.array.at", "python/_autosummary/mlx.core.array.cos", "python/_autosummary/mlx.core.array.cummax", "python/_autosummary/mlx.core.array.cummin", "python/_autosummary/mlx.core.array.cumprod", "python/_autosummary/mlx.core.array.cumsum", "python/_autosummary/mlx.core.array.diag", "python/_autosummary/mlx.core.array.diagonal", "python/_autosummary/mlx.core.array.dtype", "python/_autosummary/mlx.core.array.exp", "python/_autosummary/mlx.core.array.flatten", "python/_autosummary/mlx.core.array.item", "python/_autosummary/mlx.core.array.itemsize", "python/_autosummary/mlx.core.array.log", "python/_autosummary/mlx.core.array.log10", 
"python/_autosummary/mlx.core.array.log1p", "python/_autosummary/mlx.core.array.log2", "python/_autosummary/mlx.core.array.logsumexp", "python/_autosummary/mlx.core.array.max", "python/_autosummary/mlx.core.array.mean", "python/_autosummary/mlx.core.array.min", "python/_autosummary/mlx.core.array.moveaxis", "python/_autosummary/mlx.core.array.nbytes", "python/_autosummary/mlx.core.array.ndim", "python/_autosummary/mlx.core.array.prod", "python/_autosummary/mlx.core.array.reciprocal", "python/_autosummary/mlx.core.array.reshape", "python/_autosummary/mlx.core.array.round", "python/_autosummary/mlx.core.array.rsqrt", "python/_autosummary/mlx.core.array.shape", "python/_autosummary/mlx.core.array.sin", "python/_autosummary/mlx.core.array.size", "python/_autosummary/mlx.core.array.split", "python/_autosummary/mlx.core.array.sqrt", "python/_autosummary/mlx.core.array.square", "python/_autosummary/mlx.core.array.squeeze", "python/_autosummary/mlx.core.array.sum", "python/_autosummary/mlx.core.array.swapaxes", "python/_autosummary/mlx.core.array.tolist", "python/_autosummary/mlx.core.array.transpose", "python/_autosummary/mlx.core.array.var", "python/_autosummary/mlx.core.array_equal", "python/_autosummary/mlx.core.atleast_1d", "python/_autosummary/mlx.core.atleast_2d", "python/_autosummary/mlx.core.atleast_3d", "python/_autosummary/mlx.core.broadcast_to", "python/_autosummary/mlx.core.ceil", "python/_autosummary/mlx.core.clip", "python/_autosummary/mlx.core.compile", "python/_autosummary/mlx.core.concatenate", "python/_autosummary/mlx.core.conv1d", "python/_autosummary/mlx.core.conv2d", "python/_autosummary/mlx.core.conv_general", "python/_autosummary/mlx.core.convolve", "python/_autosummary/mlx.core.cos", "python/_autosummary/mlx.core.cosh", "python/_autosummary/mlx.core.cummax", "python/_autosummary/mlx.core.cummin", "python/_autosummary/mlx.core.cumprod", "python/_autosummary/mlx.core.cumsum", "python/_autosummary/mlx.core.default_device", "python/_autosummary/mlx.core.default_stream", "python/_autosummary/mlx.core.dequantize", "python/_autosummary/mlx.core.diag", "python/_autosummary/mlx.core.diagonal", "python/_autosummary/mlx.core.disable_compile", "python/_autosummary/mlx.core.divide", "python/_autosummary/mlx.core.divmod", "python/_autosummary/mlx.core.enable_compile", "python/_autosummary/mlx.core.equal", "python/_autosummary/mlx.core.erf", "python/_autosummary/mlx.core.erfinv", "python/_autosummary/mlx.core.eval", "python/_autosummary/mlx.core.exp", "python/_autosummary/mlx.core.expand_dims", "python/_autosummary/mlx.core.expm1", "python/_autosummary/mlx.core.eye", "python/_autosummary/mlx.core.fast.layer_norm", "python/_autosummary/mlx.core.fast.rms_norm", "python/_autosummary/mlx.core.fast.rope", "python/_autosummary/mlx.core.fast.scaled_dot_product_attention", "python/_autosummary/mlx.core.fft.fft", "python/_autosummary/mlx.core.fft.fft2", "python/_autosummary/mlx.core.fft.fftn", "python/_autosummary/mlx.core.fft.ifft", "python/_autosummary/mlx.core.fft.ifft2", "python/_autosummary/mlx.core.fft.ifftn", "python/_autosummary/mlx.core.fft.irfft", "python/_autosummary/mlx.core.fft.irfft2", "python/_autosummary/mlx.core.fft.irfftn", "python/_autosummary/mlx.core.fft.rfft", "python/_autosummary/mlx.core.fft.rfft2", "python/_autosummary/mlx.core.fft.rfftn", "python/_autosummary/mlx.core.flatten", "python/_autosummary/mlx.core.floor", "python/_autosummary/mlx.core.floor_divide", "python/_autosummary/mlx.core.full", "python/_autosummary/mlx.core.grad", "python/_autosummary/mlx.core.greater", 
"python/_autosummary/mlx.core.greater_equal", "python/_autosummary/mlx.core.identity", "python/_autosummary/mlx.core.inner", "python/_autosummary/mlx.core.isclose", "python/_autosummary/mlx.core.isinf", "python/_autosummary/mlx.core.isnan", "python/_autosummary/mlx.core.isneginf", "python/_autosummary/mlx.core.isposinf", "python/_autosummary/mlx.core.issubdtype", "python/_autosummary/mlx.core.jvp", "python/_autosummary/mlx.core.less", "python/_autosummary/mlx.core.less_equal", "python/_autosummary/mlx.core.linalg.norm", "python/_autosummary/mlx.core.linalg.qr", "python/_autosummary/mlx.core.linspace", "python/_autosummary/mlx.core.load", "python/_autosummary/mlx.core.log", "python/_autosummary/mlx.core.log10", "python/_autosummary/mlx.core.log1p", "python/_autosummary/mlx.core.log2", "python/_autosummary/mlx.core.logaddexp", "python/_autosummary/mlx.core.logical_and", "python/_autosummary/mlx.core.logical_not", "python/_autosummary/mlx.core.logical_or", "python/_autosummary/mlx.core.logsumexp", "python/_autosummary/mlx.core.matmul", "python/_autosummary/mlx.core.max", "python/_autosummary/mlx.core.maximum", "python/_autosummary/mlx.core.mean", "python/_autosummary/mlx.core.meshgrid", "python/_autosummary/mlx.core.metal.get_active_memory", "python/_autosummary/mlx.core.metal.get_cache_memory", "python/_autosummary/mlx.core.metal.get_peak_memory", "python/_autosummary/mlx.core.metal.is_available", "python/_autosummary/mlx.core.metal.set_cache_limit", "python/_autosummary/mlx.core.metal.set_memory_limit", "python/_autosummary/mlx.core.metal.start_capture", "python/_autosummary/mlx.core.metal.stop_capture", "python/_autosummary/mlx.core.min", "python/_autosummary/mlx.core.minimum", "python/_autosummary/mlx.core.moveaxis", "python/_autosummary/mlx.core.multiply", "python/_autosummary/mlx.core.negative", "python/_autosummary/mlx.core.new_stream", "python/_autosummary/mlx.core.ones", "python/_autosummary/mlx.core.ones_like", "python/_autosummary/mlx.core.outer", "python/_autosummary/mlx.core.pad", "python/_autosummary/mlx.core.partition", "python/_autosummary/mlx.core.prod", "python/_autosummary/mlx.core.quantize", "python/_autosummary/mlx.core.quantized_matmul", "python/_autosummary/mlx.core.random.bernoulli", "python/_autosummary/mlx.core.random.categorical", "python/_autosummary/mlx.core.random.gumbel", "python/_autosummary/mlx.core.random.key", "python/_autosummary/mlx.core.random.multivariate_normal", "python/_autosummary/mlx.core.random.normal", "python/_autosummary/mlx.core.random.randint", "python/_autosummary/mlx.core.random.seed", "python/_autosummary/mlx.core.random.split", "python/_autosummary/mlx.core.random.truncated_normal", "python/_autosummary/mlx.core.random.uniform", "python/_autosummary/mlx.core.reciprocal", "python/_autosummary/mlx.core.repeat", "python/_autosummary/mlx.core.reshape", "python/_autosummary/mlx.core.round", "python/_autosummary/mlx.core.rsqrt", "python/_autosummary/mlx.core.save", "python/_autosummary/mlx.core.save_gguf", "python/_autosummary/mlx.core.save_safetensors", "python/_autosummary/mlx.core.savez", "python/_autosummary/mlx.core.savez_compressed", "python/_autosummary/mlx.core.set_default_device", "python/_autosummary/mlx.core.set_default_stream", "python/_autosummary/mlx.core.sigmoid", "python/_autosummary/mlx.core.sign", "python/_autosummary/mlx.core.sin", "python/_autosummary/mlx.core.sinh", "python/_autosummary/mlx.core.softmax", "python/_autosummary/mlx.core.sort", "python/_autosummary/mlx.core.split", "python/_autosummary/mlx.core.sqrt", 
"python/_autosummary/mlx.core.square", "python/_autosummary/mlx.core.squeeze", "python/_autosummary/mlx.core.stack", "python/_autosummary/mlx.core.std", "python/_autosummary/mlx.core.stop_gradient", "python/_autosummary/mlx.core.stream", "python/_autosummary/mlx.core.subtract", "python/_autosummary/mlx.core.sum", "python/_autosummary/mlx.core.swapaxes", "python/_autosummary/mlx.core.take", "python/_autosummary/mlx.core.take_along_axis", "python/_autosummary/mlx.core.tan", "python/_autosummary/mlx.core.tanh", "python/_autosummary/mlx.core.tensordot", "python/_autosummary/mlx.core.tile", "python/_autosummary/mlx.core.topk", "python/_autosummary/mlx.core.transpose", "python/_autosummary/mlx.core.tri", "python/_autosummary/mlx.core.tril", "python/_autosummary/mlx.core.triu", "python/_autosummary/mlx.core.value_and_grad", "python/_autosummary/mlx.core.var", "python/_autosummary/mlx.core.vjp", "python/_autosummary/mlx.core.vmap", "python/_autosummary/mlx.core.where", "python/_autosummary/mlx.core.zeros", "python/_autosummary/mlx.core.zeros_like", "python/_autosummary/mlx.nn.value_and_grad", "python/_autosummary/mlx.utils.tree_flatten", "python/_autosummary/mlx.utils.tree_map", "python/_autosummary/mlx.utils.tree_unflatten", "python/_autosummary/stream_class", "python/array", "python/data_types", "python/devices_and_streams", "python/fast", "python/fft", "python/linalg", "python/metal", "python/nn", "python/nn/_autosummary/mlx.nn.ALiBi", "python/nn/_autosummary/mlx.nn.AvgPool1d", "python/nn/_autosummary/mlx.nn.AvgPool2d", "python/nn/_autosummary/mlx.nn.BatchNorm", "python/nn/_autosummary/mlx.nn.Conv1d", "python/nn/_autosummary/mlx.nn.Conv2d", "python/nn/_autosummary/mlx.nn.Dropout", "python/nn/_autosummary/mlx.nn.Dropout2d", "python/nn/_autosummary/mlx.nn.Dropout3d", "python/nn/_autosummary/mlx.nn.Embedding", "python/nn/_autosummary/mlx.nn.GELU", "python/nn/_autosummary/mlx.nn.GRU", "python/nn/_autosummary/mlx.nn.GroupNorm", "python/nn/_autosummary/mlx.nn.InstanceNorm", "python/nn/_autosummary/mlx.nn.LSTM", "python/nn/_autosummary/mlx.nn.LayerNorm", "python/nn/_autosummary/mlx.nn.Linear", "python/nn/_autosummary/mlx.nn.MaxPool1d", "python/nn/_autosummary/mlx.nn.MaxPool2d", "python/nn/_autosummary/mlx.nn.Mish", "python/nn/_autosummary/mlx.nn.Module.apply", "python/nn/_autosummary/mlx.nn.Module.apply_to_modules", "python/nn/_autosummary/mlx.nn.Module.children", "python/nn/_autosummary/mlx.nn.Module.eval", "python/nn/_autosummary/mlx.nn.Module.filter_and_map", "python/nn/_autosummary/mlx.nn.Module.freeze", "python/nn/_autosummary/mlx.nn.Module.leaf_modules", "python/nn/_autosummary/mlx.nn.Module.load_weights", "python/nn/_autosummary/mlx.nn.Module.modules", "python/nn/_autosummary/mlx.nn.Module.named_modules", "python/nn/_autosummary/mlx.nn.Module.parameters", "python/nn/_autosummary/mlx.nn.Module.save_weights", "python/nn/_autosummary/mlx.nn.Module.set_dtype", "python/nn/_autosummary/mlx.nn.Module.state", "python/nn/_autosummary/mlx.nn.Module.train", "python/nn/_autosummary/mlx.nn.Module.trainable_parameters", "python/nn/_autosummary/mlx.nn.Module.training", "python/nn/_autosummary/mlx.nn.Module.unfreeze", "python/nn/_autosummary/mlx.nn.Module.update", "python/nn/_autosummary/mlx.nn.Module.update_modules", "python/nn/_autosummary/mlx.nn.MultiHeadAttention", "python/nn/_autosummary/mlx.nn.PReLU", "python/nn/_autosummary/mlx.nn.QuantizedLinear", "python/nn/_autosummary/mlx.nn.RMSNorm", "python/nn/_autosummary/mlx.nn.RNN", "python/nn/_autosummary/mlx.nn.ReLU", "python/nn/_autosummary/mlx.nn.RoPE", 
"python/nn/_autosummary/mlx.nn.SELU", "python/nn/_autosummary/mlx.nn.Sequential", "python/nn/_autosummary/mlx.nn.SiLU", "python/nn/_autosummary/mlx.nn.SinusoidalPositionalEncoding", "python/nn/_autosummary/mlx.nn.Softshrink", "python/nn/_autosummary/mlx.nn.Step", "python/nn/_autosummary/mlx.nn.Transformer", "python/nn/_autosummary/mlx.nn.Upsample", "python/nn/_autosummary/mlx.nn.init.constant", "python/nn/_autosummary/mlx.nn.init.glorot_normal", "python/nn/_autosummary/mlx.nn.init.glorot_uniform", "python/nn/_autosummary/mlx.nn.init.he_normal", "python/nn/_autosummary/mlx.nn.init.he_uniform", "python/nn/_autosummary/mlx.nn.init.identity", "python/nn/_autosummary/mlx.nn.init.normal", "python/nn/_autosummary/mlx.nn.init.uniform", "python/nn/_autosummary_functions/mlx.nn.elu", "python/nn/_autosummary_functions/mlx.nn.gelu", "python/nn/_autosummary_functions/mlx.nn.gelu_approx", "python/nn/_autosummary_functions/mlx.nn.gelu_fast_approx", "python/nn/_autosummary_functions/mlx.nn.glu", "python/nn/_autosummary_functions/mlx.nn.hardswish", "python/nn/_autosummary_functions/mlx.nn.leaky_relu", "python/nn/_autosummary_functions/mlx.nn.log_sigmoid", "python/nn/_autosummary_functions/mlx.nn.log_softmax", "python/nn/_autosummary_functions/mlx.nn.losses.binary_cross_entropy", "python/nn/_autosummary_functions/mlx.nn.losses.cosine_similarity_loss", "python/nn/_autosummary_functions/mlx.nn.losses.cross_entropy", "python/nn/_autosummary_functions/mlx.nn.losses.gaussian_nll_loss", "python/nn/_autosummary_functions/mlx.nn.losses.hinge_loss", "python/nn/_autosummary_functions/mlx.nn.losses.huber_loss", "python/nn/_autosummary_functions/mlx.nn.losses.kl_div_loss", "python/nn/_autosummary_functions/mlx.nn.losses.l1_loss", "python/nn/_autosummary_functions/mlx.nn.losses.log_cosh_loss", "python/nn/_autosummary_functions/mlx.nn.losses.margin_ranking_loss", "python/nn/_autosummary_functions/mlx.nn.losses.mse_loss", "python/nn/_autosummary_functions/mlx.nn.losses.nll_loss", "python/nn/_autosummary_functions/mlx.nn.losses.smooth_l1_loss", "python/nn/_autosummary_functions/mlx.nn.losses.triplet_loss", "python/nn/_autosummary_functions/mlx.nn.mish", "python/nn/_autosummary_functions/mlx.nn.prelu", "python/nn/_autosummary_functions/mlx.nn.relu", "python/nn/_autosummary_functions/mlx.nn.relu6", "python/nn/_autosummary_functions/mlx.nn.selu", "python/nn/_autosummary_functions/mlx.nn.sigmoid", "python/nn/_autosummary_functions/mlx.nn.silu", "python/nn/_autosummary_functions/mlx.nn.softmax", "python/nn/_autosummary_functions/mlx.nn.softplus", "python/nn/_autosummary_functions/mlx.nn.softshrink", "python/nn/_autosummary_functions/mlx.nn.step", "python/nn/_autosummary_functions/mlx.nn.tanh", "python/nn/functions", "python/nn/init", "python/nn/layers", "python/nn/losses", "python/nn/module", "python/ops", "python/optimizers", "python/optimizers/_autosummary/mlx.optimizers.AdaDelta", "python/optimizers/_autosummary/mlx.optimizers.Adafactor", "python/optimizers/_autosummary/mlx.optimizers.Adagrad", "python/optimizers/_autosummary/mlx.optimizers.Adam", "python/optimizers/_autosummary/mlx.optimizers.AdamW", "python/optimizers/_autosummary/mlx.optimizers.Adamax", "python/optimizers/_autosummary/mlx.optimizers.Lion", "python/optimizers/_autosummary/mlx.optimizers.Optimizer.apply_gradients", "python/optimizers/_autosummary/mlx.optimizers.Optimizer.init", "python/optimizers/_autosummary/mlx.optimizers.Optimizer.state", "python/optimizers/_autosummary/mlx.optimizers.Optimizer.update", 
"python/optimizers/_autosummary/mlx.optimizers.RMSprop", "python/optimizers/_autosummary/mlx.optimizers.SGD", "python/optimizers/_autosummary/mlx.optimizers.cosine_decay", "python/optimizers/_autosummary/mlx.optimizers.exponential_decay", "python/optimizers/_autosummary/mlx.optimizers.join_schedules", "python/optimizers/_autosummary/mlx.optimizers.linear_schedule", "python/optimizers/_autosummary/mlx.optimizers.step_decay", "python/optimizers/common_optimizers", "python/optimizers/optimizer", "python/optimizers/schedulers", "python/random", "python/transforms", "python/tree_utils", "usage/compile", "usage/function_transforms", "usage/indexing", "usage/lazy_evaluation", "usage/numpy", "usage/quick_start", "usage/saving_and_loading", "usage/unified_memory", "usage/using_streams"], "filenames": ["cpp/ops.rst", "dev/extensions.rst", "dev/metal_debugger.rst", "examples/linear_regression.rst", "examples/llama-inference.rst", "examples/mlp.rst", "index.rst", "install.rst", "python/_autosummary/mlx.core.Device.rst", "python/_autosummary/mlx.core.Dtype.rst", "python/_autosummary/mlx.core.DtypeCategory.rst", "python/_autosummary/mlx.core.abs.rst", "python/_autosummary/mlx.core.add.rst", "python/_autosummary/mlx.core.all.rst", "python/_autosummary/mlx.core.allclose.rst", "python/_autosummary/mlx.core.any.rst", "python/_autosummary/mlx.core.arange.rst", "python/_autosummary/mlx.core.arccos.rst", "python/_autosummary/mlx.core.arccosh.rst", "python/_autosummary/mlx.core.arcsin.rst", "python/_autosummary/mlx.core.arcsinh.rst", "python/_autosummary/mlx.core.arctan.rst", "python/_autosummary/mlx.core.arctanh.rst", "python/_autosummary/mlx.core.argmax.rst", "python/_autosummary/mlx.core.argmin.rst", "python/_autosummary/mlx.core.argpartition.rst", "python/_autosummary/mlx.core.argsort.rst", "python/_autosummary/mlx.core.array.rst", "python/_autosummary/mlx.core.array.T.rst", "python/_autosummary/mlx.core.array.abs.rst", "python/_autosummary/mlx.core.array.all.rst", "python/_autosummary/mlx.core.array.any.rst", "python/_autosummary/mlx.core.array.argmax.rst", "python/_autosummary/mlx.core.array.argmin.rst", "python/_autosummary/mlx.core.array.astype.rst", "python/_autosummary/mlx.core.array.at.rst", "python/_autosummary/mlx.core.array.cos.rst", "python/_autosummary/mlx.core.array.cummax.rst", "python/_autosummary/mlx.core.array.cummin.rst", "python/_autosummary/mlx.core.array.cumprod.rst", "python/_autosummary/mlx.core.array.cumsum.rst", "python/_autosummary/mlx.core.array.diag.rst", "python/_autosummary/mlx.core.array.diagonal.rst", "python/_autosummary/mlx.core.array.dtype.rst", "python/_autosummary/mlx.core.array.exp.rst", "python/_autosummary/mlx.core.array.flatten.rst", "python/_autosummary/mlx.core.array.item.rst", "python/_autosummary/mlx.core.array.itemsize.rst", "python/_autosummary/mlx.core.array.log.rst", "python/_autosummary/mlx.core.array.log10.rst", "python/_autosummary/mlx.core.array.log1p.rst", "python/_autosummary/mlx.core.array.log2.rst", "python/_autosummary/mlx.core.array.logsumexp.rst", "python/_autosummary/mlx.core.array.max.rst", "python/_autosummary/mlx.core.array.mean.rst", "python/_autosummary/mlx.core.array.min.rst", "python/_autosummary/mlx.core.array.moveaxis.rst", "python/_autosummary/mlx.core.array.nbytes.rst", "python/_autosummary/mlx.core.array.ndim.rst", "python/_autosummary/mlx.core.array.prod.rst", "python/_autosummary/mlx.core.array.reciprocal.rst", "python/_autosummary/mlx.core.array.reshape.rst", "python/_autosummary/mlx.core.array.round.rst", 
"python/_autosummary/mlx.core.array.rsqrt.rst", "python/_autosummary/mlx.core.array.shape.rst", "python/_autosummary/mlx.core.array.sin.rst", "python/_autosummary/mlx.core.array.size.rst", "python/_autosummary/mlx.core.array.split.rst", "python/_autosummary/mlx.core.array.sqrt.rst", "python/_autosummary/mlx.core.array.square.rst", "python/_autosummary/mlx.core.array.squeeze.rst", "python/_autosummary/mlx.core.array.sum.rst", "python/_autosummary/mlx.core.array.swapaxes.rst", "python/_autosummary/mlx.core.array.tolist.rst", "python/_autosummary/mlx.core.array.transpose.rst", "python/_autosummary/mlx.core.array.var.rst", "python/_autosummary/mlx.core.array_equal.rst", "python/_autosummary/mlx.core.atleast_1d.rst", "python/_autosummary/mlx.core.atleast_2d.rst", "python/_autosummary/mlx.core.atleast_3d.rst", "python/_autosummary/mlx.core.broadcast_to.rst", "python/_autosummary/mlx.core.ceil.rst", "python/_autosummary/mlx.core.clip.rst", "python/_autosummary/mlx.core.compile.rst", "python/_autosummary/mlx.core.concatenate.rst", "python/_autosummary/mlx.core.conv1d.rst", "python/_autosummary/mlx.core.conv2d.rst", "python/_autosummary/mlx.core.conv_general.rst", "python/_autosummary/mlx.core.convolve.rst", "python/_autosummary/mlx.core.cos.rst", "python/_autosummary/mlx.core.cosh.rst", "python/_autosummary/mlx.core.cummax.rst", "python/_autosummary/mlx.core.cummin.rst", "python/_autosummary/mlx.core.cumprod.rst", "python/_autosummary/mlx.core.cumsum.rst", "python/_autosummary/mlx.core.default_device.rst", "python/_autosummary/mlx.core.default_stream.rst", "python/_autosummary/mlx.core.dequantize.rst", "python/_autosummary/mlx.core.diag.rst", "python/_autosummary/mlx.core.diagonal.rst", "python/_autosummary/mlx.core.disable_compile.rst", "python/_autosummary/mlx.core.divide.rst", "python/_autosummary/mlx.core.divmod.rst", "python/_autosummary/mlx.core.enable_compile.rst", "python/_autosummary/mlx.core.equal.rst", "python/_autosummary/mlx.core.erf.rst", "python/_autosummary/mlx.core.erfinv.rst", "python/_autosummary/mlx.core.eval.rst", "python/_autosummary/mlx.core.exp.rst", "python/_autosummary/mlx.core.expand_dims.rst", "python/_autosummary/mlx.core.expm1.rst", "python/_autosummary/mlx.core.eye.rst", "python/_autosummary/mlx.core.fast.layer_norm.rst", "python/_autosummary/mlx.core.fast.rms_norm.rst", "python/_autosummary/mlx.core.fast.rope.rst", "python/_autosummary/mlx.core.fast.scaled_dot_product_attention.rst", "python/_autosummary/mlx.core.fft.fft.rst", "python/_autosummary/mlx.core.fft.fft2.rst", "python/_autosummary/mlx.core.fft.fftn.rst", "python/_autosummary/mlx.core.fft.ifft.rst", "python/_autosummary/mlx.core.fft.ifft2.rst", "python/_autosummary/mlx.core.fft.ifftn.rst", "python/_autosummary/mlx.core.fft.irfft.rst", "python/_autosummary/mlx.core.fft.irfft2.rst", "python/_autosummary/mlx.core.fft.irfftn.rst", "python/_autosummary/mlx.core.fft.rfft.rst", "python/_autosummary/mlx.core.fft.rfft2.rst", "python/_autosummary/mlx.core.fft.rfftn.rst", "python/_autosummary/mlx.core.flatten.rst", "python/_autosummary/mlx.core.floor.rst", "python/_autosummary/mlx.core.floor_divide.rst", "python/_autosummary/mlx.core.full.rst", "python/_autosummary/mlx.core.grad.rst", "python/_autosummary/mlx.core.greater.rst", "python/_autosummary/mlx.core.greater_equal.rst", "python/_autosummary/mlx.core.identity.rst", "python/_autosummary/mlx.core.inner.rst", "python/_autosummary/mlx.core.isclose.rst", "python/_autosummary/mlx.core.isinf.rst", "python/_autosummary/mlx.core.isnan.rst", 
"python/_autosummary/mlx.core.isneginf.rst", "python/_autosummary/mlx.core.isposinf.rst", "python/_autosummary/mlx.core.issubdtype.rst", "python/_autosummary/mlx.core.jvp.rst", "python/_autosummary/mlx.core.less.rst", "python/_autosummary/mlx.core.less_equal.rst", "python/_autosummary/mlx.core.linalg.norm.rst", "python/_autosummary/mlx.core.linalg.qr.rst", "python/_autosummary/mlx.core.linspace.rst", "python/_autosummary/mlx.core.load.rst", "python/_autosummary/mlx.core.log.rst", "python/_autosummary/mlx.core.log10.rst", "python/_autosummary/mlx.core.log1p.rst", "python/_autosummary/mlx.core.log2.rst", "python/_autosummary/mlx.core.logaddexp.rst", "python/_autosummary/mlx.core.logical_and.rst", "python/_autosummary/mlx.core.logical_not.rst", "python/_autosummary/mlx.core.logical_or.rst", "python/_autosummary/mlx.core.logsumexp.rst", "python/_autosummary/mlx.core.matmul.rst", "python/_autosummary/mlx.core.max.rst", "python/_autosummary/mlx.core.maximum.rst", "python/_autosummary/mlx.core.mean.rst", "python/_autosummary/mlx.core.meshgrid.rst", "python/_autosummary/mlx.core.metal.get_active_memory.rst", "python/_autosummary/mlx.core.metal.get_cache_memory.rst", "python/_autosummary/mlx.core.metal.get_peak_memory.rst", "python/_autosummary/mlx.core.metal.is_available.rst", "python/_autosummary/mlx.core.metal.set_cache_limit.rst", "python/_autosummary/mlx.core.metal.set_memory_limit.rst", "python/_autosummary/mlx.core.metal.start_capture.rst", "python/_autosummary/mlx.core.metal.stop_capture.rst", "python/_autosummary/mlx.core.min.rst", "python/_autosummary/mlx.core.minimum.rst", "python/_autosummary/mlx.core.moveaxis.rst", "python/_autosummary/mlx.core.multiply.rst", "python/_autosummary/mlx.core.negative.rst", "python/_autosummary/mlx.core.new_stream.rst", "python/_autosummary/mlx.core.ones.rst", "python/_autosummary/mlx.core.ones_like.rst", "python/_autosummary/mlx.core.outer.rst", "python/_autosummary/mlx.core.pad.rst", "python/_autosummary/mlx.core.partition.rst", "python/_autosummary/mlx.core.prod.rst", "python/_autosummary/mlx.core.quantize.rst", "python/_autosummary/mlx.core.quantized_matmul.rst", "python/_autosummary/mlx.core.random.bernoulli.rst", "python/_autosummary/mlx.core.random.categorical.rst", "python/_autosummary/mlx.core.random.gumbel.rst", "python/_autosummary/mlx.core.random.key.rst", "python/_autosummary/mlx.core.random.multivariate_normal.rst", "python/_autosummary/mlx.core.random.normal.rst", "python/_autosummary/mlx.core.random.randint.rst", "python/_autosummary/mlx.core.random.seed.rst", "python/_autosummary/mlx.core.random.split.rst", "python/_autosummary/mlx.core.random.truncated_normal.rst", "python/_autosummary/mlx.core.random.uniform.rst", "python/_autosummary/mlx.core.reciprocal.rst", "python/_autosummary/mlx.core.repeat.rst", "python/_autosummary/mlx.core.reshape.rst", "python/_autosummary/mlx.core.round.rst", "python/_autosummary/mlx.core.rsqrt.rst", "python/_autosummary/mlx.core.save.rst", "python/_autosummary/mlx.core.save_gguf.rst", "python/_autosummary/mlx.core.save_safetensors.rst", "python/_autosummary/mlx.core.savez.rst", "python/_autosummary/mlx.core.savez_compressed.rst", "python/_autosummary/mlx.core.set_default_device.rst", "python/_autosummary/mlx.core.set_default_stream.rst", "python/_autosummary/mlx.core.sigmoid.rst", "python/_autosummary/mlx.core.sign.rst", "python/_autosummary/mlx.core.sin.rst", "python/_autosummary/mlx.core.sinh.rst", "python/_autosummary/mlx.core.softmax.rst", "python/_autosummary/mlx.core.sort.rst", 
"python/_autosummary/mlx.core.split.rst", "python/_autosummary/mlx.core.sqrt.rst", "python/_autosummary/mlx.core.square.rst", "python/_autosummary/mlx.core.squeeze.rst", "python/_autosummary/mlx.core.stack.rst", "python/_autosummary/mlx.core.std.rst", "python/_autosummary/mlx.core.stop_gradient.rst", "python/_autosummary/mlx.core.stream.rst", "python/_autosummary/mlx.core.subtract.rst", "python/_autosummary/mlx.core.sum.rst", "python/_autosummary/mlx.core.swapaxes.rst", "python/_autosummary/mlx.core.take.rst", "python/_autosummary/mlx.core.take_along_axis.rst", "python/_autosummary/mlx.core.tan.rst", "python/_autosummary/mlx.core.tanh.rst", "python/_autosummary/mlx.core.tensordot.rst", "python/_autosummary/mlx.core.tile.rst", "python/_autosummary/mlx.core.topk.rst", "python/_autosummary/mlx.core.transpose.rst", "python/_autosummary/mlx.core.tri.rst", "python/_autosummary/mlx.core.tril.rst", "python/_autosummary/mlx.core.triu.rst", "python/_autosummary/mlx.core.value_and_grad.rst", "python/_autosummary/mlx.core.var.rst", "python/_autosummary/mlx.core.vjp.rst", "python/_autosummary/mlx.core.vmap.rst", "python/_autosummary/mlx.core.where.rst", "python/_autosummary/mlx.core.zeros.rst", "python/_autosummary/mlx.core.zeros_like.rst", "python/_autosummary/mlx.nn.value_and_grad.rst", "python/_autosummary/mlx.utils.tree_flatten.rst", "python/_autosummary/mlx.utils.tree_map.rst", "python/_autosummary/mlx.utils.tree_unflatten.rst", "python/_autosummary/stream_class.rst", "python/array.rst", "python/data_types.rst", "python/devices_and_streams.rst", "python/fast.rst", "python/fft.rst", "python/linalg.rst", "python/metal.rst", "python/nn.rst", "python/nn/_autosummary/mlx.nn.ALiBi.rst", "python/nn/_autosummary/mlx.nn.AvgPool1d.rst", "python/nn/_autosummary/mlx.nn.AvgPool2d.rst", "python/nn/_autosummary/mlx.nn.BatchNorm.rst", "python/nn/_autosummary/mlx.nn.Conv1d.rst", "python/nn/_autosummary/mlx.nn.Conv2d.rst", "python/nn/_autosummary/mlx.nn.Dropout.rst", "python/nn/_autosummary/mlx.nn.Dropout2d.rst", "python/nn/_autosummary/mlx.nn.Dropout3d.rst", "python/nn/_autosummary/mlx.nn.Embedding.rst", "python/nn/_autosummary/mlx.nn.GELU.rst", "python/nn/_autosummary/mlx.nn.GRU.rst", "python/nn/_autosummary/mlx.nn.GroupNorm.rst", "python/nn/_autosummary/mlx.nn.InstanceNorm.rst", "python/nn/_autosummary/mlx.nn.LSTM.rst", "python/nn/_autosummary/mlx.nn.LayerNorm.rst", "python/nn/_autosummary/mlx.nn.Linear.rst", "python/nn/_autosummary/mlx.nn.MaxPool1d.rst", "python/nn/_autosummary/mlx.nn.MaxPool2d.rst", "python/nn/_autosummary/mlx.nn.Mish.rst", "python/nn/_autosummary/mlx.nn.Module.apply.rst", "python/nn/_autosummary/mlx.nn.Module.apply_to_modules.rst", "python/nn/_autosummary/mlx.nn.Module.children.rst", "python/nn/_autosummary/mlx.nn.Module.eval.rst", "python/nn/_autosummary/mlx.nn.Module.filter_and_map.rst", "python/nn/_autosummary/mlx.nn.Module.freeze.rst", "python/nn/_autosummary/mlx.nn.Module.leaf_modules.rst", "python/nn/_autosummary/mlx.nn.Module.load_weights.rst", "python/nn/_autosummary/mlx.nn.Module.modules.rst", "python/nn/_autosummary/mlx.nn.Module.named_modules.rst", "python/nn/_autosummary/mlx.nn.Module.parameters.rst", "python/nn/_autosummary/mlx.nn.Module.save_weights.rst", "python/nn/_autosummary/mlx.nn.Module.set_dtype.rst", "python/nn/_autosummary/mlx.nn.Module.state.rst", "python/nn/_autosummary/mlx.nn.Module.train.rst", "python/nn/_autosummary/mlx.nn.Module.trainable_parameters.rst", "python/nn/_autosummary/mlx.nn.Module.training.rst", "python/nn/_autosummary/mlx.nn.Module.unfreeze.rst", 
"python/nn/_autosummary/mlx.nn.Module.update.rst", "python/nn/_autosummary/mlx.nn.Module.update_modules.rst", "python/nn/_autosummary/mlx.nn.MultiHeadAttention.rst", "python/nn/_autosummary/mlx.nn.PReLU.rst", "python/nn/_autosummary/mlx.nn.QuantizedLinear.rst", "python/nn/_autosummary/mlx.nn.RMSNorm.rst", "python/nn/_autosummary/mlx.nn.RNN.rst", "python/nn/_autosummary/mlx.nn.ReLU.rst", "python/nn/_autosummary/mlx.nn.RoPE.rst", "python/nn/_autosummary/mlx.nn.SELU.rst", "python/nn/_autosummary/mlx.nn.Sequential.rst", "python/nn/_autosummary/mlx.nn.SiLU.rst", "python/nn/_autosummary/mlx.nn.SinusoidalPositionalEncoding.rst", "python/nn/_autosummary/mlx.nn.Softshrink.rst", "python/nn/_autosummary/mlx.nn.Step.rst", "python/nn/_autosummary/mlx.nn.Transformer.rst", "python/nn/_autosummary/mlx.nn.Upsample.rst", "python/nn/_autosummary/mlx.nn.init.constant.rst", "python/nn/_autosummary/mlx.nn.init.glorot_normal.rst", "python/nn/_autosummary/mlx.nn.init.glorot_uniform.rst", "python/nn/_autosummary/mlx.nn.init.he_normal.rst", "python/nn/_autosummary/mlx.nn.init.he_uniform.rst", "python/nn/_autosummary/mlx.nn.init.identity.rst", "python/nn/_autosummary/mlx.nn.init.normal.rst", "python/nn/_autosummary/mlx.nn.init.uniform.rst", "python/nn/_autosummary_functions/mlx.nn.elu.rst", "python/nn/_autosummary_functions/mlx.nn.gelu.rst", "python/nn/_autosummary_functions/mlx.nn.gelu_approx.rst", "python/nn/_autosummary_functions/mlx.nn.gelu_fast_approx.rst", "python/nn/_autosummary_functions/mlx.nn.glu.rst", "python/nn/_autosummary_functions/mlx.nn.hardswish.rst", "python/nn/_autosummary_functions/mlx.nn.leaky_relu.rst", "python/nn/_autosummary_functions/mlx.nn.log_sigmoid.rst", "python/nn/_autosummary_functions/mlx.nn.log_softmax.rst", "python/nn/_autosummary_functions/mlx.nn.losses.binary_cross_entropy.rst", "python/nn/_autosummary_functions/mlx.nn.losses.cosine_similarity_loss.rst", "python/nn/_autosummary_functions/mlx.nn.losses.cross_entropy.rst", "python/nn/_autosummary_functions/mlx.nn.losses.gaussian_nll_loss.rst", "python/nn/_autosummary_functions/mlx.nn.losses.hinge_loss.rst", "python/nn/_autosummary_functions/mlx.nn.losses.huber_loss.rst", "python/nn/_autosummary_functions/mlx.nn.losses.kl_div_loss.rst", "python/nn/_autosummary_functions/mlx.nn.losses.l1_loss.rst", "python/nn/_autosummary_functions/mlx.nn.losses.log_cosh_loss.rst", "python/nn/_autosummary_functions/mlx.nn.losses.margin_ranking_loss.rst", "python/nn/_autosummary_functions/mlx.nn.losses.mse_loss.rst", "python/nn/_autosummary_functions/mlx.nn.losses.nll_loss.rst", "python/nn/_autosummary_functions/mlx.nn.losses.smooth_l1_loss.rst", "python/nn/_autosummary_functions/mlx.nn.losses.triplet_loss.rst", "python/nn/_autosummary_functions/mlx.nn.mish.rst", "python/nn/_autosummary_functions/mlx.nn.prelu.rst", "python/nn/_autosummary_functions/mlx.nn.relu.rst", "python/nn/_autosummary_functions/mlx.nn.relu6.rst", "python/nn/_autosummary_functions/mlx.nn.selu.rst", "python/nn/_autosummary_functions/mlx.nn.sigmoid.rst", "python/nn/_autosummary_functions/mlx.nn.silu.rst", "python/nn/_autosummary_functions/mlx.nn.softmax.rst", "python/nn/_autosummary_functions/mlx.nn.softplus.rst", "python/nn/_autosummary_functions/mlx.nn.softshrink.rst", "python/nn/_autosummary_functions/mlx.nn.step.rst", "python/nn/_autosummary_functions/mlx.nn.tanh.rst", "python/nn/functions.rst", "python/nn/init.rst", "python/nn/layers.rst", "python/nn/losses.rst", "python/nn/module.rst", "python/ops.rst", "python/optimizers.rst", 
"python/optimizers/_autosummary/mlx.optimizers.AdaDelta.rst", "python/optimizers/_autosummary/mlx.optimizers.Adafactor.rst", "python/optimizers/_autosummary/mlx.optimizers.Adagrad.rst", "python/optimizers/_autosummary/mlx.optimizers.Adam.rst", "python/optimizers/_autosummary/mlx.optimizers.AdamW.rst", "python/optimizers/_autosummary/mlx.optimizers.Adamax.rst", "python/optimizers/_autosummary/mlx.optimizers.Lion.rst", "python/optimizers/_autosummary/mlx.optimizers.Optimizer.apply_gradients.rst", "python/optimizers/_autosummary/mlx.optimizers.Optimizer.init.rst", "python/optimizers/_autosummary/mlx.optimizers.Optimizer.state.rst", "python/optimizers/_autosummary/mlx.optimizers.Optimizer.update.rst", "python/optimizers/_autosummary/mlx.optimizers.RMSprop.rst", "python/optimizers/_autosummary/mlx.optimizers.SGD.rst", "python/optimizers/_autosummary/mlx.optimizers.cosine_decay.rst", "python/optimizers/_autosummary/mlx.optimizers.exponential_decay.rst", "python/optimizers/_autosummary/mlx.optimizers.join_schedules.rst", "python/optimizers/_autosummary/mlx.optimizers.linear_schedule.rst", "python/optimizers/_autosummary/mlx.optimizers.step_decay.rst", "python/optimizers/common_optimizers.rst", "python/optimizers/optimizer.rst", "python/optimizers/schedulers.rst", "python/random.rst", "python/transforms.rst", "python/tree_utils.rst", "usage/compile.rst", "usage/function_transforms.rst", "usage/indexing.rst", "usage/lazy_evaluation.rst", "usage/numpy.rst", "usage/quick_start.rst", "usage/saving_and_loading.rst", "usage/unified_memory.rst", "usage/using_streams.rst"], "titles": ["Operations", "Developer Documentation", "Metal Debugger", "Linear Regression", "LLM inference", "Multi-Layer Perceptron", "MLX", "Build and Install", "mlx.core.Device", "mlx.core.Dtype", "mlx.core.DtypeCategory", "mlx.core.abs", "mlx.core.add", "mlx.core.all", "mlx.core.allclose", "mlx.core.any", "mlx.core.arange", "mlx.core.arccos", "mlx.core.arccosh", "mlx.core.arcsin", "mlx.core.arcsinh", "mlx.core.arctan", "mlx.core.arctanh", "mlx.core.argmax", "mlx.core.argmin", "mlx.core.argpartition", "mlx.core.argsort", "mlx.core.array", "mlx.core.array.T", "mlx.core.array.abs", "mlx.core.array.all", "mlx.core.array.any", "mlx.core.array.argmax", "mlx.core.array.argmin", "mlx.core.array.astype", "mlx.core.array.at", "mlx.core.array.cos", "mlx.core.array.cummax", "mlx.core.array.cummin", "mlx.core.array.cumprod", "mlx.core.array.cumsum", "mlx.core.array.diag", "mlx.core.array.diagonal", "mlx.core.array.dtype", "mlx.core.array.exp", "mlx.core.array.flatten", "mlx.core.array.item", "mlx.core.array.itemsize", "mlx.core.array.log", "mlx.core.array.log10", "mlx.core.array.log1p", "mlx.core.array.log2", "mlx.core.array.logsumexp", "mlx.core.array.max", "mlx.core.array.mean", "mlx.core.array.min", "mlx.core.array.moveaxis", "mlx.core.array.nbytes", "mlx.core.array.ndim", "mlx.core.array.prod", "mlx.core.array.reciprocal", "mlx.core.array.reshape", "mlx.core.array.round", "mlx.core.array.rsqrt", "mlx.core.array.shape", "mlx.core.array.sin", "mlx.core.array.size", "mlx.core.array.split", "mlx.core.array.sqrt", "mlx.core.array.square", "mlx.core.array.squeeze", "mlx.core.array.sum", "mlx.core.array.swapaxes", "mlx.core.array.tolist", "mlx.core.array.transpose", "mlx.core.array.var", "mlx.core.array_equal", "mlx.core.atleast_1d", "mlx.core.atleast_2d", "mlx.core.atleast_3d", "mlx.core.broadcast_to", "mlx.core.ceil", "mlx.core.clip", "mlx.core.compile", "mlx.core.concatenate", "mlx.core.conv1d", "mlx.core.conv2d", "mlx.core.conv_general", 
"mlx.core.convolve", "mlx.core.cos", "mlx.core.cosh", "mlx.core.cummax", "mlx.core.cummin", "mlx.core.cumprod", "mlx.core.cumsum", "mlx.core.default_device", "mlx.core.default_stream", "mlx.core.dequantize", "mlx.core.diag", "mlx.core.diagonal", "mlx.core.disable_compile", "mlx.core.divide", "mlx.core.divmod", "mlx.core.enable_compile", "mlx.core.equal", "mlx.core.erf", "mlx.core.erfinv", "mlx.core.eval", "mlx.core.exp", "mlx.core.expand_dims", "mlx.core.expm1", "mlx.core.eye", "mlx.core.fast.layer_norm", "mlx.core.fast.rms_norm", "mlx.core.fast.rope", "mlx.core.fast.scaled_dot_product_attention", "mlx.core.fft.fft", "mlx.core.fft.fft2", "mlx.core.fft.fftn", "mlx.core.fft.ifft", "mlx.core.fft.ifft2", "mlx.core.fft.ifftn", "mlx.core.fft.irfft", "mlx.core.fft.irfft2", "mlx.core.fft.irfftn", "mlx.core.fft.rfft", "mlx.core.fft.rfft2", "mlx.core.fft.rfftn", "mlx.core.flatten", "mlx.core.floor", "mlx.core.floor_divide", "mlx.core.full", "mlx.core.grad", "mlx.core.greater", "mlx.core.greater_equal", "mlx.core.identity", "mlx.core.inner", "mlx.core.isclose", "mlx.core.isinf", "mlx.core.isnan", "mlx.core.isneginf", "mlx.core.isposinf", "mlx.core.issubdtype", "mlx.core.jvp", "mlx.core.less", "mlx.core.less_equal", "mlx.core.linalg.norm", "mlx.core.linalg.qr", "mlx.core.linspace", "mlx.core.load", "mlx.core.log", "mlx.core.log10", "mlx.core.log1p", "mlx.core.log2", "mlx.core.logaddexp", "mlx.core.logical_and", "mlx.core.logical_not", "mlx.core.logical_or", "mlx.core.logsumexp", "mlx.core.matmul", "mlx.core.max", "mlx.core.maximum", "mlx.core.mean", "mlx.core.meshgrid", "mlx.core.metal.get_active_memory", "mlx.core.metal.get_cache_memory", "mlx.core.metal.get_peak_memory", "mlx.core.metal.is_available", "mlx.core.metal.set_cache_limit", "mlx.core.metal.set_memory_limit", "mlx.core.metal.start_capture", "mlx.core.metal.stop_capture", "mlx.core.min", "mlx.core.minimum", "mlx.core.moveaxis", "mlx.core.multiply", "mlx.core.negative", "mlx.core.new_stream", "mlx.core.ones", "mlx.core.ones_like", "mlx.core.outer", "mlx.core.pad", "mlx.core.partition", "mlx.core.prod", "mlx.core.quantize", "mlx.core.quantized_matmul", "mlx.core.random.bernoulli", "mlx.core.random.categorical", "mlx.core.random.gumbel", "mlx.core.random.key", "mlx.core.random.multivariate_normal", "mlx.core.random.normal", "mlx.core.random.randint", "mlx.core.random.seed", "mlx.core.random.split", "mlx.core.random.truncated_normal", "mlx.core.random.uniform", "mlx.core.reciprocal", "mlx.core.repeat", "mlx.core.reshape", "mlx.core.round", "mlx.core.rsqrt", "mlx.core.save", "mlx.core.save_gguf", "mlx.core.save_safetensors", "mlx.core.savez", "mlx.core.savez_compressed", "mlx.core.set_default_device", "mlx.core.set_default_stream", "mlx.core.sigmoid", "mlx.core.sign", "mlx.core.sin", "mlx.core.sinh", "mlx.core.softmax", "mlx.core.sort", "mlx.core.split", "mlx.core.sqrt", "mlx.core.square", "mlx.core.squeeze", "mlx.core.stack", "mlx.core.std", "mlx.core.stop_gradient", "mlx.core.stream", "mlx.core.subtract", "mlx.core.sum", "mlx.core.swapaxes", "mlx.core.take", "mlx.core.take_along_axis", "mlx.core.tan", "mlx.core.tanh", "mlx.core.tensordot", "mlx.core.tile", "mlx.core.topk", "mlx.core.transpose", "mlx.core.tri", "mlx.core.tril", "mlx.core.triu", "mlx.core.value_and_grad", "mlx.core.var", "mlx.core.vjp", "mlx.core.vmap", "mlx.core.where", "mlx.core.zeros", "mlx.core.zeros_like", "mlx.nn.value_and_grad", "mlx.utils.tree_flatten", "mlx.utils.tree_map", "mlx.utils.tree_unflatten", "mlx.core.Stream", "Array", "Data Types", "Devices and Streams", 
"Fast", "FFT", "Linear Algebra", "Metal", "Neural Networks", "mlx.nn.ALiBi", "mlx.nn.AvgPool1d", "mlx.nn.AvgPool2d", "mlx.nn.BatchNorm", "mlx.nn.Conv1d", "mlx.nn.Conv2d", "mlx.nn.Dropout", "mlx.nn.Dropout2d", "mlx.nn.Dropout3d", "mlx.nn.Embedding", "mlx.nn.GELU", "mlx.nn.GRU", "mlx.nn.GroupNorm", "mlx.nn.InstanceNorm", "mlx.nn.LSTM", "mlx.nn.LayerNorm", "mlx.nn.Linear", "mlx.nn.MaxPool1d", "mlx.nn.MaxPool2d", "mlx.nn.Mish", "mlx.nn.Module.apply", "mlx.nn.Module.apply_to_modules", "mlx.nn.Module.children", "mlx.nn.Module.eval", "mlx.nn.Module.filter_and_map", "mlx.nn.Module.freeze", "mlx.nn.Module.leaf_modules", "mlx.nn.Module.load_weights", "mlx.nn.Module.modules", "mlx.nn.Module.named_modules", "mlx.nn.Module.parameters", "mlx.nn.Module.save_weights", "mlx.nn.Module.set_dtype", "mlx.nn.Module.state", "mlx.nn.Module.train", "mlx.nn.Module.trainable_parameters", "mlx.nn.Module.training", "mlx.nn.Module.unfreeze", "mlx.nn.Module.update", "mlx.nn.Module.update_modules", "mlx.nn.MultiHeadAttention", "mlx.nn.PReLU", "mlx.nn.QuantizedLinear", "mlx.nn.RMSNorm", "mlx.nn.RNN", "mlx.nn.ReLU", "mlx.nn.RoPE", "mlx.nn.SELU", "mlx.nn.Sequential", "mlx.nn.SiLU", "mlx.nn.SinusoidalPositionalEncoding", "mlx.nn.Softshrink", "mlx.nn.Step", "mlx.nn.Transformer", "mlx.nn.Upsample", "mlx.nn.init.constant", "mlx.nn.init.glorot_normal", "mlx.nn.init.glorot_uniform", "mlx.nn.init.he_normal", "mlx.nn.init.he_uniform", "mlx.nn.init.identity", "mlx.nn.init.normal", "mlx.nn.init.uniform", "mlx.nn.elu", "mlx.nn.gelu", "mlx.nn.gelu_approx", "mlx.nn.gelu_fast_approx", "mlx.nn.glu", "mlx.nn.hardswish", "mlx.nn.leaky_relu", "mlx.nn.log_sigmoid", "mlx.nn.log_softmax", "mlx.nn.losses.binary_cross_entropy", "mlx.nn.losses.cosine_similarity_loss", "mlx.nn.losses.cross_entropy", "mlx.nn.losses.gaussian_nll_loss", "mlx.nn.losses.hinge_loss", "mlx.nn.losses.huber_loss", "mlx.nn.losses.kl_div_loss", "mlx.nn.losses.l1_loss", "mlx.nn.losses.log_cosh_loss", "mlx.nn.losses.margin_ranking_loss", "mlx.nn.losses.mse_loss", "mlx.nn.losses.nll_loss", "mlx.nn.losses.smooth_l1_loss", "mlx.nn.losses.triplet_loss", "mlx.nn.mish", "mlx.nn.prelu", "mlx.nn.relu", "mlx.nn.relu6", "mlx.nn.selu", "mlx.nn.sigmoid", "mlx.nn.silu", "mlx.nn.softmax", "mlx.nn.softplus", "mlx.nn.softshrink", "mlx.nn.step", "mlx.nn.tanh", "Functions", "Initializers", "Layers", "Loss Functions", "Module", "Operations", "Optimizers", "mlx.optimizers.AdaDelta", "mlx.optimizers.Adafactor", "mlx.optimizers.Adagrad", "mlx.optimizers.Adam", "mlx.optimizers.AdamW", "mlx.optimizers.Adamax", "mlx.optimizers.Lion", "mlx.optimizers.Optimizer.apply_gradients", "mlx.optimizers.Optimizer.init", "mlx.optimizers.Optimizer.state", "mlx.optimizers.Optimizer.update", "mlx.optimizers.RMSprop", "mlx.optimizers.SGD", "mlx.optimizers.cosine_decay", "mlx.optimizers.exponential_decay", "mlx.optimizers.join_schedules", "mlx.optimizers.linear_schedule", "mlx.optimizers.step_decay", "Common Optimizers", "Optimizer", "Schedulers", "Random", "Transforms", "Tree Utils", "Compilation", "Function Transforms", "Indexing Arrays", "Lazy Evaluation", "Conversion to NumPy and Other Frameworks", "Quick Start Guide", "Saving and Loading Arrays", "Unified Memory", "Using Streams"], "terms": {"mlx": [1, 2, 3, 4, 5, 7, 256, 356, 359, 361, 383, 385, 386, 387, 388, 389, 390, 391, 392, 393], "provid": [1, 4, 97, 132, 230, 237, 246, 256, 277, 282, 284, 294, 295, 296, 299, 310, 311, 355, 359, 392, 394], "open": [2, 7, 16, 192, 196], "flexibl": 6, "which": [1, 4, 5, 6, 7, 16, 34, 83, 87, 99, 107, 114, 116, 117, 118, 119, 
120, 121, 122, 123, 124, 125, 126, 127, 128, 132, 138, 139, 140, 141, 143, 146, 147, 149, 163, 170, 184, 187, 188, 198, 199, 202, 203, 204, 205, 206, 218, 219, 226, 237, 239, 240, 259, 264, 265, 267, 275, 277, 281, 303, 331, 334, 338, 341, 356, 369, 370, 383, 386, 387, 388, 389, 393, 394], "user": [1, 4, 256], "mai": [1, 146, 264, 387, 388], "add": [1, 2, 4, 35, 109, 154, 181, 184, 261, 262, 387, 393], "special": 1, "without": [4, 6, 221, 297, 355, 385, 386, 389, 390, 393], "much": [1, 4, 258, 259, 274, 275, 386, 389], "hassl": [], "while": [1, 2, 4, 7, 199, 303, 389, 390], "librari": [1, 7, 256], "suppli": [], "effici": [4, 6, 264, 303, 389, 391], "can": [1, 2, 4, 6, 7, 12, 16, 61, 74, 83, 99, 100, 101, 102, 104, 107, 133, 134, 144, 145, 146, 154, 161, 173, 175, 186, 187, 192, 195, 196, 203, 223, 237, 256, 259, 266, 275, 281, 294, 305, 311, 331, 356, 359, 361, 369, 370, 383, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394], "compos": [6, 256, 386, 387, 391], "ani": [1, 4, 6, 16, 83, 245, 246, 247, 256, 267, 277, 278, 281, 290, 299, 310, 311, 356, 378, 385, 386, 387, 389, 391, 392, 393], "number": [1, 10, 16, 57, 66, 83, 86, 87, 97, 111, 132, 135, 143, 148, 181, 184, 185, 187, 191, 194, 196, 198, 200, 230, 231, 234, 237, 239, 240, 256, 260, 261, 262, 264, 265, 269, 270, 297, 298, 310, 311, 313, 314, 315, 316, 375, 377, 378, 383, 386, 387, 394], "applic": [2, 7], "aris": 390, "case": [1, 4, 118, 121, 122, 124, 125, 126, 127, 128, 147, 159, 199, 218, 259, 264, 275, 309, 341, 347, 352, 353, 369, 370, 386, 387, 391, 392, 393, 394], "where": [5, 111, 137, 184, 237, 240, 258, 259, 260, 261, 262, 263, 264, 265, 267, 268, 269, 270, 271, 272, 273, 274, 275, 281, 298, 300, 301, 309, 315, 316, 320, 321, 322, 323, 332, 338, 344, 347, 349, 353, 370, 387, 388], "new": [1, 5, 80, 99, 174, 177, 199, 219, 233, 246, 289, 297, 359, 361, 372, 377, 386, 388, 389, 390], "function": [1, 2, 3, 4, 5, 6, 14, 83, 102, 105, 106, 132, 137, 143, 146, 147, 159, 209, 237, 239, 240, 244, 246, 256, 267, 276, 278, 282, 289, 294, 298, 301, 302, 304, 305, 306, 308, 309, 310, 321, 322, 323, 324, 325, 327, 328, 343, 348, 350, 351, 352, 353, 354, 356, 361, 370, 383, 385, 388, 389, 390, 392], "highli": 7, "optim": [2, 3, 5, 6, 295, 386, 387, 389], "ar": [1, 3, 4, 5, 6, 7, 14, 16, 76, 80, 82, 83, 87, 88, 99, 107, 111, 117, 118, 120, 121, 123, 124, 126, 127, 128, 132, 137, 138, 139, 140, 141, 142, 143, 146, 147, 149, 159, 169, 180, 181, 182, 184, 185, 186, 187, 188, 192, 195, 196, 205, 206, 218, 219, 226, 237, 239, 240, 245, 246, 250, 260, 261, 262, 263, 264, 265, 269, 270, 272, 273, 284, 297, 299, 311, 329, 331, 332, 355, 359, 368, 370, 385, 386, 387, 388, 389, 390, 391, 392, 393], "need": [1, 4, 5, 6, 76, 184, 256, 295, 296, 307, 310, 383, 387, 389, 390, 391, 393], "For": [1, 4, 7, 35, 115, 142, 146, 184, 247, 256, 260, 264, 277, 282, 291, 294, 299, 303, 307, 311, 313, 314, 315, 316, 356, 383, 386, 387, 388, 389, 390, 391, 392, 393], "you": [1, 2, 4, 5, 6, 7, 256, 307, 310, 356, 383, 386, 387, 388, 390, 392, 393], "design": [3, 6, 383, 393], "your": [1, 4, 7, 359, 387, 389], "own": [7, 390], "link": [1, 7], "top": [1, 232, 273, 311], "core": [1, 2, 3, 4, 5, 256, 258, 259, 260, 270, 274, 275, 284, 287, 289, 292, 311, 312, 313, 314, 315, 316, 317, 318, 319, 329, 331, 338, 356, 359, 361, 386, 390, 391], "we": [1, 3, 4, 5, 97, 184, 185, 256, 266, 305, 366, 368, 383, 385, 386, 387, 389, 393], "inner": 386, "work": [1, 2, 4, 7, 169, 386, 387, 388, 389], "go": [1, 4, 387], "over": [1, 4, 5, 13, 15, 23, 24, 25, 26, 85, 86, 87, 91, 92, 
93, 94, 118, 121, 124, 127, 136, 146, 148, 158, 160, 162, 172, 182, 183, 201, 213, 214, 220, 224, 230, 232, 238, 260, 261, 262, 269, 272, 300, 331, 375, 378, 387], "simpl": [1, 4, 5, 256, 266, 355, 386, 387, 389], "learn": [3, 5, 6, 260, 269, 270, 272, 298, 300, 362, 363, 364, 365, 366, 367, 368, 373, 374], "step": [2, 4, 5, 16, 256, 268, 271, 301, 363, 370, 375, 377, 378, 379, 386], "involv": [361, 386], "ad": [1, 3, 7, 112, 270, 359, 362, 363, 364, 365, 366, 367, 373, 389, 392], "let": [1, 3, 4, 386, 387, 389, 390], "": [1, 3, 4, 5, 43, 47, 58, 83, 96, 97, 117, 118, 120, 121, 123, 124, 126, 127, 132, 146, 149, 162, 180, 184, 187, 200, 203, 204, 220, 222, 237, 238, 240, 244, 256, 259, 268, 271, 275, 281, 282, 284, 288, 289, 290, 294, 301, 361, 370, 371, 383, 386, 387, 389, 390, 391, 392, 393], "sai": [1, 4, 356, 389], "would": [1, 4, 311, 388, 389, 390, 393], "like": [1, 4, 6, 142, 179, 243, 265, 337, 370, 372, 386, 387, 389, 390, 391, 393], "an": [1, 2, 4, 5, 7, 9, 13, 15, 27, 77, 78, 79, 80, 85, 86, 87, 107, 111, 112, 115, 128, 131, 135, 146, 149, 169, 174, 178, 179, 181, 183, 184, 185, 198, 199, 200, 215, 218, 225, 226, 227, 230, 231, 234, 240, 242, 243, 245, 246, 256, 258, 259, 263, 269, 271, 272, 273, 274, 275, 277, 297, 298, 299, 301, 310, 311, 312, 313, 314, 315, 316, 317, 318, 319, 322, 344, 356, 362, 372, 376, 381, 383, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394], "take": [1, 4, 5, 83, 132, 143, 161, 173, 179, 185, 227, 237, 239, 240, 243, 297, 383, 387, 388, 392, 393, 394], "two": [1, 12, 14, 76, 78, 99, 101, 104, 117, 120, 126, 133, 134, 137, 144, 145, 147, 154, 159, 161, 173, 175, 180, 225, 259, 271, 275, 299, 324, 330, 386, 387, 388, 393], "arrai": [1, 4, 5, 6, 9, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 97, 98, 99, 101, 102, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 172, 173, 174, 175, 176, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 256, 260, 271, 277, 284, 287, 292, 298, 311, 312, 313, 314, 315, 316, 317, 318, 319, 321, 324, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, 343, 344, 353, 356, 359, 362, 363, 364, 365, 366, 367, 368, 373, 374, 375, 376, 377, 378, 379, 386, 387, 389, 390, 391, 393], "x": [1, 3, 4, 5, 35, 105, 110, 112, 113, 135, 146, 185, 188, 200, 205, 209, 235, 236, 241, 246, 256, 258, 259, 260, 267, 269, 270, 272, 273, 274, 275, 276, 277, 298, 300, 302, 307, 309, 311, 320, 321, 322, 323, 324, 325, 326, 327, 328, 341, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 359, 361, 368, 386, 387, 388, 389, 390, 391, 393], "y": [1, 3, 4, 5, 35, 241, 256, 260, 264, 269, 270, 272, 273, 300, 333, 338, 341, 361, 364, 386, 387, 389, 390], "scale": [1, 4, 97, 112, 113, 114, 115, 184, 185, 191, 264, 265, 272, 297, 303, 304, 307, 311, 347, 363], "them": [1, 4, 256, 282, 294, 393], "both": [1, 12, 101, 102, 104, 133, 134, 142, 144, 145, 146, 154, 161, 173, 175, 187, 223, 258, 259, 270, 271, 274, 275, 361, 386, 387, 391, 
393], "some": [1, 3, 4, 5, 282, 294, 370, 386, 387, 389], "coeffici": [1, 362, 363, 365, 366, 367, 368], "alpha": [1, 184, 320, 342, 344, 347, 366, 373], "beta": [1, 97, 184, 260, 269, 270, 272, 341, 365, 366, 367, 368], "respect": [1, 3, 5, 112, 113, 132, 184, 237, 246, 256, 260, 267, 269, 270, 272, 359, 387, 391], "togeth": [1, 5, 184, 246], "get": [1, 3, 5, 7, 86, 87, 95, 96, 164, 165, 166, 189, 256, 386, 387, 389, 393], "z": [1, 268, 386, 389], "well": [4, 256, 282, 294, 297, 389], "veri": [4, 297, 389, 393], "easili": [], "do": [1, 4, 7, 256, 283, 294, 356, 359, 366, 386, 387, 389], "just": [1, 5, 272, 386, 388], "write": [1, 4, 256, 390], "out": [1, 7, 258, 259, 264, 265, 274, 275, 291, 386, 387, 388], "follow": [1, 4, 5, 6, 7, 16, 88, 97, 146, 184, 256, 322, 323, 335, 362, 363, 364, 365, 366, 367, 368, 374, 383, 386, 387, 393], "import": [1, 2, 3, 4, 5, 7, 146, 205, 237, 245, 246, 247, 256, 258, 259, 260, 270, 274, 275, 284, 311, 329, 331, 338, 356, 359, 386, 387, 388, 389, 390, 391], "mx": [1, 2, 3, 4, 5, 35, 128, 142, 146, 147, 149, 205, 237, 256, 258, 259, 260, 270, 274, 275, 277, 284, 288, 302, 311, 312, 313, 314, 315, 316, 317, 318, 319, 320, 326, 329, 330, 331, 335, 338, 345, 354, 356, 359, 361, 383, 386, 387, 388, 389, 390, 391, 392, 393, 394], "def": [1, 3, 4, 5, 237, 256, 359, 386, 387, 388, 389, 390, 393], "simple_axpbi": 1, "float": [1, 10, 14, 16, 73, 112, 113, 114, 115, 130, 131, 137, 142, 146, 185, 186, 191, 250, 260, 263, 264, 265, 269, 270, 272, 277, 289, 300, 303, 307, 309, 310, 311, 312, 313, 314, 315, 316, 318, 319, 330, 331, 332, 334, 338, 341, 342, 352, 353, 362, 363, 364, 365, 366, 367, 368, 373, 374, 375, 376, 378, 379], "return": [1, 3, 4, 5, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 34, 46, 64, 73, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 97, 98, 99, 101, 102, 104, 105, 106, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 165, 168, 169, 170, 172, 173, 174, 175, 176, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 194, 195, 196, 197, 198, 199, 200, 201, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 256, 268, 271, 277, 278, 279, 281, 282, 283, 284, 285, 286, 287, 291, 292, 294, 295, 296, 299, 301, 312, 313, 314, 315, 316, 317, 318, 319, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, 356, 359, 369, 385, 386, 387, 388, 389, 390, 392, 393], "thi": [1, 4, 5, 7, 13, 14, 15, 16, 23, 24, 25, 26, 103, 137, 143, 146, 147, 154, 158, 159, 160, 162, 164, 172, 182, 183, 187, 208, 213, 214, 215, 220, 224, 226, 232, 238, 256, 263, 264, 265, 268, 271, 278, 279, 281, 282, 285, 286, 287, 292, 294, 295, 296, 297, 299, 301, 309, 313, 314, 315, 316, 322, 323, 324, 337, 353, 359, 370, 385, 386, 387, 389, 390, 392], "perform": [1, 2, 4, 6, 87, 91, 92, 93, 94, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 159, 185, 200, 213, 226, 256, 269, 310, 315, 316, 386, 388, 389, 393], "leav": [1, 107, 246], "differenti": [1, 6], "howev": [1, 256, 267, 269, 370, 383, 386, 389, 390], "vector": [1, 3, 6, 136, 143, 146, 226, 239, 240, 266, 331, 391], "math": [4, 342, 386], "often": 265, "realiz": [], 
"axpbi": 1, "routin": 1, "defin": [1, 3, 4, 5, 7, 146, 185, 245, 390], "same": [1, 4, 7, 14, 35, 76, 80, 83, 86, 87, 88, 112, 113, 122, 125, 126, 127, 132, 137, 143, 181, 187, 200, 239, 241, 256, 259, 260, 263, 269, 270, 275, 299, 312, 313, 314, 315, 316, 317, 318, 319, 331, 342, 359, 369, 383, 386, 388, 393], "realli": 272, "part": [1, 387, 388], "doe": [1, 2, 4, 7, 164, 256, 386, 388, 389, 390], "fast": [6, 267, 323, 393], "so": [1, 4, 7, 132, 237, 263, 311, 361, 386, 389, 393], "decid": [246, 281], "want": [4, 387, 393], "reli": 1, "acceler": [1, 260], "framework": [1, 6], "continu": 387, "impos": [], "our": [1, 4, 5, 305, 362, 363, 364, 365, 367, 368], "assumpt": [], "also": [1, 4, 5, 6, 7, 10, 12, 100, 101, 102, 104, 118, 121, 124, 127, 133, 134, 144, 145, 154, 161, 173, 175, 184, 223, 244, 256, 281, 295, 297, 299, 306, 321, 347, 349, 355, 361, 386, 387, 388, 389, 390, 391, 394], "assum": [1, 4, 147, 246, 256, 258, 259, 269, 274, 275], "how": [1, 4, 5, 256, 258, 259, 261, 262, 266, 274, 275, 311, 369, 386, 388, 393], "gradient": [3, 5, 132, 221, 237, 244, 256, 282, 295, 299, 310, 337, 359, 361, 362, 363, 365, 366, 367, 368, 369, 372, 374, 386, 387, 388, 389, 390, 391], "ins": [], "what": [1, 4, 246], "coincid": [], "right": [1, 7, 184, 258, 259, 267, 274, 275, 311, 322, 323, 332, 334, 342], "place": [1, 4, 35, 200, 389, 390], "cours": 387, "The": [1, 2, 4, 5, 6, 7, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 34, 43, 47, 57, 58, 64, 73, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 97, 98, 99, 101, 102, 104, 105, 106, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 133, 134, 135, 136, 137, 138, 139, 140, 141, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 165, 166, 168, 169, 170, 172, 173, 174, 175, 176, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 194, 195, 196, 197, 198, 199, 203, 204, 209, 210, 211, 212, 213, 214, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 250, 258, 259, 260, 261, 262, 263, 264, 265, 266, 268, 269, 270, 271, 272, 273, 274, 275, 277, 278, 282, 284, 288, 289, 290, 291, 294, 295, 296, 297, 299, 300, 301, 303, 305, 307, 309, 310, 311, 312, 313, 314, 315, 316, 317, 318, 319, 324, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, 353, 356, 359, 361, 362, 363, 364, 365, 366, 367, 368, 371, 373, 374, 375, 378, 381, 386, 387, 388, 389, 390, 391, 392, 393, 394], "structur": [1, 369, 387], "from": [1, 4, 5, 6, 97, 99, 123, 124, 126, 127, 131, 146, 149, 159, 163, 166, 168, 179, 184, 186, 187, 188, 189, 192, 195, 205, 218, 221, 223, 226, 227, 232, 241, 243, 245, 246, 247, 256, 273, 282, 284, 297, 313, 314, 315, 316, 318, 319, 332, 341, 356, 385, 386, 387, 389, 390, 391, 392, 393], "frontend": [], "api": [1, 387], "redirect": 1, "when": [1, 4, 6, 7, 83, 87, 146, 149, 261, 262, 311, 315, 316, 335, 341, 359, 377, 383, 386, 393], "appropri": [1, 386], "fallback": 1, "metal": [1, 6], "vjp": [1, 391], "jvp": [1, 391], "In": [1, 4, 5, 35, 159, 184, 246, 256, 264, 269, 359, 362, 364, 365, 367, 368, 369, 385, 386, 387, 389, 392, 393], "one": [1, 4, 7, 35, 73, 77, 82, 86, 87, 109, 111, 112, 113, 146, 152, 159, 185, 187, 218, 223, 250, 294, 311, 331, 393], "sentenc": [], "comput": [1, 3, 4, 5, 6, 7, 91, 92, 93, 94, 97, 110, 114, 132, 143, 146, 154, 
162, 180, 184, 213, 220, 221, 230, 237, 238, 239, 244, 256, 260, 268, 269, 270, 271, 272, 282, 295, 299, 300, 303, 310, 313, 314, 315, 316, 322, 323, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, 361, 362, 363, 365, 366, 367, 368, 372, 386, 387, 391, 393], "graph": [1, 4, 5, 6, 387], "rule": 1, "evalu": [1, 4, 5, 6, 107, 143, 239, 256, 280, 291, 359, 361, 386, 391], "said": 4, "start": [1, 3, 4, 6, 7, 16, 114, 148, 170, 215, 386, 388, 393], "discuss": 1, "more": [1, 2, 5, 9, 73, 99, 159, 168, 169, 203, 204, 250, 256, 260, 264, 303, 307, 310, 311, 313, 314, 315, 316, 383, 386, 387, 388, 391, 393], "detail": [1, 9, 168, 256, 264, 303, 307, 311, 313, 314, 315, 316, 362, 364, 365, 367, 368, 388, 391], "thei": [1, 3, 4, 14, 88, 137, 305, 333, 359, 368, 385, 386, 389, 391, 392, 393], "c": [1, 4, 146, 258, 259, 260, 261, 262, 264, 265, 270, 271, 274, 275, 390, 391, 393], "scalar": [1, 12, 14, 27, 46, 73, 76, 80, 82, 101, 102, 104, 130, 131, 132, 133, 134, 137, 144, 145, 146, 148, 154, 155, 156, 157, 159, 161, 173, 175, 181, 186, 192, 195, 196, 203, 223, 237, 241, 244, 342, 387, 389, 391], "i": [1, 2, 4, 5, 6, 7, 14, 16, 25, 34, 73, 82, 85, 86, 87, 88, 91, 92, 93, 94, 98, 99, 102, 107, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 130, 131, 137, 142, 143, 146, 147, 149, 154, 158, 159, 163, 166, 167, 169, 181, 182, 184, 185, 186, 187, 190, 191, 194, 195, 196, 199, 202, 203, 204, 209, 213, 215, 220, 221, 226, 227, 230, 233, 237, 238, 239, 240, 241, 245, 246, 250, 256, 258, 259, 260, 261, 262, 263, 264, 265, 267, 268, 269, 270, 271, 272, 273, 274, 275, 281, 282, 288, 290, 291, 293, 294, 296, 297, 298, 299, 300, 301, 303, 307, 309, 310, 311, 315, 316, 321, 322, 323, 329, 330, 332, 337, 338, 341, 342, 344, 349, 353, 359, 363, 366, 368, 369, 370, 375, 377, 378, 383, 386, 387, 388, 389, 390, 391, 392, 393, 394], "sum": [1, 3, 12, 94, 136, 146, 158, 213, 230, 256, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, 388, 390], "element": [1, 11, 12, 17, 18, 19, 20, 21, 22, 25, 66, 81, 89, 90, 91, 92, 93, 94, 97, 101, 102, 104, 105, 106, 108, 110, 111, 129, 130, 133, 134, 137, 138, 139, 140, 141, 144, 145, 150, 151, 152, 153, 154, 155, 156, 157, 161, 163, 173, 175, 176, 182, 184, 185, 197, 198, 201, 209, 210, 211, 212, 216, 217, 223, 226, 228, 229, 232, 237, 241, 263, 264, 265, 268, 271, 276, 298, 301, 303, 325, 327, 328, 343, 344, 346, 349, 350, 351, 386, 387], "wise": [1, 11, 12, 17, 18, 19, 20, 21, 22, 81, 89, 90, 101, 102, 104, 105, 106, 108, 110, 129, 130, 133, 134, 137, 144, 145, 150, 151, 152, 153, 154, 155, 156, 157, 161, 173, 175, 176, 197, 201, 209, 210, 211, 212, 216, 217, 223, 228, 229, 264, 265, 276, 298, 325, 327, 328, 343, 344, 346, 349, 350, 351, 386], "numpi": [1, 4, 5, 6, 12, 14, 16, 80, 101, 102, 104, 133, 134, 137, 144, 145, 154, 159, 161, 173, 175, 223, 389, 391, 392], "style": [1, 12, 14, 101, 102, 104, 133, 134, 137, 144, 145, 154, 159, 161, 173, 175, 223], "broadcast": [1, 12, 14, 80, 82, 101, 102, 104, 131, 133, 134, 137, 144, 145, 154, 159, 161, 173, 175, 186, 187, 190, 195, 196, 223, 227, 241, 297], "between": [1, 6, 82, 128, 310, 330, 333, 334, 337, 377, 389, 393], "input": [1, 3, 4, 11, 12, 13, 14, 15, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 98, 99, 101, 102, 104, 105, 106, 108, 109, 110, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 132, 133, 134, 136, 137, 138, 139, 140, 
141, 143, 144, 145, 146, 147, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 172, 173, 174, 175, 176, 179, 180, 181, 182, 183, 184, 185, 194, 197, 198, 199, 200, 201, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 235, 236, 237, 238, 240, 241, 243, 258, 259, 260, 261, 262, 264, 265, 266, 268, 269, 270, 271, 272, 273, 274, 275, 297, 299, 300, 301, 303, 309, 310, 311, 312, 313, 314, 315, 316, 317, 318, 319, 324, 329, 330, 332, 333, 334, 335, 337, 338, 340, 342, 353, 356, 386, 387, 388, 391, 392], "upcast": 1, "const": [1, 332], "factor": [1, 147, 311, 331, 376, 379], "streamordevic": 1, "stream": [1, 6, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 29, 30, 31, 32, 33, 34, 36, 37, 38, 39, 40, 41, 42, 44, 45, 48, 49, 50, 51, 52, 53, 54, 55, 56, 59, 60, 61, 62, 63, 65, 67, 68, 69, 70, 71, 72, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 96, 97, 98, 99, 101, 102, 104, 105, 106, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 133, 134, 135, 136, 137, 138, 139, 140, 141, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 190, 191, 192, 194, 195, 196, 197, 198, 199, 200, 201, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 238, 241, 242, 243, 393], "schedul": [1, 169, 361, 375, 376, 377, 378, 379, 381, 393], "itself": [1, 370], "call": [1, 2, 4, 5, 28, 130, 256, 266, 282, 294, 305, 359, 361, 370, 386, 387, 389], "other": [1, 4, 6, 142, 146, 256, 283, 359, 368, 386, 388, 389, 391], "within": [2, 25, 137], "simplest": [1, 256], "wai": [1, 4, 7, 256, 311, 386, 387, 388], "about": [1, 4, 5, 389, 393], "term": [1, 332, 362, 363, 364, 365, 366, 367, 373], "exist": [1, 2, 4, 282, 294], "auto": [1, 7], "ax": [1, 13, 15, 23, 24, 74, 109, 117, 118, 120, 121, 123, 124, 126, 127, 128, 136, 146, 158, 160, 162, 172, 181, 183, 213, 218, 220, 224, 225, 230, 233, 238, 387], "multipli": [1, 35, 184, 185, 263, 307, 311], "earlier": 1, "goal": [], "themselv": [1, 386], "contain": [1, 4, 25, 26, 64, 83, 99, 122, 123, 124, 146, 155, 156, 157, 184, 215, 241, 256, 281, 283, 284, 290, 310, 338, 356, 359, 386, 387], "act": [1, 337], "data": [1, 5, 6, 9, 16, 111, 125, 126, 131, 135, 148, 178, 195, 234, 242, 265, 312, 313, 314, 315, 316, 317, 318, 319, 386, 388, 390], "nor": [1, 132, 237], "rather": [1, 387, 393], "easi": [1, 256], "interfac": 1, "block": [1, 4, 310], "A": [1, 4, 6, 7, 8, 64, 76, 83, 112, 113, 115, 132, 143, 146, 147, 149, 158, 159, 160, 172, 184, 186, 187, 188, 190, 191, 192, 195, 196, 215, 219, 222, 237, 239, 240, 244, 245, 246, 247, 248, 256, 260, 264, 268, 269, 270, 272, 281, 285, 286, 289, 295, 296, 300, 305, 307, 310, 313, 314, 316, 323, 342, 343, 359, 361, 365, 367, 369, 370, 372, 377, 386, 387, 389, 390], "It": [1, 4, 7, 132, 208, 237, 256, 296, 299, 369, 381, 390, 392], "creat": [1, 4, 7, 111, 135, 222, 256, 359, 361, 377, 386, 388, 390], "output": [1, 4, 7, 13, 14, 15, 16, 25, 80, 83, 91, 92, 93, 94, 111, 112, 113, 114, 115, 122, 125, 126, 127, 131, 132, 135, 137, 146, 148, 158, 160, 162, 163, 172, 178, 179, 182, 183, 186, 187, 188, 190, 191, 192, 195, 196, 205, 206, 213, 218, 220, 224, 227, 234, 237, 238, 239, 240, 241, 242, 243, 258, 259, 260, 261, 262, 270, 
273, 274, 275, 297, 299, 309, 310, 311, 313, 314, 315, 316, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, 353, 356, 386, 387, 388, 389, 390, 391, 392, 393], "given": [1, 13, 15, 25, 35, 80, 82, 84, 91, 92, 93, 94, 97, 99, 107, 109, 116, 117, 118, 119, 120, 121, 125, 126, 127, 131, 146, 158, 160, 162, 168, 172, 177, 183, 190, 192, 200, 208, 213, 215, 220, 224, 231, 232, 234, 235, 236, 238, 248, 258, 259, 263, 274, 275, 281, 297, 330, 332, 338], "set": [1, 4, 5, 7, 83, 100, 103, 112, 114, 168, 169, 207, 208, 222, 267, 272, 273, 280, 282, 289, 290, 291, 294, 295, 299, 303, 309, 330, 342, 353, 359, 363, 370, 383, 387, 389], "further": [1, 7, 387], "class": [1, 4, 5, 8, 9, 10, 27, 248, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 331, 359, 362, 363, 364, 365, 366, 367, 368, 373, 374, 381], "under": [1, 146], "These": [1, 83, 227, 331, 393], "word": [], "bit": [97, 184, 185, 250, 277, 299, 300], "abstract": [], "back": [4, 167, 390], "give": [1, 4, 5, 25, 386], "ourselv": [], "concret": [1, 268, 271, 273, 301, 389, 393], "imag": [262, 264, 265, 311], "public": [1, 256], "explicit": [1, 370, 383, 390], "alpha_": 1, "beta_": 1, "must": [1, 2, 7, 82, 131, 146, 186, 187, 190, 192, 195, 196, 241, 311, 390], "know": [1, 4], "popul": 1, "To": [1, 2, 3, 4, 5, 7, 168, 256, 356, 386, 387, 391], "avoid": [1, 289, 386], "unnecessari": [1, 4], "alloc": [1, 165, 168, 169, 359], "respons": 1, "space": [1, 148, 340], "void": 1, "eval_cpu": 1, "std": [1, 318], "overrid": [1, 103], "eval_gpu": 1, "jacobian": [1, 143, 239, 391], "product": [1, 93, 136, 143, 159, 180, 183, 230, 239, 297, 391], "primal": [1, 143, 239], "tangent": [1, 21, 22, 143, 228, 229, 354], "int": [1, 4, 5, 8, 13, 15, 16, 23, 24, 25, 26, 30, 31, 32, 33, 37, 38, 39, 40, 41, 42, 45, 52, 53, 54, 55, 56, 59, 62, 64, 67, 70, 71, 72, 73, 75, 80, 84, 85, 86, 87, 91, 92, 93, 94, 97, 98, 99, 109, 111, 114, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 131, 132, 135, 142, 146, 148, 158, 160, 162, 164, 165, 166, 168, 169, 172, 174, 178, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 198, 199, 200, 213, 214, 215, 218, 219, 220, 224, 225, 226, 227, 230, 231, 232, 233, 234, 235, 236, 237, 238, 240, 242, 248, 256, 258, 259, 260, 261, 262, 266, 268, 269, 270, 271, 272, 273, 274, 275, 297, 299, 300, 301, 303, 307, 310, 324, 330, 331, 335, 340, 342, 359, 375, 377, 378, 379], "argnum": [1, 132, 237, 387], "cotan": 1, "across": [1, 269], "pair": [1, 181, 284, 303], "repres": [1, 4, 338, 342, 390], "axi": [1, 4, 5, 13, 15, 23, 24, 25, 26, 30, 31, 32, 33, 37, 38, 39, 40, 52, 53, 54, 55, 59, 67, 70, 71, 75, 84, 91, 92, 93, 94, 99, 109, 112, 113, 116, 119, 122, 123, 124, 125, 126, 127, 128, 146, 158, 160, 162, 172, 174, 181, 182, 183, 187, 198, 213, 214, 215, 218, 219, 220, 224, 225, 226, 227, 231, 232, 233, 238, 240, 258, 259, 274, 275, 301, 324, 328, 330, 331, 335, 340, 342, 350, 388], "correspond": [1, 13, 15, 73, 82, 97, 99, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 158, 160, 172, 183, 224, 230, 240, 246, 387], "dimens": [1, 4, 13, 15, 23, 24, 58, 64, 73, 77, 78, 79, 83, 86, 87, 99, 109, 114, 123, 124, 126, 127, 128, 136, 146, 147, 158, 159, 160, 162, 172, 183, 184, 187, 194, 220, 224, 227, 230, 233, 238, 260, 261, 262, 264, 265, 268, 269, 270, 271, 272, 297, 300, 301, 303, 310, 311, 324, 331, 386, 387], "vmap": [1, 387, 389, 391], "print": [1, 
2, 3, 4, 5, 7, 245, 246, 247, 256, 383, 386, 387, 388, 389, 390, 391], "ostream": 1, "o": [1, 7, 115, 271], "equival": [1, 28, 61, 74, 102, 130, 226, 267, 296, 298, 299, 302, 304, 306, 308], "check": [1, 7, 76, 142, 167, 284, 387, 388], "bool": [1, 13, 14, 15, 23, 24, 30, 31, 32, 33, 37, 38, 39, 40, 52, 53, 54, 55, 59, 71, 73, 75, 76, 83, 87, 91, 92, 93, 94, 114, 137, 142, 146, 149, 158, 160, 162, 163, 167, 169, 170, 172, 183, 185, 220, 224, 238, 260, 261, 262, 268, 269, 270, 271, 272, 273, 277, 281, 282, 284, 289, 291, 294, 297, 299, 301, 303, 307, 310, 311, 329, 332, 363, 374], "is_equival": 1, "privat": 1, "fall": 1, "eval": [1, 2, 3, 4, 5, 256, 359, 361, 386, 387, 389, 391], "deriv": [1, 387, 389], "base": [1, 114, 146, 151, 153, 303, 310, 359, 361, 367, 381, 383, 386, 388], "abov": [1, 4, 184, 235, 256, 311, 366, 387, 388, 389, 393], "demonstr": 390, "treat": [1, 123, 124, 126, 127, 226, 311, 386], "paramet": [1, 3, 4, 5, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 34, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 97, 98, 99, 101, 102, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 168, 169, 170, 172, 173, 174, 175, 176, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 277, 278, 281, 282, 284, 289, 290, 291, 294, 295, 296, 297, 298, 299, 300, 301, 303, 305, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, 317, 318, 319, 324, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, 353, 355, 356, 359, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 372, 373, 374, 375, 376, 377, 378, 379, 381, 386, 387, 389], "produc": [1, 83, 297, 356], "through": [1, 221, 310, 368, 386, 387, 390], "construct": [1, 5, 41, 98, 131, 178, 231, 242], "its": [1, 7, 159, 182, 194, 234, 244, 247, 256, 299, 365, 366, 367, 390, 393], "type": [1, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 34, 64, 73, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 97, 98, 99, 101, 102, 104, 105, 106, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 168, 169, 170, 172, 173, 174, 175, 176, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 194, 195, 196, 197, 198, 199, 200, 201, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 245, 256, 289, 310, 312, 313, 314, 315, 316, 317, 318, 319, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, 386, 388], "shape": [1, 2, 4, 5, 61, 76, 80, 83, 85, 86, 87, 99, 115, 116, 119, 122, 125, 126, 127, 131, 143, 159, 178, 179, 186, 187, 188, 190, 191, 192, 195, 196, 199, 227, 
239, 241, 242, 243, 256, 258, 259, 260, 261, 262, 264, 265, 268, 270, 271, 273, 274, 275, 284, 301, 312, 313, 314, 315, 316, 317, 318, 319, 331, 342, 361, 386, 387, 388, 391, 393], "pass": [1, 4, 5, 61, 74, 180, 181, 237, 244, 245, 246, 256, 282, 294, 295, 296, 299, 305, 386, 389], "re": [5, 7, 356], "now": [1, 4, 7, 299, 386, 390], "promot": 1, "dtype": [1, 4, 10, 16, 27, 34, 35, 73, 111, 128, 131, 135, 142, 146, 147, 148, 178, 188, 190, 191, 192, 195, 196, 234, 242, 250, 289, 311, 312, 313, 314, 315, 316, 317, 318, 319, 329, 331, 338, 375, 376, 377, 378, 379, 386, 387, 388, 390, 391, 392], "promoted_dtyp": 1, "promote_typ": 1, "float32": [1, 10, 16, 111, 115, 135, 142, 146, 147, 148, 178, 188, 190, 191, 195, 196, 234, 242, 250, 311, 312, 313, 314, 315, 316, 317, 318, 319, 329, 331, 338, 375, 376, 377, 378, 379, 386, 387, 388, 389, 390, 391, 392], "non": [1, 7, 163, 292, 301, 343, 359], "point": [1, 3, 4, 7, 130, 185, 250], "out_dtyp": 1, "is_floating_point": 1, "cast": [1, 34, 125, 126, 127, 149, 277, 289, 390], "up": [1, 4, 299, 386], "determin": [1, 99, 190, 250, 288, 392], "x_cast": 1, "astyp": [1, 4, 277, 390], "y_cast": 1, "broadcasted_input": 1, "broadcast_arrai": 1, "out_shap": 1, "0": [1, 3, 4, 5, 7, 8, 16, 35, 41, 42, 45, 62, 67, 75, 84, 85, 86, 87, 98, 99, 111, 115, 128, 132, 146, 147, 168, 181, 186, 191, 196, 198, 200, 215, 219, 220, 234, 235, 236, 237, 238, 240, 245, 256, 258, 259, 260, 261, 262, 263, 264, 265, 267, 269, 270, 272, 274, 275, 298, 302, 303, 307, 308, 309, 310, 312, 313, 314, 315, 316, 317, 318, 319, 320, 322, 323, 325, 326, 329, 331, 333, 334, 338, 341, 342, 344, 345, 346, 347, 352, 353, 356, 359, 362, 363, 365, 366, 367, 368, 370, 373, 374, 375, 376, 377, 378, 379, 383, 386, 387, 388, 389, 390, 391, 392], "unique_ptr": 1, "make_shar": 1, "to_stream": 1, "handl": [1, 256, 386], "resolv": 1, "No": [1, 4], "happen": [1, 4, 112, 310, 361, 386, 389], "alon": [1, 390], "effect": [264, 386, 389], "onli": [1, 4, 6, 7, 76, 85, 86, 87, 146, 184, 190, 256, 281, 282, 284, 289, 291, 294, 295, 296, 359, 386, 387, 392, 393], "execut": [1, 7, 77, 78, 79, 166, 390, 393], "depend": [1, 2, 3, 73, 146, 268, 271, 301, 388, 392, 393], "devic": [1, 6, 7, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 29, 30, 31, 32, 33, 34, 36, 37, 38, 39, 40, 41, 42, 44, 45, 48, 49, 50, 51, 52, 53, 54, 55, 56, 59, 60, 61, 62, 63, 65, 67, 68, 69, 70, 71, 72, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 101, 102, 104, 105, 106, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 133, 134, 135, 136, 137, 138, 139, 140, 141, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 169, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 190, 191, 192, 194, 195, 196, 197, 198, 199, 200, 201, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 238, 241, 242, 243, 248, 393, 394], "specifi": [1, 16, 34, 86, 87, 99, 123, 124, 131, 132, 146, 148, 174, 178, 187, 198, 225, 226, 227, 230, 233, 237, 240, 242, 260, 309, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, 353, 387, 393], "memori": [1, 6, 164, 165, 166, 168, 169, 310, 359, 363, 386, 389, 390], "ha": [1, 2, 4, 5, 6, 73, 83, 99, 122, 123, 125, 126, 127, 132, 163, 165, 187, 260, 268, 271, 273, 301, 359, 361, 386, 388, 389, 
391, 393], "been": [1, 4, 165, 389], "try": [1, 7], "naiv": [1, 387], "gener": [1, 2, 3, 10, 16, 87, 111, 123, 124, 148, 163, 186, 190, 191, 192, 195, 196, 310, 383, 386, 388, 389, 394], "version": [1, 7, 97, 154, 158, 184, 213, 240, 383, 387, 388], "declar": 1, "member": [1, 256, 287, 292], "method": [1, 4, 8, 9, 10, 27, 248, 256, 288, 359, 362, 363, 364, 365, 366, 367, 368, 370, 373, 374, 381], "each": [1, 64, 97, 107, 114, 142, 159, 163, 181, 184, 185, 187, 198, 205, 206, 215, 231, 233, 240, 241, 264, 265, 266, 268, 269, 271, 301, 303, 310, 329, 331, 383, 386, 389], "find": [1, 3, 7], "pointwis": [], "captur": [1, 2, 83, 170, 171, 256, 386], "templat": 1, "axpby_impl": 1, "typenam": 1, "t": [1, 4, 105, 115, 185, 237, 256, 258, 268, 271, 274, 301, 362, 363, 364, 365, 366, 367, 368, 373, 374, 386, 387, 393], "readi": 1, "fill": [1, 131, 179, 234, 243, 312, 313, 314, 315, 316, 318, 319], "malloc_or_wait": 1, "synchron": [1, 386], "avail": [1, 3, 4, 5, 7, 9, 167, 393], "There": [1, 256, 311, 386], "wait": [1, 4, 169], "here": [1, 4, 386, 387, 389, 392, 393], "request": 1, "pressur": 1, "condit": [1, 241, 393], "set_data": 1, "nbyte": 1, "collect": [1, 246, 385], "pointer": 1, "x_ptr": 1, "y_ptr": 1, "out_ptr": 1, "relev": 1, "static_cast": 1, "size_t": 1, "out_idx": 1, "size": [1, 4, 5, 47, 64, 86, 97, 109, 112, 113, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 131, 135, 142, 146, 165, 169, 184, 185, 187, 199, 215, 218, 256, 258, 259, 261, 262, 266, 270, 274, 275, 299, 311, 363, 389, 390], "map": [1, 5, 35, 149, 246, 266, 277], "linear": [1, 4, 5, 6, 246, 256, 267, 284, 299, 301, 302, 304, 306, 311, 320, 321, 322, 323, 324, 326, 345, 346, 347, 349, 356, 359, 370, 378, 386], "indic": [1, 14, 23, 24, 25, 26, 35, 132, 137, 138, 139, 140, 141, 215, 226, 227, 237, 291, 293, 331, 338, 377, 388], "offset": [1, 4, 42, 99, 112, 114], "x_offset": 1, "elem_to_loc": 1, "stride": [1, 85, 86, 87, 258, 259, 261, 262, 274, 275, 303, 388], "y_offset": 1, "contigu": 1, "regularli": 1, "default": [1, 7, 13, 14, 15, 16, 23, 24, 25, 26, 76, 83, 84, 85, 86, 87, 95, 96, 97, 98, 99, 111, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 132, 135, 137, 146, 147, 148, 149, 158, 160, 162, 163, 168, 169, 172, 178, 182, 183, 184, 185, 186, 187, 188, 190, 191, 192, 194, 195, 196, 198, 199, 200, 207, 208, 214, 215, 218, 219, 220, 222, 224, 230, 232, 233, 234, 235, 236, 237, 238, 240, 242, 250, 258, 259, 260, 261, 262, 268, 270, 271, 273, 274, 275, 277, 282, 284, 289, 291, 294, 297, 298, 299, 301, 303, 307, 308, 310, 311, 312, 313, 314, 315, 316, 317, 318, 319, 324, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, 359, 362, 363, 364, 365, 366, 367, 368, 373, 374, 375, 383, 385, 386, 387, 390, 392, 394], "row": [1, 111, 135, 184, 234], "major": 1, "henc": [1, 184, 386], "doesn": [1, 256], "addit": [1, 4, 12, 112, 113, 115, 149, 260, 269, 272, 297, 300, 359, 387], "abl": [1, 184], "all": [1, 2, 5, 7, 14, 25, 35, 77, 78, 79, 83, 86, 87, 111, 118, 121, 124, 127, 159, 181, 182, 218, 256, 277, 278, 282, 285, 286, 287, 292, 294, 297, 299, 307, 310, 311, 356, 359, 381, 383, 386, 388, 389, 391, 394], "incom": 1, "accordingli": 1, "dispatch": 1, "float16": [1, 10, 149, 250, 277, 389, 390], "bfloat16": [1, 10, 250, 390], "complex64": [1, 250], "throw": [1, 83], "error": [1, 7, 105, 106, 169, 215, 267, 299, 321, 322, 323, 337, 339, 387, 390], "encount": [1, 387], "unexpect": [1, 16], "regist": [1, 5], "op": [1, 180, 282, 389], "assert": 1, "2": [1, 3, 4, 5, 35, 86, 98, 99, 105, 117, 
120, 122, 123, 124, 125, 126, 127, 128, 142, 146, 147, 153, 159, 184, 190, 194, 230, 234, 235, 236, 250, 256, 258, 259, 262, 267, 274, 275, 300, 307, 311, 312, 313, 314, 315, 316, 317, 318, 319, 322, 331, 332, 334, 341, 342, 356, 359, 362, 364, 365, 366, 370, 373, 386, 387, 388, 389, 390, 391, 392, 393], "1": [1, 2, 4, 5, 16, 25, 26, 35, 42, 45, 85, 86, 87, 98, 99, 110, 115, 116, 117, 119, 120, 122, 123, 124, 125, 126, 127, 128, 136, 142, 146, 147, 159, 163, 169, 180, 182, 184, 187, 190, 191, 196, 209, 214, 226, 232, 237, 250, 256, 258, 259, 260, 261, 262, 263, 264, 265, 267, 268, 269, 270, 271, 272, 273, 274, 275, 298, 300, 301, 303, 307, 309, 311, 313, 314, 315, 316, 317, 318, 319, 320, 322, 323, 324, 327, 328, 329, 330, 331, 332, 333, 334, 335, 337, 338, 340, 341, 342, 347, 348, 350, 351, 353, 356, 359, 361, 362, 363, 364, 365, 366, 367, 368, 370, 373, 374, 375, 376, 377, 378, 379, 386, 387, 388, 390, 391, 392, 393], "correct": [1, 7, 365, 366, 367, 388, 389], "els": [1, 4, 256, 282, 389], "float16_t": 1, "bfloat16_t": 1, "complex64_t": 1, "runtime_error": 1, "support": [1, 4, 6, 7, 14, 85, 86, 87, 115, 128, 137, 147, 149, 159, 184, 190, 387, 388, 390, 392], "have": [1, 4, 7, 14, 76, 77, 78, 79, 123, 124, 126, 127, 137, 159, 170, 187, 245, 271, 297, 305, 368, 370, 385, 386, 388, 389, 393], "rememb": [], "3": [1, 4, 7, 128, 142, 146, 147, 311, 314, 316, 325, 363, 368, 383, 386, 388, 390, 391], "complic": [], "keep": [1, 13, 15, 23, 24, 158, 160, 162, 172, 183, 220, 224, 238, 256, 281, 387, 389], "mind": [1, 4], "half": [1, 16, 192, 196, 303, 389], "precis": [1, 4, 110, 115, 256, 267, 300, 369, 386], "direct": [1, 4, 279, 368, 393], "fix": [1, 4, 7, 389], "possibli": [4, 159], "due": [], "transpos": [4, 28, 185], "aren": [], "guarante": [], "fit": [1, 184, 393], "requir": [1, 4, 256, 389, 390], "column": [1, 111, 135, 184], "inplac": 1, "expect": [1, 4, 261, 262, 263, 264, 265, 307, 310, 332, 386, 388], "answer": [], "copi": [1, 4, 6, 182, 214, 390], "simpli": [1, 4, 7, 302, 320, 326, 345, 354, 359, 386, 387], "catlas_saxpbi": 1, "axpby_impl_acceler": 1, "first": [1, 2, 3, 4, 5, 7, 99, 128, 132, 155, 157, 159, 182, 194, 225, 230, 237, 245, 256, 259, 269, 275, 311, 330, 338, 363, 365, 366, 367, 370, 386, 387, 390, 393], "mode": [1, 88, 280, 291, 293, 311, 315, 316], "e": [1, 5, 7, 105, 143, 209, 260, 261, 262, 264, 265, 269, 270, 272, 282, 300, 327, 328, 350, 355, 361, 364, 386, 389, 394], "match": [7, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 164, 284, 311, 331, 388, 390], "transposit": [], "data_s": [], "items": [], "flag": [1, 386, 390], "copy_inplac": 1, "copytyp": 1, "n": [1, 4, 27, 85, 86, 87, 111, 116, 118, 119, 121, 122, 125, 127, 135, 190, 220, 234, 238, 258, 259, 260, 261, 262, 264, 265, 268, 271, 274, 275, 301, 311, 337, 342], "incx": 1, "inci": 1, "great": 2, "But": 393, "criteria": 1, "luckili": 389, "alwai": [164, 245, 387], "With": 1, "final": [1, 3, 4, 5, 378], "singl": [1, 5, 107, 143, 149, 163, 181, 239, 259, 275, 386, 388, 392], "row_contigu": 1, "col_contigu": 1, "common": [1, 361, 386, 389], "hit": 1, "mileston": [], "enough": [1, 389], "run": [1, 2, 4, 5, 6, 7, 8, 180, 248, 260, 277, 362, 363, 365, 366, 367, 386, 389, 393, 394], "If": [1, 4, 7, 13, 14, 15, 16, 23, 24, 25, 26, 73, 76, 82, 84, 88, 91, 92, 93, 94, 98, 99, 107, 112, 114, 125, 126, 127, 130, 131, 132, 137, 146, 149, 158, 159, 160, 162, 163, 168, 169, 172, 178, 181, 182, 183, 187, 190, 198, 213, 214, 215, 220, 224, 226, 227, 230, 232, 237, 238, 240, 242, 246, 260, 261, 262, 269, 272, 
273, 282, 284, 294, 299, 301, 303, 305, 307, 311, 329, 331, 342, 363, 386, 387, 389, 392, 393, 394], "plan": [1, 386], "stop": [1, 4, 16, 148, 171, 221, 387, 388], "enjoi": 1, "speed": 1, "appl": [1, 4, 6, 7, 393], "silicon": [1, 4, 6, 7, 393], "address": 1, "shade": 1, "languag": 1, "kernel": [1, 85, 86, 87, 258, 259, 274, 275, 386, 388], "written": 1, "help": [1, 4, 386, 393], "resourc": 1, "walkthrough": 1, "pipelin": 1, "specif": [1, 7, 387], "cpp": 1, "algorithm": [311, 368], "launch": [1, 388], "exactli": [1, 4, 284, 387], "mani": [1, 215, 261, 262, 266, 386, 389], "thread": 1, "pick": 1, "updat": [1, 3, 4, 5, 35, 83, 246, 260, 277, 278, 284, 289, 290, 291, 296, 361, 363, 366, 368, 369, 370, 374, 375, 376, 377, 378, 379, 386, 389], "assign": [1, 35, 359], "axpby_gener": 1, "buffer": [1, 164, 390], "constant": [1, 4, 7, 112, 113, 181, 256, 260, 269, 272, 300, 332, 342, 373, 375, 386, 390], "4": [1, 4, 97, 128, 146, 184, 185, 205, 250, 258, 259, 260, 270, 274, 275, 299, 310, 311, 313, 314, 315, 329, 386, 388, 391, 393], "5": [1, 3, 4, 7, 146, 169, 186, 258, 260, 263, 264, 265, 270, 274, 308, 311, 312, 315, 316, 341, 352, 356, 373, 375, 376, 386, 387, 388], "x_stride": 1, "6": [1, 4, 146, 205, 310, 314, 322, 323, 325, 332, 342, 346, 373, 386, 388, 391], "y_stride": 1, "7": [1, 4, 146, 184, 388], "ndim": [1, 128, 146, 311], "8": [1, 4, 7, 146, 184, 250, 259, 270, 275, 310, 330, 362, 363, 364, 365, 366, 367, 373, 386, 388, 391, 393], "uint": 1, "index": [1, 6, 8, 25, 35, 109, 111, 132, 163, 182, 226, 227, 237, 248], "thread_position_in_grid": 1, "convert": [1, 73, 77, 78, 79, 128, 299, 389, 390, 391], "instanti": [1, 5, 389], "uniqu": [1, 383], "host": 1, "name": [1, 149, 184, 185, 203, 204, 205, 206, 256, 269, 281, 284, 286, 388, 392], "identifi": [1, 245, 385], "instantiate_axpbi": 1, "type_nam": 1, "host_nam": 1, "axpby_general_": 1, "compil": [2, 6, 7, 100, 103, 387, 389], "mlx_ext": 1, "metallib": [1, 7], "see": [1, 4, 5, 7, 9, 10, 29, 30, 31, 32, 33, 36, 37, 38, 39, 40, 42, 44, 45, 48, 49, 50, 51, 52, 53, 54, 55, 56, 59, 60, 61, 62, 63, 65, 67, 68, 69, 70, 71, 72, 74, 75, 146, 168, 203, 204, 250, 256, 260, 264, 267, 280, 298, 299, 302, 303, 304, 306, 307, 308, 311, 313, 314, 315, 316, 321, 322, 323, 347, 386, 387, 388, 391, 393], "later": [2, 7], "co": [1, 307, 387], "locat": [1, 295, 296, 393], "share": [6, 97, 184, 185], "register_librari": 1, "potenti": 169, "path": [2, 7, 170, 205, 206, 284], "tri": [], "load": [5, 6, 284], "hasn": [], "alreadi": [1, 2, 4], "static": 7, "object": [2, 9, 27, 46, 73, 83, 142, 205, 240, 245, 246, 250, 264, 310, 385], "why": 4, "packag": [1, 3, 5, 356], "process": [4, 87, 88, 246, 265, 266, 310, 385], "logic": [1, 155, 156, 157], "grid": [1, 163], "shown": 1, "below": [1, 7, 146, 234, 236, 250, 311, 389], "prepar": [1, 4], "carri": 1, "should": [1, 3, 4, 5, 7, 99, 112, 113, 115, 143, 170, 184, 227, 237, 239, 245, 256, 261, 262, 264, 265, 291, 297, 305, 331, 333, 338, 359, 385, 386, 387, 389, 390, 394], "d": [1, 4, 98, 99, 136, 146, 159, 163, 180, 226, 234, 235, 236, 247, 265, 268, 271, 301, 362, 365, 367, 393], "ostringstream": 1, "kname": 1, "axpby_": 1, "general_": 1, "type_to_nam": 1, "make": [1, 2, 4, 5, 7, 159, 177, 208, 256, 375, 376, 378, 379, 386, 389, 391, 393], "sure": [1, 2, 4, 7, 256, 386], "look": [1, 4], "folder": 1, "get_colocated_mtllib_path": 1, "get_kernel": 1, "str": [1, 88, 132, 146, 149, 163, 170, 202, 203, 204, 205, 206, 237, 245, 247, 277, 278, 281, 282, 284, 286, 288, 294, 311, 315, 316, 329, 330, 331, 332, 333, 334, 335, 
336, 337, 338, 339, 340, 341, 342], "encod": [1, 114, 303, 307, 310, 331], "compute_encod": 1, "get_command_encod": 1, "setcomputepipelinest": 1, "those": [1, 4, 256], "nelem": 1, "set_array_buff": 1, "setbyt": 1, "sizeof": 1, "threadgroup": 1, "higher": [1, 136, 338, 387], "than": [1, 4, 73, 88, 99, 102, 114, 133, 134, 144, 145, 159, 168, 246, 303, 309, 311, 338, 341, 353, 363, 368, 386, 387, 393], "max": [1, 146, 161, 274, 275, 298, 325, 330, 332, 333, 338, 342, 344, 346, 363, 367, 386, 387, 393], "allow": [1, 142, 256, 296, 359, 381, 388, 391], "tgp_size": 1, "min": [1, 146, 173, 298, 325, 344, 346], "maxtotalthreadsperthreadgroup": 1, "3d": [1, 260, 265, 311], "mtl": 1, "group_dim": 1, "grid_dim": 1, "divid": [1, 35, 130, 184], "among": 1, "dispatchthread": 1, "few": [1, 4, 5, 6, 389, 391], "thing": [1, 4], "note": [1, 4, 7, 14, 83, 85, 86, 115, 123, 124, 137, 146, 164, 184, 187, 256, 300, 311, 390, 392], "befor": [1, 4, 7, 25, 182, 281, 310, 370, 388, 389], "move": [1, 174, 393], "track": [1, 256, 260], "activ": [1, 7, 164, 264, 309, 310, 343, 352, 353, 355, 386], "u": [1, 273, 296, 381, 389], "command": [1, 2, 7], "instead": [1, 7, 256, 296, 307, 387, 389], "end_encod": 1, "end": [99, 167, 184, 259, 268, 271, 275, 309, 334, 341, 347, 352, 353, 378], "until": [1, 389, 391], "limit": [1, 82, 168, 169, 388], "flush": 1, "enqueu": [], "commit": [], "associ": [1, 205, 206, 389], "suggest": [], "deeper": [], "dive": [], "studi": [], "come": [1, 4, 387], "far": 361, "built": [1, 7, 389], "includ": [1, 91, 92, 93, 94, 164, 165, 169, 272, 278, 290, 299, 332, 386, 387, 388, 391, 392, 394], "forward": [1, 237, 386, 389], "diff": 1, "push": 1, "along": [1, 23, 24, 83, 84, 91, 92, 93, 94, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 146, 198, 213, 215, 219, 226, 227, 230, 231, 232, 256, 301, 324], "similarli": [1, 7, 159, 387, 389], "scale_arr": 1, "contribut": 1, "tangent_x": 1, "tangent_i": 1, "revers": [1, 37, 38, 39, 40, 91, 92, 93, 94, 233, 307], "arg": [1, 4, 9, 10, 107, 205, 206], "push_back": 1, "fulli": [1, 6, 386, 390, 393], "overal": 1, "directori": [1, 4, 7], "extens": [1, 149, 170, 288, 392], "h": [1, 85, 86, 146, 259, 260, 262, 264, 265, 268, 271, 275, 301, 387, 389], "mlx_sample_extens": 1, "__init__": [1, 4, 5, 8, 9, 10, 27, 248, 256, 359], "py": [1, 4, 7], "cmakelist": 1, "txt": 1, "setup": [1, 3, 5, 7, 386], "hold": [1, 4, 9, 10, 146, 386], "instal": 1, "pybind11": [], "sinc": [1, 4, 5, 359, 368, 377, 390, 393], "compon": [1, 4], "etc": [1, 184, 256, 311], "pybind11_modul": [], "m": [1, 4, 7, 111, 146, 234, 258, 259, 274, 275, 362, 386], "doc": [1, 5], "sampl": [1, 3, 4, 148, 186, 187, 188, 190, 192, 195, 196, 313, 314, 315, 316, 318, 319, 332, 338, 342, 383, 386], "_a": 1, "pos_onli": [], "kw_onli": 1, "none": [1, 4, 8, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 29, 30, 31, 32, 33, 34, 36, 37, 38, 39, 40, 41, 42, 44, 45, 48, 49, 50, 51, 52, 53, 54, 55, 56, 59, 60, 61, 62, 63, 65, 67, 68, 69, 70, 71, 72, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 171, 172, 173, 174, 175, 176, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 
201, 202, 204, 205, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 240, 241, 242, 243, 245, 246, 248, 258, 259, 267, 274, 275, 277, 281, 282, 289, 294, 297, 301, 307, 310, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, 363, 381, 388], "r": [1, 4, 147, 237, 264, 268], "pbdoc": [], "most": [1, 187, 256, 372, 386, 387, 388, 389], "complex": [1, 123, 124, 125, 126, 127, 245, 250, 256, 296, 386, 387], "bell": 1, "whistl": 1, "liter": [1, 311, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342], "string": [1, 390, 392], "modul": [1, 4, 5, 244, 299, 305, 310, 356, 372, 385, 386, 389], "ensur": [1, 7, 337], "caster": 1, "find_packag": 1, "config": 1, "add_librari": 1, "sourc": [1, 2, 56, 174, 233], "target_sourc": 1, "cmake_current_list_dir": 1, "header": 1, "target_include_directori": 1, "target_link_librari": 1, "attach": 1, "conveni": [1, 5, 142], "mlx_build_metallib": 1, "target": [1, 237, 329, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, 386], "destin": [1, 56, 174], "automat": [1, 6, 149, 391, 392, 393], "practic": [1, 386], "mlx_build_met": [1, 7], "mlx_ext_metallib": 1, "titl": 1, "include_dir": 1, "project_source_dir": 1, "mlx_include_dir": 1, "output_directori": 1, "cmake_library_output_directori": 1, "add_depend": 1, "endif": 1, "pybind11_add_modul": [], "build_shared_lib": 1, "target_link_opt": 1, "wl": 1, "rpath": 1, "loader_path": 1, "onc": [1, 386], "describ": [1, 389], "util": [1, 4, 6, 7, 205, 256], "__name__": [1, 4], "__main__": [1, 4], "descript": [1, 4, 250], "ext_modul": 1, "cmakeextens": 1, "cmdclass": 1, "build_ext": 1, "cmakebuild": 1, "package_dir": [], "package_data": 1, "dylib": 1, "zip_saf": 1, "fals": [1, 4, 13, 14, 15, 23, 24, 30, 31, 32, 33, 37, 38, 39, 40, 52, 53, 54, 55, 59, 71, 75, 76, 83, 87, 91, 92, 93, 94, 137, 142, 146, 149, 158, 160, 162, 163, 169, 172, 183, 220, 224, 238, 241, 245, 246, 250, 269, 270, 272, 273, 282, 284, 294, 297, 299, 303, 307, 310, 311, 329, 332, 363, 374, 390], "python_requir": 1, "even": [1, 4, 83, 386, 389, 390], "though": [1, 4, 386, 389, 390], "j8": 1, "libmlx_ext": 1, "cpython": 1, "3x": 1, "darwin": 1, "pip": [1, 7], "after": [1, 4, 5, 25, 128, 130, 182, 184, 260, 269, 272, 277, 278, 282, 284, 291, 294, 295, 296, 297, 310, 341, 386, 393], "plai": [1, 4], "ones": [1, 4, 179, 205, 234, 295, 296, 299, 388], "b": [1, 2, 4, 12, 14, 76, 101, 102, 104, 130, 133, 134, 136, 137, 144, 145, 146, 154, 155, 157, 159, 161, 173, 175, 180, 184, 223, 230, 237, 273, 301, 311, 324, 387, 388, 389, 390, 391, 392, 393], "f": [1, 2, 3, 5, 146, 256, 271, 366, 386, 390], "item": [1, 3, 4, 5, 246, 389, 390, 391], "true": [1, 3, 4, 14, 37, 38, 39, 40, 76, 83, 91, 92, 93, 94, 114, 137, 142, 146, 149, 163, 169, 185, 213, 241, 245, 246, 250, 256, 260, 261, 262, 268, 269, 270, 271, 272, 273, 281, 282, 284, 291, 294, 299, 301, 303, 307, 310, 311, 329, 337, 363], "quick": [1, 6], "benchmark": [1, 386], "compar": [1, 76, 386], "time": [1, 4, 7, 169, 231, 256, 258, 259, 268, 271, 274, 275, 301, 386, 387, 389, 393], "set_default_devic": 1, "256": [1, 5], "512": [1, 2, 4, 310, 393], "random": [1, 2, 3, 4, 5, 6, 258, 259, 260, 270, 274, 275, 284, 291, 386, 387, 393, 394], "normal": [1, 3, 4, 112, 113, 190, 195, 256, 258, 259, 260, 269, 270, 272, 274, 275, 300, 310, 313, 315, 390, 393], "bench": 1, "warm": [1, 386], "rang": [1, 2, 3, 4, 5, 7, 16, 128, 148, 314, 316, 322, 323, 361, 375, 376, 377, 378, 379, 383, 
386, 387, 389, 393], "100": [1, 3, 4, 378, 386, 387, 389, 393], "5000": 1, "simple_tim": 1, "custom_tim": 1, "3f": [1, 5, 386], "custom": [1, 310], "114": 1, "109": 1, "modest": 1, "improv": [1, 2, 4, 362, 363, 364, 365, 366, 367, 373, 386], "awai": [1, 4], "good": [1, 7, 386, 393], "nn": [1, 4, 5, 205, 246, 256, 356, 359, 361, 370, 372, 386, 389], "grad": [1, 3, 5, 237, 361, 369, 386, 387, 388, 389, 391], "full": [1, 5, 61, 74, 88, 213, 295, 296, 332, 386, 389], "profil": 2, "kei": [2, 4, 115, 186, 187, 188, 190, 191, 192, 194, 195, 196, 245, 246, 281, 282, 294, 297, 370, 383, 385, 387], "build": [2, 4, 6, 315, 359, 386], "mlx_metal_debug": [2, 7], "option": [2, 4, 13, 15, 16, 23, 24, 25, 26, 29, 30, 31, 32, 33, 34, 36, 37, 38, 39, 40, 41, 42, 44, 45, 48, 49, 50, 51, 52, 53, 54, 55, 56, 59, 60, 61, 62, 63, 65, 67, 68, 69, 70, 71, 72, 74, 75, 77, 78, 79, 83, 84, 85, 86, 87, 88, 91, 92, 93, 94, 97, 98, 99, 111, 112, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 131, 132, 135, 140, 141, 146, 147, 148, 149, 158, 160, 162, 163, 169, 172, 178, 181, 182, 183, 184, 185, 186, 187, 188, 190, 191, 192, 194, 195, 196, 198, 199, 213, 214, 215, 218, 219, 220, 224, 226, 230, 232, 233, 234, 235, 236, 237, 238, 240, 242, 245, 246, 258, 259, 260, 261, 262, 268, 271, 273, 274, 275, 277, 281, 282, 284, 289, 294, 297, 299, 301, 303, 307, 310, 311, 312, 313, 314, 315, 316, 317, 318, 319, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, 362, 363, 364, 365, 366, 367, 368, 370, 373, 374, 375, 383, 386, 392, 394], "debug": 2, "record": [2, 166, 389], "dure": [2, 83, 263, 264, 265, 311, 390], "inspect": [2, 386, 391], "label": [2, 3, 331, 338], "queue": 2, "readabl": 2, "start_captur": 2, "initi": [2, 3, 4, 256, 260, 269, 270, 272, 273, 298, 300, 312, 313, 314, 315, 316, 317, 318, 319, 359, 370, 375, 376, 378, 379, 386, 389], "gpu": [2, 6, 386, 388, 393], "main": [6, 99, 111, 246, 256], "jane": [], "develop": [6, 7], "gputrac": [2, 170], "arang": [146, 250, 311, 388, 390], "10": [2, 4, 5, 151, 200, 205, 246, 256, 284, 356, 377, 379, 386, 388], "20": 146, "30": 363, "40": [], "stop_captur": 2, "replai": 2, "trace": [2, 386], "view": [2, 390], "overview": 2, "oper": [2, 4, 6, 8, 34, 77, 78, 79, 87, 115, 213, 221, 227, 248, 256, 310, 368, 386, 387, 388, 389, 390, 391, 393, 394], "checkout": [2, 386], "document": [2, 6, 61, 74, 203, 204, 250, 386, 387, 388], "inform": [2, 4, 5, 7, 203, 204, 250, 256, 260, 267, 297, 387, 393], "skip": 2, "save": [2, 4, 6, 149, 170, 184, 203, 204, 205, 206, 288, 389], "project": [2, 4, 297], "us": [2, 3, 4, 5, 6, 7, 16, 35, 97, 100, 102, 114, 128, 146, 147, 159, 164, 165, 166, 168, 184, 185, 198, 199, 245, 250, 256, 259, 264, 266, 267, 268, 271, 273, 275, 277, 281, 288, 295, 297, 299, 301, 303, 307, 310, 311, 315, 316, 322, 323, 330, 356, 359, 361, 362, 363, 365, 366, 367, 368, 369, 370, 383, 385, 386, 387, 388, 391, 393], "cmake": [2, 7], "mkdir": [2, 7], "cd": [2, 7], "dmlx_metal_debug": 2, "ON": [2, 7], "g": [2, 7, 146, 184, 271, 355, 373, 374, 389, 394], "xcodeproj": 2, "select": [2, 7, 232, 241, 277, 281, 289], "metal_captur": 2, "exampl": [2, 3, 4, 5, 16, 35, 128, 146, 147, 222, 226, 256, 258, 259, 260, 270, 274, 275, 282, 284, 291, 294, 311, 312, 313, 314, 315, 316, 317, 318, 319, 329, 331, 338, 356, 361, 370, 375, 376, 377, 378, 379, 383, 387, 388, 389, 390, 391, 392], "schema": 2, "implement": [3, 5, 114, 115, 146, 266, 281, 297, 303, 305, 307, 309, 310, 311, 353, 362, 363, 364, 365, 367, 368, 369, 381, 386, 387, 390], "basic": [3, 
200, 387], "model": [3, 5, 6, 205, 244, 246, 256, 277, 280, 282, 284, 288, 291, 293, 294, 295, 297, 310, 356, 359, 361, 369, 370, 372, 386, 389], "problem": [3, 5, 256], "metadata": [3, 149, 203, 204], "num_featur": [3, 260], "num_exampl": 3, "1_000": 3, "num_it": 3, "10_000": 3, "iter": [3, 5, 246, 383, 386, 389], "sgd": [3, 5, 361, 368, 370, 375, 376, 379, 386], "lr": [3, 368], "01": [3, 326, 366], "rate": [3, 362, 363, 364, 365, 366, 367, 368, 373, 374], "ll": [3, 5, 334, 386, 387], "synthet": 3, "dataset": [3, 389], "matrix": [3, 41, 97, 98, 111, 135, 146, 147, 159, 163, 184, 185, 190, 299, 317, 356], "ground": [3, 4, 331, 341], "truth": [3, 331, 341], "w_star": 3, "valu": [3, 4, 11, 14, 16, 23, 24, 46, 73, 76, 82, 111, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 131, 137, 146, 148, 181, 186, 187, 188, 190, 191, 192, 195, 196, 203, 226, 227, 237, 240, 244, 245, 246, 250, 259, 263, 264, 265, 270, 273, 275, 281, 297, 298, 308, 309, 310, 312, 329, 330, 331, 332, 333, 334, 336, 337, 338, 339, 340, 341, 353, 359, 363, 366, 375, 376, 378, 379, 387], "gaussian": [3, 267, 321, 322, 323, 332], "nois": 3, "noisi": 3, "ep": [3, 112, 113, 260, 269, 270, 272, 300, 330, 332, 342, 362, 363, 364, 365, 366, 367, 373], "1e": [3, 5, 14, 137, 260, 269, 270, 272, 300, 330, 332, 342, 362, 363, 364, 365, 366, 367, 370, 373, 375, 376, 377, 378, 379], "weight": [3, 85, 86, 87, 112, 113, 246, 256, 284, 288, 299, 329, 331, 359, 363, 366, 368, 370, 374, 387, 389], "squar": [3, 4, 113, 135, 201, 216, 237, 246, 256, 300, 339, 341, 362, 363, 365, 366, 367, 387, 390], "loss": [3, 5, 237, 256, 361, 386, 387, 389], "loss_fn": [3, 5, 361, 386, 387], "w": [3, 86, 97, 184, 185, 237, 259, 260, 262, 264, 265, 273, 275, 374, 387], "mean": [3, 4, 5, 113, 190, 191, 237, 256, 260, 269, 282, 300, 318, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, 386, 387, 390], "grad_fn": [3, 386, 387], "randomli": [3, 4, 263, 264, 265], "Then": [3, 7], "repeatedli": 3, "_": [2, 3, 4, 256, 375, 376, 377, 378, 379, 383, 386, 389, 393], "verifi": [3, 7], "close": [3, 6, 7, 14, 137], "error_norm": 3, "5f": 3, "someth": [3, 4, 388], "00005": 3, "00364": 3, "complet": [3, 4, 7, 169, 295, 296, 387, 393], "logist": [3, 209, 322, 323, 349], "github": [3, 5, 7, 386], "repo": [3, 5, 7, 386], "enabl": [2, 4, 7, 83, 103, 374], "larg": [4, 256, 297, 337, 386, 389], "ish": 4, "transform": [4, 6, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 244, 256, 260, 269, 272, 273, 281, 282, 294, 299, 303, 388], "compromis": 4, "eas": 4, "llama": 4, "famili": 4, "less": [4, 25, 145, 182, 303, 341], "200": [4, 377], "line": [4, 389, 390], "python": [2, 4, 46, 64, 73, 107, 245, 246, 247, 359, 369, 370, 372, 385, 387, 390], "neural": [4, 6, 266, 313, 314, 343, 356, 359, 373], "network": [4, 6, 260, 264, 266, 313, 314, 356, 359, 373], "concis": 4, "architectur": [4, 7, 256, 296, 393], "notabl": [4, 6], "rope": [4, 256], "posit": [4, 25, 99, 114, 128, 132, 141, 174, 182, 190, 237, 246, 256, 261, 262, 297, 303, 307, 332, 342], "cach": [4, 164, 165, 168, 386], "concaten": 4, "llamaattent": 4, "self": [4, 5, 8, 27, 28, 29, 30, 31, 32, 33, 34, 36, 37, 38, 39, 40, 41, 42, 44, 45, 46, 48, 49, 50, 51, 52, 53, 54, 55, 56, 59, 60, 61, 62, 63, 65, 67, 68, 69, 70, 71, 72, 73, 74, 75, 248, 256, 343, 359], "dim": [4, 114, 115, 266, 269, 270, 272, 297, 300, 303, 307, 310], "num_head": [4, 297, 310], "super": [4, 5, 256, 359], "tradit": [4, 114, 264, 265, 303], "query_proj": 4, "bia": [4, 97, 112, 184, 185, 246, 256, 261, 262, 
268, 271, 272, 273, 282, 284, 294, 297, 299, 301, 365, 366, 367, 370, 387], "key_proj": 4, "value_proj": 4, "out_proj": [4, 359], "__call__": [4, 5, 256, 359], "queri": [4, 115, 297], "mask": [4, 115, 291, 297, 388], "extract": [4, 41, 98, 99, 256, 281, 359], "l": [4, 5, 256, 258, 260, 261, 268, 271, 274, 301, 341], "reshap": [4, 146, 311, 388], "combin": 4, "key_cach": 4, "value_cach": 4, "sqrt": [4, 105, 115, 260, 269, 270, 272, 273, 300, 307, 313, 314, 315, 316, 362, 364, 365, 366, 373, 386], "score": [4, 115, 338], "softmax": [4, 115, 256, 328, 331], "values_hat": 4, "rm": [4, 7, 113, 363], "swiglu": 4, "rmsnorm": [4, 256], "llamaencoderlay": 4, "mlp_dim": [4, 310], "norm1": 4, "norm2": 4, "linear1": 4, "linear2": 4, "linear3": 4, "sigmoid": [4, 256, 306, 322, 323, 327, 349], "instanc": [4, 35, 184, 247, 256, 270, 277, 278, 279, 282, 284, 285, 286, 291, 294, 295, 296, 305, 359, 390], "embed": [4, 256, 303, 307, 330], "emb": [4, 266, 307], "token": [4, 266], "num_lay": [4, 5, 361], "vocab_s": 4, "norm": [4, 113, 269, 342, 367, 368], "multiheadattent": [4, 256], "create_additive_causal_mask": 4, "list": [4, 9, 13, 15, 27, 67, 73, 77, 78, 79, 80, 83, 84, 87, 107, 117, 118, 120, 121, 123, 124, 126, 127, 131, 132, 143, 146, 158, 160, 162, 163, 172, 178, 181, 183, 186, 187, 188, 190, 191, 192, 195, 196, 203, 213, 215, 219, 220, 224, 230, 231, 233, 237, 238, 239, 242, 245, 247, 256, 282, 284, 285, 286, 287, 292, 294, 295, 296, 359, 365, 366, 367, 368, 377, 385, 386, 387, 389], "still": [4, 7, 146, 386, 389], "consid": [4, 14, 76, 137, 245, 246, 269, 385], "train": [4, 5, 256, 260, 263, 264, 265, 280, 282, 294, 313, 314], "ignor": [4, 35, 82, 83, 107, 363], "whatsoev": 4, "rest": [4, 114, 246, 303], "subsect": 4, "prompt": 4, "autoregress": 4, "yield": [4, 5, 383], "temp": 4, "causal": 4, "append": [4, 159, 386, 389], "store": 4, "per": [4, 5, 97, 184, 185, 260, 269, 270, 272, 300, 381, 386, 389], "care": [4, 389], "last": [4, 26, 73, 112, 113, 118, 121, 123, 124, 126, 127, 128, 136, 147, 159, 187, 214, 230, 261, 262, 264, 265, 269, 311, 390], "logit": [4, 187, 329, 331, 386], "next": [1, 4, 5, 168], "categor": 4, "lazili": [4, 256], "noth": [4, 256, 389], "yet": [4, 146, 256, 359, 370, 387, 388, 389, 391], "forc": [4, 5, 256, 391], "choos": [4, 114, 303], "pars": 4, "feed": 4, "loop": [4, 5, 386, 387, 389], "unsqueez": 4, "sequenc": [4, 13, 15, 30, 31, 52, 53, 54, 55, 59, 67, 70, 71, 75, 80, 87, 109, 117, 118, 120, 121, 123, 124, 126, 127, 131, 158, 160, 162, 172, 178, 183, 186, 187, 188, 190, 191, 192, 195, 196, 199, 213, 215, 218, 220, 224, 230, 231, 233, 238, 242, 260, 261, 268, 271, 301, 310, 383, 393], "length": [4, 218, 260, 261, 268, 271, 301, 377], "len": [4, 118, 121, 124, 127, 377], "overwrit": 4, "discard": [4, 245], "old": 4, "moment": [4, 87, 363, 365, 366, 367], "anymor": 4, "everyth": 4, "small": [4, 110, 112, 113, 260, 269, 272, 300, 332, 337, 342, 386, 393], "12": [4, 377], "8192": 4, "1024": 4, "actual": [4, 16, 284, 359, 389], "materi": [4, 6], "could": [4, 256], "20_000": 4, "machin": [4, 6, 7, 373], "8gb": 4, "ram": 4, "32": [4, 5, 184, 185, 250, 259, 275, 300, 386], "44": 4, "doubl": 4, "bracket": 4, "becaus": [4, 164, 256, 389], "batch": [4, 159, 190, 260, 261, 262, 264, 265, 268, 271, 297, 301, 311, 389], "zip": [4, 5], "haven": 4, "anyth": [4, 237, 389], "result": [4, 16, 35, 73, 83, 97, 112, 113, 136, 146, 149, 159, 180, 185, 190, 198, 200, 219, 230, 231, 241, 246, 307, 386, 387, 390], "similar": [4, 142, 246, 295, 296, 297, 330, 390, 392], "runtim": [4, 386], 
"section": [4, 7, 215, 342, 386, 387], "access": [4, 46, 256, 359, 370, 389, 393], "origin": [4, 99, 260, 290, 313, 314, 315, 316, 362, 363, 364, 365, 367, 368, 390], "sentencepiec": 4, "pytorch": [4, 6, 269, 387], "compat": [4, 187, 190, 392], "npz": [4, 149, 205, 206, 284, 288, 392], "file": [4, 7, 149, 202, 203, 204, 205, 206, 284, 288, 387, 392], "directli": [1, 4], "argpars": 4, "itertool": [4, 246], "starmap": [4, 246], "np": [4, 5, 390, 391], "torch": [4, 390], "map_torch_to_mlx": 4, "tok_embed": 4, "elif": 4, "replac": [4, 295, 296, 310, 341], "attention_norm": 4, "ffn_norm": 4, "wq": 4, "wk": 4, "wv": 4, "wo": 4, "w1": 4, "w2": 4, "w3": 4, "ffn": 4, "separ": [4, 61, 74, 269, 338], "submodul": [4, 5, 256, 278, 282, 283, 294, 296], "feed_forward": 4, "parser": 4, "argumentpars": 4, "add_argu": 4, "torch_weight": 4, "output_fil": 4, "parse_arg": 4, "state": [4, 5, 256, 268, 271, 301, 361, 370, 383, 386], "savez": [4, 288, 392], "k": [4, 41, 98, 111, 115, 232, 234, 235, 236, 258, 273, 274, 282], "v": [4, 88, 115, 256, 282, 390], "left": [4, 114, 146, 184, 258, 259, 267, 274, 275, 303, 311, 322, 323, 332, 334, 342], "disk": 4, "text": [4, 258, 259, 268, 271, 274, 275, 276, 301, 309, 313, 314, 315, 316, 325, 332, 333, 334, 337, 338, 341, 343, 344, 347, 348, 352, 353, 363, 368], "format": [4, 149, 202, 203, 204, 205, 206, 390], "dictionari": [4, 83, 149, 203, 204, 245, 256, 281, 290, 295, 296, 371, 385, 392], "represent": [4, 184, 245, 247], "tree_unflatten": 4, "helper": [4, 386], "weight_fil": 4, "incur": 4, "sever": [4, 85, 86, 87, 205, 206, 386, 392], "futur": [4, 299, 388, 389], "pth": 4, "current": [4, 6, 7, 85, 86, 87, 165, 184, 256, 363, 389], "around": 4, "m1": [4, 386, 387, 393], "ultra": 4, "7b": 4, "me": 4, "ishmael": 4, "year": 4, "ago": 4, "never": [4, 389], "long": 4, "info": [4, 7], "247": 4, "press": [4, 146], "enter": 4, "littl": 4, "monei": 4, "my": [4, 7], "purs": 4, "greater": [4, 25, 110, 134, 182, 309, 353], "consequ": 4, "walk": 4, "down": 4, "gower": 4, "street": 4, "afternoon": 4, "heavi": 4, "rain": 4, "saw": [4, 387], "off": [4, 7, 389], "man": 4, "rag": 4, "who": 4, "sat": 4, "upon": [4, 246], "hi": [4, 271], "bundl": 4, "hard": 4, "wet": 4, "he": [4, 315, 316], "were": [4, 393], "cry": 4, "watch": [4, 386], "him": 4, "observ": 4, "numer": [4, 112, 113, 146, 154, 158, 213, 260, 269, 270, 272, 300, 330, 332, 342, 362, 363, 364, 365, 366, 367, 373, 386, 389], "crowd": 4, "wa": [4, 170, 389], "hurri": 4, "437": 4, "330": 4, "second": [4, 99, 155, 157, 159, 225, 237, 259, 275, 330, 338, 363, 365, 366, 367, 387, 393], "spent": 4, "amount": [4, 166, 258, 274], "39": 4, "By": [4, 289, 387, 390], "bigger": [4, 363], "remain": [4, 237, 263, 264, 265], "almost": 4, "nobodi": 4, "took": 4, "least": [4, 77, 78, 79, 82, 147, 184], "notic": [4, 387, 392], "distanc": [4, 342], "had": 4, "doubt": 4, "minut": 4, "straight": 4, "slowli": 4, "rais": [4, 146, 169, 215, 284], "ey": 4, "speak": [4, 146], "resum": 4, "postur": 4, "stood": 4, "feel": 4, "pain": 4, "heart": 4, "smile": 4, "face": 4, "am": 4, "someon": 4, "three": [4, 79, 311], "quarter": 4, "hour": 4, "made": 4, "immedi": [4, 277], "repli": 4, "again": [4, 7, 256, 386], "hand": [4, 387, 389], "did": 4, "accustom": 4, "thu": [4, 256], "question": [4, 389], "reason": [4, 388], "tell": [4, 386, 390], "understand": [4, 313, 314], "579": 4, "690": 4, "num": [4, 148, 194], "500": [4, 393], "628": 4, "went": 4, "nervou": 4, "trembl": 4, "told": 4, "And": [4, 311], "perhap": [1, 4], "surpris": 4, "matter": [4, 256], "shall": 
4, "anyhow": 4, "friend": 4, "ye": 4, "slight": [4, 389], "kind": 4, "longer": [4, 88, 387], "soon": 4, "unless": [4, 14, 137, 146, 359], "unlik": [4, 14, 137, 264, 265, 290], "strang": 4, "amus": 4, "That": 4, "secret": 4, "disappoint": 4, "mine": 4, "cannot": [4, 82, 388, 390], "happi": 4, "ask": 4, "shop": 4, "bui": 4, "food": 4, "633": 4, "21": [4, 379], "475": 4, "su": 4, "j": [4, 7, 146, 264, 364, 365, 367], "lu": 4, "pan": 4, "murtadha": 4, "wen": 4, "liu": 4, "2021": 4, "roform": [4, 303], "enhanc": [4, 303, 389], "rotari": [4, 114, 303], "arxiv": [4, 269, 270, 272, 276, 300, 323, 343, 362, 368], "preprint": [4, 362, 368], "2104": 4, "09864": 4, "zhang": 4, "sennrich": 4, "2019": [4, 366], "root": [4, 113, 201, 216, 300], "advanc": [4, 386], "system": [4, 7, 164, 165], "shazeer": 4, "2020": 4, "glu": [4, 256], "variant": [4, 341, 367], "2002": 4, "05202": 4, "classifi": 5, "mnist": 5, "As": [5, 35, 226, 256, 386], "mlp": [5, 256, 310, 361], "inherit": [5, 385], "standard": [5, 46, 73, 159, 188, 191, 220, 310, 313, 315, 318, 391], "idiom": [5, 386], "input_dim": [5, 256, 273, 299], "hidden_dim": [5, 359, 361], "output_dim": [5, 256, 273, 299], "layer_s": 5, "idim": 5, "odim": 5, "maximum": [5, 23, 35, 82, 91, 166, 169, 256, 302, 307, 322, 323, 326, 345, 359, 389], "cross": [5, 87, 329, 331], "entropi": [5, 329, 331], "sub": [5, 99, 194], "commonli": [5, 295, 356, 386], "cross_entropi": [5, 256], "accuraci": 5, "valid": [5, 88, 128, 240, 245, 282, 294, 385], "eval_fn": 5, "argmax": 5, "loader": 5, "num_class": [5, 361], "batch_siz": [5, 361], "num_epoch": [5, 361], "learning_r": [5, 361, 362, 363, 364, 365, 366, 367, 368, 370, 373, 374, 375, 376, 377, 378, 379, 386], "train_imag": [5, 361], "train_label": [5, 361], "test_imag": 5, "test_label": 5, "shuffl": 5, "minibatch": 5, "batch_iter": [5, 361], "perm": 5, "permut": 5, "id": [5, 7], "put": [5, 386], "trainabl": [5, 244, 256, 359], "loss_and_grad_fn": [5, 361, 386, 387], "value_and_grad": [5, 256, 295, 359, 361, 372, 386, 387, 390, 391], "epoch": 5, "test": [5, 7], "confus": 5, "decent": 5, "95": 5, "brought": 6, "research": 6, "except": [6, 111, 122, 123, 125, 126, 127, 269, 284, 388, 390], "featur": [6, 85, 86, 87, 114, 260, 268, 269, 270, 271, 272, 273, 299, 300, 301, 303, 310, 311, 386, 389], "differ": [6, 142, 223, 341, 387], "lazi": [6, 359, 391], "multi": [6, 115, 261, 262, 388, 390], "cpu": [6, 147, 386, 393], "inspir": 6, "jax": [6, 383], "arrayfir": 6, "unifi": 6, "live": [6, 393], "guid": [1, 6], "convers": 6, "regress": [6, 337], "layer": [6, 112, 256, 258, 259, 264, 265, 268, 269, 271, 272, 273, 274, 275, 291, 296, 299, 301, 305, 310, 355, 359], "perceptron": 6, "llm": 6, "infer": [6, 131, 149], "fft": 6, "algebra": 6, "tree": [6, 83, 107, 132, 237, 240, 245, 246, 247, 369, 370, 372, 381, 387], "debugg": 6, "pypi": 7, "meet": 7, "seri": 7, "chip": 7, "nativ": 7, "maco": 7, "13": 7, "recommend": [7, 169, 368], "14": 7, "sonoma": 7, "conda": 7, "forg": 7, "distribut": [7, 186, 187, 188, 190, 191, 195, 196, 273, 313, 314, 315, 316, 318, 319, 332, 335, 340, 342, 356], "probabl": [7, 192, 263, 264, 265, 299, 329, 331, 335, 393], "platform": 7, "processor": 7, "arm": 7, "i386": 7, "switch": 7, "17": 7, "clang": 7, "24": 7, "xcode": 7, "15": [7, 146, 386], "sdk": 7, "environ": [7, 100, 103], "via": [7, 369, 372, 389, 390], "rosetta": 7, "unam": 7, "p": [7, 186, 256, 263, 264, 265, 342, 365, 367], "clone": 7, "git": 7, "com": 7, "ml": 7, "explor": 7, "nanobind": [1, 7, 310], "http": [7, 269, 270, 272, 276, 300, 323, 343], 
"wjakob": 7, "env": 7, "cmake_build_parallel_level": 7, "edit": [7, 296], "unittest": 7, "discov": 7, "stub": 7, "dev": [1, 7], "generate_stub": 7, "either": [7, 12, 61, 73, 74, 82, 101, 102, 104, 130, 133, 134, 144, 145, 146, 154, 159, 161, 173, 175, 223, 237, 259, 275, 305, 311, 315, 316], "libmlx": 7, "preprocessor": 7, "metal_path": 7, "mlx_build_test": 7, "mlx_build_exampl": 7, "mlx_build_benchmark": 7, "mlx_build_python_bind": 7, "multipl": [7, 112, 113, 159, 175, 184, 185, 297, 307, 376, 377, 379, 386, 389, 392], "wish": 7, "variabl": [7, 83, 100, 103, 132, 143, 237, 239, 240], "export": 7, "developer_dir": 7, "app": 7, "content": [7, 281, 386], "xcrun": 7, "macosx": 7, "show": [7, 250, 386], "unabl": 7, "tool": 7, "sudo": 7, "ouptut": 7, "finder": 7, "iterm": 7, "termin": 7, "click": 7, "uncheck": 7, "window": [7, 258, 259, 274, 275], "restart": 7, "grep": 7, "cmake_host_system_processor": 7, "arm64": 7, "x86_64": 7, "wipe": 7, "cahc": 7, "rf": 7, "devicetyp": 8, "attribut": [8, 9, 27, 248, 290, 359, 381], "kwarg": [9, 10, 205, 206, 394], "categori": [10, 250], "bool_": [10, 250], "integ": [10, 130, 142, 146, 181, 184, 185, 186, 192, 215, 230, 240, 250, 266, 289, 377, 388], "unsignedinteg": 10, "uint8": [10, 250], "uint16": [10, 250], "uint32": [10, 23, 24, 25, 26, 187, 250], "uint64": [10, 250], "signedinteg": [10, 142], "int8": [10, 250], "int32": [10, 16, 35, 128, 142, 146, 192, 250, 311, 388, 391], "int64": [10, 250], "inexact": [10, 142], "complexflo": 10, "complex128": 10, "issubdtyp": [10, 250], "absolut": [11, 14, 137, 322, 323, 341], "semant": [12, 80, 101, 102, 104, 133, 134, 144, 145, 154, 159, 161, 173, 175, 223, 393], "keepdim": [13, 15, 23, 24, 30, 31, 32, 33, 52, 53, 54, 55, 59, 71, 75, 146, 158, 160, 162, 172, 183, 213, 220, 224, 238], "reduct": [13, 15, 158, 160, 172, 183, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342], "reduc": [13, 15, 23, 24, 158, 160, 162, 172, 183, 220, 224, 238, 260, 310, 337], "unspecifi": [13, 15, 16, 23, 24, 25, 26, 84, 91, 92, 93, 94, 131, 158, 160, 162, 172, 178, 182, 183, 198, 213, 214, 220, 224, 226, 232, 238, 242, 394], "entir": [13, 15, 23, 24, 158, 160, 162, 172, 183, 220, 224, 238, 264, 265], "singleton": [13, 15, 23, 24, 158, 159, 160, 162, 172, 183, 220, 224, 238], "rtol": [14, 137], "05": [14, 137, 260, 269, 270, 272, 300], "atol": [14, 137], "08": [14, 137, 330, 364, 365, 366, 367, 373], "equal_nan": [14, 76, 137], "approxim": [14, 267, 321, 322, 323], "comparison": [14, 104, 133, 134, 144, 145], "infinit": [14, 137], "equal": [14, 25, 76, 111, 134, 137, 145, 182, 192, 215, 270, 273], "sign": [14, 137, 250, 368], "nan": [14, 76, 137, 139], "ab": [14, 137, 146, 237, 269, 270, 272, 276, 300, 323, 343, 386], "array_equ": [14, 137], "rel": [14, 137, 363, 386], "toler": [14, 137], "boolean": [14, 76, 137, 138, 139, 140, 141, 155, 156, 157, 250, 293, 388], "interv": [16, 148, 192, 196], "increment": 16, "otherwis": [16, 87, 169, 245, 246, 282, 284, 294, 309, 310, 311, 329, 334, 341, 352, 353, 389, 390], "convent": [16, 88, 311, 366], "lead": [16, 386], "fraction": 16, "integr": [16, 226, 389], "invers": [17, 18, 19, 20, 21, 22, 106, 119, 120, 121, 122, 123, 124], "cosin": [17, 18, 89, 90, 330, 375, 377, 387], "hyperbol": [18, 20, 22, 90, 212, 229, 354], "sine": [19, 20, 211, 212, 387], "minimum": [24, 35, 82, 92, 307, 330, 375], "kth": [25, 182], "partit": 25, "order": [25, 87, 146, 182, 184, 232, 256, 269, 295, 305, 370, 386, 387], "undefin": [25, 182, 190, 388], "sort": [25, 26, 182, 232], "flatten": [25, 26, 
91, 92, 93, 94, 146, 180, 182, 198, 214, 226, 227, 232, 245], "dimension": [27, 112, 113, 116, 117, 118, 119, 120, 121, 125, 126, 127, 258, 259, 260, 261, 262, 266, 273, 274, 275, 299, 307, 388, 390], "val": [27, 131], "tupl": [27, 61, 64, 74, 84, 86, 87, 102, 107, 109, 143, 146, 147, 181, 184, 199, 218, 237, 239, 245, 246, 247, 258, 259, 262, 274, 275, 284, 286, 305, 311, 363, 365, 366, 367, 368, 385, 387], "ndarrai": [27, 388, 389, 391], "properti": [28, 35, 43, 47, 57, 58, 64, 66, 290, 293, 371, 387], "argument": [28, 61, 74, 83, 107, 132, 237, 246, 256, 311, 383, 387, 392, 393, 394], "union": [29, 30, 31, 32, 33, 34, 36, 37, 38, 39, 40, 41, 42, 44, 45, 48, 49, 50, 51, 52, 53, 54, 55, 56, 59, 60, 61, 62, 63, 65, 67, 68, 69, 70, 71, 72, 74, 75, 77, 78, 79, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 140, 141, 194, 195, 222], "appli": [35, 114, 115, 246, 256, 258, 259, 260, 261, 262, 264, 265, 267, 269, 270, 272, 273, 274, 275, 276, 278, 291, 298, 299, 300, 301, 302, 304, 306, 308, 309, 311, 320, 321, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 356, 369, 372, 378, 381, 386], "regular": [35, 264, 343, 366, 386, 388], "idx": [35, 388], "correctli": 35, "syntax": [35, 388], "subtract": 35, "inclus": [37, 38, 39, 40, 91, 92, 93, 94, 128], "diagon": [41, 98, 111, 234, 235, 236], "axis1": [42, 72, 99, 225], "axis2": [42, 72, 99, 225], "start_axi": [45, 128], "end_axi": [45, 128], "datatyp": 47, "byte": [47, 57, 164, 165, 166, 168, 169, 250], "decim": [62, 200], "indices_or_sect": [67, 215], "nest": [73, 83, 256, 359, 385, 387], "ddof": [75, 220, 238], "ari": [77, 78, 79], "a_min": 82, "a_max": 82, "edg": [82, 181, 311, 386], "At": 82, "anoth": [82, 142, 159, 223, 241, 250, 256, 277, 386, 387, 388, 393], "fun": [83, 132, 143, 237, 239, 240, 386, 388, 389, 393], "callabl": [83, 132, 143, 237, 239, 240, 244, 245, 246, 277, 278, 281, 289, 301, 305, 310, 312, 313, 314, 315, 316, 317, 318, 319, 362, 363, 364, 365, 366, 367, 368, 373, 374, 375, 376, 377, 378, 379], "shapeless": 83, "dict": [83, 107, 149, 203, 204, 205, 287, 292, 295, 296, 359, 369, 370, 372, 385, 387, 392], "arbitrarili": [83, 256, 385, 387, 391], "leaf": [83, 245, 246, 281], "node": [83, 107, 240], "recompil": [83, 386], "chang": [83, 208, 295, 299, 311, 334, 341, 386, 390], "Not": [83, 386], "attempt": 83, "pad": [85, 86, 87, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 258, 259, 261, 262, 274, 275], "dilat": [85, 86, 87, 261, 262], "group": [85, 86, 87, 97, 115, 184, 185, 269, 299], "1d": [85, 87, 88, 203, 227], "convolut": [85, 86, 87, 88, 261, 262, 264, 265], "channel": [85, 86, 87, 260, 261, 262, 264, 265], "c_in": [85, 86, 87], "c_out": [85, 86, 87], "convolv": [85, 86, 87], "2d": [86, 87, 99, 184, 260, 264], "spatial": [86, 87, 258, 269, 274, 311], "symmetr": 86, "kernel_dil": 87, "input_dil": 87, "flip": [87, 88], "correl": [87, 264], "discret": [88, 116, 117, 118, 119, 120, 121, 125, 126, 127, 266], "swap": [88, 169, 225, 296, 299], "conv": 88, "filter": [88, 261, 262, 277, 281], "signal": [88, 311], "cumul": [91, 92, 93, 94], "th": [91, 92, 93, 94, 98, 111, 377], "bias": [97, 184, 185, 268, 271, 282, 294, 297], "group_siz": [97, 184, 185, 299], "64": [97, 184, 185, 250, 299], "configur": 97, "formal": [97, 184], "notat": [97, 245, 286], "quantiz": [97, 149, 185, 299], "w_i": [97, 184], "hat": [97, 184], "occupi": [97, 184, 185], "subarrai": [99, 215], "remov": [99, 159, 187, 218, 
331], "insert": [99, 109, 393], "neg": [99, 128, 140, 274, 275, 297, 332, 340, 342, 388], "taken": [99, 226], "global": [100, 103, 193, 383, 386], "disabl": [100, 168, 386], "mlx_disable_compil": [100, 103, 386], "divis": [101, 130, 184], "quotient": [101, 102, 130], "remaind": 102, "fuction": 102, "faster": [1, 102, 321, 386, 387], "mathrm": [105, 209, 270], "frac": [105, 184, 209, 258, 259, 260, 263, 264, 265, 269, 270, 272, 273, 274, 275, 300, 313, 314, 315, 316, 330, 332, 334, 337, 348, 350, 362, 364, 365, 366, 367, 373], "pi": [105, 307, 387], "int_0": 105, "dt": 105, "erf": [106, 386], "exponenti": [108, 110, 304, 320, 347, 376], "ident": [111, 221, 256, 291], "zero": [111, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 163, 234, 235, 236, 243, 256, 258, 259, 263, 264, 265, 284, 312, 313, 314, 315, 316, 317, 318, 319, 356, 363, 388], "whose": [111, 244], "translat": [112, 272], "stabil": [112, 113, 260, 269, 270, 272, 300, 330, 332, 362, 363, 364, 365, 366, 367, 373], "traditino": 114, "rotat": [114, 303], "larger": [114, 303, 368], "unchang": [114, 221, 303], "consecut": [114, 184, 303], "angular": [114, 303], "frequenc": [114, 303, 307], "q": [115, 147], "head": [115, 297, 310], "attent": [115, 282, 297, 307, 310], "regardless": 115, "pre": 115, "tile": 115, "typic": [115, 266, 361, 386, 389], "One": [116, 119, 125, 201, 386, 387], "fourier": [116, 117, 118, 119, 120, 121, 125, 126, 127], "truncat": [116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 195], "dft": [116, 117, 118, 119, 120, 121, 125, 126, 127], "rfft": 122, "real": [122, 123, 124, 125, 126, 127], "rfft2": 123, "rfftn": 124, "silent": [125, 126, 127], "outsid": 128, "clamp": 128, "floor": 130, "argnam": [132, 237], "neither": [132, 237], "keyword": [132, 205, 206, 237, 246, 256, 383, 392, 394], "strict": [133, 144, 282, 284, 294], "ordinari": 136, "inifn": 138, "infin": [138, 140, 141, 274, 275, 367], "dtypecategori": [142, 250], "subtyp": [142, 250], "subdtyp": 142, "float64": 142, "too": [142, 386, 389], "ord": 146, "tabl": [146, 250, 266], "frobeniu": 146, "matric": [146, 147], "strictli": 146, "mathemat": 146, "variou": 146, "purpos": 146, "calcul": [146, 332, 338, 363], "fro": 146, "inf": [146, 297], "largest": [146, 232], "sing": 146, "smallest": 146, "singular": 146, "nuclear": 146, "_f": 146, "sum_": [146, 258, 259, 337], "a_": 146, "valueerror": [146, 284, 387], "refer": [146, 270, 276, 290, 313, 314, 315, 316, 323, 343, 388], "golub": 146, "van": 146, "loan": 146, "baltimor": 146, "md": 146, "john": 146, "hopkin": 146, "univers": 146, "1985": 146, "pg": 146, "la": 146, "9": [146, 331, 362, 365, 366, 367, 368, 370, 376, 379, 390], "74597": 146, "84804": 146, "41421": 146, "23607": [146, 147], "74166": 146, "24264": 146, "11": 146, "225": 146, "894427": 147, "447214": 147, "57771": 147, "50": 148, "evenli": 148, "return_metadata": 149, "binari": [149, 202, 203, 204, 205, 206, 309, 329, 353, 386], "npy": [149, 202, 392], "safetensor": [149, 204, 284, 288, 389, 392], "gguf": [149, 203, 392], "matadata": 149, "unsupport": 149, "tensor": [149, 230, 258, 259, 274, 275, 342, 390], "natur": [150, 152, 389], "logarithm": [150, 151, 152, 153], "log": [152, 154, 158, 327, 328, 332, 335, 337, 340, 351], "plu": 152, "exp": [110, 154, 158, 188, 213, 320, 335, 347, 348, 351, 386, 393], "stabl": [154, 158, 213, 337], "prepend": [2, 159], "report": [164, 169], "peak": 166, "begin": [166, 184, 259, 268, 271, 275, 309, 334, 341, 347, 352, 353], "program": 166, "free": 168, "reclaim": 168, 
"set_memory_limit": 168, "previou": [168, 169], "relax": 169, "task": [169, 337], "exceed": 169, "negat": 176, "beforehand": 180, "pad_with": 181, "constant_valu": 181, "pad_width": 181, "before_1": 181, "after_1": 181, "before_2": 181, "after_2": 181, "before_n": 181, "after_n": 181, "before_i": 181, "after_i": 181, "extend": [1, 181], "side": [181, 258, 259, 274, 275, 386], "smaller": [182, 368, 386], "everi": [184, 246, 379, 387], "particular": [184, 269], "w_1": 184, "w_g": 184, "align": [184, 259, 268, 271, 275], "max_i": 184, "min_i": 184, "textrm": [184, 267, 321, 324], "round": 184, "pack": [184, 185], "unsign": [184, 185, 250], "lower": [184, 192, 195, 196, 234, 319], "upper": [184, 192, 195, 196, 319], "1st": 184, "signific": 184, "2nd": 184, "dequant": 184, "w_q": 184, "whether": [170, 185, 268, 271, 281, 297, 301, 329, 332, 338], "prng": [186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 383], "num_sampl": 187, "unnorm": [187, 329, 331], "draw": 187, "cdf": [188, 267, 321], "accord": [188, 241, 297, 313, 314, 315, 316], "seed": 189, "loc": 191, "deviat": [191, 220, 313, 315, 318], "low": [192, 196, 319, 356], "high": [192, 196, 256, 266, 319, 356], "bound": [192, 195, 196, 267, 319, 386, 388, 393], "roadcast": 192, "domain": 195, "uniformli": 196, "repetit": 198, "preserv": [199, 387], "reciproc": 201, "arr": [202, 388], "obj": 203, "uncompress": 205, "my_path": 205, "tree_flatten": [205, 246, 247, 256], "transformerencod": 205, "128": [205, 256], "flat_param": 205, "compress": 206, "possibl": [215, 266, 386, 388, 393], "being": [221, 256], "prevent": [221, 342, 390], "flow": [221, 389], "streamcontext": 222, "context": 222, "manag": [222, 383, 393], "prior": [226, 227], "exclud": 227, "dot": [230, 245, 286, 297], "rep": 231, "repeat": 231, "necessarili": 232, "elsewher": [234, 388], "col": 234, "triangl": 234, "mse": 237, "param": [237, 256, 356, 387], "lvalu": 237, "dlvalu": 237, "dparam": 237, "lasso": 237, "l1": [237, 334, 336, 337, 341], "varianc": [220, 238, 260, 269, 332], "divisor": [220, 238], "cotang": [1, 239], "in_ax": [240, 387], "out_ax": [240, 387], "prefix": [240, 245], "fn": [244, 246, 391], "wrt": 244, "is_leaf": [245, 246], "arbitrari": [245, 359], "depth": [245, 265, 387], "hello": [245, 247], "charact": 245, "flat": [245, 247], "superset": [246, 369], "extra": 246, "closer": 246, "constitut": 246, "dict_kei": [246, 370], "lambda": [246, 256, 277, 282, 289, 308, 347, 352, 362, 363, 364, 365, 366, 367, 368, 373, 374, 386, 387], "recreat": 247, "world": 247, "42": 247, "16": [250, 258, 270, 274, 277, 359], "int16": 250, "brain": 250, "e8": 250, "m7": 250, "ieee": 250, "e5": 250, "m10": 250, "hierarchi": 250, "done": [256, 263, 300, 386, 389, 390], "manual": 256, "explicitli": [256, 383], "solv": 256, "intuit": 256, "freez": [256, 294, 359], "finetun": 256, "in_dim": [256, 359], "out_dim": [256, 359], "enumer": 256, "caus": [256, 386, 389], "local": [256, 264], "scope": 256, "l2_loss": 256, "y_hat": 256, "trainable_paramet": [256, 281, 370], "loss_and_grad": 256, "workhors": 256, "Its": 256, "recurs": [256, 281, 282, 287, 292, 294, 359], "frozen": [256, 282, 292, 294, 299, 359], "individu": [256, 264, 265], "subset": [256, 281], "action": 256, "displai": 256, "tree_map": 256, "count": [256, 377], "num_param": 256, "preclud": 256, "pure": [256, 361], "pattern": [256, 389], "achiev": 256, "other_input": 256, "necessari": 256, "wrap": 256, "apply_to_modul": [256, 282], "children": 256, "filter_and_map": 256, "leaf_modul": 256, "load_weight": [256, 389], 
"named_modul": 256, "save_weight": 256, "set_dtyp": 256, "unfreez": [256, 282], "update_modul": 256, "alibi": 256, "avgpool1d": 256, "avgpool2d": 256, "batchnorm": 256, "conv1d": 256, "conv2d": 256, "dropout": [256, 264, 265, 291, 310, 386], "dropout2d": 256, "dropout3d": 256, "gelu": [256, 322, 323, 386], "groupnorm": 256, "gru": 256, "instancenorm": 256, "layernorm": 256, "lstm": 256, "maxpool1d": 256, "maxpool2d": [256, 259], "mish": 256, "prelu": 256, "quantizedlinear": 256, "relu": [256, 298, 310, 344, 356], "rnn": [256, 268], "selu": 256, "sequenti": [256, 356], "silu": 256, "sinusoidalpositionalencod": 256, "softshrink": 256, "upsampl": 256, "elu": [256, 347], "gelu_approx": [256, 267, 321], "gelu_fast_approx": [256, 267, 321], "hardswish": 256, "leaky_relu": 256, "log_sigmoid": 256, "log_softmax": 256, "relu6": 256, "softplu": [256, 276, 343], "tanh": [256, 268, 271, 276, 301, 343], "binary_cross_entropi": [256, 386], "cosine_similarity_loss": 256, "gaussian_nll_loss": 256, "hinge_loss": 256, "huber_loss": 256, "kl_div_loss": 256, "l1_loss": 256, "log_cosh_loss": 256, "margin_ranking_loss": 256, "mse_loss": 256, "nll_loss": 256, "smooth_l1_loss": 256, "triplet_loss": 256, "init": [256, 298, 356, 361, 375, 376, 378, 379], "uniform": [2, 256, 273, 284, 314, 316, 356, 383, 386, 387, 393], "glorot_norm": 256, "glorot_uniform": 256, "he_norm": 256, "he_uniform": 256, "kernel_s": [258, 259, 261, 262, 274, 275], "averag": [258, 259, 362, 363, 365, 366, 367], "pool": [258, 259, 274, 275, 393], "l_": [258, 274, 334], "n_i": [258, 259, 274, 275], "c_j": [258, 259, 274, 275], "ldot": [258, 259, 274, 275], "lfloor": [258, 259, 274, 275], "_size": [258, 259, 274, 275], "rfloor": [258, 259, 274, 275], "k_h": [259, 275], "k_w": [259, 275], "h_": [259, 268, 271, 275, 301], "w_": [259, 268, 271, 275, 301, 362, 363, 364, 365, 366, 367, 368, 373, 374], "height": [259, 260, 262, 264, 265, 275], "width": [259, 260, 262, 264, 265, 275, 299], "momentum": [260, 368, 370, 374, 386], "affin": [260, 269, 270, 272, 273, 299], "track_running_stat": 260, "var": [260, 269, 270, 272, 332], "epsilon": [260, 269, 270, 272, 300, 330, 332, 362, 364, 365, 366, 367, 373], "gamma": [260, 269, 270, 272, 300, 313, 314, 315, 316], "nc": 260, "nlc": [260, 261], "four": 260, "nhwc": [260, 262], "paper": [260, 307, 362, 363, 364, 365, 367, 368], "deep": [260, 313, 314, 315, 316], "intern": 260, "covari": [190, 260], "shift": 260, "bn": 260, "in_channel": [261, 262], "out_channel": [261, 262], "learnabl": [261, 262, 305], "portion": 263, "independ": [264, 265], "nwhc": 264, "whc": 264, "maintain": [264, 265, 368], "entri": [264, 265], "benefici": [264, 265, 389], "earli": 264, "adjac": 264, "pixel": 264, "thompson": 264, "goroshin": 264, "jain": 264, "lecun": 264, "bregler": 264, "2015": [264, 365, 367], "cvpr": 264, "ndhwc": 265, "dhwc": 265, "medic": 265, "video": 265, "num_embed": 266, "lookup": 266, "usual": [266, 385, 389], "vocabulari": 266, "approx": 267, "unit": [267, 268, 302, 304, 306, 313, 314, 315, 316, 320, 321, 322, 323, 324, 326, 345, 346, 347, 349], "phi": [267, 321], "geluapprox": 267, "sigma": [267, 268, 271, 313, 314, 315, 316, 322, 323, 324, 327, 348, 349], "60033": [267, 322], "0433603": [267, 322], "gelufast": 267, "773": 267, "regard": 267, "input_s": [268, 271, 301], "hidden_s": [268, 271, 301], "gate": [268, 324], "recurr": [268, 271, 301], "nld": [268, 271, 301], "ld": [268, 271, 301], "r_t": 268, "xr": 268, "x_t": [268, 271, 301], "hr": 268, "h_t": [268, 271, 301], "b_": [268, 271], "z_t": 268, "xz": 
268, "hz": 268, "n_t": 268, "xn": 268, "odot": [268, 271], "hn": 268, "hidden": [268, 271, 301, 310], "nh": [268, 271, 301], "nlh": [268, 271, 301], "lh": [268, 271, 301], "num_group": 269, "pytorch_compat": 269, "split": [269, 324], "preced": 269, "org": [269, 270, 272, 276, 300, 323, 343], "1803": 269, "08494": 269, "denomin": [270, 330, 362, 364, 365, 366, 367, 373], "inorm": 270, "1607": [270, 272], "08022": 270, "i_t": 271, "xi": 271, "f_t": 271, "xf": 271, "hf": 271, "g_t": [271, 362, 364, 365, 366, 367, 368, 373, 374], "xg": 271, "hg": 271, "o_t": 271, "xo": 271, "ho": 271, "c_": [271, 368], "c_t": [271, 368], "cell": 271, "06450": 272, "mathcal": 273, "d_i": 273, "max_": [274, 275], "1908": [276, 343], "08681": [276, 343], "map_fn": [277, 281], "filter_fn": [277, 281], "valid_parameter_filt": 277, "apply_fn": 278, "descend": 279, "is_leaf_fn": 281, "found": 281, "drop": 281, "idempot": [282, 294], "endswith": 282, "file_or_weight": 284, "miss": [284, 392], "ok": [284, 387], "save_safetensor": [288, 392], "predic": 289, "reflect": [290, 386, 388, 390], "certain": [1, 291, 386], "ie": 294, "noop": 294, "unfrozen": 294, "tracer": 295, "partial": [295, 296, 386, 389], "child": 296, "flexibli": 296, "programmat": 296, "query_input_dim": 297, "key_input_dim": 297, "value_input_dim": 297, "value_dim": 297, "value_output_dim": 297, "aggreg": 297, "linearli": 297, "attend": 297, "num_paramet": 298, "25": [298, 311], "parametr": [298, 344], "classmethod": 299, "from_linear": 299, "quantize_modul": 299, "accumul": 300, "1910": 300, "07467": 300, "nonlinear": [301, 386], "elman": 301, "ih": 301, "hh": 301, "func": 301, "rectifi": [302, 315, 316, 326, 345, 346], "10000": 303, "slightli": [303, 393], "plain": 305, "known": [306, 349], "swish": [306, 349], "min_freq": 307, "0001": 307, "max_freq": 307, "cos_first": 307, "full_turn": 307, "sinusoid": 307, "sin": [307, 387, 391], "lambd": [308, 352], "threshold": [309, 334, 341, 353], "geq": [309, 353], "num_encoder_lay": 310, "num_decoder_lay": 310, "nb_func": 310, "custom_encod": 310, "custom_decod": 310, "norm_first": 310, "checkpoint": 310, "decod": 310, "interact": 310, "mechan": 310, "chekpoint": 310, "usag": [310, 386], "expens": 310, "scale_factor": 311, "nearest": 311, "align_corn": 311, "audio": 311, "4d": 311, "forth": 311, "neighbor": 311, "interpol": 311, "bilinear": 311, "trilinear": 311, "corner": 311, "bottom": 311, "squeez": [311, 386], "75": 311, "33333": 311, "66667": 311, "init_fn": [312, 313, 314, 315, 316, 317, 318, 319, 356], "glorot": [313, 314], "fan_in": [313, 314, 315, 316], "fan_out": [313, 314, 315, 316], "fan": [313, 314, 315, 316], "_in": [313, 314], "_out": [313, 314], "difficulti": [313, 314], "feedforward": [313, 314], "191107": 313, "61278": 313, "150594": 313, "363207": 313, "gain": [313, 314, 315, 316], "89613": 313, "53947": 313, "48095": 313, "995016": 313, "223404": 314, "890597": 314, "379159": 314, "776856": 314, "90041": 314, "02264": 314, "912766": 314, "12451": 314, "delv": [315, 316], "surpass": [315, 316], "human": [315, 316], "level": [315, 316], "imagenet": [315, 316], "classif": [315, 316], "25211": 315, "458835": 315, "177208": 315, "0137595": 315, "6967": 315, "02765": 315, "15268": 315, "75787": 315, "kaim": 316, "0300242": 316, "0184009": 316, "793615": 316, "666329": 316, "64331": 316, "16506": 316, "08619": 316, "79854": 316, "982273": 318, "534422": 318, "380709": 318, "0645099": 318, "883935": 319, "863726": 319, "617261": 319, "417497": 319, "exact": [322, 323], "0003": 322, "cdot": [322, 
323, 330, 333, 349], "015": 323, "702": 323, "hendryck": 323, "1606": 323, "08415": 323, "halv": 324, "negative_slop": 326, "leaki": 326, "sum_i": 328, "x_i": [328, 350], "with_logit": 329, "predict": [329, 332, 333, 334, 335, 336, 337, 339, 340, 341], "105361": 329, "223144": 329, "20397": 329, "916291": 329, "539245": 329, "prob": 329, "510826": 329, "x1": 330, "x2": 330, "x_1": [330, 338], "x_2": [330, 338], "label_smooth": 331, "hot": 331, "smooth": [331, 341, 373], "0485873": 331, "348587": 331, "06": [332, 342, 362], "likelihood": [332, 340], "nll": [332, 340], "hing": 333, "y_": [333, 337], "pred": [333, 337], "delta": [334, 362], "huber": 334, "leq": [334, 347], "l2": [334, 337, 374], "kullback": 335, "leibler": 335, "diverg": 335, "cosh": 337, "logcosh": 337, "sensit": 337, "outlier": 337, "dual": 337, "behavior": [190, 337, 388, 389], "offer": 337, "balanc": 337, "robust": 337, "approach": [337, 387], "inputs1": 338, "inputs2": 338, "margin": [338, 342], "rank": 338, "573409": 338, "765166": 338, "0638": 338, "75596": 338, "225763": 338, "256995": 338, "773433": 338, "formula": 341, "anchor": 342, "triplet": 342, "_p": 342, "degre": 342, "pairwis": 342, "instabl": 342, "monoton": 343, "0507": 347, "67326": 347, "sum_j": 350, "x_j": 350, "subclass": 359, "concept": 359, "mymlp": 359, "in_proj": 359, "subsequ": 361, "apply_gradi": 361, "rmsprop": 361, "adagrad": 361, "adafactor": 361, "adadelta": 361, "adam": [361, 367, 368, 377, 378], "adamw": [361, 368], "adamax": 361, "lion": 361, "cosine_decai": [361, 377], "exponential_decai": 361, "join_schedul": 361, "linear_schedul": [361, 377], "step_decai": 361, "rho": 362, "zeiler": 362, "2012": [362, 373], "adapt": [362, 363, 364], "1212": 362, "5701": 362, "v_": [362, 364, 365, 366, 367, 373, 374], "v_t": [362, 364, 365, 366, 367, 373, 374], "u_t": 362, "u_": 362, "w_t": [362, 364, 365, 366, 367, 368, 373, 374], "001": 363, "clip_threshold": 363, "decay_r": [363, 376, 379], "beta_1": [363, 365, 366, 367, 368], "weight_decai": [363, 366, 368, 374], "scale_paramet": 363, "relative_step": 363, "warmup_init": 363, "sublinear": 363, "cost": [363, 389], "epsilon_1": 363, "epsilon_2": 363, "parameter_scal": 363, "clip": 363, "unscal": 363, "decai": [363, 366, 368, 374, 375, 376, 379], "duchi": 364, "hazan": 364, "singer": 364, "2011": 364, "subgradi": 364, "onlin": 364, "stochast": [364, 365, 367, 374, 389], "jmlr": 364, "999": [365, 366, 367], "omit": [365, 367], "estim": [365, 367], "kingma": [365, 367], "ba": [365, 367], "iclr": [365, 366, 367], "m_": [365, 366, 367, 368], "m_t": [365, 366, 367, 368], "beta_2": [365, 366, 367, 368], "contrast": 366, "loshchilov": 366, "hutter": 366, "decoupl": 366, "99": [368, 373], "tend": 368, "10x": 368, "strength": [368, 374], "wd": 368, "chen": 368, "symbol": 368, "discoveri": 368, "2302": 368, "06675": 368, "eta": 368, "opt": 369, "tieleman": 373, "hinton": 373, "lectur": 373, "coursera": 373, "dampen": 374, "nesterov": 374, "descent": [374, 386, 389], "mu": 374, "tau": 374, "penalti": 374, "decay_step": 375, "beyond": [375, 378], "minim": 375, "lr_schedul": [375, 376, 377, 379], "1000": [375, 386], "0999961": 375, "06561": 376, "boundari": 377, "join": 377, "receiv": [377, 390], "transit": 377, "warmup": [377, 378], "0999938": 377, "101": 378, "step_siz": 379, "081": 379, "basi": 381, "implicit": [383, 386, 387], "fine": [383, 389], "grain": 383, "control": [383, 389], "pseudo": 383, "altern": 383, "splittabl": 383, "threefri": 383, "counter": 383, "cycl": 385, "merg": 386, "fuse": 386, "big": 386, 
"awar": [386, 389], "36788": 386, "compiled_fun": 386, "code": [386, 389], "slow": 386, "stack": 386, "rerun": [386, 389], "frequent": [386, 389], "destroi": 386, "anonym": 386, "don": [386, 393], "unari": 386, "overhead": [386, 389, 393], "bandwidth": 386, "fusibl": 386, "consider": 386, "versu": 386, "timeit": [386, 387], "tic": 386, "perf_count": 386, "toc": 386, "tpi": 386, "1e3": 386, "4096": [386, 387, 393], "On": [386, 387, 389], "millisecond": [386, 393], "five": 386, "latest": 386, "won": 386, "placehold": 386, "insid": 386, "crash": 386, "disable_compil": 386, "okai": [386, 389], "intend": 386, "deal": 386, "pretti": [386, 389], "inconveni": 386, "functool": 386, "particularli": 386, "backward": [386, 387], "compiled_grad_fn": 386, "71828": 386, "outer": [386, 389], "opportun": 386, "idea": [387, 389], "behind": 387, "dfdx": [387, 388], "d2fdx2": 387, "zero_grad": 387, "detach": 387, "requires_grad": 387, "dloss_dw": 387, "dloss_dx": 387, "lot": 387, "redund": 387, "suppos": [387, 393], "nice": [387, 389], "propag": [387, 388], "stop_gradi": 387, "autom": 387, "contriv": [387, 393], "sake": 387, "clariti": 387, "quit": [387, 390], "power": [387, 390], "difficult": 387, "primit": 387, "issu": [387, 390], "priorit": 387, "naive_add": 387, "vmap_add": 387, "total": 387, "390": 387, "wherea": 387, "025": 387, "ten": [387, 389], "Of": 387, "better": [387, 393], "handi": 387, "slice": 388, "ellipsi": 388, "mix": 388, "take_along_axi": 388, "lack": 388, "extrem": [388, 389], "ineffici": [388, 389], "nonzero": 388, "dynam": 389, "easier": 389, "worri": 389, "fun1": 389, "expensive_fun": 389, "consum": 389, "eager": 389, "thank": 389, "weights_fp16": 389, "trade": 389, "bad": 389, "grow": 389, "computation": 389, "costli": 389, "wide": 389, "thousand": 389, "value_and_grad_fn": 389, "implicitli": 389, "anytim": 389, "memoryview": [389, 390], "perfectli": 389, "first_lay": 389, "second_layer_a": 389, "second_layer_b": 389, "protocol": 390, "pep": 390, "3118": 390, "a_view": 390, "owndata": 390, "extern": 390, "x_view": 390, "modifi": 390, "df": 390, "x\u00b2": 390, "2x": 390, "indirectli": 390, "modif": 390, "seen": 390, "occur": 390, "incorpor": 390, "incorrect": 390, "experiment": 390, "break": 390, "advis": 390, "intermedi": 390, "jnp": 390, "tf": 390, "page": 391, "composit": 391, "archiv": 392, "savez_compress": 392, "save_gguf": 392, "arr_0": 392, "advantag": 393, "parallel": 393, "race": 393, "interest": 393, "albeit": 393, "d1": 393, "d2": 393, "matmul": 393, "dens": [163, 393], "twice": 393, "measur": 393, "default_stream": 394, "default_devic": 394, "my_devic": 394, "explain": 1, "underli": 1, "tutori": 1, "cover": 1, "front": 1, "virtual": 1, "reimplement": 1, "finish": 1, "command_buff": 1, "mtlcommandbuff": 1, "unus": 1, "nb_modul": 1, "_ext": 1, "nb": 1, "nanobind_add_modul": 1, "nb_static": 1, "stable_abi": 1, "lto": 1, "nomins": 1, "nb_domain": 1, "extras_requir": 1, "cmake_arg": 2, "mtl_capture_en": 2, "trace_fil": 2, "mlx_trace": 2, "exit": 2, "minu": 110, "spars": 163, "xy": 163, "multidimension": 163, "coordin": 163, "cartesian": 163, "ij": 163, "successfulli": 170, "cov": 190, "jointli": 190, "semi": 190, "definit": 190, "empti": 190, "cubic": 311, "bicub": 311}, "objects": {"mlx.core": [[8, 0, 1, "", "Device"], [9, 0, 1, "", "Dtype"], [10, 0, 1, "", "DtypeCategory"], [248, 0, 1, "", "Stream"], [11, 2, 1, "", "abs"], [12, 2, 1, "", "add"], [13, 2, 1, "", "all"], [14, 2, 1, "", "allclose"], [15, 2, 1, "", "any"], [16, 2, 1, "", "arange"], [17, 2, 1, "", "arccos"], [18, 
2, 1, "", "arccosh"], [19, 2, 1, "", "arcsin"], [20, 2, 1, "", "arcsinh"], [21, 2, 1, "", "arctan"], [22, 2, 1, "", "arctanh"], [23, 2, 1, "", "argmax"], [24, 2, 1, "", "argmin"], [25, 2, 1, "", "argpartition"], [26, 2, 1, "", "argsort"], [27, 0, 1, "", "array"], [76, 2, 1, "", "array_equal"], [77, 2, 1, "", "atleast_1d"], [78, 2, 1, "", "atleast_2d"], [79, 2, 1, "", "atleast_3d"], [80, 2, 1, "", "broadcast_to"], [81, 2, 1, "", "ceil"], [82, 2, 1, "", "clip"], [83, 2, 1, "", "compile"], [84, 2, 1, "", "concatenate"], [85, 2, 1, "", "conv1d"], [86, 2, 1, "", "conv2d"], [87, 2, 1, "", "conv_general"], [88, 2, 1, "", "convolve"], [89, 2, 1, "", "cos"], [90, 2, 1, "", "cosh"], [91, 2, 1, "", "cummax"], [92, 2, 1, "", "cummin"], [93, 2, 1, "", "cumprod"], [94, 2, 1, "", "cumsum"], [95, 2, 1, "", "default_device"], [96, 2, 1, "", "default_stream"], [97, 2, 1, "", "dequantize"], [98, 2, 1, "", "diag"], [99, 2, 1, "", "diagonal"], [100, 2, 1, "", "disable_compile"], [101, 2, 1, "", "divide"], [102, 2, 1, "", "divmod"], [103, 2, 1, "", "enable_compile"], [104, 2, 1, "", "equal"], [105, 2, 1, "", "erf"], [106, 2, 1, "", "erfinv"], [107, 2, 1, "", "eval"], [108, 2, 1, "", "exp"], [109, 2, 1, "", "expand_dims"], [110, 2, 1, "", "expm1"], [111, 2, 1, "", "eye"], [128, 2, 1, "", "flatten"], [129, 2, 1, "", "floor"], [130, 2, 1, "", "floor_divide"], [131, 2, 1, "", "full"], [132, 2, 1, "", "grad"], [133, 2, 1, "", "greater"], [134, 2, 1, "", "greater_equal"], [135, 2, 1, "", "identity"], [136, 2, 1, "", "inner"], [137, 2, 1, "", "isclose"], [138, 2, 1, "", "isinf"], [139, 2, 1, "", "isnan"], [140, 2, 1, "", "isneginf"], [141, 2, 1, "", "isposinf"], [142, 2, 1, "", "issubdtype"], [143, 2, 1, "", "jvp"], [144, 2, 1, "", "less"], [145, 2, 1, "", "less_equal"], [148, 2, 1, "", "linspace"], [149, 2, 1, "", "load"], [150, 2, 1, "", "log"], [151, 2, 1, "", "log10"], [152, 2, 1, "", "log1p"], [153, 2, 1, "", "log2"], [154, 2, 1, "", "logaddexp"], [155, 2, 1, "", "logical_and"], [156, 2, 1, "", "logical_not"], [157, 2, 1, "", "logical_or"], [158, 2, 1, "", "logsumexp"], [159, 2, 1, "", "matmul"], [160, 2, 1, "", "max"], [161, 2, 1, "", "maximum"], [162, 2, 1, "", "mean"], [163, 2, 1, "", "meshgrid"], [172, 2, 1, "", "min"], [173, 2, 1, "", "minimum"], [174, 2, 1, "", "moveaxis"], [175, 2, 1, "", "multiply"], [176, 2, 1, "", "negative"], [177, 2, 1, "", "new_stream"], [178, 2, 1, "", "ones"], [179, 2, 1, "", "ones_like"], [180, 2, 1, "", "outer"], [181, 2, 1, "", "pad"], [182, 2, 1, "", "partition"], [183, 2, 1, "", "prod"], [184, 2, 1, "", "quantize"], [185, 2, 1, "", "quantized_matmul"], [197, 2, 1, "", "reciprocal"], [198, 2, 1, "", "repeat"], [199, 2, 1, "", "reshape"], [200, 2, 1, "", "round"], [201, 2, 1, "", "rsqrt"], [202, 2, 1, "", "save"], [203, 2, 1, "", "save_gguf"], [204, 2, 1, "", "save_safetensors"], [205, 2, 1, "", "savez"], [206, 2, 1, "", "savez_compressed"], [207, 2, 1, "", "set_default_device"], [208, 2, 1, "", "set_default_stream"], [209, 2, 1, "", "sigmoid"], [210, 2, 1, "", "sign"], [211, 2, 1, "", "sin"], [212, 2, 1, "", "sinh"], [213, 2, 1, "", "softmax"], [214, 2, 1, "", "sort"], [215, 2, 1, "", "split"], [216, 2, 1, "", "sqrt"], [217, 2, 1, "", "square"], [218, 2, 1, "", "squeeze"], [219, 2, 1, "", "stack"], [220, 2, 1, "", "std"], [221, 2, 1, "", "stop_gradient"], [222, 2, 1, "", "stream"], [223, 2, 1, "", "subtract"], [224, 2, 1, "", "sum"], [225, 2, 1, "", "swapaxes"], [226, 2, 1, "", "take"], [227, 2, 1, "", "take_along_axis"], [228, 2, 1, "", "tan"], [229, 2, 1, "", "tanh"], [230, 
2, 1, "", "tensordot"], [231, 2, 1, "", "tile"], [232, 2, 1, "", "topk"], [233, 2, 1, "", "transpose"], [234, 2, 1, "", "tri"], [235, 2, 1, "", "tril"], [236, 2, 1, "", "triu"], [237, 2, 1, "", "value_and_grad"], [238, 2, 1, "", "var"], [239, 2, 1, "", "vjp"], [240, 2, 1, "", "vmap"], [241, 2, 1, "", "where"], [242, 2, 1, "", "zeros"], [243, 2, 1, "", "zeros_like"]], "mlx.core.Device": [[8, 1, 1, "", "__init__"]], "mlx.core.Dtype": [[9, 1, 1, "", "__init__"]], "mlx.core.DtypeCategory": [[10, 1, 1, "", "__init__"]], "mlx.core.Stream": [[248, 1, 1, "", "__init__"]], "mlx.core.array": [[28, 3, 1, "", "T"], [27, 1, 1, "", "__init__"], [29, 1, 1, "", "abs"], [30, 1, 1, "", "all"], [31, 1, 1, "", "any"], [32, 1, 1, "", "argmax"], [33, 1, 1, "", "argmin"], [34, 1, 1, "", "astype"], [35, 3, 1, "", "at"], [36, 1, 1, "", "cos"], [37, 1, 1, "", "cummax"], [38, 1, 1, "", "cummin"], [39, 1, 1, "", "cumprod"], [40, 1, 1, "", "cumsum"], [41, 1, 1, "", "diag"], [42, 1, 1, "", "diagonal"], [43, 3, 1, "", "dtype"], [44, 1, 1, "", "exp"], [45, 1, 1, "", "flatten"], [46, 1, 1, "", "item"], [47, 3, 1, "", "itemsize"], [48, 1, 1, "", "log"], [49, 1, 1, "", "log10"], [50, 1, 1, "", "log1p"], [51, 1, 1, "", "log2"], [52, 1, 1, "", "logsumexp"], [53, 1, 1, "", "max"], [54, 1, 1, "", "mean"], [55, 1, 1, "", "min"], [56, 1, 1, "", "moveaxis"], [57, 3, 1, "", "nbytes"], [58, 3, 1, "", "ndim"], [59, 1, 1, "", "prod"], [60, 1, 1, "", "reciprocal"], [61, 1, 1, "", "reshape"], [62, 1, 1, "", "round"], [63, 1, 1, "", "rsqrt"], [64, 3, 1, "", "shape"], [65, 1, 1, "", "sin"], [66, 3, 1, "", "size"], [67, 1, 1, "", "split"], [68, 1, 1, "", "sqrt"], [69, 1, 1, "", "square"], [70, 1, 1, "", "squeeze"], [71, 1, 1, "", "sum"], [72, 1, 1, "", "swapaxes"], [73, 1, 1, "", "tolist"], [74, 1, 1, "", "transpose"], [75, 1, 1, "", "var"]], "mlx.core.fast": [[112, 2, 1, "", "layer_norm"], [113, 2, 1, "", "rms_norm"], [114, 2, 1, "", "rope"], [115, 2, 1, "", "scaled_dot_product_attention"]], "mlx.core.fft": [[116, 2, 1, "", "fft"], [117, 2, 1, "", "fft2"], [118, 2, 1, "", "fftn"], [119, 2, 1, "", "ifft"], [120, 2, 1, "", "ifft2"], [121, 2, 1, "", "ifftn"], [122, 2, 1, "", "irfft"], [123, 2, 1, "", "irfft2"], [124, 2, 1, "", "irfftn"], [125, 2, 1, "", "rfft"], [126, 2, 1, "", "rfft2"], [127, 2, 1, "", "rfftn"]], "mlx.core.linalg": [[146, 2, 1, "", "norm"], [147, 2, 1, "", "qr"]], "mlx.core.metal": [[164, 2, 1, "", "get_active_memory"], [165, 2, 1, "", "get_cache_memory"], [166, 2, 1, "", "get_peak_memory"], [167, 2, 1, "", "is_available"], [168, 2, 1, "", "set_cache_limit"], [169, 2, 1, "", "set_memory_limit"], [170, 2, 1, "", "start_capture"], [171, 2, 1, "", "stop_capture"]], "mlx.core.random": [[186, 2, 1, "", "bernoulli"], [187, 2, 1, "", "categorical"], [188, 2, 1, "", "gumbel"], [189, 2, 1, "", "key"], [190, 2, 1, "", "multivariate_normal"], [191, 2, 1, "", "normal"], [192, 2, 1, "", "randint"], [193, 2, 1, "", "seed"], [194, 2, 1, "", "split"], [195, 2, 1, "", "truncated_normal"], [196, 2, 1, "", "uniform"]], "mlx.nn": [[257, 0, 1, "", "ALiBi"], [258, 0, 1, "", "AvgPool1d"], [259, 0, 1, "", "AvgPool2d"], [260, 0, 1, "", "BatchNorm"], [261, 0, 1, "", "Conv1d"], [262, 0, 1, "", "Conv2d"], [263, 0, 1, "", "Dropout"], [264, 0, 1, "", "Dropout2d"], [265, 0, 1, "", "Dropout3d"], [266, 0, 1, "", "Embedding"], [267, 0, 1, "", "GELU"], [268, 0, 1, "", "GRU"], [269, 0, 1, "", "GroupNorm"], [270, 0, 1, "", "InstanceNorm"], [271, 0, 1, "", "LSTM"], [272, 0, 1, "", "LayerNorm"], [273, 0, 1, "", "Linear"], [274, 0, 1, "", "MaxPool1d"], [275, 0, 
1, "", "MaxPool2d"], [276, 0, 1, "", "Mish"], [359, 0, 1, "", "Module"], [297, 0, 1, "", "MultiHeadAttention"], [298, 0, 1, "", "PReLU"], [299, 0, 1, "", "QuantizedLinear"], [300, 0, 1, "", "RMSNorm"], [301, 0, 1, "", "RNN"], [302, 0, 1, "", "ReLU"], [303, 0, 1, "", "RoPE"], [304, 0, 1, "", "SELU"], [305, 0, 1, "", "Sequential"], [306, 0, 1, "", "SiLU"], [307, 0, 1, "", "SinusoidalPositionalEncoding"], [308, 0, 1, "", "Softshrink"], [309, 0, 1, "", "Step"], [310, 0, 1, "", "Transformer"], [311, 0, 1, "", "Upsample"], [320, 2, 1, "", "elu"], [321, 2, 1, "", "gelu"], [322, 2, 1, "", "gelu_approx"], [323, 2, 1, "", "gelu_fast_approx"], [324, 2, 1, "", "glu"], [325, 2, 1, "", "hardswish"], [326, 2, 1, "", "leaky_relu"], [327, 2, 1, "", "log_sigmoid"], [328, 2, 1, "", "log_softmax"], [343, 2, 1, "", "mish"], [344, 2, 1, "", "prelu"], [345, 2, 1, "", "relu"], [346, 2, 1, "", "relu6"], [347, 2, 1, "", "selu"], [348, 2, 1, "", "sigmoid"], [349, 2, 1, "", "silu"], [350, 2, 1, "", "softmax"], [351, 2, 1, "", "softplus"], [352, 2, 1, "", "softshrink"], [353, 2, 1, "", "step"], [354, 2, 1, "", "tanh"], [244, 2, 1, "", "value_and_grad"]], "mlx.nn.Module": [[277, 1, 1, "", "apply"], [278, 1, 1, "", "apply_to_modules"], [279, 1, 1, "", "children"], [280, 1, 1, "", "eval"], [281, 1, 1, "", "filter_and_map"], [282, 1, 1, "", "freeze"], [283, 1, 1, "", "leaf_modules"], [284, 1, 1, "", "load_weights"], [285, 1, 1, "", "modules"], [286, 1, 1, "", "named_modules"], [287, 1, 1, "", "parameters"], [288, 1, 1, "", "save_weights"], [289, 1, 1, "", "set_dtype"], [290, 3, 1, "", "state"], [291, 1, 1, "", "train"], [292, 1, 1, "", "trainable_parameters"], [293, 3, 1, "", "training"], [294, 1, 1, "", "unfreeze"], [295, 1, 1, "", "update"], [296, 1, 1, "", "update_modules"]], "mlx.nn.init": [[312, 2, 1, "", "constant"], [313, 2, 1, "", "glorot_normal"], [314, 2, 1, "", "glorot_uniform"], [315, 2, 1, "", "he_normal"], [316, 2, 1, "", "he_uniform"], [317, 2, 1, "", "identity"], [318, 2, 1, "", "normal"], [319, 2, 1, "", "uniform"]], "mlx.nn.losses": [[329, 2, 1, "", "binary_cross_entropy"], [330, 2, 1, "", "cosine_similarity_loss"], [331, 2, 1, "", "cross_entropy"], [332, 2, 1, "", "gaussian_nll_loss"], [333, 2, 1, "", "hinge_loss"], [334, 2, 1, "", "huber_loss"], [335, 2, 1, "", "kl_div_loss"], [336, 2, 1, "", "l1_loss"], [337, 2, 1, "", "log_cosh_loss"], [338, 2, 1, "", "margin_ranking_loss"], [339, 2, 1, "", "mse_loss"], [340, 2, 1, "", "nll_loss"], [341, 2, 1, "", "smooth_l1_loss"], [342, 2, 1, "", "triplet_loss"]], "mlx.optimizers": [[362, 0, 1, "", "AdaDelta"], [363, 0, 1, "", "Adafactor"], [364, 0, 1, "", "Adagrad"], [365, 0, 1, "", "Adam"], [366, 0, 1, "", "AdamW"], [367, 0, 1, "", "Adamax"], [368, 0, 1, "", "Lion"], [381, 0, 1, "", "Optimizer"], [373, 0, 1, "", "RMSprop"], [374, 0, 1, "", "SGD"], [375, 2, 1, "", "cosine_decay"], [376, 2, 1, "", "exponential_decay"], [377, 2, 1, "", "join_schedules"], [378, 2, 1, "", "linear_schedule"], [379, 2, 1, "", "step_decay"]], "mlx.optimizers.Optimizer": [[369, 1, 1, "", "apply_gradients"], [370, 1, 1, "", "init"], [371, 3, 1, "", "state"], [372, 1, 1, "", "update"]], "mlx.utils": [[245, 2, 1, "", "tree_flatten"], [246, 2, 1, "", "tree_map"], [247, 2, 1, "", "tree_unflatten"]]}, "objtypes": {"0": "py:class", "1": "py:method", "2": "py:function", "3": "py:property"}, "objnames": {"0": ["py", "class", "Python class"], "1": ["py", "method", "Python method"], "2": ["py", "function", "Python function"], "3": ["py", "property", "Python property"]}, "titleterms": {"oper": [0, 1, 
360], "develop": 1, "document": 1, "introduc": 1, "exampl": [1, 6, 386, 393], "primit": 1, "us": [1, 389, 394], "implement": [1, 4], "cpu": 1, "backend": [], "gpu": 1, "transform": [1, 310, 384, 386, 387, 389, 391], "build": [1, 7], "bind": 1, "python": [1, 6, 7], "cmake": 1, "setuptool": 1, "usag": [1, 6], "result": 1, "script": [1, 4], "download": [1, 4], "code": [1, 4], "metal": [2, 7, 164, 165, 166, 167, 168, 169, 170, 171, 255], "debugg": 2, "xcode": 2, "workflow": 2, "linear": [3, 254, 273], "regress": 3, "llm": 4, "infer": 4, "model": 4, "attent": 4, "layer": [4, 5, 357], "encod": 4, "full": [4, 131], "gener": 4, "put": 4, "all": [4, 13, 30], "togeth": 4, "convert": 4, "weight": 4, "load": [4, 149, 392], "benchmark": 4, "multi": 5, "perceptron": 5, "mlx": [6, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379], "instal": [6, 7], "api": [6, 7], "refer": 6, "c": [6, 7], "further": 6, "read": 6, "troubleshoot": 7, "from": [7, 388], "sourc": 7, "requir": 7, "option": 7, "found": 7, "x86": 7, "shell": 7, "core": [8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 
203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 248], "devic": [8, 251], "dtype": [9, 43], "dtypecategori": 10, "ab": [11, 29], "add": 12, "allclos": 14, "ani": [15, 31], "arang": 16, "arcco": 17, "arccosh": 18, "arcsin": 19, "arcsinh": 20, "arctan": 21, "arctanh": 22, "argmax": [23, 32], "argmin": [24, 33], "argpartit": 25, "argsort": 26, "arrai": [27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 249, 388, 392], "t": 28, "astyp": 34, "co": [36, 89], "cummax": [37, 91], "cummin": [38, 92], "cumprod": [39, 93], "cumsum": [40, 94], "diag": [41, 98], "diagon": [42, 99], "exp": [44, 108], "flatten": [45, 128], "item": 46, "items": 47, "log": [48, 150], "log10": [49, 151], "log1p": [50, 152], "log2": [51, 153], "logsumexp": [52, 158], "max": [53, 160], "mean": [54, 162], "min": [55, 172], "moveaxi": [56, 174], "nbyte": 57, "ndim": 58, "prod": [59, 183], "reciproc": [60, 197], "reshap": [61, 199], "round": [62, 200], "rsqrt": [63, 201], "shape": 64, "sin": [65, 211], "size": 66, "split": [67, 194, 215], "sqrt": [68, 216], "squar": [69, 217], "squeez": [70, 218], "sum": [71, 224], "swapax": [72, 225], "tolist": 73, "transpos": [74, 233], "var": [75, 238], "array_equ": 76, "atleast_1d": 77, "atleast_2d": 78, "atleast_3d": 79, "broadcast_to": 80, "ceil": 81, "clip": 82, "compil": [83, 386], "concaten": 84, "conv1d": [85, 261], "conv2d": [86, 262], "conv_gener": 87, "convolv": 88, "cosh": 90, "default_devic": 95, "default_stream": 96, "dequant": 97, "disable_compil": 100, "divid": 101, "divmod": 102, "enable_compil": 103, "equal": 104, "erf": 105, "erfinv": 106, "eval": [107, 280], "expand_dim": 109, "ey": 111, "fast": [112, 113, 114, 115, 252], "layer_norm": 112, "rms_norm": 113, "rope": [114, 303], "scaled_dot_product_attent": 115, "fft": [116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 253], "fft2": 117, "fftn": 118, "ifft": 119, "ifft2": 120, "ifftn": 121, "irfft": 122, "irfft2": 123, "irfftn": 124, "rfft": 125, "rfft2": 126, "rfftn": 127, "floor": 129, "floor_divid": 130, "grad": [132, 256], "greater": 133, "greater_equ": 134, "ident": [135, 317], "inner": 136, "isclos": 137, "isinf": 138, "isnan": 139, "isneginf": 140, "isposinf": 141, "issubdtyp": 142, "jvp": 143, "less": 144, "less_equ": 145, "linalg": [146, 147], "norm": 146, "qr": 147, "linspac": 148, "logaddexp": 154, "logical_and": 155, "logical_not": 156, "logical_or": 157, "matmul": 159, "maximum": 161, "get_active_memori": 164, "get_cache_memori": 165, "get_peak_memori": 166, "is_avail": 167, "set_cache_limit": 168, "set_memory_limit": 169, "minimum": 173, "multipli": 175, "neg": 176, "new_stream": 177, "ones": 178, "ones_lik": 179, "outer": 180, "pad": 181, "partit": 182, "quantiz": 184, "quantized_matmul": 185, "random": [186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 383], "bernoulli": 186, "categor": 187, "gumbel": 188, "kei": 189, "normal": [191, 318], "randint": 192, "seed": 193, "truncated_norm": 195, "uniform": [196, 319], "repeat": 198, "save": [202, 392], "save_gguf": 203, "save_safetensor": 204, "savez": 205, "savez_compress": 206, "set_default_devic": 207, "set_default_stream": 208, "sigmoid": [209, 348], "sign": 210, "sinh": 212, "softmax": [213, 350], "sort": 214, "stack": 219, "stop_gradi": 221, 
"stream": [222, 248, 251, 394], "subtract": 223, "take": 226, "take_along_axi": 227, "tan": 228, "tanh": [229, 354], "tensordot": 230, "tile": 231, "topk": 232, "tri": 234, "tril": 235, "triu": 236, "value_and_grad": [237, 244], "vjp": 239, "vmap": 240, "where": 241, "zero": 242, "zeros_lik": 243, "nn": [244, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354], "util": [245, 246, 247, 385], "tree_flatten": 245, "tree_map": 246, "tree_unflatten": 247, "data": 250, "type": 250, "support": 250, "algebra": 254, "neural": 256, "network": 256, "quick": [256, 391], "start": [256, 391], "The": 256, "modul": [256, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 359], "class": 256, "paramet": [256, 287], "updat": [256, 295, 372, 388], "inspect": 256, "valu": 256, "alibi": 257, "avgpool1d": 258, "avgpool2d": 259, "batchnorm": 260, "dropout": 263, "dropout2d": 264, "dropout3d": 265, "embed": 266, "gelu": [267, 321], "gru": 268, "groupnorm": 269, "instancenorm": 270, "lstm": 271, "layernorm": 272, "maxpool1d": 274, "maxpool2d": 275, "mish": [276, 343], "appli": 277, "apply_to_modul": 278, "children": 279, "filter_and_map": 281, "freez": 282, "leaf_modul": 283, "load_weight": 284, "named_modul": 286, "save_weight": 288, "set_dtyp": 289, "state": [290, 371], "train": [291, 293, 386], "trainable_paramet": 292, "unfreez": 294, "update_modul": 296, "multiheadattent": 297, "prelu": [298, 344], "quantizedlinear": 299, "rmsnorm": 300, "rnn": 301, "relu": [302, 345], "selu": [304, 347], "sequenti": 305, "silu": [306, 349], "sinusoidalpositionalencod": 307, "softshrink": [308, 352], "step": [309, 353], "upsampl": 311, "init": [312, 313, 314, 315, 316, 317, 318, 319, 370], "constant": 312, "glorot_norm": 313, "glorot_uniform": 314, "he_norm": 315, "he_uniform": 316, "elu": 320, "gelu_approx": 322, "gelu_fast_approx": 323, "glu": 324, "hardswish": 325, "leaky_relu": 326, "log_sigmoid": 327, "log_softmax": 328, "loss": [329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, 358], "binary_cross_entropi": 329, "cosine_similarity_loss": 330, "cross_entropi": 331, "gaussian_nll_loss": 332, "hinge_loss": 333, "huber_loss": 334, "kl_div_loss": 335, "l1_loss": 336, "log_cosh_loss": 337, "margin_ranking_loss": 338, "mse_loss": 339, "nll_loss": 340, "smooth_l1_loss": 341, "triplet_loss": 342, "relu6": 346, "softplu": 351, "function": [355, 358, 386, 387, 391], "initi": 356, "optim": [361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381], "adadelta": 362, "adafactor": 363, "adagrad": 364, "adam": 365, "adamw": 366, "adamax": 367, "lion": 368, "apply_gradi": 369, "rmsprop": 373, "sgd": 374, "cosine_decai": 375, "exponential_decai": 376, "join_schedul": 377, "linear_schedul": 378, "step_decai": 379, "common": 380, "schedul": 382, "tree": 385, "basic": [386, 391], "speedup": 386, "debug": 386, "pure": 386, "graph": [386, 389, 391], "automat": 387, "differenti": 387, "vector": 387, "index": 388, "differ": 388, "numpi": [388, 390], "In": 388, "place": 388, "lazi": 389, "evalu": 389, 
"why": 389, "comput": 389, "onli": 389, "what": 389, "you": 389, "when": 389, "convers": 390, "other": 390, "framework": 390, "pytorch": 390, "jax": 390, "tensorflow": 390, "guid": 391, "serial": 392, "format": 392, "unifi": 393, "memori": 393, "A": 393, "simpl": 393, "specifi": 394, "back": 1, "end": 1, "expm1": 110, "meshgrid": 163, "start_captur": 170, "stop_captur": 171, "multivariate_norm": 190, "std": 220}, "envversion": {"sphinx.domains.c": 3, "sphinx.domains.changeset": 1, "sphinx.domains.citation": 1, "sphinx.domains.cpp": 9, "sphinx.domains.index": 1, "sphinx.domains.javascript": 3, "sphinx.domains.math": 2, "sphinx.domains.python": 4, "sphinx.domains.rst": 2, "sphinx.domains.std": 2, "sphinx.ext.intersphinx": 1, "sphinx": 60}, "alltitles": {"Operations": [[0, "operations"], [1, "operations"], [360, "operations"]], "Developer Documentation": [[1, "developer-documentation"]], "Introducing the Example": [[1, "introducing-the-example"]], "Operations and Primitives": [[1, "operations-and-primitives"]], "Primitives": [[1, "primitives"]], "Using the Primitive": [[1, "using-the-primitive"]], "Implementing the Primitive": [[1, "implementing-the-primitive"]], "Implementing the CPU Back-end": [[1, "implementing-the-cpu-back-end"]], "Implementing the GPU Back-end": [[1, "implementing-the-gpu-back-end"]], "Primitive Transforms": [[1, "primitive-transforms"]], "Building and Binding": [[1, "building-and-binding"]], "Binding to Python": [[1, "binding-to-python"]], "Building with CMake": [[1, "building-with-cmake"]], "Building with setuptools": [[1, "building-with-setuptools"]], "Usage": [[1, "usage"], [6, null]], "Results": [[1, "results"]], "Scripts": [[1, "scripts"], [4, "scripts"]], "Download the code": [[1, null], [4, null]], "Metal Debugger": [[2, "metal-debugger"]], "Xcode Workflow": [[2, "xcode-workflow"]], "Linear Regression": [[3, "linear-regression"]], "LLM inference": [[4, "llm-inference"]], "Implementing the model": [[4, "implementing-the-model"]], "Attention layer": [[4, "attention-layer"]], "Encoder layer": [[4, "encoder-layer"]], "Full model": [[4, "full-model"]], "Generation": [[4, "generation"]], "Putting it all together": [[4, "putting-it-all-together"]], "Converting the weights": [[4, "converting-the-weights"]], "Weight loading and benchmarking": [[4, "weight-loading-and-benchmarking"]], "Multi-Layer Perceptron": [[5, "multi-layer-perceptron"]], "MLX": [[6, "mlx"]], "Install": [[6, null]], "Examples": [[6, null]], "Python API Reference": [[6, null]], "C++ API Reference": [[6, null]], "Further Reading": [[6, null]], "Build and Install": [[7, "build-and-install"]], "Python Installation": [[7, "python-installation"]], "Troubleshooting": [[7, "troubleshooting"], [7, "id2"]], "Build from source": [[7, "build-from-source"]], "Build Requirements": [[7, "build-requirements"]], "Python API": [[7, "python-api"]], "C++ API": [[7, "c-api"]], "Build Options": [[7, "id3"]], "Metal not found": [[7, "metal-not-found"]], "x86 Shell": [[7, "x86-shell"]], "mlx.core.Device": [[8, "mlx-core-device"]], "mlx.core.Dtype": [[9, "mlx-core-dtype"]], "mlx.core.DtypeCategory": [[10, "mlx-core-dtypecategory"]], "mlx.core.abs": [[11, "mlx-core-abs"]], "mlx.core.add": [[12, "mlx-core-add"]], "mlx.core.all": [[13, "mlx-core-all"]], "mlx.core.allclose": [[14, "mlx-core-allclose"]], "mlx.core.any": [[15, "mlx-core-any"]], "mlx.core.arange": [[16, "mlx-core-arange"]], "mlx.core.arccos": [[17, "mlx-core-arccos"]], "mlx.core.arccosh": [[18, "mlx-core-arccosh"]], "mlx.core.arcsin": [[19, "mlx-core-arcsin"]], 
"mlx.core.arcsinh": [[20, "mlx-core-arcsinh"]], "mlx.core.arctan": [[21, "mlx-core-arctan"]], "mlx.core.arctanh": [[22, "mlx-core-arctanh"]], "mlx.core.argmax": [[23, "mlx-core-argmax"]], "mlx.core.argmin": [[24, "mlx-core-argmin"]], "mlx.core.argpartition": [[25, "mlx-core-argpartition"]], "mlx.core.argsort": [[26, "mlx-core-argsort"]], "mlx.core.array": [[27, "mlx-core-array"]], "mlx.core.array.T": [[28, "mlx-core-array-t"]], "mlx.core.array.abs": [[29, "mlx-core-array-abs"]], "mlx.core.array.all": [[30, "mlx-core-array-all"]], "mlx.core.array.any": [[31, "mlx-core-array-any"]], "mlx.core.array.argmax": [[32, "mlx-core-array-argmax"]], "mlx.core.array.argmin": [[33, "mlx-core-array-argmin"]], "mlx.core.array.astype": [[34, "mlx-core-array-astype"]], "mlx.core.array.at": [[35, "mlx-core-array-at"]], "mlx.core.array.cos": [[36, "mlx-core-array-cos"]], "mlx.core.array.cummax": [[37, "mlx-core-array-cummax"]], "mlx.core.array.cummin": [[38, "mlx-core-array-cummin"]], "mlx.core.array.cumprod": [[39, "mlx-core-array-cumprod"]], "mlx.core.array.cumsum": [[40, "mlx-core-array-cumsum"]], "mlx.core.array.diag": [[41, "mlx-core-array-diag"]], "mlx.core.array.diagonal": [[42, "mlx-core-array-diagonal"]], "mlx.core.array.dtype": [[43, "mlx-core-array-dtype"]], "mlx.core.array.exp": [[44, "mlx-core-array-exp"]], "mlx.core.array.flatten": [[45, "mlx-core-array-flatten"]], "mlx.core.array.item": [[46, "mlx-core-array-item"]], "mlx.core.array.itemsize": [[47, "mlx-core-array-itemsize"]], "mlx.core.array.log": [[48, "mlx-core-array-log"]], "mlx.core.array.log10": [[49, "mlx-core-array-log10"]], "mlx.core.array.log1p": [[50, "mlx-core-array-log1p"]], "mlx.core.array.log2": [[51, "mlx-core-array-log2"]], "mlx.core.array.logsumexp": [[52, "mlx-core-array-logsumexp"]], "mlx.core.array.max": [[53, "mlx-core-array-max"]], "mlx.core.array.mean": [[54, "mlx-core-array-mean"]], "mlx.core.array.min": [[55, "mlx-core-array-min"]], "mlx.core.array.moveaxis": [[56, "mlx-core-array-moveaxis"]], "mlx.core.array.nbytes": [[57, "mlx-core-array-nbytes"]], "mlx.core.array.ndim": [[58, "mlx-core-array-ndim"]], "mlx.core.array.prod": [[59, "mlx-core-array-prod"]], "mlx.core.array.reciprocal": [[60, "mlx-core-array-reciprocal"]], "mlx.core.array.reshape": [[61, "mlx-core-array-reshape"]], "mlx.core.array.round": [[62, "mlx-core-array-round"]], "mlx.core.array.rsqrt": [[63, "mlx-core-array-rsqrt"]], "mlx.core.array.shape": [[64, "mlx-core-array-shape"]], "mlx.core.array.sin": [[65, "mlx-core-array-sin"]], "mlx.core.array.size": [[66, "mlx-core-array-size"]], "mlx.core.array.split": [[67, "mlx-core-array-split"]], "mlx.core.array.sqrt": [[68, "mlx-core-array-sqrt"]], "mlx.core.array.square": [[69, "mlx-core-array-square"]], "mlx.core.array.squeeze": [[70, "mlx-core-array-squeeze"]], "mlx.core.array.sum": [[71, "mlx-core-array-sum"]], "mlx.core.array.swapaxes": [[72, "mlx-core-array-swapaxes"]], "mlx.core.array.tolist": [[73, "mlx-core-array-tolist"]], "mlx.core.array.transpose": [[74, "mlx-core-array-transpose"]], "mlx.core.array.var": [[75, "mlx-core-array-var"]], "mlx.core.array_equal": [[76, "mlx-core-array-equal"]], "mlx.core.atleast_1d": [[77, "mlx-core-atleast-1d"]], "mlx.core.atleast_2d": [[78, "mlx-core-atleast-2d"]], "mlx.core.atleast_3d": [[79, "mlx-core-atleast-3d"]], "mlx.core.broadcast_to": [[80, "mlx-core-broadcast-to"]], "mlx.core.ceil": [[81, "mlx-core-ceil"]], "mlx.core.clip": [[82, "mlx-core-clip"]], "mlx.core.compile": [[83, "mlx-core-compile"]], "mlx.core.concatenate": [[84, "mlx-core-concatenate"]], 
"mlx.core.conv1d": [[85, "mlx-core-conv1d"]], "mlx.core.conv2d": [[86, "mlx-core-conv2d"]], "mlx.core.conv_general": [[87, "mlx-core-conv-general"]], "mlx.core.convolve": [[88, "mlx-core-convolve"]], "mlx.core.cos": [[89, "mlx-core-cos"]], "mlx.core.cosh": [[90, "mlx-core-cosh"]], "mlx.core.cummax": [[91, "mlx-core-cummax"]], "mlx.core.cummin": [[92, "mlx-core-cummin"]], "mlx.core.cumprod": [[93, "mlx-core-cumprod"]], "mlx.core.cumsum": [[94, "mlx-core-cumsum"]], "mlx.core.default_device": [[95, "mlx-core-default-device"]], "mlx.core.default_stream": [[96, "mlx-core-default-stream"]], "mlx.core.dequantize": [[97, "mlx-core-dequantize"]], "mlx.core.diag": [[98, "mlx-core-diag"]], "mlx.core.diagonal": [[99, "mlx-core-diagonal"]], "mlx.core.disable_compile": [[100, "mlx-core-disable-compile"]], "mlx.core.divide": [[101, "mlx-core-divide"]], "mlx.core.divmod": [[102, "mlx-core-divmod"]], "mlx.core.enable_compile": [[103, "mlx-core-enable-compile"]], "mlx.core.equal": [[104, "mlx-core-equal"]], "mlx.core.erf": [[105, "mlx-core-erf"]], "mlx.core.erfinv": [[106, "mlx-core-erfinv"]], "mlx.core.eval": [[107, "mlx-core-eval"]], "mlx.core.exp": [[108, "mlx-core-exp"]], "mlx.core.expand_dims": [[109, "mlx-core-expand-dims"]], "mlx.core.expm1": [[110, "mlx-core-expm1"]], "mlx.core.eye": [[111, "mlx-core-eye"]], "mlx.core.fast.layer_norm": [[112, "mlx-core-fast-layer-norm"]], "mlx.core.fast.rms_norm": [[113, "mlx-core-fast-rms-norm"]], "mlx.core.fast.rope": [[114, "mlx-core-fast-rope"]], "mlx.core.fast.scaled_dot_product_attention": [[115, "mlx-core-fast-scaled-dot-product-attention"]], "mlx.core.fft.fft": [[116, "mlx-core-fft-fft"]], "mlx.core.fft.fft2": [[117, "mlx-core-fft-fft2"]], "mlx.core.fft.fftn": [[118, "mlx-core-fft-fftn"]], "mlx.core.fft.ifft": [[119, "mlx-core-fft-ifft"]], "mlx.core.fft.ifft2": [[120, "mlx-core-fft-ifft2"]], "mlx.core.fft.ifftn": [[121, "mlx-core-fft-ifftn"]], "mlx.core.fft.irfft": [[122, "mlx-core-fft-irfft"]], "mlx.core.fft.irfft2": [[123, "mlx-core-fft-irfft2"]], "mlx.core.fft.irfftn": [[124, "mlx-core-fft-irfftn"]], "mlx.core.fft.rfft": [[125, "mlx-core-fft-rfft"]], "mlx.core.fft.rfft2": [[126, "mlx-core-fft-rfft2"]], "mlx.core.fft.rfftn": [[127, "mlx-core-fft-rfftn"]], "mlx.core.flatten": [[128, "mlx-core-flatten"]], "mlx.core.floor": [[129, "mlx-core-floor"]], "mlx.core.floor_divide": [[130, "mlx-core-floor-divide"]], "mlx.core.full": [[131, "mlx-core-full"]], "mlx.core.grad": [[132, "mlx-core-grad"]], "mlx.core.greater": [[133, "mlx-core-greater"]], "mlx.core.greater_equal": [[134, "mlx-core-greater-equal"]], "mlx.core.identity": [[135, "mlx-core-identity"]], "mlx.core.inner": [[136, "mlx-core-inner"]], "mlx.core.isclose": [[137, "mlx-core-isclose"]], "mlx.core.isinf": [[138, "mlx-core-isinf"]], "mlx.core.isnan": [[139, "mlx-core-isnan"]], "mlx.core.isneginf": [[140, "mlx-core-isneginf"]], "mlx.core.isposinf": [[141, "mlx-core-isposinf"]], "mlx.core.issubdtype": [[142, "mlx-core-issubdtype"]], "mlx.core.jvp": [[143, "mlx-core-jvp"]], "mlx.core.less": [[144, "mlx-core-less"]], "mlx.core.less_equal": [[145, "mlx-core-less-equal"]], "mlx.core.linalg.norm": [[146, "mlx-core-linalg-norm"]], "mlx.core.linalg.qr": [[147, "mlx-core-linalg-qr"]], "mlx.core.linspace": [[148, "mlx-core-linspace"]], "mlx.core.load": [[149, "mlx-core-load"]], "mlx.core.log": [[150, "mlx-core-log"]], "mlx.core.log10": [[151, "mlx-core-log10"]], "mlx.core.log1p": [[152, "mlx-core-log1p"]], "mlx.core.log2": [[153, "mlx-core-log2"]], "mlx.core.logaddexp": [[154, "mlx-core-logaddexp"]], 
"mlx.core.logical_and": [[155, "mlx-core-logical-and"]], "mlx.core.logical_not": [[156, "mlx-core-logical-not"]], "mlx.core.logical_or": [[157, "mlx-core-logical-or"]], "mlx.core.logsumexp": [[158, "mlx-core-logsumexp"]], "mlx.core.matmul": [[159, "mlx-core-matmul"]], "mlx.core.max": [[160, "mlx-core-max"]], "mlx.core.maximum": [[161, "mlx-core-maximum"]], "mlx.core.mean": [[162, "mlx-core-mean"]], "mlx.core.meshgrid": [[163, "mlx-core-meshgrid"]], "mlx.core.metal.get_active_memory": [[164, "mlx-core-metal-get-active-memory"]], "mlx.core.metal.get_cache_memory": [[165, "mlx-core-metal-get-cache-memory"]], "mlx.core.metal.get_peak_memory": [[166, "mlx-core-metal-get-peak-memory"]], "mlx.core.metal.is_available": [[167, "mlx-core-metal-is-available"]], "mlx.core.metal.set_cache_limit": [[168, "mlx-core-metal-set-cache-limit"]], "mlx.core.metal.set_memory_limit": [[169, "mlx-core-metal-set-memory-limit"]], "mlx.core.metal.start_capture": [[170, "mlx-core-metal-start-capture"]], "mlx.core.metal.stop_capture": [[171, "mlx-core-metal-stop-capture"]], "mlx.core.min": [[172, "mlx-core-min"]], "mlx.core.minimum": [[173, "mlx-core-minimum"]], "mlx.core.moveaxis": [[174, "mlx-core-moveaxis"]], "mlx.core.multiply": [[175, "mlx-core-multiply"]], "mlx.core.negative": [[176, "mlx-core-negative"]], "mlx.core.new_stream": [[177, "mlx-core-new-stream"]], "mlx.core.ones": [[178, "mlx-core-ones"]], "mlx.core.ones_like": [[179, "mlx-core-ones-like"]], "mlx.core.outer": [[180, "mlx-core-outer"]], "mlx.core.pad": [[181, "mlx-core-pad"]], "mlx.core.partition": [[182, "mlx-core-partition"]], "mlx.core.prod": [[183, "mlx-core-prod"]], "mlx.core.quantize": [[184, "mlx-core-quantize"]], "mlx.core.quantized_matmul": [[185, "mlx-core-quantized-matmul"]], "mlx.core.random.bernoulli": [[186, "mlx-core-random-bernoulli"]], "mlx.core.random.categorical": [[187, "mlx-core-random-categorical"]], "mlx.core.random.gumbel": [[188, "mlx-core-random-gumbel"]], "mlx.core.random.key": [[189, "mlx-core-random-key"]], "mlx.core.random.multivariate_normal": [[190, "mlx-core-random-multivariate-normal"]], "mlx.core.random.normal": [[191, "mlx-core-random-normal"]], "mlx.core.random.randint": [[192, "mlx-core-random-randint"]], "mlx.core.random.seed": [[193, "mlx-core-random-seed"]], "mlx.core.random.split": [[194, "mlx-core-random-split"]], "mlx.core.random.truncated_normal": [[195, "mlx-core-random-truncated-normal"]], "mlx.core.random.uniform": [[196, "mlx-core-random-uniform"]], "mlx.core.reciprocal": [[197, "mlx-core-reciprocal"]], "mlx.core.repeat": [[198, "mlx-core-repeat"]], "mlx.core.reshape": [[199, "mlx-core-reshape"]], "mlx.core.round": [[200, "mlx-core-round"]], "mlx.core.rsqrt": [[201, "mlx-core-rsqrt"]], "mlx.core.save": [[202, "mlx-core-save"]], "mlx.core.save_gguf": [[203, "mlx-core-save-gguf"]], "mlx.core.save_safetensors": [[204, "mlx-core-save-safetensors"]], "mlx.core.savez": [[205, "mlx-core-savez"]], "mlx.core.savez_compressed": [[206, "mlx-core-savez-compressed"]], "mlx.core.set_default_device": [[207, "mlx-core-set-default-device"]], "mlx.core.set_default_stream": [[208, "mlx-core-set-default-stream"]], "mlx.core.sigmoid": [[209, "mlx-core-sigmoid"]], "mlx.core.sign": [[210, "mlx-core-sign"]], "mlx.core.sin": [[211, "mlx-core-sin"]], "mlx.core.sinh": [[212, "mlx-core-sinh"]], "mlx.core.softmax": [[213, "mlx-core-softmax"]], "mlx.core.sort": [[214, "mlx-core-sort"]], "mlx.core.split": [[215, "mlx-core-split"]], "mlx.core.sqrt": [[216, "mlx-core-sqrt"]], "mlx.core.square": [[217, "mlx-core-square"]], 
"mlx.core.squeeze": [[218, "mlx-core-squeeze"]], "mlx.core.stack": [[219, "mlx-core-stack"]], "mlx.core.std": [[220, "mlx-core-std"]], "mlx.core.stop_gradient": [[221, "mlx-core-stop-gradient"]], "mlx.core.stream": [[222, "mlx-core-stream"]], "mlx.core.subtract": [[223, "mlx-core-subtract"]], "mlx.core.sum": [[224, "mlx-core-sum"]], "mlx.core.swapaxes": [[225, "mlx-core-swapaxes"]], "mlx.core.take": [[226, "mlx-core-take"]], "mlx.core.take_along_axis": [[227, "mlx-core-take-along-axis"]], "mlx.core.tan": [[228, "mlx-core-tan"]], "mlx.core.tanh": [[229, "mlx-core-tanh"]], "mlx.core.tensordot": [[230, "mlx-core-tensordot"]], "mlx.core.tile": [[231, "mlx-core-tile"]], "mlx.core.topk": [[232, "mlx-core-topk"]], "mlx.core.transpose": [[233, "mlx-core-transpose"]], "mlx.core.tri": [[234, "mlx-core-tri"]], "mlx.core.tril": [[235, "mlx-core-tril"]], "mlx.core.triu": [[236, "mlx-core-triu"]], "mlx.core.value_and_grad": [[237, "mlx-core-value-and-grad"]], "mlx.core.var": [[238, "mlx-core-var"]], "mlx.core.vjp": [[239, "mlx-core-vjp"]], "mlx.core.vmap": [[240, "mlx-core-vmap"]], "mlx.core.where": [[241, "mlx-core-where"]], "mlx.core.zeros": [[242, "mlx-core-zeros"]], "mlx.core.zeros_like": [[243, "mlx-core-zeros-like"]], "mlx.nn.value_and_grad": [[244, "mlx-nn-value-and-grad"]], "mlx.utils.tree_flatten": [[245, "mlx-utils-tree-flatten"]], "mlx.utils.tree_map": [[246, "mlx-utils-tree-map"]], "mlx.utils.tree_unflatten": [[247, "mlx-utils-tree-unflatten"]], "mlx.core.Stream": [[248, "mlx-core-stream"]], "Array": [[249, "array"]], "Data Types": [[250, "data-types"]], "Supported Data Types": [[250, "id2"]], "Devices and Streams": [[251, "devices-and-streams"]], "Fast": [[252, "fast"]], "FFT": [[253, "fft"]], "Linear Algebra": [[254, "linear-algebra"]], "Metal": [[255, "metal"]], "Neural Networks": [[256, "neural-networks"]], "Quick Start with Neural Networks": [[256, "quick-start-with-neural-networks"]], "The Module Class": [[256, "the-module-class"]], "Parameters": [[256, "parameters"]], "Updating the Parameters": [[256, "updating-the-parameters"]], "Inspecting Modules": [[256, "inspecting-modules"]], "Value and Grad": [[256, "value-and-grad"]], "mlx.nn.ALiBi": [[257, "mlx-nn-alibi"]], "mlx.nn.AvgPool1d": [[258, "mlx-nn-avgpool1d"]], "mlx.nn.AvgPool2d": [[259, "mlx-nn-avgpool2d"]], "mlx.nn.BatchNorm": [[260, "mlx-nn-batchnorm"]], "mlx.nn.Conv1d": [[261, "mlx-nn-conv1d"]], "mlx.nn.Conv2d": [[262, "mlx-nn-conv2d"]], "mlx.nn.Dropout": [[263, "mlx-nn-dropout"]], "mlx.nn.Dropout2d": [[264, "mlx-nn-dropout2d"]], "mlx.nn.Dropout3d": [[265, "mlx-nn-dropout3d"]], "mlx.nn.Embedding": [[266, "mlx-nn-embedding"]], "mlx.nn.GELU": [[267, "mlx-nn-gelu"]], "mlx.nn.GRU": [[268, "mlx-nn-gru"]], "mlx.nn.GroupNorm": [[269, "mlx-nn-groupnorm"]], "mlx.nn.InstanceNorm": [[270, "mlx-nn-instancenorm"]], "mlx.nn.LSTM": [[271, "mlx-nn-lstm"]], "mlx.nn.LayerNorm": [[272, "mlx-nn-layernorm"]], "mlx.nn.Linear": [[273, "mlx-nn-linear"]], "mlx.nn.MaxPool1d": [[274, "mlx-nn-maxpool1d"]], "mlx.nn.MaxPool2d": [[275, "mlx-nn-maxpool2d"]], "mlx.nn.Mish": [[276, "mlx-nn-mish"]], "mlx.nn.Module.apply": [[277, "mlx-nn-module-apply"]], "mlx.nn.Module.apply_to_modules": [[278, "mlx-nn-module-apply-to-modules"]], "mlx.nn.Module.children": [[279, "mlx-nn-module-children"]], "mlx.nn.Module.eval": [[280, "mlx-nn-module-eval"]], "mlx.nn.Module.filter_and_map": [[281, "mlx-nn-module-filter-and-map"]], "mlx.nn.Module.freeze": [[282, "mlx-nn-module-freeze"]], "mlx.nn.Module.leaf_modules": [[283, "mlx-nn-module-leaf-modules"]], 
"mlx.nn.Module.load_weights": [[284, "mlx-nn-module-load-weights"]], "mlx.nn.Module.modules": [[285, "mlx-nn-module-modules"]], "mlx.nn.Module.named_modules": [[286, "mlx-nn-module-named-modules"]], "mlx.nn.Module.parameters": [[287, "mlx-nn-module-parameters"]], "mlx.nn.Module.save_weights": [[288, "mlx-nn-module-save-weights"]], "mlx.nn.Module.set_dtype": [[289, "mlx-nn-module-set-dtype"]], "mlx.nn.Module.state": [[290, "mlx-nn-module-state"]], "mlx.nn.Module.train": [[291, "mlx-nn-module-train"]], "mlx.nn.Module.trainable_parameters": [[292, "mlx-nn-module-trainable-parameters"]], "mlx.nn.Module.training": [[293, "mlx-nn-module-training"]], "mlx.nn.Module.unfreeze": [[294, "mlx-nn-module-unfreeze"]], "mlx.nn.Module.update": [[295, "mlx-nn-module-update"]], "mlx.nn.Module.update_modules": [[296, "mlx-nn-module-update-modules"]], "mlx.nn.MultiHeadAttention": [[297, "mlx-nn-multiheadattention"]], "mlx.nn.PReLU": [[298, "mlx-nn-prelu"]], "mlx.nn.QuantizedLinear": [[299, "mlx-nn-quantizedlinear"]], "mlx.nn.RMSNorm": [[300, "mlx-nn-rmsnorm"]], "mlx.nn.RNN": [[301, "mlx-nn-rnn"]], "mlx.nn.ReLU": [[302, "mlx-nn-relu"]], "mlx.nn.RoPE": [[303, "mlx-nn-rope"]], "mlx.nn.SELU": [[304, "mlx-nn-selu"]], "mlx.nn.Sequential": [[305, "mlx-nn-sequential"]], "mlx.nn.SiLU": [[306, "mlx-nn-silu"]], "mlx.nn.SinusoidalPositionalEncoding": [[307, "mlx-nn-sinusoidalpositionalencoding"]], "mlx.nn.Softshrink": [[308, "mlx-nn-softshrink"]], "mlx.nn.Step": [[309, "mlx-nn-step"]], "mlx.nn.Transformer": [[310, "mlx-nn-transformer"]], "mlx.nn.Upsample": [[311, "mlx-nn-upsample"]], "mlx.nn.init.constant": [[312, "mlx-nn-init-constant"]], "mlx.nn.init.glorot_normal": [[313, "mlx-nn-init-glorot-normal"]], "mlx.nn.init.glorot_uniform": [[314, "mlx-nn-init-glorot-uniform"]], "mlx.nn.init.he_normal": [[315, "mlx-nn-init-he-normal"]], "mlx.nn.init.he_uniform": [[316, "mlx-nn-init-he-uniform"]], "mlx.nn.init.identity": [[317, "mlx-nn-init-identity"]], "mlx.nn.init.normal": [[318, "mlx-nn-init-normal"]], "mlx.nn.init.uniform": [[319, "mlx-nn-init-uniform"]], "mlx.nn.elu": [[320, "mlx-nn-elu"]], "mlx.nn.gelu": [[321, "mlx-nn-gelu"]], "mlx.nn.gelu_approx": [[322, "mlx-nn-gelu-approx"]], "mlx.nn.gelu_fast_approx": [[323, "mlx-nn-gelu-fast-approx"]], "mlx.nn.glu": [[324, "mlx-nn-glu"]], "mlx.nn.hardswish": [[325, "mlx-nn-hardswish"]], "mlx.nn.leaky_relu": [[326, "mlx-nn-leaky-relu"]], "mlx.nn.log_sigmoid": [[327, "mlx-nn-log-sigmoid"]], "mlx.nn.log_softmax": [[328, "mlx-nn-log-softmax"]], "mlx.nn.losses.binary_cross_entropy": [[329, "mlx-nn-losses-binary-cross-entropy"]], "mlx.nn.losses.cosine_similarity_loss": [[330, "mlx-nn-losses-cosine-similarity-loss"]], "mlx.nn.losses.cross_entropy": [[331, "mlx-nn-losses-cross-entropy"]], "mlx.nn.losses.gaussian_nll_loss": [[332, "mlx-nn-losses-gaussian-nll-loss"]], "mlx.nn.losses.hinge_loss": [[333, "mlx-nn-losses-hinge-loss"]], "mlx.nn.losses.huber_loss": [[334, "mlx-nn-losses-huber-loss"]], "mlx.nn.losses.kl_div_loss": [[335, "mlx-nn-losses-kl-div-loss"]], "mlx.nn.losses.l1_loss": [[336, "mlx-nn-losses-l1-loss"]], "mlx.nn.losses.log_cosh_loss": [[337, "mlx-nn-losses-log-cosh-loss"]], "mlx.nn.losses.margin_ranking_loss": [[338, "mlx-nn-losses-margin-ranking-loss"]], "mlx.nn.losses.mse_loss": [[339, "mlx-nn-losses-mse-loss"]], "mlx.nn.losses.nll_loss": [[340, "mlx-nn-losses-nll-loss"]], "mlx.nn.losses.smooth_l1_loss": [[341, "mlx-nn-losses-smooth-l1-loss"]], "mlx.nn.losses.triplet_loss": [[342, "mlx-nn-losses-triplet-loss"]], "mlx.nn.mish": [[343, "mlx-nn-mish"]], "mlx.nn.prelu": [[344, 
"mlx-nn-prelu"]], "mlx.nn.relu": [[345, "mlx-nn-relu"]], "mlx.nn.relu6": [[346, "mlx-nn-relu6"]], "mlx.nn.selu": [[347, "mlx-nn-selu"]], "mlx.nn.sigmoid": [[348, "mlx-nn-sigmoid"]], "mlx.nn.silu": [[349, "mlx-nn-silu"]], "mlx.nn.softmax": [[350, "mlx-nn-softmax"]], "mlx.nn.softplus": [[351, "mlx-nn-softplus"]], "mlx.nn.softshrink": [[352, "mlx-nn-softshrink"]], "mlx.nn.step": [[353, "mlx-nn-step"]], "mlx.nn.tanh": [[354, "mlx-nn-tanh"]], "Functions": [[355, "functions"]], "Initializers": [[356, "initializers"]], "Layers": [[357, "layers"]], "Loss Functions": [[358, "loss-functions"]], "Module": [[359, "module"]], "Optimizers": [[361, "optimizers"]], "mlx.optimizers.AdaDelta": [[362, "mlx-optimizers-adadelta"]], "mlx.optimizers.Adafactor": [[363, "mlx-optimizers-adafactor"]], "mlx.optimizers.Adagrad": [[364, "mlx-optimizers-adagrad"]], "mlx.optimizers.Adam": [[365, "mlx-optimizers-adam"]], "mlx.optimizers.AdamW": [[366, "mlx-optimizers-adamw"]], "mlx.optimizers.Adamax": [[367, "mlx-optimizers-adamax"]], "mlx.optimizers.Lion": [[368, "mlx-optimizers-lion"]], "mlx.optimizers.Optimizer.apply_gradients": [[369, "mlx-optimizers-optimizer-apply-gradients"]], "mlx.optimizers.Optimizer.init": [[370, "mlx-optimizers-optimizer-init"]], "mlx.optimizers.Optimizer.state": [[371, "mlx-optimizers-optimizer-state"]], "mlx.optimizers.Optimizer.update": [[372, "mlx-optimizers-optimizer-update"]], "mlx.optimizers.RMSprop": [[373, "mlx-optimizers-rmsprop"]], "mlx.optimizers.SGD": [[374, "mlx-optimizers-sgd"]], "mlx.optimizers.cosine_decay": [[375, "mlx-optimizers-cosine-decay"]], "mlx.optimizers.exponential_decay": [[376, "mlx-optimizers-exponential-decay"]], "mlx.optimizers.join_schedules": [[377, "mlx-optimizers-join-schedules"]], "mlx.optimizers.linear_schedule": [[378, "mlx-optimizers-linear-schedule"]], "mlx.optimizers.step_decay": [[379, "mlx-optimizers-step-decay"]], "Common Optimizers": [[380, "common-optimizers"]], "Optimizer": [[381, "optimizer"]], "Schedulers": [[382, "schedulers"]], "Random": [[383, "random"]], "Transforms": [[384, "transforms"]], "Tree Utils": [[385, "tree-utils"]], "Compilation": [[386, "compilation"]], "Basics of Compile": [[386, "basics-of-compile"]], "Example Speedup": [[386, "example-speedup"]], "Debugging": [[386, "debugging"]], "Pure Functions": [[386, "pure-functions"]], "Compiling Training Graphs": [[386, "compiling-training-graphs"]], "Transformations with Compile": [[386, "transformations-with-compile"]], "Function Transforms": [[387, "function-transforms"]], "Automatic Differentiation": [[387, "automatic-differentiation"]], "Automatic Vectorization": [[387, "automatic-vectorization"]], "Indexing Arrays": [[388, "indexing-arrays"]], "Differences from NumPy": [[388, "differences-from-numpy"]], "In Place Updates": [[388, "in-place-updates"]], "Lazy Evaluation": [[389, "lazy-evaluation"]], "Why Lazy Evaluation": [[389, "why-lazy-evaluation"]], "Transforming Compute Graphs": [[389, "transforming-compute-graphs"]], "Only Compute What You Use": [[389, "only-compute-what-you-use"]], "When to Evaluate": [[389, "when-to-evaluate"]], "Conversion to NumPy and Other Frameworks": [[390, "conversion-to-numpy-and-other-frameworks"]], "PyTorch": [[390, "pytorch"]], "JAX": [[390, "jax"]], "TensorFlow": [[390, "tensorflow"]], "Quick Start Guide": [[391, "quick-start-guide"]], "Basics": [[391, "basics"]], "Function and Graph Transformations": [[391, "function-and-graph-transformations"]], "Saving and Loading Arrays": [[392, "saving-and-loading-arrays"]], "Serialization Formats": [[392, 
"id1"]], "Unified Memory": [[393, "unified-memory"]], "A Simple Example": [[393, "a-simple-example"]], "Using Streams": [[394, "using-streams"]], "Specifying the Stream": [[394, "specifying-the-stream"]]}, "indexentries": {"device (class in mlx.core)": [[8, "mlx.core.Device"]], "__init__() (device method)": [[8, "mlx.core.Device.__init__"]], "dtype (class in mlx.core)": [[9, "mlx.core.Dtype"]], "__init__() (dtype method)": [[9, "mlx.core.Dtype.__init__"]], "dtypecategory (class in mlx.core)": [[10, "mlx.core.DtypeCategory"]], "__init__() (dtypecategory method)": [[10, "mlx.core.DtypeCategory.__init__"]], "abs() (in module mlx.core)": [[11, "mlx.core.abs"]], "add() (in module mlx.core)": [[12, "mlx.core.add"]], "all() (in module mlx.core)": [[13, "mlx.core.all"]], "allclose() (in module mlx.core)": [[14, "mlx.core.allclose"]], "any() (in module mlx.core)": [[15, "mlx.core.any"]], "arange() (in module mlx.core)": [[16, "mlx.core.arange"]], "arccos() (in module mlx.core)": [[17, "mlx.core.arccos"]], "arccosh() (in module mlx.core)": [[18, "mlx.core.arccosh"]], "arcsin() (in module mlx.core)": [[19, "mlx.core.arcsin"]], "arcsinh() (in module mlx.core)": [[20, "mlx.core.arcsinh"]], "arctan() (in module mlx.core)": [[21, "mlx.core.arctan"]], "arctanh() (in module mlx.core)": [[22, "mlx.core.arctanh"]], "argmax() (in module mlx.core)": [[23, "mlx.core.argmax"]], "argmin() (in module mlx.core)": [[24, "mlx.core.argmin"]], "argpartition() (in module mlx.core)": [[25, "mlx.core.argpartition"]], "argsort() (in module mlx.core)": [[26, "mlx.core.argsort"]], "__init__() (array method)": [[27, "mlx.core.array.__init__"]], "array (class in mlx.core)": [[27, "mlx.core.array"]], "t (array property)": [[28, "mlx.core.array.T"]], "abs() (array method)": [[29, "mlx.core.array.abs"]], "all() (array method)": [[30, "mlx.core.array.all"]], "any() (array method)": [[31, "mlx.core.array.any"]], "argmax() (array method)": [[32, "mlx.core.array.argmax"]], "argmin() (array method)": [[33, "mlx.core.array.argmin"]], "astype() (array method)": [[34, "mlx.core.array.astype"]], "at (array property)": [[35, "mlx.core.array.at"]], "cos() (array method)": [[36, "mlx.core.array.cos"]], "cummax() (array method)": [[37, "mlx.core.array.cummax"]], "cummin() (array method)": [[38, "mlx.core.array.cummin"]], "cumprod() (array method)": [[39, "mlx.core.array.cumprod"]], "cumsum() (array method)": [[40, "mlx.core.array.cumsum"]], "diag() (array method)": [[41, "mlx.core.array.diag"]], "diagonal() (array method)": [[42, "mlx.core.array.diagonal"]], "dtype (array property)": [[43, "mlx.core.array.dtype"]], "exp() (array method)": [[44, "mlx.core.array.exp"]], "flatten() (array method)": [[45, "mlx.core.array.flatten"]], "item() (array method)": [[46, "mlx.core.array.item"]], "itemsize (array property)": [[47, "mlx.core.array.itemsize"]], "log() (array method)": [[48, "mlx.core.array.log"]], "log10() (array method)": [[49, "mlx.core.array.log10"]], "log1p() (array method)": [[50, "mlx.core.array.log1p"]], "log2() (array method)": [[51, "mlx.core.array.log2"]], "logsumexp() (array method)": [[52, "mlx.core.array.logsumexp"]], "max() (array method)": [[53, "mlx.core.array.max"]], "mean() (array method)": [[54, "mlx.core.array.mean"]], "min() (array method)": [[55, "mlx.core.array.min"]], "moveaxis() (array method)": [[56, "mlx.core.array.moveaxis"]], "nbytes (array property)": [[57, "mlx.core.array.nbytes"]], "ndim (array property)": [[58, "mlx.core.array.ndim"]], "prod() (array method)": [[59, "mlx.core.array.prod"]], "reciprocal() 
(array method)": [[60, "mlx.core.array.reciprocal"]], "reshape() (array method)": [[61, "mlx.core.array.reshape"]], "round() (array method)": [[62, "mlx.core.array.round"]], "rsqrt() (array method)": [[63, "mlx.core.array.rsqrt"]], "shape (array property)": [[64, "mlx.core.array.shape"]], "sin() (array method)": [[65, "mlx.core.array.sin"]], "size (array property)": [[66, "mlx.core.array.size"]], "split() (array method)": [[67, "mlx.core.array.split"]], "sqrt() (array method)": [[68, "mlx.core.array.sqrt"]], "square() (array method)": [[69, "mlx.core.array.square"]], "squeeze() (array method)": [[70, "mlx.core.array.squeeze"]], "sum() (array method)": [[71, "mlx.core.array.sum"]], "swapaxes() (array method)": [[72, "mlx.core.array.swapaxes"]], "tolist() (array method)": [[73, "mlx.core.array.tolist"]], "transpose() (array method)": [[74, "mlx.core.array.transpose"]], "var() (array method)": [[75, "mlx.core.array.var"]], "array_equal() (in module mlx.core)": [[76, "mlx.core.array_equal"]], "atleast_1d() (in module mlx.core)": [[77, "mlx.core.atleast_1d"]], "atleast_2d() (in module mlx.core)": [[78, "mlx.core.atleast_2d"]], "atleast_3d() (in module mlx.core)": [[79, "mlx.core.atleast_3d"]], "broadcast_to() (in module mlx.core)": [[80, "mlx.core.broadcast_to"]], "ceil() (in module mlx.core)": [[81, "mlx.core.ceil"]], "clip() (in module mlx.core)": [[82, "mlx.core.clip"]], "compile() (in module mlx.core)": [[83, "mlx.core.compile"]], "concatenate() (in module mlx.core)": [[84, "mlx.core.concatenate"]], "conv1d() (in module mlx.core)": [[85, "mlx.core.conv1d"]], "conv2d() (in module mlx.core)": [[86, "mlx.core.conv2d"]], "conv_general() (in module mlx.core)": [[87, "mlx.core.conv_general"]], "convolve() (in module mlx.core)": [[88, "mlx.core.convolve"]], "cos() (in module mlx.core)": [[89, "mlx.core.cos"]], "cosh() (in module mlx.core)": [[90, "mlx.core.cosh"]], "cummax() (in module mlx.core)": [[91, "mlx.core.cummax"]], "cummin() (in module mlx.core)": [[92, "mlx.core.cummin"]], "cumprod() (in module mlx.core)": [[93, "mlx.core.cumprod"]], "cumsum() (in module mlx.core)": [[94, "mlx.core.cumsum"]], "default_device() (in module mlx.core)": [[95, "mlx.core.default_device"]], "default_stream() (in module mlx.core)": [[96, "mlx.core.default_stream"]], "dequantize() (in module mlx.core)": [[97, "mlx.core.dequantize"]], "diag() (in module mlx.core)": [[98, "mlx.core.diag"]], "diagonal() (in module mlx.core)": [[99, "mlx.core.diagonal"]], "disable_compile() (in module mlx.core)": [[100, "mlx.core.disable_compile"]], "divide() (in module mlx.core)": [[101, "mlx.core.divide"]], "divmod() (in module mlx.core)": [[102, "mlx.core.divmod"]], "enable_compile() (in module mlx.core)": [[103, "mlx.core.enable_compile"]], "equal() (in module mlx.core)": [[104, "mlx.core.equal"]], "erf() (in module mlx.core)": [[105, "mlx.core.erf"]], "erfinv() (in module mlx.core)": [[106, "mlx.core.erfinv"]], "eval() (in module mlx.core)": [[107, "mlx.core.eval"]], "exp() (in module mlx.core)": [[108, "mlx.core.exp"]], "expand_dims() (in module mlx.core)": [[109, "mlx.core.expand_dims"]], "expm1() (in module mlx.core)": [[110, "mlx.core.expm1"]], "eye() (in module mlx.core)": [[111, "mlx.core.eye"]], "layer_norm() (in module mlx.core.fast)": [[112, "mlx.core.fast.layer_norm"]], "rms_norm() (in module mlx.core.fast)": [[113, "mlx.core.fast.rms_norm"]], "rope() (in module mlx.core.fast)": [[114, "mlx.core.fast.rope"]], "scaled_dot_product_attention() (in module mlx.core.fast)": [[115, 
"mlx.core.fast.scaled_dot_product_attention"]], "fft() (in module mlx.core.fft)": [[116, "mlx.core.fft.fft"]], "fft2() (in module mlx.core.fft)": [[117, "mlx.core.fft.fft2"]], "fftn() (in module mlx.core.fft)": [[118, "mlx.core.fft.fftn"]], "ifft() (in module mlx.core.fft)": [[119, "mlx.core.fft.ifft"]], "ifft2() (in module mlx.core.fft)": [[120, "mlx.core.fft.ifft2"]], "ifftn() (in module mlx.core.fft)": [[121, "mlx.core.fft.ifftn"]], "irfft() (in module mlx.core.fft)": [[122, "mlx.core.fft.irfft"]], "irfft2() (in module mlx.core.fft)": [[123, "mlx.core.fft.irfft2"]], "irfftn() (in module mlx.core.fft)": [[124, "mlx.core.fft.irfftn"]], "rfft() (in module mlx.core.fft)": [[125, "mlx.core.fft.rfft"]], "rfft2() (in module mlx.core.fft)": [[126, "mlx.core.fft.rfft2"]], "rfftn() (in module mlx.core.fft)": [[127, "mlx.core.fft.rfftn"]], "flatten() (in module mlx.core)": [[128, "mlx.core.flatten"]], "floor() (in module mlx.core)": [[129, "mlx.core.floor"]], "floor_divide() (in module mlx.core)": [[130, "mlx.core.floor_divide"]], "full() (in module mlx.core)": [[131, "mlx.core.full"]], "grad() (in module mlx.core)": [[132, "mlx.core.grad"]], "greater() (in module mlx.core)": [[133, "mlx.core.greater"]], "greater_equal() (in module mlx.core)": [[134, "mlx.core.greater_equal"]], "identity() (in module mlx.core)": [[135, "mlx.core.identity"]], "inner() (in module mlx.core)": [[136, "mlx.core.inner"]], "isclose() (in module mlx.core)": [[137, "mlx.core.isclose"]], "isinf() (in module mlx.core)": [[138, "mlx.core.isinf"]], "isnan() (in module mlx.core)": [[139, "mlx.core.isnan"]], "isneginf() (in module mlx.core)": [[140, "mlx.core.isneginf"]], "isposinf() (in module mlx.core)": [[141, "mlx.core.isposinf"]], "issubdtype() (in module mlx.core)": [[142, "mlx.core.issubdtype"]], "jvp() (in module mlx.core)": [[143, "mlx.core.jvp"]], "less() (in module mlx.core)": [[144, "mlx.core.less"]], "less_equal() (in module mlx.core)": [[145, "mlx.core.less_equal"]], "norm() (in module mlx.core.linalg)": [[146, "mlx.core.linalg.norm"]], "qr() (in module mlx.core.linalg)": [[147, "mlx.core.linalg.qr"]], "linspace() (in module mlx.core)": [[148, "mlx.core.linspace"]], "load() (in module mlx.core)": [[149, "mlx.core.load"]], "log() (in module mlx.core)": [[150, "mlx.core.log"]], "log10() (in module mlx.core)": [[151, "mlx.core.log10"]], "log1p() (in module mlx.core)": [[152, "mlx.core.log1p"]], "log2() (in module mlx.core)": [[153, "mlx.core.log2"]], "logaddexp() (in module mlx.core)": [[154, "mlx.core.logaddexp"]], "logical_and() (in module mlx.core)": [[155, "mlx.core.logical_and"]], "logical_not() (in module mlx.core)": [[156, "mlx.core.logical_not"]], "logical_or() (in module mlx.core)": [[157, "mlx.core.logical_or"]], "logsumexp() (in module mlx.core)": [[158, "mlx.core.logsumexp"]], "matmul() (in module mlx.core)": [[159, "mlx.core.matmul"]], "max() (in module mlx.core)": [[160, "mlx.core.max"]], "maximum() (in module mlx.core)": [[161, "mlx.core.maximum"]], "mean() (in module mlx.core)": [[162, "mlx.core.mean"]], "meshgrid() (in module mlx.core)": [[163, "mlx.core.meshgrid"]], "get_active_memory() (in module mlx.core.metal)": [[164, "mlx.core.metal.get_active_memory"]], "get_cache_memory() (in module mlx.core.metal)": [[165, "mlx.core.metal.get_cache_memory"]], "get_peak_memory() (in module mlx.core.metal)": [[166, "mlx.core.metal.get_peak_memory"]], "is_available() (in module mlx.core.metal)": [[167, "mlx.core.metal.is_available"]], "set_cache_limit() (in module mlx.core.metal)": [[168, 
"mlx.core.metal.set_cache_limit"]], "set_memory_limit() (in module mlx.core.metal)": [[169, "mlx.core.metal.set_memory_limit"]], "start_capture() (in module mlx.core.metal)": [[170, "mlx.core.metal.start_capture"]], "stop_capture() (in module mlx.core.metal)": [[171, "mlx.core.metal.stop_capture"]], "min() (in module mlx.core)": [[172, "mlx.core.min"]], "minimum() (in module mlx.core)": [[173, "mlx.core.minimum"]], "moveaxis() (in module mlx.core)": [[174, "mlx.core.moveaxis"]], "multiply() (in module mlx.core)": [[175, "mlx.core.multiply"]], "negative() (in module mlx.core)": [[176, "mlx.core.negative"]], "new_stream() (in module mlx.core)": [[177, "mlx.core.new_stream"]], "ones() (in module mlx.core)": [[178, "mlx.core.ones"]], "ones_like() (in module mlx.core)": [[179, "mlx.core.ones_like"]], "outer() (in module mlx.core)": [[180, "mlx.core.outer"]], "pad() (in module mlx.core)": [[181, "mlx.core.pad"]], "partition() (in module mlx.core)": [[182, "mlx.core.partition"]], "prod() (in module mlx.core)": [[183, "mlx.core.prod"]], "quantize() (in module mlx.core)": [[184, "mlx.core.quantize"]], "quantized_matmul() (in module mlx.core)": [[185, "mlx.core.quantized_matmul"]], "bernoulli() (in module mlx.core.random)": [[186, "mlx.core.random.bernoulli"]], "categorical() (in module mlx.core.random)": [[187, "mlx.core.random.categorical"]], "gumbel() (in module mlx.core.random)": [[188, "mlx.core.random.gumbel"]], "key() (in module mlx.core.random)": [[189, "mlx.core.random.key"]], "multivariate_normal() (in module mlx.core.random)": [[190, "mlx.core.random.multivariate_normal"]], "normal() (in module mlx.core.random)": [[191, "mlx.core.random.normal"]], "randint() (in module mlx.core.random)": [[192, "mlx.core.random.randint"]], "seed() (in module mlx.core.random)": [[193, "mlx.core.random.seed"]], "split() (in module mlx.core.random)": [[194, "mlx.core.random.split"]], "truncated_normal() (in module mlx.core.random)": [[195, "mlx.core.random.truncated_normal"]], "uniform() (in module mlx.core.random)": [[196, "mlx.core.random.uniform"]], "reciprocal() (in module mlx.core)": [[197, "mlx.core.reciprocal"]], "repeat() (in module mlx.core)": [[198, "mlx.core.repeat"]], "reshape() (in module mlx.core)": [[199, "mlx.core.reshape"]], "round() (in module mlx.core)": [[200, "mlx.core.round"]], "rsqrt() (in module mlx.core)": [[201, "mlx.core.rsqrt"]], "save() (in module mlx.core)": [[202, "mlx.core.save"]], "save_gguf() (in module mlx.core)": [[203, "mlx.core.save_gguf"]], "save_safetensors() (in module mlx.core)": [[204, "mlx.core.save_safetensors"]], "savez() (in module mlx.core)": [[205, "mlx.core.savez"]], "savez_compressed() (in module mlx.core)": [[206, "mlx.core.savez_compressed"]], "set_default_device() (in module mlx.core)": [[207, "mlx.core.set_default_device"]], "set_default_stream() (in module mlx.core)": [[208, "mlx.core.set_default_stream"]], "sigmoid() (in module mlx.core)": [[209, "mlx.core.sigmoid"]], "sign() (in module mlx.core)": [[210, "mlx.core.sign"]], "sin() (in module mlx.core)": [[211, "mlx.core.sin"]], "sinh() (in module mlx.core)": [[212, "mlx.core.sinh"]], "softmax() (in module mlx.core)": [[213, "mlx.core.softmax"]], "sort() (in module mlx.core)": [[214, "mlx.core.sort"]], "split() (in module mlx.core)": [[215, "mlx.core.split"]], "sqrt() (in module mlx.core)": [[216, "mlx.core.sqrt"]], "square() (in module mlx.core)": [[217, "mlx.core.square"]], "squeeze() (in module mlx.core)": [[218, "mlx.core.squeeze"]], "stack() (in module mlx.core)": [[219, "mlx.core.stack"]], "std() 
(in module mlx.core)": [[220, "mlx.core.std"]], "stop_gradient() (in module mlx.core)": [[221, "mlx.core.stop_gradient"]], "stream() (in module mlx.core)": [[222, "mlx.core.stream"]], "subtract() (in module mlx.core)": [[223, "mlx.core.subtract"]], "sum() (in module mlx.core)": [[224, "mlx.core.sum"]], "swapaxes() (in module mlx.core)": [[225, "mlx.core.swapaxes"]], "take() (in module mlx.core)": [[226, "mlx.core.take"]], "take_along_axis() (in module mlx.core)": [[227, "mlx.core.take_along_axis"]], "tan() (in module mlx.core)": [[228, "mlx.core.tan"]], "tanh() (in module mlx.core)": [[229, "mlx.core.tanh"]], "tensordot() (in module mlx.core)": [[230, "mlx.core.tensordot"]], "tile() (in module mlx.core)": [[231, "mlx.core.tile"]], "topk() (in module mlx.core)": [[232, "mlx.core.topk"]], "transpose() (in module mlx.core)": [[233, "mlx.core.transpose"]], "tri() (in module mlx.core)": [[234, "mlx.core.tri"]], "tril() (in module mlx.core)": [[235, "mlx.core.tril"]], "triu() (in module mlx.core)": [[236, "mlx.core.triu"]], "value_and_grad() (in module mlx.core)": [[237, "mlx.core.value_and_grad"]], "var() (in module mlx.core)": [[238, "mlx.core.var"]], "vjp() (in module mlx.core)": [[239, "mlx.core.vjp"]], "vmap() (in module mlx.core)": [[240, "mlx.core.vmap"]], "where() (in module mlx.core)": [[241, "mlx.core.where"]], "zeros() (in module mlx.core)": [[242, "mlx.core.zeros"]], "zeros_like() (in module mlx.core)": [[243, "mlx.core.zeros_like"]], "value_and_grad() (in module mlx.nn)": [[244, "mlx.nn.value_and_grad"]], "tree_flatten() (in module mlx.utils)": [[245, "mlx.utils.tree_flatten"]], "tree_map() (in module mlx.utils)": [[246, "mlx.utils.tree_map"]], "tree_unflatten() (in module mlx.utils)": [[247, "mlx.utils.tree_unflatten"]], "stream (class in mlx.core)": [[248, "mlx.core.Stream"]], "__init__() (stream method)": [[248, "mlx.core.Stream.__init__"]], "alibi (class in mlx.nn)": [[257, "mlx.nn.ALiBi"]], "avgpool1d (class in mlx.nn)": [[258, "mlx.nn.AvgPool1d"]], "avgpool2d (class in mlx.nn)": [[259, "mlx.nn.AvgPool2d"]], "batchnorm (class in mlx.nn)": [[260, "mlx.nn.BatchNorm"]], "conv1d (class in mlx.nn)": [[261, "mlx.nn.Conv1d"]], "conv2d (class in mlx.nn)": [[262, "mlx.nn.Conv2d"]], "dropout (class in mlx.nn)": [[263, "mlx.nn.Dropout"]], "dropout2d (class in mlx.nn)": [[264, "mlx.nn.Dropout2d"]], "dropout3d (class in mlx.nn)": [[265, "mlx.nn.Dropout3d"]], "embedding (class in mlx.nn)": [[266, "mlx.nn.Embedding"]], "gelu (class in mlx.nn)": [[267, "mlx.nn.GELU"]], "gru (class in mlx.nn)": [[268, "mlx.nn.GRU"]], "groupnorm (class in mlx.nn)": [[269, "mlx.nn.GroupNorm"]], "instancenorm (class in mlx.nn)": [[270, "mlx.nn.InstanceNorm"]], "lstm (class in mlx.nn)": [[271, "mlx.nn.LSTM"]], "layernorm (class in mlx.nn)": [[272, "mlx.nn.LayerNorm"]], "linear (class in mlx.nn)": [[273, "mlx.nn.Linear"]], "maxpool1d (class in mlx.nn)": [[274, "mlx.nn.MaxPool1d"]], "maxpool2d (class in mlx.nn)": [[275, "mlx.nn.MaxPool2d"]], "mish (class in mlx.nn)": [[276, "mlx.nn.Mish"]], "apply() (module method)": [[277, "mlx.nn.Module.apply"]], "apply_to_modules() (module method)": [[278, "mlx.nn.Module.apply_to_modules"]], "children() (module method)": [[279, "mlx.nn.Module.children"]], "eval() (module method)": [[280, "mlx.nn.Module.eval"]], "filter_and_map() (module method)": [[281, "mlx.nn.Module.filter_and_map"]], "freeze() (module method)": [[282, "mlx.nn.Module.freeze"]], "leaf_modules() (module method)": [[283, "mlx.nn.Module.leaf_modules"]], "load_weights() (module method)": [[284, 
"mlx.nn.Module.load_weights"]], "modules() (module method)": [[285, "mlx.nn.Module.modules"]], "named_modules() (module method)": [[286, "mlx.nn.Module.named_modules"]], "parameters() (module method)": [[287, "mlx.nn.Module.parameters"]], "save_weights() (module method)": [[288, "mlx.nn.Module.save_weights"]], "set_dtype() (module method)": [[289, "mlx.nn.Module.set_dtype"]], "state (module property)": [[290, "mlx.nn.Module.state"]], "train() (module method)": [[291, "mlx.nn.Module.train"]], "trainable_parameters() (module method)": [[292, "mlx.nn.Module.trainable_parameters"]], "training (module property)": [[293, "mlx.nn.Module.training"]], "unfreeze() (module method)": [[294, "mlx.nn.Module.unfreeze"]], "update() (module method)": [[295, "mlx.nn.Module.update"]], "update_modules() (module method)": [[296, "mlx.nn.Module.update_modules"]], "multiheadattention (class in mlx.nn)": [[297, "mlx.nn.MultiHeadAttention"]], "prelu (class in mlx.nn)": [[298, "mlx.nn.PReLU"]], "quantizedlinear (class in mlx.nn)": [[299, "mlx.nn.QuantizedLinear"]], "rmsnorm (class in mlx.nn)": [[300, "mlx.nn.RMSNorm"]], "rnn (class in mlx.nn)": [[301, "mlx.nn.RNN"]], "relu (class in mlx.nn)": [[302, "mlx.nn.ReLU"]], "rope (class in mlx.nn)": [[303, "mlx.nn.RoPE"]], "selu (class in mlx.nn)": [[304, "mlx.nn.SELU"]], "sequential (class in mlx.nn)": [[305, "mlx.nn.Sequential"]], "silu (class in mlx.nn)": [[306, "mlx.nn.SiLU"]], "sinusoidalpositionalencoding (class in mlx.nn)": [[307, "mlx.nn.SinusoidalPositionalEncoding"]], "softshrink (class in mlx.nn)": [[308, "mlx.nn.Softshrink"]], "step (class in mlx.nn)": [[309, "mlx.nn.Step"]], "transformer (class in mlx.nn)": [[310, "mlx.nn.Transformer"]], "upsample (class in mlx.nn)": [[311, "mlx.nn.Upsample"]], "constant() (in module mlx.nn.init)": [[312, "mlx.nn.init.constant"]], "glorot_normal() (in module mlx.nn.init)": [[313, "mlx.nn.init.glorot_normal"]], "glorot_uniform() (in module mlx.nn.init)": [[314, "mlx.nn.init.glorot_uniform"]], "he_normal() (in module mlx.nn.init)": [[315, "mlx.nn.init.he_normal"]], "he_uniform() (in module mlx.nn.init)": [[316, "mlx.nn.init.he_uniform"]], "identity() (in module mlx.nn.init)": [[317, "mlx.nn.init.identity"]], "normal() (in module mlx.nn.init)": [[318, "mlx.nn.init.normal"]], "uniform() (in module mlx.nn.init)": [[319, "mlx.nn.init.uniform"]], "elu() (in module mlx.nn)": [[320, "mlx.nn.elu"]], "gelu() (in module mlx.nn)": [[321, "mlx.nn.gelu"]], "gelu_approx() (in module mlx.nn)": [[322, "mlx.nn.gelu_approx"]], "gelu_fast_approx() (in module mlx.nn)": [[323, "mlx.nn.gelu_fast_approx"]], "glu() (in module mlx.nn)": [[324, "mlx.nn.glu"]], "hardswish() (in module mlx.nn)": [[325, "mlx.nn.hardswish"]], "leaky_relu() (in module mlx.nn)": [[326, "mlx.nn.leaky_relu"]], "log_sigmoid() (in module mlx.nn)": [[327, "mlx.nn.log_sigmoid"]], "log_softmax() (in module mlx.nn)": [[328, "mlx.nn.log_softmax"]], "binary_cross_entropy() (in module mlx.nn.losses)": [[329, "mlx.nn.losses.binary_cross_entropy"]], "cosine_similarity_loss() (in module mlx.nn.losses)": [[330, "mlx.nn.losses.cosine_similarity_loss"]], "cross_entropy() (in module mlx.nn.losses)": [[331, "mlx.nn.losses.cross_entropy"]], "gaussian_nll_loss() (in module mlx.nn.losses)": [[332, "mlx.nn.losses.gaussian_nll_loss"]], "hinge_loss() (in module mlx.nn.losses)": [[333, "mlx.nn.losses.hinge_loss"]], "huber_loss() (in module mlx.nn.losses)": [[334, "mlx.nn.losses.huber_loss"]], "kl_div_loss() (in module mlx.nn.losses)": [[335, "mlx.nn.losses.kl_div_loss"]], "l1_loss() (in module 
mlx.nn.losses)": [[336, "mlx.nn.losses.l1_loss"]], "log_cosh_loss() (in module mlx.nn.losses)": [[337, "mlx.nn.losses.log_cosh_loss"]], "margin_ranking_loss() (in module mlx.nn.losses)": [[338, "mlx.nn.losses.margin_ranking_loss"]], "mse_loss() (in module mlx.nn.losses)": [[339, "mlx.nn.losses.mse_loss"]], "nll_loss() (in module mlx.nn.losses)": [[340, "mlx.nn.losses.nll_loss"]], "smooth_l1_loss() (in module mlx.nn.losses)": [[341, "mlx.nn.losses.smooth_l1_loss"]], "triplet_loss() (in module mlx.nn.losses)": [[342, "mlx.nn.losses.triplet_loss"]], "mish() (in module mlx.nn)": [[343, "mlx.nn.mish"]], "prelu() (in module mlx.nn)": [[344, "mlx.nn.prelu"]], "relu() (in module mlx.nn)": [[345, "mlx.nn.relu"]], "relu6() (in module mlx.nn)": [[346, "mlx.nn.relu6"]], "selu() (in module mlx.nn)": [[347, "mlx.nn.selu"]], "sigmoid() (in module mlx.nn)": [[348, "mlx.nn.sigmoid"]], "silu() (in module mlx.nn)": [[349, "mlx.nn.silu"]], "softmax() (in module mlx.nn)": [[350, "mlx.nn.softmax"]], "softplus() (in module mlx.nn)": [[351, "mlx.nn.softplus"]], "softshrink() (in module mlx.nn)": [[352, "mlx.nn.softshrink"]], "step() (in module mlx.nn)": [[353, "mlx.nn.step"]], "tanh() (in module mlx.nn)": [[354, "mlx.nn.tanh"]], "module (class in mlx.nn)": [[359, "mlx.nn.Module"]], "adadelta (class in mlx.optimizers)": [[362, "mlx.optimizers.AdaDelta"]], "adafactor (class in mlx.optimizers)": [[363, "mlx.optimizers.Adafactor"]], "adagrad (class in mlx.optimizers)": [[364, "mlx.optimizers.Adagrad"]], "adam (class in mlx.optimizers)": [[365, "mlx.optimizers.Adam"]], "adamw (class in mlx.optimizers)": [[366, "mlx.optimizers.AdamW"]], "adamax (class in mlx.optimizers)": [[367, "mlx.optimizers.Adamax"]], "lion (class in mlx.optimizers)": [[368, "mlx.optimizers.Lion"]], "apply_gradients() (optimizer method)": [[369, "mlx.optimizers.Optimizer.apply_gradients"]], "init() (optimizer method)": [[370, "mlx.optimizers.Optimizer.init"]], "state (optimizer property)": [[371, "mlx.optimizers.Optimizer.state"]], "update() (optimizer method)": [[372, "mlx.optimizers.Optimizer.update"]], "rmsprop (class in mlx.optimizers)": [[373, "mlx.optimizers.RMSprop"]], "sgd (class in mlx.optimizers)": [[374, "mlx.optimizers.SGD"]], "cosine_decay() (in module mlx.optimizers)": [[375, "mlx.optimizers.cosine_decay"]], "exponential_decay() (in module mlx.optimizers)": [[376, "mlx.optimizers.exponential_decay"]], "join_schedules() (in module mlx.optimizers)": [[377, "mlx.optimizers.join_schedules"]], "linear_schedule() (in module mlx.optimizers)": [[378, "mlx.optimizers.linear_schedule"]], "step_decay() (in module mlx.optimizers)": [[379, "mlx.optimizers.step_decay"]], "optimizer (class in mlx.optimizers)": [[381, "mlx.optimizers.Optimizer"]]}}) \ No newline at end of file diff --git a/docs/build/html/usage/compile.html b/docs/build/html/usage/compile.html index e647a9484..76c5eea2a 100644 --- a/docs/build/html/usage/compile.html +++ b/docs/build/html/usage/compile.html @@ -8,7 +8,7 @@ - Compilation — MLX 0.9.0 documentation + Compilation — MLX 0.10.0 documentation @@ -36,7 +36,7 @@ - + @@ -131,8 +131,8 @@ - MLX 0.9.0 documentation - Home - + MLX 0.10.0 documentation - Home + @@ -286,6 +286,7 @@
    • mlx.core.erf
    • mlx.core.erfinv
    • mlx.core.exp
    • + mlx.core.expm1
    • mlx.core.expand_dims
    • mlx.core.eye
    • mlx.core.flatten
@@ -318,6 +319,7 @@
    • mlx.core.max
    • mlx.core.maximum
    • mlx.core.mean
    • + mlx.core.meshgrid
    • mlx.core.min
    • mlx.core.minimum
    • mlx.core.moveaxis
@@ -352,6 +354,7 @@
    • mlx.core.square
    • mlx.core.squeeze
    • mlx.core.stack
    • + mlx.core.std
    • mlx.core.stop_gradient
    • mlx.core.subtract
    • mlx.core.sum
@@ -379,6 +382,7 @@
    • mlx.core.random.gumbel
    • mlx.core.random.key
    • mlx.core.random.normal
    • + mlx.core.random.multivariate_normal
    • mlx.core.random.randint
    • mlx.core.random.seed
    • mlx.core.random.split
@@ -432,6 +436,8 @@
    • mlx.core.metal.get_cache_memory
    • mlx.core.metal.set_memory_limit
    • mlx.core.metal.set_cache_limit
    • + mlx.core.metal.start_capture
    • + mlx.core.metal.stop_capture
  • Neural Networks