mlx/docs/build/html/searchindex.js
2025-06-04 01:01:48 +00:00

1 line
127 KiB
JavaScript

Search.setIndex({"docnames": ["cpp/ops", "dev/extensions", "examples/linear_regression", "examples/llama-inference", "examples/mlp", "index", "install", "python/_autosummary/mlx.core.Device", "python/_autosummary/mlx.core.Dtype", "python/_autosummary/mlx.core.Stream", "python/_autosummary/mlx.core.abs", "python/_autosummary/mlx.core.add", "python/_autosummary/mlx.core.all", "python/_autosummary/mlx.core.allclose", "python/_autosummary/mlx.core.any", "python/_autosummary/mlx.core.arange", "python/_autosummary/mlx.core.arccos", "python/_autosummary/mlx.core.arccosh", "python/_autosummary/mlx.core.arcsin", "python/_autosummary/mlx.core.arcsinh", "python/_autosummary/mlx.core.arctan", "python/_autosummary/mlx.core.arctanh", "python/_autosummary/mlx.core.argmax", "python/_autosummary/mlx.core.argmin", "python/_autosummary/mlx.core.argpartition", "python/_autosummary/mlx.core.argsort", "python/_autosummary/mlx.core.array", "python/_autosummary/mlx.core.array.T", "python/_autosummary/mlx.core.array.abs", "python/_autosummary/mlx.core.array.all", "python/_autosummary/mlx.core.array.any", "python/_autosummary/mlx.core.array.argmax", "python/_autosummary/mlx.core.array.argmin", "python/_autosummary/mlx.core.array.astype", "python/_autosummary/mlx.core.array.cos", "python/_autosummary/mlx.core.array.dtype", "python/_autosummary/mlx.core.array.exp", "python/_autosummary/mlx.core.array.item", "python/_autosummary/mlx.core.array.log", "python/_autosummary/mlx.core.array.log1p", "python/_autosummary/mlx.core.array.logsumexp", "python/_autosummary/mlx.core.array.max", "python/_autosummary/mlx.core.array.mean", "python/_autosummary/mlx.core.array.min", "python/_autosummary/mlx.core.array.ndim", "python/_autosummary/mlx.core.array.prod", "python/_autosummary/mlx.core.array.reciprocal", "python/_autosummary/mlx.core.array.reshape", "python/_autosummary/mlx.core.array.round", "python/_autosummary/mlx.core.array.rsqrt", "python/_autosummary/mlx.core.array.shape", "python/_autosummary/mlx.core.array.sin", "python/_autosummary/mlx.core.array.size", "python/_autosummary/mlx.core.array.split", "python/_autosummary/mlx.core.array.sqrt", "python/_autosummary/mlx.core.array.square", "python/_autosummary/mlx.core.array.sum", "python/_autosummary/mlx.core.array.tolist", "python/_autosummary/mlx.core.array.transpose", "python/_autosummary/mlx.core.array.var", "python/_autosummary/mlx.core.array_equal", "python/_autosummary/mlx.core.broadcast_to", "python/_autosummary/mlx.core.ceil", "python/_autosummary/mlx.core.clip", "python/_autosummary/mlx.core.concatenate", "python/_autosummary/mlx.core.conv1d", "python/_autosummary/mlx.core.conv2d", "python/_autosummary/mlx.core.convolve", "python/_autosummary/mlx.core.cos", "python/_autosummary/mlx.core.cosh", "python/_autosummary/mlx.core.default_device", "python/_autosummary/mlx.core.default_stream", "python/_autosummary/mlx.core.dequantize", "python/_autosummary/mlx.core.diag", "python/_autosummary/mlx.core.diagonal", "python/_autosummary/mlx.core.divide", "python/_autosummary/mlx.core.divmod", "python/_autosummary/mlx.core.equal", "python/_autosummary/mlx.core.erf", "python/_autosummary/mlx.core.erfinv", "python/_autosummary/mlx.core.eval", "python/_autosummary/mlx.core.exp", "python/_autosummary/mlx.core.expand_dims", "python/_autosummary/mlx.core.eye", "python/_autosummary/mlx.core.fft.fft", "python/_autosummary/mlx.core.fft.fft2", "python/_autosummary/mlx.core.fft.fftn", "python/_autosummary/mlx.core.fft.ifft", "python/_autosummary/mlx.core.fft.ifft2", "python/_autosummary/mlx.core.fft.ifftn", "python/_autosummary/mlx.core.fft.irfft", "python/_autosummary/mlx.core.fft.irfft2", "python/_autosummary/mlx.core.fft.irfftn", "python/_autosummary/mlx.core.fft.rfft", "python/_autosummary/mlx.core.fft.rfft2", "python/_autosummary/mlx.core.fft.rfftn", "python/_autosummary/mlx.core.flatten", "python/_autosummary/mlx.core.floor", "python/_autosummary/mlx.core.floor_divide", "python/_autosummary/mlx.core.full", "python/_autosummary/mlx.core.grad", "python/_autosummary/mlx.core.greater", "python/_autosummary/mlx.core.greater_equal", "python/_autosummary/mlx.core.identity", "python/_autosummary/mlx.core.inner", "python/_autosummary/mlx.core.isinf", "python/_autosummary/mlx.core.isnan", "python/_autosummary/mlx.core.isneginf", "python/_autosummary/mlx.core.isposinf", "python/_autosummary/mlx.core.jvp", "python/_autosummary/mlx.core.less", "python/_autosummary/mlx.core.less_equal", "python/_autosummary/mlx.core.linalg.norm", "python/_autosummary/mlx.core.linalg.qr", "python/_autosummary/mlx.core.linspace", "python/_autosummary/mlx.core.load", "python/_autosummary/mlx.core.log", "python/_autosummary/mlx.core.log10", "python/_autosummary/mlx.core.log1p", "python/_autosummary/mlx.core.log2", "python/_autosummary/mlx.core.logaddexp", "python/_autosummary/mlx.core.logical_and", "python/_autosummary/mlx.core.logical_not", "python/_autosummary/mlx.core.logical_or", "python/_autosummary/mlx.core.logsumexp", "python/_autosummary/mlx.core.matmul", "python/_autosummary/mlx.core.max", "python/_autosummary/mlx.core.maximum", "python/_autosummary/mlx.core.mean", "python/_autosummary/mlx.core.min", "python/_autosummary/mlx.core.minimum", "python/_autosummary/mlx.core.moveaxis", "python/_autosummary/mlx.core.multiply", "python/_autosummary/mlx.core.negative", "python/_autosummary/mlx.core.new_stream", "python/_autosummary/mlx.core.ones", "python/_autosummary/mlx.core.ones_like", "python/_autosummary/mlx.core.outer", "python/_autosummary/mlx.core.pad", "python/_autosummary/mlx.core.partition", "python/_autosummary/mlx.core.prod", "python/_autosummary/mlx.core.quantize", "python/_autosummary/mlx.core.quantized_matmul", "python/_autosummary/mlx.core.random.bernoulli", "python/_autosummary/mlx.core.random.categorical", "python/_autosummary/mlx.core.random.gumbel", "python/_autosummary/mlx.core.random.key", "python/_autosummary/mlx.core.random.normal", "python/_autosummary/mlx.core.random.randint", "python/_autosummary/mlx.core.random.seed", "python/_autosummary/mlx.core.random.split", "python/_autosummary/mlx.core.random.truncated_normal", "python/_autosummary/mlx.core.random.uniform", "python/_autosummary/mlx.core.reciprocal", "python/_autosummary/mlx.core.repeat", "python/_autosummary/mlx.core.reshape", "python/_autosummary/mlx.core.round", "python/_autosummary/mlx.core.rsqrt", "python/_autosummary/mlx.core.save", "python/_autosummary/mlx.core.save_gguf", "python/_autosummary/mlx.core.save_safetensors", "python/_autosummary/mlx.core.savez", "python/_autosummary/mlx.core.savez_compressed", "python/_autosummary/mlx.core.set_default_device", "python/_autosummary/mlx.core.set_default_stream", "python/_autosummary/mlx.core.sigmoid", "python/_autosummary/mlx.core.sign", "python/_autosummary/mlx.core.sin", "python/_autosummary/mlx.core.sinh", "python/_autosummary/mlx.core.softmax", "python/_autosummary/mlx.core.sort", "python/_autosummary/mlx.core.split", "python/_autosummary/mlx.core.sqrt", "python/_autosummary/mlx.core.square", "python/_autosummary/mlx.core.squeeze", "python/_autosummary/mlx.core.stack", "python/_autosummary/mlx.core.stop_gradient", "python/_autosummary/mlx.core.subtract", "python/_autosummary/mlx.core.sum", "python/_autosummary/mlx.core.swapaxes", "python/_autosummary/mlx.core.take", "python/_autosummary/mlx.core.take_along_axis", "python/_autosummary/mlx.core.tan", "python/_autosummary/mlx.core.tanh", "python/_autosummary/mlx.core.tensordot", "python/_autosummary/mlx.core.transpose", "python/_autosummary/mlx.core.tri", "python/_autosummary/mlx.core.tril", "python/_autosummary/mlx.core.triu", "python/_autosummary/mlx.core.value_and_grad", "python/_autosummary/mlx.core.var", "python/_autosummary/mlx.core.vjp", "python/_autosummary/mlx.core.vmap", "python/_autosummary/mlx.core.where", "python/_autosummary/mlx.core.zeros", "python/_autosummary/mlx.core.zeros_like", "python/_autosummary/mlx.nn.value_and_grad", "python/_autosummary/mlx.optimizers.AdaDelta", "python/_autosummary/mlx.optimizers.Adafactor", "python/_autosummary/mlx.optimizers.Adagrad", "python/_autosummary/mlx.optimizers.Adam", "python/_autosummary/mlx.optimizers.AdamW", "python/_autosummary/mlx.optimizers.Adamax", "python/_autosummary/mlx.optimizers.Lion", "python/_autosummary/mlx.optimizers.Optimizer", "python/_autosummary/mlx.optimizers.OptimizerState", "python/_autosummary/mlx.optimizers.RMSprop", "python/_autosummary/mlx.optimizers.SGD", "python/_autosummary/mlx.utils.tree_flatten", "python/_autosummary/mlx.utils.tree_map", "python/_autosummary/mlx.utils.tree_unflatten", "python/array", "python/data_types", "python/devices_and_streams", "python/fft", "python/linalg", "python/nn", "python/nn/_autosummary/mlx.nn.ALiBi", "python/nn/_autosummary/mlx.nn.BatchNorm", "python/nn/_autosummary/mlx.nn.Conv1d", "python/nn/_autosummary/mlx.nn.Conv2d", "python/nn/_autosummary/mlx.nn.Dropout", "python/nn/_autosummary/mlx.nn.Dropout2d", "python/nn/_autosummary/mlx.nn.Dropout3d", "python/nn/_autosummary/mlx.nn.Embedding", "python/nn/_autosummary/mlx.nn.GELU", "python/nn/_autosummary/mlx.nn.GroupNorm", "python/nn/_autosummary/mlx.nn.InstanceNorm", "python/nn/_autosummary/mlx.nn.LayerNorm", "python/nn/_autosummary/mlx.nn.Linear", "python/nn/_autosummary/mlx.nn.Mish", "python/nn/_autosummary/mlx.nn.Module.apply", "python/nn/_autosummary/mlx.nn.Module.apply_to_modules", "python/nn/_autosummary/mlx.nn.Module.children", "python/nn/_autosummary/mlx.nn.Module.eval", "python/nn/_autosummary/mlx.nn.Module.filter_and_map", "python/nn/_autosummary/mlx.nn.Module.freeze", "python/nn/_autosummary/mlx.nn.Module.leaf_modules", "python/nn/_autosummary/mlx.nn.Module.load_weights", "python/nn/_autosummary/mlx.nn.Module.modules", "python/nn/_autosummary/mlx.nn.Module.named_modules", "python/nn/_autosummary/mlx.nn.Module.parameters", "python/nn/_autosummary/mlx.nn.Module.save_weights", "python/nn/_autosummary/mlx.nn.Module.train", "python/nn/_autosummary/mlx.nn.Module.trainable_parameters", "python/nn/_autosummary/mlx.nn.Module.training", "python/nn/_autosummary/mlx.nn.Module.unfreeze", "python/nn/_autosummary/mlx.nn.Module.update", "python/nn/_autosummary/mlx.nn.Module.update_modules", "python/nn/_autosummary/mlx.nn.MultiHeadAttention", "python/nn/_autosummary/mlx.nn.PReLU", "python/nn/_autosummary/mlx.nn.QuantizedLinear", "python/nn/_autosummary/mlx.nn.RMSNorm", "python/nn/_autosummary/mlx.nn.ReLU", "python/nn/_autosummary/mlx.nn.RoPE", "python/nn/_autosummary/mlx.nn.SELU", "python/nn/_autosummary/mlx.nn.Sequential", "python/nn/_autosummary/mlx.nn.SiLU", "python/nn/_autosummary/mlx.nn.SinusoidalPositionalEncoding", "python/nn/_autosummary/mlx.nn.Softshrink", "python/nn/_autosummary/mlx.nn.Step", "python/nn/_autosummary/mlx.nn.Transformer", "python/nn/_autosummary/mlx.nn.init.constant", "python/nn/_autosummary/mlx.nn.init.glorot_normal", "python/nn/_autosummary/mlx.nn.init.glorot_uniform", "python/nn/_autosummary/mlx.nn.init.he_normal", "python/nn/_autosummary/mlx.nn.init.he_uniform", "python/nn/_autosummary/mlx.nn.init.identity", "python/nn/_autosummary/mlx.nn.init.normal", "python/nn/_autosummary/mlx.nn.init.uniform", "python/nn/_autosummary_functions/mlx.nn.gelu", "python/nn/_autosummary_functions/mlx.nn.gelu_approx", "python/nn/_autosummary_functions/mlx.nn.gelu_fast_approx", "python/nn/_autosummary_functions/mlx.nn.losses.binary_cross_entropy", "python/nn/_autosummary_functions/mlx.nn.losses.cosine_similarity_loss", "python/nn/_autosummary_functions/mlx.nn.losses.cross_entropy", "python/nn/_autosummary_functions/mlx.nn.losses.gaussian_nll_loss", "python/nn/_autosummary_functions/mlx.nn.losses.hinge_loss", "python/nn/_autosummary_functions/mlx.nn.losses.huber_loss", "python/nn/_autosummary_functions/mlx.nn.losses.kl_div_loss", "python/nn/_autosummary_functions/mlx.nn.losses.l1_loss", "python/nn/_autosummary_functions/mlx.nn.losses.log_cosh_loss", "python/nn/_autosummary_functions/mlx.nn.losses.mse_loss", "python/nn/_autosummary_functions/mlx.nn.losses.nll_loss", "python/nn/_autosummary_functions/mlx.nn.losses.smooth_l1_loss", "python/nn/_autosummary_functions/mlx.nn.losses.triplet_loss", "python/nn/_autosummary_functions/mlx.nn.mish", "python/nn/_autosummary_functions/mlx.nn.prelu", "python/nn/_autosummary_functions/mlx.nn.relu", "python/nn/_autosummary_functions/mlx.nn.selu", "python/nn/_autosummary_functions/mlx.nn.silu", "python/nn/_autosummary_functions/mlx.nn.softshrink", "python/nn/_autosummary_functions/mlx.nn.step", "python/nn/functions", "python/nn/init", "python/nn/layers", "python/nn/losses", "python/nn/module", "python/ops", "python/optimizers", "python/random", "python/transforms", "python/tree_utils", "usage/function_transforms", "usage/indexing", "usage/lazy_evaluation", "usage/numpy", "usage/quick_start", "usage/saving_and_loading", "usage/unified_memory", "usage/using_streams"], "filenames": ["cpp/ops.rst", "dev/extensions.rst", "examples/linear_regression.rst", "examples/llama-inference.rst", "examples/mlp.rst", "index.rst", "install.rst", "python/_autosummary/mlx.core.Device.rst", "python/_autosummary/mlx.core.Dtype.rst", "python/_autosummary/mlx.core.Stream.rst", "python/_autosummary/mlx.core.abs.rst", "python/_autosummary/mlx.core.add.rst", "python/_autosummary/mlx.core.all.rst", "python/_autosummary/mlx.core.allclose.rst", "python/_autosummary/mlx.core.any.rst", "python/_autosummary/mlx.core.arange.rst", "python/_autosummary/mlx.core.arccos.rst", "python/_autosummary/mlx.core.arccosh.rst", "python/_autosummary/mlx.core.arcsin.rst", "python/_autosummary/mlx.core.arcsinh.rst", "python/_autosummary/mlx.core.arctan.rst", "python/_autosummary/mlx.core.arctanh.rst", "python/_autosummary/mlx.core.argmax.rst", "python/_autosummary/mlx.core.argmin.rst", "python/_autosummary/mlx.core.argpartition.rst", "python/_autosummary/mlx.core.argsort.rst", "python/_autosummary/mlx.core.array.rst", "python/_autosummary/mlx.core.array.T.rst", "python/_autosummary/mlx.core.array.abs.rst", "python/_autosummary/mlx.core.array.all.rst", "python/_autosummary/mlx.core.array.any.rst", "python/_autosummary/mlx.core.array.argmax.rst", "python/_autosummary/mlx.core.array.argmin.rst", "python/_autosummary/mlx.core.array.astype.rst", "python/_autosummary/mlx.core.array.cos.rst", "python/_autosummary/mlx.core.array.dtype.rst", "python/_autosummary/mlx.core.array.exp.rst", "python/_autosummary/mlx.core.array.item.rst", "python/_autosummary/mlx.core.array.log.rst", "python/_autosummary/mlx.core.array.log1p.rst", "python/_autosummary/mlx.core.array.logsumexp.rst", "python/_autosummary/mlx.core.array.max.rst", "python/_autosummary/mlx.core.array.mean.rst", "python/_autosummary/mlx.core.array.min.rst", "python/_autosummary/mlx.core.array.ndim.rst", "python/_autosummary/mlx.core.array.prod.rst", "python/_autosummary/mlx.core.array.reciprocal.rst", "python/_autosummary/mlx.core.array.reshape.rst", "python/_autosummary/mlx.core.array.round.rst", "python/_autosummary/mlx.core.array.rsqrt.rst", "python/_autosummary/mlx.core.array.shape.rst", "python/_autosummary/mlx.core.array.sin.rst", "python/_autosummary/mlx.core.array.size.rst", "python/_autosummary/mlx.core.array.split.rst", "python/_autosummary/mlx.core.array.sqrt.rst", "python/_autosummary/mlx.core.array.square.rst", "python/_autosummary/mlx.core.array.sum.rst", "python/_autosummary/mlx.core.array.tolist.rst", "python/_autosummary/mlx.core.array.transpose.rst", "python/_autosummary/mlx.core.array.var.rst", "python/_autosummary/mlx.core.array_equal.rst", "python/_autosummary/mlx.core.broadcast_to.rst", "python/_autosummary/mlx.core.ceil.rst", "python/_autosummary/mlx.core.clip.rst", "python/_autosummary/mlx.core.concatenate.rst", "python/_autosummary/mlx.core.conv1d.rst", "python/_autosummary/mlx.core.conv2d.rst", "python/_autosummary/mlx.core.convolve.rst", "python/_autosummary/mlx.core.cos.rst", "python/_autosummary/mlx.core.cosh.rst", "python/_autosummary/mlx.core.default_device.rst", "python/_autosummary/mlx.core.default_stream.rst", "python/_autosummary/mlx.core.dequantize.rst", "python/_autosummary/mlx.core.diag.rst", "python/_autosummary/mlx.core.diagonal.rst", "python/_autosummary/mlx.core.divide.rst", "python/_autosummary/mlx.core.divmod.rst", "python/_autosummary/mlx.core.equal.rst", "python/_autosummary/mlx.core.erf.rst", "python/_autosummary/mlx.core.erfinv.rst", "python/_autosummary/mlx.core.eval.rst", "python/_autosummary/mlx.core.exp.rst", "python/_autosummary/mlx.core.expand_dims.rst", "python/_autosummary/mlx.core.eye.rst", "python/_autosummary/mlx.core.fft.fft.rst", "python/_autosummary/mlx.core.fft.fft2.rst", "python/_autosummary/mlx.core.fft.fftn.rst", "python/_autosummary/mlx.core.fft.ifft.rst", "python/_autosummary/mlx.core.fft.ifft2.rst", "python/_autosummary/mlx.core.fft.ifftn.rst", "python/_autosummary/mlx.core.fft.irfft.rst", "python/_autosummary/mlx.core.fft.irfft2.rst", "python/_autosummary/mlx.core.fft.irfftn.rst", "python/_autosummary/mlx.core.fft.rfft.rst", "python/_autosummary/mlx.core.fft.rfft2.rst", "python/_autosummary/mlx.core.fft.rfftn.rst", "python/_autosummary/mlx.core.flatten.rst", "python/_autosummary/mlx.core.floor.rst", "python/_autosummary/mlx.core.floor_divide.rst", "python/_autosummary/mlx.core.full.rst", "python/_autosummary/mlx.core.grad.rst", "python/_autosummary/mlx.core.greater.rst", "python/_autosummary/mlx.core.greater_equal.rst", "python/_autosummary/mlx.core.identity.rst", "python/_autosummary/mlx.core.inner.rst", "python/_autosummary/mlx.core.isinf.rst", "python/_autosummary/mlx.core.isnan.rst", "python/_autosummary/mlx.core.isneginf.rst", "python/_autosummary/mlx.core.isposinf.rst", "python/_autosummary/mlx.core.jvp.rst", "python/_autosummary/mlx.core.less.rst", "python/_autosummary/mlx.core.less_equal.rst", "python/_autosummary/mlx.core.linalg.norm.rst", "python/_autosummary/mlx.core.linalg.qr.rst", "python/_autosummary/mlx.core.linspace.rst", "python/_autosummary/mlx.core.load.rst", "python/_autosummary/mlx.core.log.rst", "python/_autosummary/mlx.core.log10.rst", "python/_autosummary/mlx.core.log1p.rst", "python/_autosummary/mlx.core.log2.rst", "python/_autosummary/mlx.core.logaddexp.rst", "python/_autosummary/mlx.core.logical_and.rst", "python/_autosummary/mlx.core.logical_not.rst", "python/_autosummary/mlx.core.logical_or.rst", "python/_autosummary/mlx.core.logsumexp.rst", "python/_autosummary/mlx.core.matmul.rst", "python/_autosummary/mlx.core.max.rst", "python/_autosummary/mlx.core.maximum.rst", "python/_autosummary/mlx.core.mean.rst", "python/_autosummary/mlx.core.min.rst", "python/_autosummary/mlx.core.minimum.rst", "python/_autosummary/mlx.core.moveaxis.rst", "python/_autosummary/mlx.core.multiply.rst", "python/_autosummary/mlx.core.negative.rst", "python/_autosummary/mlx.core.new_stream.rst", "python/_autosummary/mlx.core.ones.rst", "python/_autosummary/mlx.core.ones_like.rst", "python/_autosummary/mlx.core.outer.rst", "python/_autosummary/mlx.core.pad.rst", "python/_autosummary/mlx.core.partition.rst", "python/_autosummary/mlx.core.prod.rst", "python/_autosummary/mlx.core.quantize.rst", "python/_autosummary/mlx.core.quantized_matmul.rst", "python/_autosummary/mlx.core.random.bernoulli.rst", "python/_autosummary/mlx.core.random.categorical.rst", "python/_autosummary/mlx.core.random.gumbel.rst", "python/_autosummary/mlx.core.random.key.rst", "python/_autosummary/mlx.core.random.normal.rst", "python/_autosummary/mlx.core.random.randint.rst", "python/_autosummary/mlx.core.random.seed.rst", "python/_autosummary/mlx.core.random.split.rst", "python/_autosummary/mlx.core.random.truncated_normal.rst", "python/_autosummary/mlx.core.random.uniform.rst", "python/_autosummary/mlx.core.reciprocal.rst", "python/_autosummary/mlx.core.repeat.rst", "python/_autosummary/mlx.core.reshape.rst", "python/_autosummary/mlx.core.round.rst", "python/_autosummary/mlx.core.rsqrt.rst", "python/_autosummary/mlx.core.save.rst", "python/_autosummary/mlx.core.save_gguf.rst", "python/_autosummary/mlx.core.save_safetensors.rst", "python/_autosummary/mlx.core.savez.rst", "python/_autosummary/mlx.core.savez_compressed.rst", "python/_autosummary/mlx.core.set_default_device.rst", "python/_autosummary/mlx.core.set_default_stream.rst", "python/_autosummary/mlx.core.sigmoid.rst", "python/_autosummary/mlx.core.sign.rst", "python/_autosummary/mlx.core.sin.rst", "python/_autosummary/mlx.core.sinh.rst", "python/_autosummary/mlx.core.softmax.rst", "python/_autosummary/mlx.core.sort.rst", "python/_autosummary/mlx.core.split.rst", "python/_autosummary/mlx.core.sqrt.rst", "python/_autosummary/mlx.core.square.rst", "python/_autosummary/mlx.core.squeeze.rst", "python/_autosummary/mlx.core.stack.rst", "python/_autosummary/mlx.core.stop_gradient.rst", "python/_autosummary/mlx.core.subtract.rst", "python/_autosummary/mlx.core.sum.rst", "python/_autosummary/mlx.core.swapaxes.rst", "python/_autosummary/mlx.core.take.rst", "python/_autosummary/mlx.core.take_along_axis.rst", "python/_autosummary/mlx.core.tan.rst", "python/_autosummary/mlx.core.tanh.rst", "python/_autosummary/mlx.core.tensordot.rst", "python/_autosummary/mlx.core.transpose.rst", "python/_autosummary/mlx.core.tri.rst", "python/_autosummary/mlx.core.tril.rst", "python/_autosummary/mlx.core.triu.rst", "python/_autosummary/mlx.core.value_and_grad.rst", "python/_autosummary/mlx.core.var.rst", "python/_autosummary/mlx.core.vjp.rst", "python/_autosummary/mlx.core.vmap.rst", "python/_autosummary/mlx.core.where.rst", "python/_autosummary/mlx.core.zeros.rst", "python/_autosummary/mlx.core.zeros_like.rst", "python/_autosummary/mlx.nn.value_and_grad.rst", "python/_autosummary/mlx.optimizers.AdaDelta.rst", "python/_autosummary/mlx.optimizers.Adafactor.rst", "python/_autosummary/mlx.optimizers.Adagrad.rst", "python/_autosummary/mlx.optimizers.Adam.rst", "python/_autosummary/mlx.optimizers.AdamW.rst", "python/_autosummary/mlx.optimizers.Adamax.rst", "python/_autosummary/mlx.optimizers.Lion.rst", "python/_autosummary/mlx.optimizers.Optimizer.rst", "python/_autosummary/mlx.optimizers.OptimizerState.rst", "python/_autosummary/mlx.optimizers.RMSprop.rst", "python/_autosummary/mlx.optimizers.SGD.rst", "python/_autosummary/mlx.utils.tree_flatten.rst", "python/_autosummary/mlx.utils.tree_map.rst", "python/_autosummary/mlx.utils.tree_unflatten.rst", "python/array.rst", "python/data_types.rst", "python/devices_and_streams.rst", "python/fft.rst", "python/linalg.rst", "python/nn.rst", "python/nn/_autosummary/mlx.nn.ALiBi.rst", "python/nn/_autosummary/mlx.nn.BatchNorm.rst", "python/nn/_autosummary/mlx.nn.Conv1d.rst", "python/nn/_autosummary/mlx.nn.Conv2d.rst", "python/nn/_autosummary/mlx.nn.Dropout.rst", "python/nn/_autosummary/mlx.nn.Dropout2d.rst", "python/nn/_autosummary/mlx.nn.Dropout3d.rst", "python/nn/_autosummary/mlx.nn.Embedding.rst", "python/nn/_autosummary/mlx.nn.GELU.rst", "python/nn/_autosummary/mlx.nn.GroupNorm.rst", "python/nn/_autosummary/mlx.nn.InstanceNorm.rst", "python/nn/_autosummary/mlx.nn.LayerNorm.rst", "python/nn/_autosummary/mlx.nn.Linear.rst", "python/nn/_autosummary/mlx.nn.Mish.rst", "python/nn/_autosummary/mlx.nn.Module.apply.rst", "python/nn/_autosummary/mlx.nn.Module.apply_to_modules.rst", "python/nn/_autosummary/mlx.nn.Module.children.rst", "python/nn/_autosummary/mlx.nn.Module.eval.rst", "python/nn/_autosummary/mlx.nn.Module.filter_and_map.rst", "python/nn/_autosummary/mlx.nn.Module.freeze.rst", "python/nn/_autosummary/mlx.nn.Module.leaf_modules.rst", "python/nn/_autosummary/mlx.nn.Module.load_weights.rst", "python/nn/_autosummary/mlx.nn.Module.modules.rst", "python/nn/_autosummary/mlx.nn.Module.named_modules.rst", "python/nn/_autosummary/mlx.nn.Module.parameters.rst", "python/nn/_autosummary/mlx.nn.Module.save_weights.rst", "python/nn/_autosummary/mlx.nn.Module.train.rst", "python/nn/_autosummary/mlx.nn.Module.trainable_parameters.rst", "python/nn/_autosummary/mlx.nn.Module.training.rst", "python/nn/_autosummary/mlx.nn.Module.unfreeze.rst", "python/nn/_autosummary/mlx.nn.Module.update.rst", "python/nn/_autosummary/mlx.nn.Module.update_modules.rst", "python/nn/_autosummary/mlx.nn.MultiHeadAttention.rst", "python/nn/_autosummary/mlx.nn.PReLU.rst", "python/nn/_autosummary/mlx.nn.QuantizedLinear.rst", "python/nn/_autosummary/mlx.nn.RMSNorm.rst", "python/nn/_autosummary/mlx.nn.ReLU.rst", "python/nn/_autosummary/mlx.nn.RoPE.rst", "python/nn/_autosummary/mlx.nn.SELU.rst", "python/nn/_autosummary/mlx.nn.Sequential.rst", "python/nn/_autosummary/mlx.nn.SiLU.rst", "python/nn/_autosummary/mlx.nn.SinusoidalPositionalEncoding.rst", "python/nn/_autosummary/mlx.nn.Softshrink.rst", "python/nn/_autosummary/mlx.nn.Step.rst", "python/nn/_autosummary/mlx.nn.Transformer.rst", "python/nn/_autosummary/mlx.nn.init.constant.rst", "python/nn/_autosummary/mlx.nn.init.glorot_normal.rst", "python/nn/_autosummary/mlx.nn.init.glorot_uniform.rst", "python/nn/_autosummary/mlx.nn.init.he_normal.rst", "python/nn/_autosummary/mlx.nn.init.he_uniform.rst", "python/nn/_autosummary/mlx.nn.init.identity.rst", "python/nn/_autosummary/mlx.nn.init.normal.rst", "python/nn/_autosummary/mlx.nn.init.uniform.rst", "python/nn/_autosummary_functions/mlx.nn.gelu.rst", "python/nn/_autosummary_functions/mlx.nn.gelu_approx.rst", "python/nn/_autosummary_functions/mlx.nn.gelu_fast_approx.rst", "python/nn/_autosummary_functions/mlx.nn.losses.binary_cross_entropy.rst", "python/nn/_autosummary_functions/mlx.nn.losses.cosine_similarity_loss.rst", "python/nn/_autosummary_functions/mlx.nn.losses.cross_entropy.rst", "python/nn/_autosummary_functions/mlx.nn.losses.gaussian_nll_loss.rst", "python/nn/_autosummary_functions/mlx.nn.losses.hinge_loss.rst", "python/nn/_autosummary_functions/mlx.nn.losses.huber_loss.rst", "python/nn/_autosummary_functions/mlx.nn.losses.kl_div_loss.rst", "python/nn/_autosummary_functions/mlx.nn.losses.l1_loss.rst", "python/nn/_autosummary_functions/mlx.nn.losses.log_cosh_loss.rst", "python/nn/_autosummary_functions/mlx.nn.losses.mse_loss.rst", "python/nn/_autosummary_functions/mlx.nn.losses.nll_loss.rst", "python/nn/_autosummary_functions/mlx.nn.losses.smooth_l1_loss.rst", "python/nn/_autosummary_functions/mlx.nn.losses.triplet_loss.rst", "python/nn/_autosummary_functions/mlx.nn.mish.rst", "python/nn/_autosummary_functions/mlx.nn.prelu.rst", "python/nn/_autosummary_functions/mlx.nn.relu.rst", "python/nn/_autosummary_functions/mlx.nn.selu.rst", "python/nn/_autosummary_functions/mlx.nn.silu.rst", "python/nn/_autosummary_functions/mlx.nn.softshrink.rst", "python/nn/_autosummary_functions/mlx.nn.step.rst", "python/nn/functions.rst", "python/nn/init.rst", "python/nn/layers.rst", "python/nn/losses.rst", "python/nn/module.rst", "python/ops.rst", "python/optimizers.rst", "python/random.rst", "python/transforms.rst", "python/tree_utils.rst", "usage/function_transforms.rst", "usage/indexing.rst", "usage/lazy_evaluation.rst", "usage/numpy.rst", "usage/quick_start.rst", "usage/saving_and_loading.rst", "usage/unified_memory.rst", "usage/using_streams.rst"], "titles": ["Operations", "Developer Documentation", "Linear Regression", "LLM inference", "Multi-Layer Perceptron", "MLX", "Build and Install", "mlx.core.Device", "mlx.core.Dtype", "mlx.core.Stream", "mlx.core.abs", "mlx.core.add", "mlx.core.all", "mlx.core.allclose", "mlx.core.any", "mlx.core.arange", "mlx.core.arccos", "mlx.core.arccosh", "mlx.core.arcsin", "mlx.core.arcsinh", "mlx.core.arctan", "mlx.core.arctanh", "mlx.core.argmax", "mlx.core.argmin", "mlx.core.argpartition", "mlx.core.argsort", "mlx.core.array", "mlx.core.array.T", "mlx.core.array.abs", "mlx.core.array.all", "mlx.core.array.any", "mlx.core.array.argmax", "mlx.core.array.argmin", "mlx.core.array.astype", "mlx.core.array.cos", "mlx.core.array.dtype", "mlx.core.array.exp", "mlx.core.array.item", "mlx.core.array.log", "mlx.core.array.log1p", "mlx.core.array.logsumexp", "mlx.core.array.max", "mlx.core.array.mean", "mlx.core.array.min", "mlx.core.array.ndim", "mlx.core.array.prod", "mlx.core.array.reciprocal", "mlx.core.array.reshape", "mlx.core.array.round", "mlx.core.array.rsqrt", "mlx.core.array.shape", "mlx.core.array.sin", "mlx.core.array.size", "mlx.core.array.split", "mlx.core.array.sqrt", "mlx.core.array.square", "mlx.core.array.sum", "mlx.core.array.tolist", "mlx.core.array.transpose", "mlx.core.array.var", "mlx.core.array_equal", "mlx.core.broadcast_to", "mlx.core.ceil", "mlx.core.clip", "mlx.core.concatenate", "mlx.core.conv1d", "mlx.core.conv2d", "mlx.core.convolve", "mlx.core.cos", "mlx.core.cosh", "mlx.core.default_device", "mlx.core.default_stream", "mlx.core.dequantize", "mlx.core.diag", "mlx.core.diagonal", "mlx.core.divide", "mlx.core.divmod", "mlx.core.equal", "mlx.core.erf", "mlx.core.erfinv", "mlx.core.eval", "mlx.core.exp", "mlx.core.expand_dims", "mlx.core.eye", "mlx.core.fft.fft", "mlx.core.fft.fft2", "mlx.core.fft.fftn", "mlx.core.fft.ifft", "mlx.core.fft.ifft2", "mlx.core.fft.ifftn", "mlx.core.fft.irfft", "mlx.core.fft.irfft2", "mlx.core.fft.irfftn", "mlx.core.fft.rfft", "mlx.core.fft.rfft2", "mlx.core.fft.rfftn", "mlx.core.flatten", "mlx.core.floor", "mlx.core.floor_divide", "mlx.core.full", "mlx.core.grad", "mlx.core.greater", "mlx.core.greater_equal", "mlx.core.identity", "mlx.core.inner", "mlx.core.isinf", "mlx.core.isnan", "mlx.core.isneginf", "mlx.core.isposinf", "mlx.core.jvp", "mlx.core.less", "mlx.core.less_equal", "mlx.core.linalg.norm", "mlx.core.linalg.qr", "mlx.core.linspace", "mlx.core.load", "mlx.core.log", "mlx.core.log10", "mlx.core.log1p", "mlx.core.log2", "mlx.core.logaddexp", "mlx.core.logical_and", "mlx.core.logical_not", "mlx.core.logical_or", "mlx.core.logsumexp", "mlx.core.matmul", "mlx.core.max", "mlx.core.maximum", "mlx.core.mean", "mlx.core.min", "mlx.core.minimum", "mlx.core.moveaxis", "mlx.core.multiply", "mlx.core.negative", "mlx.core.new_stream", "mlx.core.ones", "mlx.core.ones_like", "mlx.core.outer", "mlx.core.pad", "mlx.core.partition", "mlx.core.prod", "mlx.core.quantize", "mlx.core.quantized_matmul", "mlx.core.random.bernoulli", "mlx.core.random.categorical", "mlx.core.random.gumbel", "mlx.core.random.key", "mlx.core.random.normal", "mlx.core.random.randint", "mlx.core.random.seed", "mlx.core.random.split", "mlx.core.random.truncated_normal", "mlx.core.random.uniform", "mlx.core.reciprocal", "mlx.core.repeat", "mlx.core.reshape", "mlx.core.round", "mlx.core.rsqrt", "mlx.core.save", "mlx.core.save_gguf", "mlx.core.save_safetensors", "mlx.core.savez", "mlx.core.savez_compressed", "mlx.core.set_default_device", "mlx.core.set_default_stream", "mlx.core.sigmoid", "mlx.core.sign", "mlx.core.sin", "mlx.core.sinh", "mlx.core.softmax", "mlx.core.sort", "mlx.core.split", "mlx.core.sqrt", "mlx.core.square", "mlx.core.squeeze", "mlx.core.stack", "mlx.core.stop_gradient", "mlx.core.subtract", "mlx.core.sum", "mlx.core.swapaxes", "mlx.core.take", "mlx.core.take_along_axis", "mlx.core.tan", "mlx.core.tanh", "mlx.core.tensordot", "mlx.core.transpose", "mlx.core.tri", "mlx.core.tril", "mlx.core.triu", "mlx.core.value_and_grad", "mlx.core.var", "mlx.core.vjp", "mlx.core.vmap", "mlx.core.where", "mlx.core.zeros", "mlx.core.zeros_like", "mlx.nn.value_and_grad", "mlx.optimizers.AdaDelta", "mlx.optimizers.Adafactor", "mlx.optimizers.Adagrad", "mlx.optimizers.Adam", "mlx.optimizers.AdamW", "mlx.optimizers.Adamax", "mlx.optimizers.Lion", "mlx.optimizers.Optimizer", "mlx.optimizers.OptimizerState", "mlx.optimizers.RMSprop", "mlx.optimizers.SGD", "mlx.utils.tree_flatten", "mlx.utils.tree_map", "mlx.utils.tree_unflatten", "Array", "Data Types", "Devices and Streams", "FFT", "Linear Algebra", "Neural Networks", "mlx.nn.ALiBi", "mlx.nn.BatchNorm", "mlx.nn.Conv1d", "mlx.nn.Conv2d", "mlx.nn.Dropout", "mlx.nn.Dropout2d", "mlx.nn.Dropout3d", "mlx.nn.Embedding", "mlx.nn.GELU", "mlx.nn.GroupNorm", "mlx.nn.InstanceNorm", "mlx.nn.LayerNorm", "mlx.nn.Linear", "mlx.nn.Mish", "mlx.nn.Module.apply", "mlx.nn.Module.apply_to_modules", "mlx.nn.Module.children", "mlx.nn.Module.eval", "mlx.nn.Module.filter_and_map", "mlx.nn.Module.freeze", "mlx.nn.Module.leaf_modules", "mlx.nn.Module.load_weights", "mlx.nn.Module.modules", "mlx.nn.Module.named_modules", "mlx.nn.Module.parameters", "mlx.nn.Module.save_weights", "mlx.nn.Module.train", "mlx.nn.Module.trainable_parameters", "mlx.nn.Module.training", "mlx.nn.Module.unfreeze", "mlx.nn.Module.update", "mlx.nn.Module.update_modules", "mlx.nn.MultiHeadAttention", "mlx.nn.PReLU", "mlx.nn.QuantizedLinear", "mlx.nn.RMSNorm", "mlx.nn.ReLU", "mlx.nn.RoPE", "mlx.nn.SELU", "mlx.nn.Sequential", "mlx.nn.SiLU", "mlx.nn.SinusoidalPositionalEncoding", "mlx.nn.Softshrink", "mlx.nn.Step", "mlx.nn.Transformer", "mlx.nn.init.constant", "mlx.nn.init.glorot_normal", "mlx.nn.init.glorot_uniform", "mlx.nn.init.he_normal", "mlx.nn.init.he_uniform", "mlx.nn.init.identity", "mlx.nn.init.normal", "mlx.nn.init.uniform", "mlx.nn.gelu", "mlx.nn.gelu_approx", "mlx.nn.gelu_fast_approx", "mlx.nn.losses.binary_cross_entropy", "mlx.nn.losses.cosine_similarity_loss", "mlx.nn.losses.cross_entropy", "mlx.nn.losses.gaussian_nll_loss", "mlx.nn.losses.hinge_loss", "mlx.nn.losses.huber_loss", "mlx.nn.losses.kl_div_loss", "mlx.nn.losses.l1_loss", "mlx.nn.losses.log_cosh_loss", "mlx.nn.losses.mse_loss", "mlx.nn.losses.nll_loss", "mlx.nn.losses.smooth_l1_loss", "mlx.nn.losses.triplet_loss", "mlx.nn.mish", "mlx.nn.prelu", "mlx.nn.relu", "mlx.nn.selu", "mlx.nn.silu", "mlx.nn.softshrink", "mlx.nn.step", "Functions", "Initializers", "Layers", "Loss Functions", "Module", "Operations", "Optimizers", "Random", "Transforms", "Tree Utils", "Function Transforms", "Indexing Arrays", "Lazy Evaluation", "Conversion to NumPy and Other Frameworks", "Quick Start Guide", "Saving and Loading Arrays", "Unified Memory", "Using Streams"], "terms": {"mlx": [1, 2, 3, 4, 6, 216, 294, 297, 299, 300, 302, 303, 304, 305, 306, 307, 308, 309], "provid": [1, 3, 72, 100, 184, 189, 209, 216, 231, 236, 238, 246, 247, 248, 251, 261, 293, 297, 308, 310], "open": [1, 6, 15, 148, 152], "flexibl": [1, 5, 248], "which": [1, 3, 4, 5, 6, 15, 33, 74, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 100, 105, 106, 107, 108, 109, 112, 113, 115, 141, 144, 145, 154, 155, 158, 159, 160, 161, 162, 174, 175, 180, 189, 191, 192, 222, 223, 225, 231, 235, 254, 275, 278, 284, 294, 300, 303, 304, 305, 309, 310], "user": [1, 3, 216], "mai": [1, 112, 222, 303, 304], "add": [1, 3, 82, 120, 138, 141, 219, 220, 303, 309], "special": 1, "without": [1, 3, 5, 176, 249, 293, 302, 305, 306, 309], "much": [1, 3, 305], "hassl": 1, "while": [1, 3, 6, 155, 254, 305, 306], "librari": [1, 6, 216], "suppli": 1, "effici": [1, 3, 5, 222, 254, 305, 307], "can": [1, 3, 5, 6, 11, 15, 47, 58, 74, 75, 76, 77, 80, 101, 102, 110, 111, 112, 120, 127, 130, 132, 143, 144, 148, 151, 152, 159, 177, 189, 216, 224, 235, 246, 256, 275, 294, 297, 299, 300, 302, 303, 304, 305, 306, 307, 308, 309, 310], "compos": [1, 5, 216, 303, 307], "ani": [1, 3, 5, 15, 208, 209, 210, 216, 225, 231, 232, 235, 251, 261, 294, 302, 303, 305, 307, 308, 309], "number": [1, 15, 52, 66, 72, 83, 100, 103, 109, 114, 138, 141, 142, 144, 147, 150, 152, 154, 156, 184, 186, 189, 191, 192, 216, 218, 219, 220, 222, 223, 226, 227, 249, 250, 261, 263, 264, 265, 266, 300, 303, 310], "applic": [1, 6], "aris": [1, 306], "case": [1, 3, 86, 89, 90, 92, 93, 94, 95, 96, 113, 125, 155, 174, 222, 255, 260, 284, 289, 291, 292, 303, 307, 308, 309, 310], "where": [1, 4, 83, 141, 189, 192, 218, 219, 220, 221, 222, 223, 225, 226, 227, 228, 229, 235, 250, 252, 255, 257, 260, 265, 266, 270, 271, 272, 276, 287, 289, 290, 292, 303, 304], "new": [1, 4, 61, 74, 131, 155, 175, 185, 209, 249, 297, 299, 304, 305, 306], "function": [1, 2, 3, 4, 5, 13, 76, 78, 79, 100, 109, 112, 113, 125, 165, 189, 191, 192, 196, 209, 216, 225, 230, 232, 236, 246, 250, 256, 259, 260, 261, 270, 271, 272, 286, 291, 292, 294, 299, 300, 302, 304, 305, 306, 308], "highli": [1, 6], "optim": [1, 2, 4, 5, 247, 303, 305], "ar": [1, 2, 3, 4, 5, 6, 13, 15, 60, 61, 63, 67, 74, 83, 85, 86, 88, 89, 91, 92, 94, 95, 96, 100, 105, 106, 107, 108, 109, 112, 113, 115, 125, 137, 138, 139, 141, 142, 143, 144, 145, 148, 151, 152, 161, 162, 174, 175, 180, 189, 191, 192, 203, 208, 209, 218, 219, 220, 221, 222, 223, 226, 227, 228, 229, 238, 249, 251, 273, 275, 276, 293, 297, 302, 303, 304, 305, 306, 307, 308, 309], "need": [1, 3, 4, 5, 60, 141, 216, 247, 248, 258, 261, 300, 303, 305, 306, 307, 309], "For": [1, 3, 6, 112, 141, 210, 216, 218, 222, 231, 236, 243, 246, 251, 254, 258, 263, 264, 265, 266, 294, 300, 304, 305, 306, 307, 308, 309], "you": [1, 3, 4, 5, 6, 216, 258, 261, 294, 300, 303, 304, 306, 308, 309], "design": [1, 2, 5, 300, 309], "your": [1, 3, 6, 297, 303, 305], "own": [1, 6, 306], "link": [1, 6], "top": [1, 229], "core": [1, 2, 3, 4, 216, 218, 227, 238, 241, 244, 262, 263, 264, 265, 266, 267, 268, 269, 273, 275, 294, 297, 299, 306, 307], "we": [1, 2, 3, 4, 72, 141, 142, 201, 203, 216, 224, 256, 300, 302, 303, 305, 309], "inner": 1, "work": [1, 3, 6, 303, 304, 305], "go": [1, 3, 303], "over": [1, 3, 4, 12, 14, 22, 23, 24, 25, 65, 66, 86, 89, 92, 95, 104, 112, 114, 124, 126, 128, 129, 139, 140, 157, 169, 170, 178, 184, 190, 218, 219, 220, 226, 228, 252, 275, 303], "simpl": [1, 3, 4, 216, 224, 293, 303, 305], "learn": [1, 2, 4, 5, 197, 198, 199, 200, 201, 202, 203, 206, 207, 218, 226, 227, 228, 250, 252], "step": [1, 3, 4, 15, 198, 216], "involv": [1, 299], "ad": [1, 2, 6, 197, 198, 199, 200, 201, 202, 206, 227, 297, 305, 308], "let": [1, 2, 3, 303, 305, 306], "s": [1, 2, 3, 4, 35, 44, 72, 85, 86, 88, 89, 91, 92, 94, 95, 100, 112, 115, 128, 137, 141, 144, 156, 159, 160, 189, 190, 192, 196, 204, 216, 235, 236, 238, 242, 246, 299, 300, 303, 305, 306, 307, 308, 309], "sai": [1, 3, 294, 305], "would": [1, 3, 304, 305, 306, 309], "like": [1, 3, 5, 136, 195, 223, 281, 303, 305, 306, 307, 309], "an": [1, 3, 4, 6, 8, 12, 14, 26, 61, 65, 66, 80, 83, 96, 99, 103, 112, 115, 126, 129, 131, 135, 136, 138, 140, 141, 142, 154, 155, 156, 171, 174, 179, 180, 181, 184, 186, 192, 194, 195, 197, 204, 205, 208, 209, 216, 221, 226, 228, 229, 231, 249, 250, 251, 261, 262, 263, 264, 265, 266, 267, 268, 269, 271, 287, 294, 300, 302, 303, 304, 305, 306, 307, 308, 309, 310], "take": [1, 3, 4, 100, 109, 127, 130, 136, 142, 181, 189, 191, 192, 195, 249, 300, 303, 304, 308, 309, 310], "two": [1, 11, 13, 60, 74, 75, 77, 85, 88, 94, 101, 102, 110, 111, 113, 120, 125, 127, 130, 132, 137, 179, 251, 274, 303, 304, 309], "arrai": [1, 3, 4, 5, 8, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 216, 218, 231, 238, 241, 244, 250, 262, 263, 264, 265, 266, 267, 268, 269, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 292, 294, 297, 303, 305, 306, 307, 309], "x": [1, 2, 3, 4, 78, 103, 112, 142, 145, 156, 161, 165, 187, 188, 193, 203, 209, 216, 218, 225, 226, 227, 228, 229, 230, 231, 250, 252, 253, 255, 257, 258, 260, 270, 271, 272, 284, 286, 287, 288, 289, 290, 291, 292, 297, 299, 303, 304, 305, 306, 307, 309], "y": [1, 2, 3, 4, 193, 199, 216, 218, 222, 226, 227, 228, 229, 252, 277, 284, 299, 303, 305, 306], "scale": [1, 3, 72, 141, 142, 198, 222, 223, 249, 254, 255, 258, 289], "them": [1, 3, 216, 236, 246, 309], "both": [1, 11, 75, 76, 77, 101, 102, 110, 111, 112, 120, 127, 130, 132, 144, 177, 227, 299, 303, 307, 309], "some": [1, 2, 3, 4, 236, 246, 303, 305], "coeffici": [1, 197, 198, 200, 201, 202, 203], "alpha": [1, 141, 201, 206, 255, 285, 287, 289], "beta": [1, 72, 141, 200, 201, 202, 203, 218, 226, 227, 228, 284], "respect": [1, 2, 4, 100, 141, 189, 209, 216, 218, 225, 226, 227, 228, 297, 303, 307], "togeth": [1, 4, 141, 209], "get": [1, 2, 4, 6, 66, 146, 205, 216, 303, 305, 309], "z": [1, 305], "well": [1, 3, 216, 236, 246, 249, 305], "veri": [1, 3, 249, 305, 309], "easili": 1, "do": [1, 3, 6, 201, 216, 237, 246, 294, 297, 303, 305], "just": [1, 4, 304], "write": [1, 3, 216, 306], "out": [1, 6, 222, 223, 243, 303, 304], "follow": [1, 3, 4, 5, 6, 15, 67, 72, 112, 141, 197, 198, 199, 200, 201, 202, 203, 207, 216, 271, 272, 279, 300, 303, 309], "import": [1, 2, 3, 4, 6, 112, 161, 189, 208, 209, 210, 216, 218, 227, 238, 273, 275, 294, 297, 303, 304, 305, 306, 307], "mx": [1, 2, 3, 4, 96, 112, 113, 115, 161, 189, 216, 218, 227, 231, 238, 242, 253, 262, 263, 264, 265, 266, 267, 268, 269, 273, 274, 275, 279, 288, 294, 297, 299, 300, 303, 304, 305, 306, 307, 308, 309, 310], "def": [1, 2, 3, 4, 189, 216, 297, 303, 304, 305, 306, 309], "simple_axpbi": 1, "float": [1, 13, 15, 57, 98, 99, 112, 142, 143, 148, 151, 152, 197, 198, 199, 200, 201, 202, 203, 206, 207, 212, 218, 221, 222, 223, 226, 227, 228, 231, 252, 254, 258, 260, 261, 262, 263, 264, 265, 266, 268, 269, 274, 275, 276, 278, 284, 285, 291, 292], "return": [1, 2, 3, 4, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 33, 37, 50, 57, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 72, 73, 74, 75, 76, 77, 78, 79, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 150, 151, 152, 153, 154, 155, 156, 157, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 208, 209, 210, 216, 233, 235, 237, 239, 240, 241, 244, 251, 262, 263, 264, 265, 266, 267, 268, 269, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 294, 297, 302, 303, 304, 305, 306, 308, 309], "thi": [1, 3, 4, 6, 12, 13, 14, 15, 22, 23, 24, 25, 109, 112, 113, 120, 124, 125, 126, 128, 129, 139, 140, 144, 169, 170, 171, 178, 180, 190, 216, 221, 222, 223, 232, 233, 235, 236, 239, 240, 241, 244, 246, 247, 248, 249, 251, 260, 263, 264, 265, 266, 271, 272, 281, 292, 297, 302, 303, 305, 306, 308], "perform": [1, 3, 5, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 125, 142, 156, 169, 180, 216, 226, 261, 265, 266, 304, 305, 309], "leav": [1, 209], "differenti": [1, 5], "howev": [1, 216, 225, 226, 300, 305, 306], "vector": [1, 2, 5, 104, 109, 112, 180, 191, 192, 224, 275, 307], "math": [1, 3, 285], "often": [1, 223], "realiz": 1, "axpbi": 1, "routin": 1, "defin": [1, 2, 3, 4, 6, 112, 142, 205, 208, 306], "same": [1, 3, 6, 13, 60, 61, 66, 67, 90, 93, 94, 95, 100, 109, 138, 144, 156, 191, 193, 216, 218, 221, 226, 227, 251, 262, 263, 264, 265, 266, 267, 268, 269, 275, 285, 297, 300, 304, 309], "realli": 1, "part": [1, 303, 304], "doe": [1, 3, 6, 216, 304, 305, 306], "fast": [1, 225, 272, 309], "so": [1, 3, 6, 100, 189, 221, 299, 305, 309], "decid": [1, 209, 235], "want": [1, 3, 303, 309], "reli": 1, "acceler": [1, 218], "framework": [1, 5], "continu": [1, 303], "impos": 1, "our": [1, 3, 4, 197, 198, 199, 200, 202, 203, 256], "assumpt": 1, "also": [1, 3, 4, 5, 6, 11, 75, 76, 77, 86, 89, 92, 95, 101, 102, 110, 111, 120, 127, 130, 132, 141, 177, 196, 205, 216, 235, 247, 249, 251, 255, 257, 270, 289, 290, 293, 299, 303, 304, 305, 306, 307, 310], "assum": [1, 3, 113, 209, 216, 226], "how": [1, 3, 4, 216, 219, 220, 224, 304, 309], "gradient": [1, 2, 4, 100, 176, 189, 196, 197, 198, 200, 201, 202, 203, 207, 216, 236, 247, 251, 261, 281, 297, 299, 303, 304, 305, 306, 307], "ins": 1, "what": [1, 3, 209], "coincid": 1, "right": [1, 6, 141, 225, 271, 272, 276, 278, 285], "place": [1, 3, 156, 305, 306], "cours": [1, 303], "The": [1, 3, 4, 5, 6, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 33, 35, 44, 50, 57, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 72, 73, 74, 75, 76, 77, 78, 79, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 150, 151, 152, 153, 154, 155, 159, 160, 165, 166, 167, 168, 169, 170, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 212, 218, 219, 220, 221, 222, 223, 224, 226, 227, 228, 229, 232, 238, 242, 247, 248, 249, 251, 252, 254, 256, 258, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 292, 294, 297, 299, 303, 304, 305, 306, 307, 308, 309, 310], "structur": [1, 303], "from": [1, 3, 4, 5, 72, 74, 91, 92, 94, 95, 99, 112, 115, 125, 136, 141, 143, 144, 145, 146, 148, 151, 161, 174, 176, 177, 180, 181, 193, 195, 208, 209, 210, 216, 229, 236, 238, 249, 263, 264, 265, 266, 268, 269, 276, 284, 294, 302, 303, 305, 306, 307, 308, 309], "frontend": 1, "api": [1, 303], "redirect": 1, "when": [1, 3, 5, 6, 112, 115, 219, 220, 265, 266, 279, 284, 297, 300, 309], "appropri": 1, "fallback": 1, "metal": 1, "vjp": [1, 307], "jvp": [1, 307], "In": [1, 3, 4, 125, 141, 197, 199, 200, 202, 203, 209, 216, 222, 226, 297, 302, 303, 305, 308, 309], "one": [1, 3, 6, 57, 63, 66, 82, 83, 112, 118, 125, 142, 144, 174, 177, 246, 275, 309], "sentenc": 1, "comput": [1, 2, 3, 4, 5, 6, 72, 100, 109, 112, 120, 128, 137, 141, 169, 176, 184, 189, 190, 191, 196, 197, 198, 200, 201, 202, 203, 216, 218, 226, 227, 228, 236, 247, 251, 252, 254, 261, 263, 264, 265, 266, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 299, 303, 307, 309], "graph": [1, 3, 4, 5, 303], "rule": 1, "evalu": [1, 3, 4, 5, 80, 109, 191, 216, 234, 243, 297, 299, 307], "said": [1, 3], "start": [1, 2, 3, 5, 6, 15, 114, 171, 304, 309], "discuss": 1, "more": [1, 4, 8, 57, 74, 125, 159, 160, 216, 218, 222, 254, 258, 261, 263, 264, 265, 266, 300, 303, 304, 307, 309], "detail": [1, 8, 197, 199, 200, 202, 203, 216, 222, 254, 258, 263, 264, 265, 266, 304, 307], "thei": [1, 2, 3, 13, 67, 203, 256, 277, 297, 302, 305, 307, 308, 309], "c": [1, 3, 112, 212, 218, 219, 220, 222, 223, 227, 306, 307, 309], "scalar": [1, 11, 13, 26, 37, 57, 60, 61, 63, 75, 76, 77, 98, 99, 100, 101, 102, 110, 111, 112, 114, 120, 121, 122, 123, 125, 127, 130, 132, 138, 148, 151, 152, 159, 177, 189, 193, 196, 285, 303, 305, 307], "sum": [1, 2, 11, 104, 112, 124, 169, 184, 216, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 304, 306], "element": [1, 10, 11, 16, 17, 18, 19, 20, 21, 24, 52, 62, 68, 69, 72, 75, 76, 77, 78, 79, 81, 83, 97, 98, 101, 102, 105, 106, 107, 108, 110, 111, 116, 117, 118, 119, 120, 121, 122, 123, 127, 130, 132, 133, 139, 141, 142, 153, 154, 157, 165, 166, 167, 168, 172, 173, 177, 180, 182, 183, 189, 193, 221, 222, 223, 230, 250, 254, 257, 286, 287, 290, 303], "wise": [1, 10, 11, 16, 17, 18, 19, 20, 21, 62, 68, 69, 75, 76, 77, 78, 79, 81, 97, 98, 101, 102, 110, 111, 116, 117, 118, 119, 120, 121, 122, 123, 127, 130, 132, 133, 153, 157, 165, 166, 167, 168, 172, 173, 177, 182, 183, 222, 223, 230, 250, 257, 286, 287, 290], "numpi": [1, 3, 4, 5, 11, 13, 15, 61, 75, 76, 77, 101, 102, 110, 111, 120, 125, 127, 130, 132, 177, 305, 307, 308], "style": [1, 11, 13, 75, 76, 77, 101, 102, 110, 111, 120, 125, 127, 130, 132, 177], "broadcast": [1, 11, 13, 61, 63, 75, 76, 77, 99, 101, 102, 110, 111, 120, 125, 127, 130, 132, 143, 144, 151, 152, 177, 181, 193, 249], "between": [1, 5, 63, 96, 261, 274, 277, 278, 281, 305, 309], "input": [1, 2, 3, 10, 11, 12, 13, 14, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 73, 74, 75, 76, 77, 78, 79, 81, 82, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 100, 101, 102, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 136, 137, 138, 139, 140, 141, 142, 150, 153, 154, 155, 156, 157, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 187, 188, 189, 190, 192, 193, 195, 218, 219, 220, 222, 223, 224, 226, 227, 228, 229, 249, 251, 252, 254, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 273, 274, 276, 277, 278, 279, 281, 283, 285, 292, 294, 303, 304, 307, 308], "upcast": 1, "const": [1, 276], "factor": [1, 113, 275], "streamordevic": 1, "stream": [1, 5, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 28, 29, 30, 31, 32, 33, 34, 36, 38, 39, 40, 41, 42, 43, 45, 46, 47, 48, 49, 51, 53, 54, 55, 56, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 71, 72, 73, 74, 75, 76, 77, 78, 79, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 101, 102, 103, 104, 105, 106, 107, 108, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 147, 148, 150, 151, 152, 153, 154, 155, 156, 157, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 190, 193, 194, 195, 309], "schedul": [1, 309], "itself": 1, "call": [1, 3, 4, 27, 98, 216, 224, 236, 246, 256, 297, 299, 303, 305], "other": [1, 3, 5, 112, 203, 216, 237, 297, 304, 305, 307], "within": [1, 24], "simplest": [1, 216], "wai": [1, 3, 6, 216, 303, 304], "about": [1, 3, 4, 305, 309], "term": [1, 197, 198, 199, 200, 201, 202, 206, 276], "exist": [1, 3, 236, 246], "auto": [1, 6], "ax": [1, 12, 14, 22, 23, 58, 82, 85, 86, 88, 89, 91, 92, 94, 95, 96, 104, 112, 124, 126, 128, 129, 138, 140, 169, 174, 178, 179, 184, 185, 190, 303], "multipli": [1, 141, 142, 221, 258], "earlier": 1, "goal": 1, "themselv": 1, "contain": [1, 3, 24, 25, 50, 74, 90, 91, 92, 112, 121, 122, 123, 141, 171, 193, 216, 235, 237, 238, 261, 294, 297, 303], "act": [1, 281], "data": [1, 4, 5, 8, 15, 83, 93, 94, 99, 103, 114, 135, 151, 186, 194, 223, 262, 263, 264, 265, 266, 267, 268, 269, 304, 306], "nor": [1, 100, 189], "rather": [1, 303, 309], "easi": [1, 216], "interfac": 1, "block": [1, 3, 261], "A": [1, 3, 5, 6, 50, 60, 100, 109, 112, 113, 115, 124, 125, 141, 143, 144, 145, 147, 148, 151, 152, 171, 175, 189, 191, 192, 196, 200, 202, 208, 209, 210, 216, 218, 222, 226, 227, 228, 230, 235, 239, 240, 247, 248, 252, 256, 258, 261, 263, 264, 266, 272, 285, 286, 297, 299, 303, 305, 306], "It": [1, 3, 6, 100, 189, 204, 216, 248, 251, 306, 308], "creat": [1, 3, 6, 83, 103, 216, 297, 299, 304, 306], "output": [1, 3, 6, 12, 13, 14, 15, 24, 61, 83, 90, 93, 94, 95, 99, 100, 103, 112, 114, 124, 126, 128, 129, 135, 136, 139, 140, 143, 144, 145, 147, 148, 151, 152, 161, 162, 169, 174, 178, 181, 186, 189, 190, 191, 192, 193, 194, 195, 218, 219, 220, 227, 229, 249, 251, 260, 261, 263, 264, 265, 266, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 292, 294, 303, 304, 305, 306, 307, 308, 309], "given": [1, 12, 14, 24, 61, 63, 64, 72, 74, 80, 82, 84, 85, 86, 87, 88, 89, 93, 94, 95, 99, 112, 124, 126, 128, 129, 140, 148, 156, 169, 171, 178, 186, 187, 188, 190, 221, 235, 249, 274, 276], "set": [1, 3, 4, 6, 198, 205, 225, 229, 234, 236, 243, 246, 247, 251, 254, 260, 274, 285, 292, 297, 300, 303, 305], "further": [1, 6, 303], "class": [1, 3, 4, 7, 8, 9, 26, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 297], "under": [1, 112], "These": [1, 181, 275, 309], "word": 1, "bit": [1, 72, 141, 142, 212, 231, 251], "abstract": 1, "back": [1, 3, 306], "give": [1, 3, 4, 24], "ourselv": 1, "concret": [1, 229, 305, 309], "imag": [1, 220, 222, 223], "public": [1, 216], "explicit": [1, 300, 306], "alpha_": 1, "beta_": 1, "must": [1, 6, 63, 80, 99, 112, 143, 144, 148, 151, 152, 193, 306], "know": [1, 3], "popul": 1, "To": [1, 2, 3, 4, 6, 216, 294, 303, 307], "avoid": 1, "unnecessari": [1, 3], "alloc": [1, 297], "respons": 1, "space": [1, 114, 283], "void": 1, "eval_cpu": 1, "std": [1, 268], "overrid": 1, "eval_gpu": 1, "jacobian": [1, 109, 191, 307], "product": [1, 104, 109, 125, 137, 140, 184, 191, 249, 307], "primal": [1, 109, 191], "tangent": [1, 20, 21, 109, 182, 183], "int": [1, 3, 4, 7, 9, 12, 14, 15, 22, 23, 24, 25, 29, 30, 31, 32, 40, 41, 42, 43, 45, 48, 50, 53, 56, 57, 59, 61, 64, 65, 66, 72, 73, 74, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 99, 100, 103, 112, 114, 124, 126, 128, 129, 131, 135, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 154, 155, 156, 169, 170, 171, 174, 175, 178, 179, 180, 181, 184, 185, 186, 187, 188, 189, 190, 192, 194, 216, 218, 219, 220, 224, 226, 227, 228, 229, 249, 251, 252, 254, 258, 261, 274, 275, 279, 283, 285, 297], "argnum": [1, 100, 189, 303], "cotan": 1, "across": [1, 226], "pair": [1, 138, 238, 254], "repres": [1, 3, 285, 306], "axi": [1, 3, 4, 12, 14, 22, 23, 24, 25, 29, 30, 31, 32, 40, 41, 42, 43, 45, 53, 56, 59, 64, 74, 82, 84, 87, 90, 91, 92, 93, 94, 95, 96, 112, 124, 126, 128, 129, 131, 138, 139, 140, 144, 154, 169, 170, 171, 174, 175, 178, 179, 180, 181, 185, 190, 192, 274, 275, 279, 283, 285, 304], "correspond": [1, 12, 14, 57, 63, 72, 74, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 124, 126, 129, 140, 178, 184, 192, 209, 303], "dimens": [1, 3, 12, 14, 22, 23, 44, 50, 57, 66, 74, 82, 91, 92, 94, 95, 96, 104, 112, 113, 124, 125, 126, 128, 129, 140, 141, 144, 150, 178, 181, 184, 185, 190, 218, 219, 220, 222, 223, 226, 227, 228, 249, 252, 254, 261, 275, 303], "vmap": [1, 303, 305, 307], "print": [1, 2, 3, 4, 6, 208, 209, 210, 216, 300, 303, 304, 305, 306, 307], "ostream": 1, "os": [1, 6], "equival": [1, 27, 47, 58, 76, 98, 180, 225, 248, 250, 251, 259], "check": [1, 6, 60, 238, 303, 304], "bool": [1, 12, 13, 14, 22, 23, 29, 30, 31, 32, 40, 41, 42, 43, 45, 56, 57, 59, 60, 112, 115, 124, 126, 128, 129, 140, 142, 143, 148, 151, 152, 178, 190, 198, 207, 218, 219, 220, 226, 227, 228, 229, 231, 235, 236, 238, 243, 246, 249, 251, 254, 258, 261, 273, 276], "is_equival": 1, "privat": 1, "fall": 1, "eval": [1, 2, 3, 4, 216, 297, 299, 303, 305, 307], "deriv": [1, 303, 305], "base": [1, 112, 117, 119, 202, 204, 254, 261, 297, 299, 300, 304], "abov": [1, 3, 6, 141, 187, 201, 216, 303, 304, 305, 309], "demonstr": [1, 306], "treat": [1, 91, 92, 94, 95, 180], "paramet": [1, 2, 3, 4, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 33, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 206, 207, 208, 209, 210, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 231, 232, 235, 236, 238, 243, 246, 247, 248, 249, 250, 251, 252, 254, 256, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 292, 293, 294, 297, 299, 303, 305], "produc": [1, 249, 294], "through": [1, 176, 203, 261, 303, 306], "construct": [1, 4, 73, 99, 135, 194], "its": [1, 6, 125, 139, 150, 186, 196, 200, 201, 202, 210, 216, 251, 306, 309], "type": [1, 5, 6, 7, 8, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 33, 50, 57, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 72, 73, 74, 75, 76, 77, 78, 79, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 150, 151, 152, 153, 154, 155, 156, 157, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 204, 208, 216, 254, 261, 262, 263, 264, 265, 266, 267, 268, 269, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 304], "shape": [1, 3, 4, 47, 60, 61, 65, 66, 74, 84, 87, 90, 93, 94, 95, 99, 109, 125, 135, 136, 143, 144, 145, 147, 148, 151, 152, 155, 181, 191, 193, 194, 195, 216, 218, 219, 220, 222, 223, 227, 229, 238, 262, 263, 264, 265, 266, 267, 268, 269, 275, 285, 299, 303, 304, 307, 309], "pass": [1, 3, 4, 47, 58, 137, 138, 189, 196, 208, 209, 216, 236, 246, 247, 248, 251, 256, 305], "re": [1, 4, 6, 294], "now": [1, 3, 6, 251, 306], "promot": 1, "dtype": [1, 3, 15, 26, 33, 57, 83, 96, 99, 103, 112, 113, 114, 135, 145, 147, 148, 151, 152, 186, 194, 212, 262, 263, 264, 265, 266, 267, 268, 269, 273, 275, 303, 304, 306, 307, 308], "promoted_dtyp": 1, "promote_typ": 1, "float32": [1, 15, 83, 103, 112, 113, 114, 135, 145, 147, 151, 152, 186, 194, 212, 262, 263, 264, 265, 266, 267, 268, 269, 273, 275, 303, 304, 305, 306, 307, 308], "non": [1, 6, 230, 244, 286, 297], "point": [1, 2, 3, 6, 98, 142, 212], "out_dtyp": 1, "is_floating_point": 1, "cast": [1, 33, 93, 94, 95, 115, 231, 306], "up": [1, 3, 251], "determin": [1, 74, 242, 308], "x_cast": 1, "astyp": [1, 3, 231, 306], "y_cast": 1, "broadcasted_input": 1, "broadcast_arrai": 1, "out_shap": 1, "0": [1, 2, 3, 4, 6, 7, 15, 48, 53, 59, 64, 65, 66, 73, 74, 83, 96, 100, 112, 113, 138, 143, 152, 154, 156, 171, 175, 186, 187, 188, 189, 190, 192, 197, 198, 200, 201, 202, 203, 206, 207, 208, 216, 218, 219, 220, 221, 222, 223, 225, 226, 227, 228, 250, 253, 254, 255, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 271, 272, 273, 275, 277, 278, 284, 285, 287, 288, 289, 291, 292, 294, 297, 300, 303, 304, 305, 306, 307, 308], "unique_ptr": 1, "make_uniqu": 1, "to_stream": 1, "handl": [1, 216], "resolv": 1, "No": [1, 3], "happen": [1, 3, 261, 299, 305], "alon": [1, 306], "effect": [1, 222, 305], "onli": [1, 3, 5, 6, 60, 65, 66, 112, 141, 212, 216, 235, 236, 238, 243, 246, 247, 248, 297, 303, 308, 309], "execut": [1, 6, 306, 309], "depend": [1, 2, 57, 112, 304, 308, 309], "devic": [1, 5, 6, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 28, 29, 30, 31, 32, 33, 34, 36, 38, 39, 40, 41, 42, 43, 45, 46, 47, 48, 49, 51, 53, 54, 55, 56, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 101, 102, 103, 104, 105, 106, 107, 108, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 147, 148, 150, 151, 152, 153, 154, 155, 156, 157, 163, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 190, 193, 194, 195, 309, 310], "specifi": [1, 15, 33, 66, 74, 91, 92, 99, 100, 112, 114, 131, 135, 144, 154, 179, 180, 181, 184, 185, 189, 192, 194, 218, 260, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 292, 303, 309], "memori": [1, 5, 198, 261, 297, 305, 306], "ha": [1, 3, 4, 5, 57, 74, 90, 91, 93, 94, 95, 100, 144, 218, 229, 297, 299, 304, 305, 307, 309], "been": [1, 3, 305], "try": [1, 6], "naiv": [1, 303], "gener": [1, 2, 15, 83, 91, 92, 114, 143, 147, 148, 151, 152, 261, 300, 304, 305, 310], "version": [1, 6, 72, 120, 124, 141, 169, 192, 300, 303, 304], "declar": 1, "member": [1, 216, 241, 244], "method": [1, 3, 7, 8, 9, 26, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 216, 242, 297], "each": [1, 50, 72, 80, 125, 138, 141, 142, 144, 154, 161, 162, 171, 185, 192, 193, 222, 223, 224, 226, 254, 261, 275, 300, 305], "find": [1, 2, 6], "pointwis": 1, "captur": [1, 216], "templat": 1, "axpby_impl": 1, "typenam": 1, "t": [1, 3, 78, 142, 189, 197, 198, 199, 200, 201, 202, 203, 206, 207, 216, 303, 309], "readi": 1, "fill": [1, 99, 136, 186, 195, 262, 263, 264, 265, 266, 268, 269], "malloc_or_wait": 1, "synchron": 1, "avail": [1, 2, 3, 4, 6, 8, 212, 309], "There": [1, 216], "wait": [1, 3], "here": [1, 3, 303, 305, 308, 309], "request": 1, "pressur": 1, "condit": [1, 193, 309], "set_data": 1, "nbyte": 1, "collect": [1, 205, 209, 302], "pointer": 1, "x_ptr": 1, "y_ptr": 1, "out_ptr": 1, "relev": 1, "static_cast": 1, "size_t": 1, "out_idx": 1, "size": [1, 3, 4, 50, 66, 72, 82, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 99, 103, 112, 141, 142, 144, 155, 171, 174, 198, 216, 219, 220, 224, 227, 251, 305, 306], "map": [1, 4, 115, 209, 224, 231], "linear": [1, 3, 4, 5, 209, 216, 225, 238, 251, 253, 255, 257, 270, 271, 272, 288, 289, 290, 294, 297], "indic": [1, 13, 22, 23, 24, 25, 100, 105, 106, 107, 108, 171, 180, 181, 189, 243, 245, 275, 304], "offset": [1, 3, 74], "x_offset": 1, "elem_to_loc": 1, "stride": [1, 65, 66, 219, 220, 254, 304], "y_offset": 1, "contigu": 1, "regularli": 1, "default": [1, 6, 12, 13, 14, 15, 22, 23, 24, 25, 60, 64, 65, 66, 72, 73, 74, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 100, 103, 112, 113, 114, 115, 124, 126, 128, 129, 135, 139, 140, 141, 142, 143, 144, 145, 147, 148, 150, 151, 152, 154, 155, 156, 170, 171, 174, 175, 178, 184, 185, 186, 187, 188, 189, 190, 192, 194, 197, 198, 199, 200, 201, 202, 203, 205, 206, 207, 212, 218, 219, 220, 227, 229, 231, 236, 238, 243, 246, 249, 250, 251, 254, 258, 259, 261, 262, 263, 264, 265, 266, 267, 268, 269, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 297, 300, 302, 303, 306, 308, 310], "row": [1, 83, 103, 141, 186], "major": 1, "henc": [1, 141], "doesn": [1, 216], "addit": [1, 3, 11, 115, 218, 226, 228, 249, 252, 297, 303], "abl": [1, 141], "all": [1, 4, 6, 13, 24, 66, 80, 83, 86, 89, 92, 95, 125, 138, 139, 174, 204, 216, 231, 232, 236, 239, 240, 241, 244, 246, 249, 251, 258, 261, 294, 297, 300, 304, 305, 307, 310], "incom": 1, "accordingli": 1, "dispatch": 1, "float16": [1, 115, 212, 231, 305, 306], "bfloat16": [1, 306], "complex64": 1, "throw": 1, "error": [1, 6, 78, 79, 171, 225, 251, 270, 271, 272, 281, 282, 303, 306], "encount": [1, 303], "unexpect": [1, 15], "regist": [1, 4], "op": [1, 137, 236, 305], "assert": 1, "2": [1, 2, 3, 4, 66, 73, 74, 78, 85, 88, 90, 91, 92, 93, 94, 95, 96, 112, 113, 119, 125, 141, 150, 184, 186, 187, 188, 197, 199, 200, 201, 206, 212, 216, 220, 225, 252, 258, 262, 263, 264, 265, 266, 267, 268, 269, 271, 275, 276, 278, 284, 285, 294, 297, 303, 304, 305, 306, 307, 308, 309], "1": [1, 3, 4, 15, 24, 25, 65, 66, 73, 74, 84, 85, 87, 88, 90, 91, 92, 93, 94, 95, 96, 104, 112, 113, 125, 137, 139, 141, 144, 152, 165, 170, 180, 189, 197, 198, 199, 200, 201, 202, 203, 206, 207, 212, 216, 218, 219, 220, 221, 222, 223, 225, 226, 227, 228, 229, 250, 252, 254, 255, 258, 260, 263, 264, 265, 266, 267, 268, 269, 271, 272, 273, 274, 275, 276, 277, 278, 279, 281, 283, 284, 285, 289, 292, 294, 297, 299, 303, 304, 306, 307, 308, 309], "correct": [1, 6, 200, 201, 202, 304, 305], "els": [1, 3, 216, 236, 305], "float16_t": 1, "bfloat16_t": 1, "complex64_t": 1, "runtime_error": 1, "support": [1, 3, 5, 6, 13, 65, 66, 96, 113, 115, 125, 141, 303, 304, 306, 308], "have": [1, 3, 6, 13, 60, 91, 92, 94, 95, 125, 144, 203, 208, 249, 256, 302, 304, 305, 309], "rememb": 1, "3": [1, 3, 6, 96, 112, 113, 198, 203, 264, 266, 300, 304, 306, 307], "complic": 1, "keep": [1, 12, 14, 22, 23, 124, 126, 128, 129, 140, 178, 190, 216, 235, 303, 305], "mind": [1, 3], "half": [1, 15, 148, 152, 254, 305], "precis": [1, 3, 216, 225], "direct": [1, 3, 203, 233, 309], "fix": [1, 3, 6, 305], "possibl": [1, 3, 125, 171, 224, 304, 309], "due": 1, "transpos": [1, 3, 27, 142], "aren": 1, "guarante": 1, "fit": [1, 141, 309], "requir": [1, 3, 216, 305, 306], "column": [1, 83, 103, 141], "inplac": 1, "expect": [1, 3, 219, 220, 221, 222, 223, 258, 261, 276, 304], "answer": 1, "copi": [1, 3, 5, 139, 170, 306], "simpli": [1, 3, 6, 253, 288, 297, 303], "catlas_saxpbi": 1, "axpby_impl_acceler": 1, "first": [1, 2, 3, 4, 6, 74, 96, 100, 121, 123, 125, 139, 150, 179, 184, 189, 198, 200, 201, 202, 208, 216, 226, 274, 303, 306, 309], "mode": [1, 67, 234, 243, 245, 265, 266], "i": [1, 3, 109, 112, 201, 216, 219, 220, 222, 223, 236, 281, 303], "e": [1, 4, 6, 78, 109, 165, 199, 218, 219, 220, 222, 223, 226, 227, 228, 236, 252, 293, 299, 305, 310], "match": [1, 6, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 238, 275, 304, 306], "transposit": 1, "data_s": 1, "items": 1, "flag": [1, 306], "copy_inplac": 1, "copytyp": 1, "n": [1, 3, 26, 65, 66, 83, 84, 86, 87, 89, 90, 93, 95, 103, 186, 190, 218, 219, 220, 222, 223, 281, 285], "incx": 1, "inci": 1, "great": 1, "But": [1, 309], "criteria": 1, "luckili": [1, 305], "alwai": [1, 208, 303], "With": 1, "final": [1, 2, 3, 4], "singl": [1, 4, 80, 109, 115, 138, 191, 304, 308], "row_contigu": 1, "col_contigu": 1, "common": [1, 305], "hit": 1, "mileston": 1, "enough": [1, 305], "run": [1, 3, 4, 5, 6, 137, 197, 198, 200, 201, 202, 218, 231, 305, 309, 310], "If": [1, 3, 6, 12, 13, 14, 15, 22, 23, 24, 25, 57, 60, 63, 64, 67, 73, 74, 80, 93, 94, 95, 98, 99, 100, 112, 115, 124, 125, 126, 128, 129, 135, 138, 139, 140, 144, 154, 169, 170, 171, 178, 180, 181, 184, 189, 190, 192, 194, 198, 209, 218, 219, 220, 226, 228, 229, 236, 238, 246, 251, 254, 256, 258, 273, 275, 285, 303, 305, 308, 309, 310], "plan": 1, "stop": [1, 3, 15, 114, 176, 303, 304], "enjoi": 1, "speed": 1, "appl": [1, 3, 5, 6, 309], "silicon": [1, 3, 5, 6, 309], "address": 1, "shade": 1, "languag": [1, 212], "kernel": [1, 65, 66, 304], "written": 1, "help": [1, 3, 309], "resourc": 1, "walkthrough": 1, "pipelin": 1, "specif": [1, 6, 303], "cpp": 1, "algorithm": [1, 203], "launch": [1, 304], "exactli": [1, 3, 238, 303], "mani": [1, 171, 219, 220, 224, 305], "thread": 1, "pick": 1, "updat": [1, 2, 3, 4, 198, 201, 203, 207, 209, 218, 231, 238, 248, 299, 305], "assign": [1, 297], "axpby_gener": 1, "buffer": [1, 306], "constant": [1, 3, 6, 138, 206, 216, 218, 226, 228, 252, 276, 285, 306], "4": [1, 3, 72, 96, 112, 141, 142, 161, 212, 218, 227, 251, 261, 263, 264, 265, 273, 304, 307, 309], "5": [1, 2, 3, 6, 112, 143, 206, 218, 221, 222, 223, 227, 259, 262, 265, 266, 284, 291, 294, 303, 304], "x_stride": 1, "6": [1, 3, 112, 161, 206, 261, 264, 271, 272, 276, 285, 304, 307], "y_stride": 1, "7": [1, 3, 112, 141, 304], "ndim": [1, 96, 112], "8": [1, 3, 6, 112, 141, 197, 198, 199, 200, 201, 202, 206, 212, 227, 261, 274, 304, 307, 309], "uint": 1, "index": [1, 5, 7, 9, 24, 82, 83, 100, 139, 180, 181, 189], "thread_position_in_grid": 1, "convert": [1, 57, 96, 251, 305, 306, 307], "instanti": [1, 4, 305], "uniqu": [1, 300], "host": 1, "name": [1, 115, 141, 142, 159, 160, 161, 162, 205, 216, 226, 235, 238, 240, 304, 308], "identifi": [1, 208, 302], "instantiate_axpbi": 1, "type_nam": 1, "host_nam": 1, "axpby_general_": 1, "compil": [1, 6, 305], "mlx_ext": 1, "metallib": [1, 6], "see": [1, 3, 4, 6, 8, 28, 29, 30, 31, 32, 34, 36, 38, 39, 40, 41, 42, 43, 45, 46, 47, 48, 49, 51, 53, 54, 55, 56, 58, 59, 112, 159, 160, 216, 218, 222, 225, 234, 250, 251, 254, 255, 258, 259, 263, 264, 265, 266, 270, 271, 272, 289, 303, 304, 307, 309], "later": [1, 6], "co": [1, 258, 303], "locat": [1, 247, 248, 309], "share": [1, 5, 72, 141, 142], "register_librari": 1, "potenti": 1, "path": [1, 6, 161, 162, 238], "tri": 1, "load": [1, 4, 5, 238], "hasn": 1, "alreadi": [1, 3], "static": [1, 6], "object": [1, 8, 26, 37, 57, 143, 148, 151, 152, 192, 208, 209, 222, 302], "why": [1, 3], "packag": [1, 2, 4, 294], "process": [1, 3, 67, 209, 223, 224, 261, 302], "logic": [1, 121, 122, 123], "grid": 1, "shown": 1, "below": [1, 6, 112, 186, 188, 212, 305], "prepar": [1, 3], "carri": 1, "should": [1, 2, 3, 4, 6, 74, 109, 141, 181, 189, 191, 208, 216, 219, 220, 222, 223, 243, 249, 256, 275, 277, 297, 302, 303, 305, 306, 310], "d": [1, 3, 73, 74, 104, 112, 125, 137, 180, 186, 187, 188, 197, 200, 202, 210, 223, 309], "ostringstream": 1, "kname": 1, "axpby_": 1, "general_": 1, "type_to_nam": 1, "make": [1, 3, 4, 6, 125, 216, 305, 307, 309], "sure": [1, 3, 6, 216], "look": [1, 3], "folder": 1, "get_colocated_mtllib_path": 1, "get_kernel": 1, "str": [1, 67, 100, 112, 115, 158, 159, 160, 161, 162, 189, 208, 210, 231, 232, 235, 236, 238, 240, 242, 246, 265, 266, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285], "encod": [1, 254, 258, 261, 275], "compute_encod": 1, "get_command_encod": 1, "setcomputepipelinest": 1, "those": [1, 3, 216], "nelem": 1, "set_array_buff": 1, "setbyt": 1, "sizeof": 1, "threadgroup": 1, "higher": [1, 104, 303], "than": [1, 3, 57, 67, 74, 76, 101, 102, 110, 111, 125, 198, 203, 209, 254, 260, 284, 292, 303, 309], "max": [1, 112, 127, 198, 202, 250, 274, 276, 277, 285, 287, 303, 309], "allow": [1, 204, 216, 248, 297, 304, 307], "tgp_size": 1, "min": [1, 112, 130, 250, 287], "maxtotalthreadsperthreadgroup": 1, "3d": [1, 218, 223], "mtl": 1, "group_dim": 1, "grid_dim": 1, "divid": [1, 98, 141], "among": 1, "dispatchthread": 1, "few": [1, 3, 4, 5, 305, 307], "thing": [1, 3], "note": [1, 3, 6, 13, 65, 66, 91, 92, 112, 141, 144, 216, 306, 308], "befor": [1, 3, 6, 24, 139, 235, 261, 304, 305], "move": [1, 131, 309], "track": [1, 216, 218], "activ": [1, 6, 222, 230, 260, 261, 286, 291, 292, 293], "command": [1, 6], "instead": [1, 6, 216, 248, 258, 303, 305], "end_encod": 1, "end": [1, 74, 141, 255, 260, 278, 284, 289, 291, 292], "until": [1, 305, 307], "limit": [1, 63, 304], "flush": 1, "enqueu": 1, "commit": 1, "associ": [1, 161, 162, 305], "suggest": 1, "deeper": 1, "dive": 1, "studi": 1, "come": [1, 3, 303], "far": [1, 299], "built": [1, 6, 305], "includ": [1, 232, 251, 276, 303, 304, 307, 308, 310], "forward": [1, 189, 305], "diff": 1, "push": 1, "along": [1, 22, 23, 64, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 112, 154, 169, 171, 175, 180, 181, 184, 216], "similarli": [1, 6, 125, 303, 305], "scale_arr": 1, "contribut": 1, "tangent_x": 1, "tangent_i": 1, "revers": [1, 185, 258], "arg": [1, 3, 8, 47, 58, 80, 161, 162], "push_back": 1, "fulli": [1, 5, 306, 309], "overal": 1, "directori": [1, 3, 6], "extens": [1, 115, 212, 242, 308], "h": [1, 65, 66, 112, 218, 220, 222, 223, 303, 305], "mlx_sample_extens": 1, "__init__": [1, 3, 4, 7, 8, 9, 26, 216, 297], "py": [1, 3, 6], "cmakelist": 1, "txt": 1, "setup": [1, 2, 4, 6], "hold": [1, 3, 8, 112, 204], "instal": 1, "pybind11": [1, 6], "sinc": [1, 3, 4, 203, 297, 306, 309], "compon": [1, 3], "etc": [1, 141, 216], "becom": 1, "pybind11_modul": 1, "m": [1, 6, 83, 112, 186, 197], "doc": [1, 4], "sampl": [1, 2, 3, 114, 143, 144, 145, 148, 151, 152, 263, 264, 265, 266, 268, 269, 276, 285, 300], "_a": 1, "pos_onli": 1, "kw_onli": 1, "none": [1, 3, 7, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 28, 29, 30, 31, 32, 33, 34, 36, 38, 39, 40, 41, 42, 43, 45, 46, 47, 48, 49, 51, 53, 54, 55, 56, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 192, 193, 194, 195, 198, 208, 209, 225, 231, 235, 236, 246, 249, 258, 261, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 304], "r": [1, 3, 113, 189, 222], "pbdoc": 1, "most": [1, 144, 216, 303, 304, 305], "complex": [1, 91, 92, 93, 94, 95, 143, 148, 151, 152, 208, 216, 248, 303], "bell": 1, "whistl": 1, "liter": [1, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285], "string": [1, 306, 308], "modul": [1, 3, 4, 196, 251, 256, 261, 294, 302, 305], "ensur": [1, 6, 281], "caster": 1, "find_packag": 1, "config": 1, "add_librari": 1, "sourc": [1, 131, 185], "target_sourc": 1, "cmake_current_list_dir": 1, "header": 1, "target_include_directori": 1, "target_link_librari": 1, "attach": 1, "conveni": [1, 4], "mlx_build_metallib": 1, "target": [1, 189, 273, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284], "destin": [1, 131], "automat": [1, 5, 115, 307, 308, 309], "practic": 1, "mlx_build_met": [1, 6], "mlx_ext_metallib": 1, "titl": 1, "include_dir": 1, "project_source_dir": 1, "mlx_include_dir": 1, "output_directori": 1, "cmake_library_output_directori": 1, "add_depend": 1, "endif": 1, "pybind11_add_modul": 1, "build_shared_lib": 1, "target_link_opt": 1, "wl": 1, "rpath": 1, "loader_path": 1, "onc": 1, "describ": [1, 305], "util": [1, 3, 5, 6, 161, 216], "__name__": [1, 3], "__main__": [1, 3], "descript": [1, 3, 212], "ext_modul": 1, "cmakeextens": 1, "cmdclass": 1, "build_ext": 1, "cmakebuild": 1, "package_dir": 1, "package_data": 1, "dylib": 1, "zip_saf": 1, "fals": [1, 3, 12, 13, 14, 22, 23, 29, 30, 31, 32, 40, 41, 42, 43, 45, 56, 59, 60, 112, 115, 124, 126, 128, 129, 140, 178, 190, 193, 198, 207, 208, 209, 212, 226, 227, 229, 236, 238, 246, 249, 251, 254, 258, 261, 273, 276, 306], "python_requir": 1, "even": [1, 3, 305, 306], "though": [1, 3, 305, 306], "j8": 1, "libmlx_ext": 1, "cpython": 1, "3x": 1, "darwin": 1, "pip": [1, 6], "after": [1, 3, 4, 24, 96, 98, 139, 141, 218, 226, 228, 249, 261, 284, 309], "plai": [1, 3], "ones": [1, 3, 136, 161, 186, 247, 248, 251, 304], "b": [1, 3, 11, 13, 60, 75, 76, 77, 98, 101, 102, 104, 110, 111, 112, 120, 121, 123, 125, 127, 130, 132, 137, 141, 177, 184, 189, 229, 303, 304, 305, 306, 307, 308, 309], "f": [1, 2, 4, 112, 201, 216, 306], "item": [1, 2, 3, 4, 209, 305, 306, 307], "true": [1, 2, 3, 13, 60, 112, 115, 142, 169, 193, 198, 208, 209, 212, 216, 218, 219, 220, 226, 227, 228, 229, 235, 236, 238, 243, 246, 251, 254, 258, 261, 273, 281], "quick": [1, 5], "benchmark": 1, "compar": [1, 60], "time": [1, 3, 6, 216, 303, 305, 309], "set_default_devic": 1, "256": [1, 4], "512": [1, 3, 261, 309], "random": [1, 2, 3, 4, 5, 218, 227, 238, 243, 303, 309, 310], "normal": [1, 2, 3, 151, 205, 216, 218, 226, 227, 228, 252, 261, 263, 265, 306, 309], "bench": 1, "warm": 1, "rang": [1, 2, 3, 4, 6, 15, 96, 114, 264, 266, 271, 272, 299, 300, 303, 305, 309], "100": [1, 2, 3, 303, 305, 309], "5000": 1, "simple_tim": 1, "custom_tim": 1, "3f": [1, 4], "custom": [1, 261], "114": 1, "109": 1, "modest": 1, "improv": [1, 3, 197, 198, 199, 200, 201, 202, 206], "awai": [1, 3], "good": [1, 6, 309], "nn": [1, 3, 4, 161, 209, 216, 294, 297, 299, 305], "grad": [1, 2, 4, 189, 299, 303, 304, 305, 307], "simplifi": [], "full": [1, 4, 47, 58, 67, 169, 247, 248, 276, 305], "implement": [2, 4, 112, 197, 198, 199, 200, 202, 203, 204, 205, 224, 235, 249, 254, 256, 258, 260, 261, 292, 303, 306], "basic": [2, 156, 303], "model": [2, 4, 5, 161, 196, 209, 216, 231, 234, 236, 238, 242, 243, 245, 246, 247, 249, 261, 294, 297, 299, 305], "problem": [2, 4, 216], "metadata": [2, 115, 159], "num_featur": [2, 218], "num_exampl": 2, "1_000": 2, "num_it": 2, "10_000": 2, "iter": [2, 4, 209, 300, 305], "sgd": [2, 4, 203, 299], "lr": [2, 203], "01": [2, 201], "rate": [2, 197, 198, 199, 200, 201, 202, 203, 206, 207], "ll": [2, 4, 278, 303], "synthet": 2, "dataset": [2, 305], "matrix": [2, 72, 73, 83, 103, 112, 113, 125, 141, 142, 251, 267, 294], "ground": [2, 3, 275, 284], "truth": [2, 275, 284], "w_star": 2, "valu": [2, 3, 10, 13, 15, 22, 23, 37, 57, 60, 63, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 99, 112, 114, 138, 143, 144, 145, 147, 148, 151, 152, 159, 180, 181, 189, 192, 196, 198, 201, 205, 208, 209, 212, 221, 222, 223, 227, 229, 235, 249, 250, 254, 259, 260, 261, 262, 273, 274, 275, 276, 277, 278, 280, 281, 282, 283, 284, 292, 297, 303], "gaussian": [2, 225, 270, 271, 272, 276], "nois": 2, "exampl": [2, 3, 4, 15, 96, 112, 113, 180, 216, 218, 227, 236, 238, 243, 246, 262, 263, 264, 265, 266, 267, 268, 269, 273, 275, 294, 299, 300, 303, 304, 305, 306, 307, 308], "noisi": 2, "label": [2, 275], "ep": [2, 197, 198, 199, 200, 201, 202, 206, 218, 226, 227, 228, 252, 274, 276, 285], "1e": [2, 4, 13, 197, 198, 199, 200, 201, 202, 206, 218, 226, 227, 228, 252, 274, 276, 285], "us": [2, 3, 4, 5, 6, 15, 72, 76, 96, 112, 113, 125, 141, 142, 154, 155, 197, 198, 200, 201, 202, 203, 204, 208, 216, 222, 224, 225, 229, 231, 235, 242, 247, 248, 249, 251, 254, 258, 261, 265, 266, 271, 272, 274, 294, 297, 299, 300, 302, 303, 304, 307, 309], "weight": [2, 65, 66, 198, 201, 203, 207, 209, 216, 238, 242, 251, 275, 297, 303, 305], "squar": [2, 3, 103, 157, 172, 189, 197, 198, 200, 201, 202, 209, 216, 252, 282, 284, 303, 306], "loss": [2, 4, 189, 216, 299, 303, 305], "loss_fn": [2, 4, 299, 303], "w": [2, 66, 72, 141, 142, 189, 207, 218, 220, 222, 223, 229, 303], "mean": [2, 3, 4, 189, 216, 218, 226, 236, 252, 268, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 303, 306], "grad_fn": [2, 303], "initi": [2, 3, 216, 218, 226, 227, 228, 229, 250, 252, 262, 263, 264, 265, 266, 267, 268, 269, 297, 305], "randomli": [2, 3, 221, 222, 223], "Then": [2, 6], "repeatedli": 2, "_": [2, 3, 216, 300, 305, 309], "verifi": [2, 6], "close": [2, 5, 6, 13], "error_norm": 2, "5f": 2, "someth": [2, 3, 304], "00005": 2, "00364": 2, "complet": [2, 3, 6, 247, 248, 303, 309], "logist": [2, 165, 257, 271, 272, 290], "github": [2, 4, 6], "repo": [2, 4, 6], "enabl": [3, 6, 207], "larg": [3, 216, 249, 281, 305], "ish": 3, "transform": [3, 5, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 196, 216, 218, 226, 228, 229, 235, 236, 246, 251, 254, 304], "compromis": 3, "eas": 3, "llama": 3, "famili": 3, "less": [3, 24, 111, 139, 254, 284], "200": 3, "line": [3, 305, 306], "python": [3, 37, 50, 57, 80, 208, 209, 210, 297, 302, 303, 306], "neural": [3, 5, 206, 224, 230, 263, 264, 286, 294, 297], "network": [3, 5, 206, 218, 222, 224, 263, 264, 294, 297], "build": [3, 5, 265, 297], "concis": 3, "architectur": [3, 6, 216, 248, 309], "notabl": [3, 5], "rope": [3, 216], "posit": [3, 24, 74, 96, 100, 108, 131, 139, 189, 209, 216, 219, 220, 249, 254, 258, 276, 285], "option": [3, 12, 14, 15, 22, 23, 24, 25, 26, 31, 32, 64, 65, 66, 67, 72, 73, 74, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 99, 100, 103, 107, 108, 112, 113, 114, 115, 124, 126, 128, 129, 135, 138, 139, 140, 141, 142, 143, 144, 145, 147, 148, 150, 151, 152, 154, 155, 169, 170, 171, 174, 175, 178, 180, 181, 184, 185, 186, 187, 188, 189, 190, 192, 194, 197, 198, 199, 200, 201, 202, 203, 206, 207, 208, 209, 218, 219, 220, 229, 231, 235, 236, 238, 246, 249, 251, 254, 258, 261, 262, 263, 264, 265, 266, 267, 268, 269, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 300, 308, 310], "kei": [3, 143, 144, 145, 147, 148, 150, 151, 152, 205, 208, 209, 235, 236, 246, 249, 254, 300, 302, 303], "cach": [3, 254], "concaten": 3, "project": [3, 249], "llamaattent": 3, "self": [3, 4, 7, 9, 26, 27, 28, 29, 30, 31, 32, 33, 34, 36, 37, 38, 39, 40, 41, 42, 43, 45, 46, 47, 48, 49, 51, 53, 54, 55, 56, 57, 58, 59, 216, 230, 286, 297], "dim": [3, 224, 226, 227, 228, 249, 252, 254, 258, 261], "num_head": [3, 249, 261], "super": [3, 4, 216, 297], "tradit": [3, 222, 223, 254], "query_proj": 3, "bia": [3, 72, 141, 142, 200, 201, 202, 209, 216, 219, 220, 229, 236, 238, 246, 249, 251, 303], "key_proj": 3, "value_proj": 3, "out_proj": [3, 297], "__call__": [3, 4, 216, 297], "queri": [3, 249], "mask": [3, 243, 249, 304], "extract": [3, 73, 74, 216, 235, 297], "l": [3, 4, 216, 218, 219, 284], "reshap": [3, 112, 304], "combin": 3, "key_cach": 3, "value_cach": 3, "sqrt": [3, 78, 197, 199, 200, 201, 206, 218, 226, 227, 228, 229, 252, 258, 263, 264, 265, 266], "score": 3, "softmax": [3, 275], "values_hat": 3, "rm": [3, 6, 198], "swiglu": 3, "rmsnorm": [3, 216], "llamaencoderlay": 3, "mlp_dim": [3, 261], "norm1": 3, "norm2": 3, "linear1": 3, "linear2": 3, "linear3": 3, "sigmoid": [3, 257, 271, 272, 290], "instanc": [3, 141, 210, 216, 227, 231, 232, 233, 236, 239, 240, 246, 248, 256, 297, 306], "embed": [3, 216, 254, 258, 274], "emb": [3, 224, 258], "token": [3, 224], "num_lay": [3, 4, 299], "vocab_s": 3, "norm": [3, 202, 203, 226, 285], "multiheadattent": [3, 216], "create_additive_causal_mask": 3, "list": [3, 8, 12, 14, 26, 29, 30, 40, 41, 42, 43, 45, 50, 53, 56, 57, 59, 61, 64, 80, 82, 85, 86, 88, 89, 91, 92, 94, 95, 99, 100, 109, 112, 124, 126, 128, 129, 135, 138, 140, 143, 144, 145, 147, 148, 151, 152, 155, 159, 169, 171, 174, 175, 178, 184, 185, 189, 190, 191, 194, 200, 201, 202, 203, 208, 210, 216, 236, 238, 239, 240, 241, 244, 246, 247, 248, 297, 302, 303, 305], "still": [3, 6, 112, 305], "consid": [3, 13, 60, 208, 209, 226, 302], "train": [3, 4, 216, 218, 221, 222, 223, 234, 236, 246, 263, 264], "ignor": [3, 63, 198], "whatsoev": 3, "rest": [3, 209, 254], "subsect": 3, "prompt": 3, "autoregress": 3, "yield": [3, 4, 300], "temp": 3, "causal": 3, "save": [3, 5, 115, 141, 159, 160, 161, 162, 242, 305], "append": [3, 125, 305], "store": 3, "per": [3, 4, 72, 141, 142, 204, 218, 226, 227, 228, 252, 305], "care": [3, 305], "last": [3, 25, 57, 86, 89, 91, 92, 94, 95, 96, 104, 113, 125, 144, 170, 184, 219, 220, 222, 223, 226, 306], "logit": [3, 144, 273, 275], "next": [3, 4], "categor": 3, "lazili": [3, 216], "noth": [3, 216, 305], "yet": [3, 112, 216, 297, 303, 304, 305, 307], "forc": [3, 4, 216, 307], "choos": [3, 254], "pars": 3, "feed": 3, "loop": [3, 4, 303, 305], "unsqueez": 3, "sequenc": [3, 218, 219, 261, 300, 309], "length": [3, 174, 218, 219], "len": [3, 86, 89, 92, 95], "overwrit": 3, "discard": [3, 208], "old": 3, "moment": [3, 198, 200, 201, 202], "anymor": 3, "everyth": 3, "small": [3, 218, 226, 228, 252, 276, 281, 285, 309], "10": [3, 4, 117, 156, 161, 209, 216, 238, 294, 304], "12": 3, "8192": 3, "1024": 3, "actual": [3, 15, 238, 297, 305], "materi": [3, 5], "could": [3, 216], "20_000": 3, "machin": [3, 5, 6, 206], "8gb": 3, "ram": 3, "32": [3, 4, 141, 142, 212], "44": 3, "doubl": 3, "bracket": 3, "becaus": [3, 216, 305], "batch": [3, 125, 218, 219, 220, 222, 223, 249, 305], "zip": [3, 4], "haven": 3, "anyth": [3, 189, 305], "result": [3, 15, 57, 72, 104, 112, 115, 125, 137, 142, 154, 156, 175, 184, 193, 209, 258, 303, 306], "similar": [3, 209, 247, 248, 249, 274, 306, 308], "runtim": 3, "section": [3, 6, 171, 285, 303], "access": [3, 37, 216, 297, 305, 309], "origin": [3, 74, 197, 198, 199, 200, 202, 203, 218, 263, 264, 265, 266, 306], "sentencepiec": 3, "pytorch": [3, 5, 226, 303], "compat": [3, 144, 308], "npz": [3, 115, 161, 162, 238, 242, 308], "file": [3, 6, 115, 158, 159, 160, 161, 162, 238, 242, 303, 308], "directli": 3, "argpars": 3, "itertool": [3, 209], "starmap": [3, 209], "np": [3, 4, 306, 307], "torch": [3, 306], "map_torch_to_mlx": 3, "tok_embed": 3, "elif": 3, "replac": [3, 247, 248, 261, 284], "attention_norm": 3, "ffn_norm": 3, "wq": 3, "wk": 3, "wv": 3, "wo": 3, "w1": 3, "w2": 3, "w3": 3, "ffn": 3, "separ": [3, 47, 58, 226], "submodul": [3, 4, 216, 236, 237, 246, 248], "feed_forward": 3, "parser": 3, "argumentpars": 3, "add_argu": 3, "torch_weight": 3, "output_fil": 3, "parse_arg": 3, "state": [3, 4, 204, 205, 216, 299, 300], "savez": [3, 242, 308], "k": [3, 73, 83, 186, 187, 188, 229, 236], "v": [3, 67, 216, 236, 306], "left": [3, 112, 141, 225, 254, 271, 272, 276, 278, 285], "disk": 3, "text": [3, 198, 203, 230, 255, 260, 263, 264, 265, 266, 276, 277, 278, 281, 284, 286, 287, 289, 291, 292], "format": [3, 115, 158, 159, 160, 161, 162, 306], "oper": [3, 5, 33, 169, 176, 181, 203, 216, 261, 303, 304, 305, 306, 307, 309, 310], "dictionari": [3, 115, 159, 160, 204, 205, 208, 216, 235, 247, 248, 302, 308], "represent": [3, 141, 208, 210], "tree_unflatten": 3, "helper": 3, "weight_fil": 3, "incur": 3, "sever": [3, 65, 66, 161, 162, 308], "futur": [3, 251, 304, 305], "pth": 3, "current": [3, 5, 6, 65, 66, 141, 198, 216, 305], "around": 3, "m1": [3, 303, 309], "ultra": 3, "7b": 3, "me": 3, "ishmael": 3, "year": 3, "ago": 3, "never": [3, 305], "long": 3, "info": [3, 6], "247": 3, "press": [3, 112], "enter": 3, "littl": 3, "monei": 3, "my": [3, 6], "purs": 3, "greater": [3, 24, 102, 139, 260, 292], "consequ": 3, "walk": 3, "down": 3, "gower": 3, "street": 3, "afternoon": 3, "heavi": 3, "rain": 3, "saw": [3, 303], "off": [3, 6, 305], "man": 3, "rag": 3, "who": 3, "sat": 3, "upon": [3, 209], "hi": 3, "bundl": 3, "hard": 3, "wet": 3, "he": [3, 265, 266], "were": [3, 309], "cry": 3, "watch": 3, "him": 3, "observ": 3, "numer": [3, 112, 120, 124, 169, 197, 198, 199, 200, 201, 202, 206, 218, 226, 227, 228, 252, 274, 276, 285, 305], "crowd": 3, "wa": [3, 205, 305], "hurri": 3, "437": 3, "330": 3, "second": [3, 74, 121, 123, 125, 179, 189, 198, 200, 201, 202, 274, 303, 309], "spent": 3, "amount": 3, "39": 3, "ms": 3, "By": [3, 303, 306], "bigger": [3, 198], "remain": [3, 189, 221, 222, 223], "almost": 3, "nobodi": 3, "took": 3, "least": [3, 63, 113, 141], "notic": [3, 303, 308], "distanc": [3, 285], "had": 3, "doubt": 3, "minut": 3, "straight": 3, "slowli": 3, "rais": [3, 112, 171, 238], "ey": 3, "speak": [3, 112], "resum": 3, "postur": 3, "stood": 3, "feel": 3, "pain": 3, "heart": 3, "smile": 3, "face": 3, "am": 3, "someon": 3, "three": 3, "quarter": 3, "hour": 3, "made": 3, "immedi": [3, 231], "repli": 3, "again": [3, 6, 216], "hand": [3, 303, 305], "did": 3, "accustom": 3, "thu": [3, 216], "question": [3, 305], "reason": [3, 304], "tell": [3, 306], "understand": [3, 263, 264], "579": 3, "690": 3, "num": [3, 114, 150], "500": [3, 309], "628": 3, "went": 3, "nervou": 3, "trembl": 3, "told": 3, "And": 3, "perhap": 3, "surpris": 3, "matter": [3, 216], "shall": 3, "anyhow": 3, "friend": 3, "ye": 3, "slight": [3, 305], "kind": 3, "longer": [3, 67, 303], "soon": 3, "unless": [3, 13, 112, 297], "unlik": [3, 13, 222, 223], "strang": 3, "amus": 3, "That": 3, "secret": 3, "disappoint": 3, "mine": 3, "cannot": [3, 63, 304, 306], "happi": 3, "ask": 3, "Is": [3, 258, 261], "shop": 3, "bui": 3, "food": 3, "633": 3, "21": 3, "475": 3, "su": 3, "j": [3, 6, 112, 199, 200, 202, 222], "lu": 3, "pan": 3, "murtadha": 3, "wen": 3, "liu": 3, "2021": 3, "roform": [3, 254], "enhanc": [3, 254, 305], "rotari": [3, 254], "arxiv": [3, 197, 203, 226, 227, 228, 230, 252, 286], "preprint": [3, 197, 203], "2104": 3, "09864": 3, "zhang": 3, "sennrich": 3, "2019": [3, 201], "root": [3, 157, 172, 252], "advanc": 3, "inform": [3, 4, 6, 159, 160, 216, 218, 225, 249, 303, 309], "system": [3, 6], "shazeer": 3, "2020": 3, "glu": 3, "variant": [3, 202, 284], "2002": 3, "05202": 3, "classifi": 4, "mnist": 4, "As": [4, 180, 216], "mlp": [4, 216, 261, 299], "inherit": [4, 302], "standard": [4, 37, 57, 125, 145, 261, 263, 265, 268, 307], "idiom": 4, "input_dim": [4, 216, 229, 251], "hidden_dim": [4, 297, 299], "output_dim": [4, 216, 229, 251], "layer_s": 4, "idim": 4, "odim": 4, "maximum": [4, 22, 63, 216, 253, 258, 271, 272, 288, 297, 305], "cross": [4, 273, 275], "entropi": [4, 273, 275], "sub": [4, 74, 150], "commonli": [4, 247, 294], "cross_entropi": [4, 216], "accuraci": 4, "valid": [4, 67, 96, 192, 208, 236, 246, 302], "eval_fn": 4, "argmax": 4, "loader": 4, "num_class": [4, 299], "batch_siz": [4, 299], "num_epoch": [4, 299], "learning_r": [4, 197, 198, 199, 200, 201, 202, 203, 206, 207, 299], "train_imag": [4, 299], "train_label": [4, 299], "test_imag": 4, "test_label": 4, "shuffl": 4, "minibatch": 4, "batch_iter": [4, 299], "perm": 4, "permut": 4, "id": [4, 6], "put": 4, "trainabl": [4, 196, 216, 297], "loss_and_grad_fn": [4, 299, 303], "value_and_grad": [4, 216, 247, 297, 299, 303, 306, 307], "epoch": 4, "test": [4, 6], "confus": 4, "decent": 4, "95": 4, "brought": 5, "research": 5, "except": [5, 83, 90, 91, 93, 94, 95, 226, 238, 304, 306], "featur": [5, 65, 66, 218, 226, 227, 228, 229, 251, 252, 254, 261, 305], "main": [5, 74, 83, 209, 216], "differ": [5, 177, 284, 303], "lazi": [5, 297, 307], "multi": [5, 219, 220, 304, 306], "cpu": [5, 113, 309], "gpu": [5, 304, 309], "inspir": 5, "jax": [5, 300], "arrayfir": 5, "unifi": 5, "live": [5, 309], "guid": 5, "convers": 5, "regress": [5, 281], "layer": [5, 216, 222, 223, 226, 228, 229, 243, 248, 251, 256, 261, 293, 297], "perceptron": 5, "llm": 5, "infer": [5, 99, 115], "fft": 5, "algebra": 5, "tree": [5, 80, 100, 189, 192, 204, 208, 209, 210, 303], "develop": [5, 6], "document": [5, 47, 58, 159, 160, 303, 304], "meet": 6, "seri": 6, "chip": 6, "nativ": 6, "maco": 6, "13": 6, "recommend": [6, 203], "14": 6, "sonoma": 6, "distribut": [6, 143, 144, 145, 147, 151, 152, 229, 263, 264, 265, 266, 268, 269, 276, 279, 283, 285, 294], "probabl": [6, 148, 221, 222, 223, 251, 273, 275, 279, 309], "platform": 6, "processor": 6, "arm": [6, 212], "i386": 6, "switch": 6, "conda": 6, "17": 6, "g": [6, 112, 141, 206, 207, 293, 305, 310], "clang": 6, "cmake": 6, "24": 6, "xcode": 6, "15": [6, 112], "environ": 6, "via": [6, 305, 306], "rosetta": 6, "unam": 6, "p": [6, 143, 200, 202, 216, 221, 222, 223, 285], "clone": 6, "git": 6, "com": 6, "ml": 6, "explor": 6, "cd": 6, "brew": 6, "global": [6, 149, 300], "env": 6, "cmake_build_parallel_level": 6, "edit": [6, 248], "unittest": 6, "discov": 6, "stub": 6, "dev": 6, "generate_stub": 6, "mkdir": 6, "either": [6, 11, 47, 57, 58, 63, 75, 76, 77, 98, 101, 102, 110, 111, 112, 120, 125, 127, 130, 132, 177, 189, 256, 265, 266], "libmlx": 6, "preprocessor": 6, "metal_path": 6, "mlx_build_test": 6, "ON": 6, "mlx_build_exampl": 6, "mlx_build_benchmark": 6, "mlx_build_python_bind": 6, "multipl": [6, 125, 132, 141, 142, 249, 258, 305, 308], "wish": 6, "variabl": [6, 100, 109, 189, 191, 192], "export": 6, "developer_dir": 6, "app": 6, "content": [6, 235], "sdk": 6, "xcrun": 6, "macosx": 6, "show": [6, 212], "unabl": 6, "tool": 6, "select": [6, 193, 231, 235], "sudo": 6, "ouptut": 6, "finder": 6, "iterm": 6, "termin": 6, "click": 6, "uncheck": 6, "window": 6, "restart": 6, "devicetyp": 7, "attribut": [7, 8, 9, 26, 297], "kwarg": [8, 161, 162, 310], "union": [10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 28, 29, 30, 31, 32, 33, 34, 36, 38, 39, 40, 41, 42, 43, 45, 46, 47, 48, 49, 51, 53, 54, 55, 56, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 72, 73, 74, 75, 76, 77, 78, 79, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 147, 148, 150, 151, 152, 153, 154, 155, 156, 157, 159, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 193, 194, 195, 220, 236, 238, 246], "absolut": [10, 13, 271, 272, 284], "semant": [11, 61, 75, 76, 77, 101, 102, 110, 111, 120, 125, 127, 130, 132, 177, 309], "keepdim": [12, 14, 22, 23, 29, 30, 31, 32, 40, 41, 42, 43, 45, 56, 59, 112, 124, 126, 128, 129, 140, 169, 178, 190], "reduct": [12, 14, 124, 126, 129, 140, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285], "reduc": [12, 14, 22, 23, 124, 126, 128, 129, 140, 178, 190, 218, 261, 281], "unspecifi": [12, 14, 15, 22, 23, 24, 25, 64, 99, 124, 126, 128, 129, 135, 139, 140, 154, 169, 170, 178, 180, 190, 194, 310], "entir": [12, 14, 22, 23, 124, 126, 128, 129, 140, 178, 190, 222, 223], "singleton": [12, 14, 22, 23, 124, 125, 126, 128, 129, 140, 178, 190], "rtol": 13, "05": [13, 218, 226, 227, 228, 252], "atol": 13, "08": [13, 199, 200, 201, 202, 206, 274], "approxim": [13, 225, 270, 271, 272], "comparison": [13, 77, 101, 102, 110, 111], "equal": [13, 24, 60, 83, 102, 111, 139, 148, 171, 227, 229], "ab": [13, 112, 189, 226, 227, 228, 230, 252, 286], "array_equ": 13, "rel": [13, 198], "toler": 13, "boolean": [13, 60, 105, 106, 107, 108, 121, 122, 123, 212, 245, 304], "interv": [15, 114, 148, 152], "increment": 15, "otherwis": [15, 208, 209, 236, 238, 246, 260, 261, 273, 278, 284, 291, 292, 305, 306], "int32": [15, 96, 112, 148, 212, 304, 307], "convent": [15, 67, 201], "lead": 15, "fraction": 15, "integr": [15, 180, 305], "invers": [16, 17, 18, 19, 20, 21, 79, 87, 88, 89, 90, 91, 92], "cosin": [16, 17, 68, 69, 254, 274, 303], "hyperbol": [17, 19, 21, 69, 168, 183], "sine": [18, 19, 167, 168, 254, 303], "minimum": [23, 63, 258, 274], "kth": [24, 139], "partit": 24, "order": [24, 112, 139, 141, 216, 226, 247, 256, 303], "undefin": [24, 139, 304], "sort": [24, 25, 139], "flatten": [24, 25, 112, 137, 139, 154, 170, 180, 181, 208], "dimension": [26, 84, 85, 86, 87, 88, 89, 93, 94, 95, 218, 219, 220, 224, 229, 251, 258, 304, 306], "val": [26, 99], "tupl": [26, 47, 50, 58, 64, 66, 76, 80, 82, 109, 112, 113, 138, 141, 155, 174, 189, 191, 198, 200, 201, 202, 203, 208, 209, 210, 220, 238, 240, 254, 256, 302, 303], "ndarrai": [26, 304, 305, 307], "properti": [27, 35, 44, 50, 52, 245, 303], "argument": [27, 47, 58, 80, 100, 189, 209, 216, 300, 303, 308, 309, 310], "decim": [48, 156], "indices_or_sect": [53, 171], "nest": [57, 216, 297, 302, 303], "ddof": [59, 190], "equal_nan": [13, 60], "nan": [13, 60, 106], "a_min": 63, "a_max": 63, "edg": [63, 138], "At": 63, "anoth": [63, 125, 177, 193, 216, 231, 303, 304, 309], "pad": [65, 66, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 219, 220], "dilat": [65, 66], "group": [65, 66, 72, 141, 142, 226, 251], "1d": [65, 67, 159, 181], "convolut": [65, 66, 67, 219, 220, 222, 223], "channel": [65, 66, 218, 219, 220, 222, 223], "c_in": [65, 66], "c_out": [65, 66], "convolv": [65, 66], "2d": [66, 74, 141, 218, 222], "spatial": [66, 226], "symmetr": 66, "discret": [67, 84, 85, 86, 87, 88, 89, 93, 94, 95, 224], "swap": [67, 179, 248, 251], "conv": 67, "filter": [67, 219, 220, 231, 235], "flip": 67, "signal": 67, "bias": [72, 141, 142, 236, 246, 249], "group_siz": [72, 141, 142, 251], "64": [72, 141, 142, 212, 251], "configur": 72, "formal": [72, 141], "notat": [72, 208, 240], "quantiz": [72, 115, 142, 251], "w_i": [72, 141], "hat": [72, 141], "occupi": [72, 141, 142], "divis": [75, 98, 141], "quotient": [75, 76, 98], "remaind": 76, "fuction": 76, "faster": [76, 270, 303], "mathrm": [78, 165, 227], "frac": [78, 141, 165, 197, 199, 200, 201, 202, 206, 218, 221, 222, 223, 226, 227, 228, 229, 252, 263, 264, 265, 266, 274, 276, 278, 281], "pi": [78, 258, 303], "int_0": 78, "dx": [], "erf": 79, "node": [80, 192], "dict": [80, 115, 159, 160, 161, 241, 244, 247, 248, 297, 302, 303, 308], "leaf": [80, 208, 209, 235], "exponenti": [81, 255, 289], "insert": [74, 82, 309], "ident": [83, 176, 216, 243], "diagon": [73, 83, 186, 187, 188], "zero": [83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 186, 187, 188, 195, 198, 216, 221, 222, 223, 238, 262, 263, 264, 265, 266, 267, 268, 269, 294, 304], "th": [73, 83], "whose": [83, 196], "One": [84, 87, 93, 157, 303], "fourier": [84, 85, 86, 87, 88, 89, 93, 94, 95], "truncat": [84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 151], "dft": [84, 85, 86, 87, 88, 89, 93, 94, 95], "rfft": 90, "real": [90, 91, 92, 93, 94, 95], "rfft2": 91, "rfftn": 92, "silent": [93, 94, 95], "start_axi": 96, "end_axi": 96, "integ": [98, 112, 138, 141, 142, 143, 148, 171, 184, 192, 212, 224, 304], "floor": 98, "fun": [100, 109, 189, 191, 192, 304, 305, 309], "argnam": [100, 189], "cpp_function": [], "neither": [100, 189], "keyword": [100, 161, 162, 189, 209, 216, 300, 308, 310], "strict": [101, 110, 236, 238, 246], "ordinari": 104, "ord": 112, "tabl": [112, 212, 224], "frobeniu": 112, "matric": [112, 113], "strictli": 112, "mathemat": 112, "variou": 112, "purpos": 112, "calcul": [112, 198, 276], "fro": 112, "inf": [112, 249], "largest": 112, "sing": 112, "smallest": 112, "singular": 112, "nuclear": 112, "_f": 112, "sum_": [112, 281], "a_": 112, "valueerror": [112, 238, 303], "refer": [112, 227, 230, 263, 264, 265, 266, 286, 304], "golub": 112, "van": 112, "loan": 112, "baltimor": 112, "md": 112, "john": 112, "hopkin": 112, "univers": 112, "1985": 112, "pg": 112, "la": 112, "arang": [112, 304, 306], "9": [112, 197, 200, 201, 202, 203, 275, 306], "74597": 112, "20": 112, "84804": 112, "41421": 112, "23607": [112, 113], "74166": 112, "24264": 112, "11": 112, "225": 112, "50": 114, "evenli": 114, "binari": [115, 158, 159, 160, 161, 162, 260, 273, 292], "npy": [115, 158, 308], "safetensor": [115, 160, 238, 242, 305, 308], "gguf": [115, 159, 308], "unsupport": 115, "tensor": [115, 184, 285, 306], "natur": [116, 118, 305], "logarithm": [116, 117, 118, 119], "log": [118, 120, 124, 276, 279, 281, 283], "plu": 118, "exp": [120, 124, 145, 169, 255, 279, 289, 309], "stabl": [120, 124, 169, 281], "prepend": 125, "remov": [74, 125, 144, 174, 275], "negat": 133, "beforehand": 137, "pad_with": 138, "constant_valu": 138, "pad_width": 138, "before_1": 138, "after_1": 138, "before_2": 138, "after_2": 138, "before_n": 138, "after_n": 138, "before_i": 138, "after_i": 138, "extend": 138, "side": 138, "smaller": [139, 203], "everi": [141, 209, 303], "particular": [141, 226], "consecut": [141, 254], "w_1": 141, "w_g": 141, "begin": [141, 255, 260, 278, 284, 289, 291, 292], "align": 141, "max_i": 141, "min_i": 141, "textrm": [141, 225, 270], "round": 141, "pack": [141, 142], "unsign": [141, 142, 212], "lower": [141, 148, 151, 152, 186, 269], "upper": [141, 148, 151, 152, 269], "1st": 141, "signific": 141, "2nd": 141, "dequant": 141, "w_q": 141, "whether": [142, 235, 249, 273, 276], "prng": [143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 300], "num_sampl": 144, "unnorm": [144, 273, 275], "draw": 144, "uint32": [22, 23, 24, 25, 144, 212], "cdf": [145, 225, 270], "accord": [145, 193, 249, 263, 264, 265, 266], "seed": 146, "low": [148, 152, 269, 294], "high": [148, 152, 216, 224, 269, 294], "bound": [148, 151, 152, 225, 269, 304, 309], "roadcast": 148, "domain": 151, "uniformli": 152, "repetit": 154, "preserv": [155, 303], "reciproc": 157, "arr": [158, 304], "uncompress": 161, "my_path": 161, "tree_flatten": [161, 209, 210, 216], "transformerencod": 161, "128": [161, 216], "flat_param": 161, "compress": 162, "simplif": [], "reus": [], "consumpt": [], "meant": [], "overhead": [305, 309], "1m": [], "thousand": 305, "foo": [], "matmul": 309, "twice": 309, "subarrai": [74, 171], "being": [176, 216], "prevent": [176, 285, 306], "flow": [176, 305], "unchang": [176, 254], "axis1": [74, 179], "axis2": [74, 179], "taken": [74, 180], "prior": [180, 181], "exclud": 181, "dot": [184, 208, 240, 249], "elsewher": [186, 304], "col": 186, "triangl": 186, "mse": 189, "param": [189, 216, 294, 303], "lvalu": 189, "dlvalu": 189, "dparam": 189, "lasso": 189, "l1": [189, 278, 280, 281, 284], "varianc": [190, 218, 226, 276], "divisor": 190, "cotang": 191, "in_ax": [192, 303], "out_ax": [192, 303], "prefix": [192, 208], "fn": [196, 209, 307], "callabl": [196, 208, 209, 231, 232, 235, 256, 261, 262, 263, 264, 265, 266, 267, 268, 269], "wrt": 196, "rho": 197, "06": [197, 276, 285], "paper": [197, 198, 199, 200, 202, 203, 218, 258], "zeiler": 197, "2012": [197, 206], "adapt": [197, 198, 199], "1212": 197, "5701": 197, "v_": [197, 199, 200, 201, 202, 206, 207], "v_t": [197, 199, 200, 201, 202, 206, 207], "g_t": [197, 199, 200, 201, 202, 203, 206, 207], "delta": [197, 278], "w_": [197, 198, 199, 200, 201, 202, 203, 206, 207], "u_t": 197, "epsilon": [197, 199, 200, 201, 202, 206, 218, 226, 227, 228, 252, 274, 276], "u_": 197, "w_t": [197, 199, 200, 201, 202, 203, 206, 207], "lambda": [197, 198, 199, 200, 201, 202, 203, 206, 207, 209, 216, 231, 236, 255, 259, 289, 291, 303], "averag": [197, 198, 200, 201, 202], "denomin": [197, 199, 200, 201, 202, 206, 227, 274], "stabil": [197, 198, 199, 200, 201, 202, 206, 218, 226, 227, 228, 252, 274, 276], "duchi": 199, "hazan": 199, "singer": 199, "2011": 199, "subgradi": 199, "onlin": 199, "stochast": [199, 200, 202, 207, 305], "jmlr": 199, "999": [200, 201, 202], "omit": [200, 202], "estim": [200, 202], "kingma": [200, 202], "ba": [200, 202], "2015": [200, 202, 222], "iclr": [200, 201, 202], "m_": [200, 201, 202, 203], "beta_1": [198, 200, 201, 202, 203], "m_t": [200, 201, 202, 203], "beta_2": [200, 201, 202, 203], "weight_decai": [198, 201, 203, 207], "contrast": [201, 205], "loshchilov": 201, "hutter": 201, "decoupl": 201, "decai": [198, 201, 203, 207], "regular": [201, 222, 230, 286, 304], "adam": [202, 203], "infin": [105, 107, 108, 202], "99": [203, 206], "sign": [13, 203, 212], "tend": 203, "larger": [203, 254], "10x": 203, "adamw": 203, "maintain": [203, 222, 223], "strength": [203, 207], "wd": 203, "chen": 203, "symbol": 203, "discoveri": 203, "2302": 203, "06675": 203, "c_": 203, "eta": 203, "c_t": 203, "momentum": [203, 207, 218], "basi": 204, "appli": [204, 209, 216, 218, 219, 220, 222, 223, 225, 226, 227, 228, 229, 230, 232, 243, 250, 251, 252, 253, 255, 257, 259, 260, 270, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 294], "optimizerst": 204, "recurs": [205, 216, 235, 236, 241, 244, 246, 297], "defaultdict": 205, "miss": [205, 238, 308], "present": 205, "tieleman": 206, "hinton": 206, "lectur": 206, "coursera": 206, "smooth": [206, 275, 284], "dampen": 207, "nesterov": 207, "descent": [207, 305], "mu": 207, "tau": 207, "l2": [207, 278, 281], "penalti": 207, "is_leaf": [208, 209], "arbitrari": [208, 297], "depth": [208, 223, 303], "hello": [208, 210], "charact": 208, "flat": [208, 210], "superset": 209, "extra": 209, "closer": 209, "constitut": 209, "dict_kei": 209, "recreat": 210, "world": 210, "42": 210, "byte": 212, "bool_": 212, "uint8": 212, "uint16": 212, "16": [212, 227, 231, 297], "uint64": 212, "int8": 212, "int16": 212, "int64": 212, "arbitrarili": [216, 302, 303, 307], "done": [216, 221, 305, 306], "manual": 216, "explicitli": [216, 300], "solv": 216, "intuit": 216, "freez": [216, 246, 297], "finetun": 216, "in_dim": [216, 297], "out_dim": [216, 297], "enumer": 216, "caus": [216, 305], "local": [216, 222], "scope": 216, "l2_loss": 216, "y_hat": 216, "trainable_paramet": [216, 235], "loss_and_grad": 216, "workhors": 216, "Its": 216, "frozen": [216, 236, 244, 246, 251, 297], "individu": [216, 222, 223], "subset": [216, 235], "action": 216, "displai": 216, "tree_map": 216, "count": 216, "num_param": 216, "preclud": 216, "pure": [216, 299], "pattern": [216, 305], "achiev": 216, "other_input": 216, "necessari": 216, "wrap": 216, "apply_to_modul": [216, 236], "children": 216, "filter_and_map": 216, "leaf_modul": 216, "load_weight": [216, 305], "named_modul": 216, "save_weight": 216, "unfreez": [216, 236], "update_modul": 216, "sequenti": [216, 294], "relu": [216, 250, 261, 287, 294], "prelu": 216, "gelu": [216, 271, 272], "silu": 216, "selu": 216, "mish": 216, "quantizedlinear": 216, "conv1d": 216, "conv2d": 216, "batchnorm": 216, "layernorm": 216, "groupnorm": 216, "instancenorm": 216, "dropout": [216, 222, 223, 243, 261], "dropout2d": 216, "dropout3d": 216, "alibi": 216, "sinusoidalpositionalencod": 216, "gelu_approx": [216, 225, 270], "gelu_fast_approx": [216, 225, 270], "binary_cross_entropi": 216, "kl_div_loss": 216, "l1_loss": 216, "mse_loss": 216, "nll_loss": 216, "smooth_l1_loss": 216, "triplet_loss": 216, "hinge_loss": 216, "huber_loss": 216, "log_cosh_loss": 216, "cosine_similarity_loss": 216, "affin": [218, 226, 227, 228, 229, 251], "track_running_stat": 218, "var": [218, 226, 227, 228, 276], "gamma": [218, 226, 227, 228, 252, 263, 264, 265, 266], "nc": 218, "nlc": [218, 219], "four": 218, "nhwc": [218, 220], "height": [218, 220, 222, 223], "width": [218, 220, 222, 223, 251], "deep": [218, 263, 264, 265, 266], "intern": 218, "covari": 218, "shift": 218, "bn": 218, "in_channel": [219, 220], "out_channel": [219, 220], "kernel_s": [219, 220], "learnabl": [219, 220, 256], "portion": 221, "dure": [221, 222, 223, 306], "independ": [222, 223], "nwhc": 222, "whc": 222, "entri": [222, 223], "benefici": [222, 223, 305], "earli": 222, "adjac": 222, "pixel": 222, "correl": 222, "thompson": 222, "goroshin": 222, "jain": 222, "lecun": 222, "bregler": 222, "cvpr": 222, "ndhwc": 223, "dhwc": 223, "medic": 223, "video": 223, "num_embed": 224, "lookup": 224, "typic": [224, 299, 305], "usual": [224, 302, 305], "vocabulari": 224, "approx": 225, "unit": [225, 253, 255, 257, 263, 264, 265, 266, 270, 271, 272, 288, 289, 290], "phi": [225, 270], "geluapprox": 225, "sigma": [225, 257, 263, 264, 265, 266, 271, 272, 290], "60033": [225, 271], "0433603": [225, 271], "gelufast": 225, "773": [225, 272], "regard": 225, "num_group": 226, "pytorch_compat": 226, "split": 226, "preced": 226, "http": [226, 227, 228, 230, 252, 286], "org": [226, 227, 228, 230, 252, 286], "1803": 226, "08494": 226, "inorm": 227, "1607": [227, 228], "08022": 227, "06450": 228, "uniform": [216, 229, 238, 264, 266, 294, 300, 303, 309], "mathcal": 229, "u": 229, "d_i": 229, "monoton": [230, 286], "1908": [230, 286], "08681": [230, 286], "tanh": [230, 286], "softplu": [230, 286], "map_fn": [231, 235], "filter_fn": [231, 235], "valid_parameter_filt": 231, "apply_fn": 232, "descend": 233, "is_leaf_fn": 235, "found": 235, "drop": 235, "idempot": [236, 246], "attent": [236, 249, 258, 261], "endswith": 236, "file_or_weight": 238, "ok": [238, 303], "certain": 243, "ie": 246, "noop": 246, "unfrozen": 246, "chang": [247, 251, 278, 284, 306], "tracer": 247, "partial": [247, 248, 305], "child": 248, "programmat": 248, "query_input_dim": 249, "key_input_dim": 249, "value_input_dim": 249, "value_dim": 249, "value_output_dim": 249, "head": [249, 261], "aggreg": 249, "linearli": 249, "neg": [74, 96, 107, 249, 276, 283, 285, 304], "attend": 249, "num_paramet": 250, "init": [216, 250, 294], "25": 250, "parametr": [250, 287], "classmethod": 251, "from_linear": 251, "quantize_modul": 251, "1910": 252, "07467": 252, "rectifi": [253, 265, 266, 288], "10000": 254, "rotat": 254, "slightli": [254, 309], "angular": 254, "frequenc": [254, 258], "_cos_sin_theta_kei": 254, "precomput": 254, "_cos_sin_theta_valu": 254, "leq": [255, 278, 289], "0507": [255, 289], "67326": [255, 289], "elu": [255, 289], "plain": 256, "known": [257, 290], "swish": [257, 290], "cdot": [257, 271, 272, 274, 277, 290], "min_freq": 258, "0001": 258, "max_freq": 258, "cos_first": 258, "full_turn": 258, "sinusoid": 258, "sin": [258, 303, 307], "threshold": [260, 278, 284, 292], "geq": [260, 292], "num_encoder_lay": 261, "num_decoder_lay": 261, "custom_encod": 261, "custom_decod": 261, "norm_first": 261, "decod": 261, "interact": 261, "mechan": 261, "hidden": 261, "exact": [271, 272], "0003": 271, "015": 272, "pre": [], "predict": [273, 276, 277, 278, 279, 280, 281, 282, 283, 284], "105361": 273, "223144": 273, "20397": 273, "916291": 273, "612192": [], "x1": 274, "x2": 274, "x_1": 274, "x_2": 274, "label_smooth": 275, "hing": 277, "y_": [277, 281], "pred": [277, 281], "huber": 278, "l_": 278, "kullback": 279, "leibler": 279, "diverg": 279, "cosh": 281, "logcosh": 281, "sensit": 281, "outlier": 281, "dual": 281, "behavior": [281, 304, 305], "offer": 281, "balanc": 281, "robust": 281, "approach": [281, 303], "task": 281, "likelihood": [276, 283], "nll": [276, 283], "formula": 284, "anchor": 285, "margin": 285, "triplet": 285, "_p": 285, "degre": 285, "pairwis": 285, "instabl": 285, "subclass": 297, "concept": 297, "mymlp": 297, "in_proj": 297, "subsequ": 299, "implicit": [300, 303], "fine": [300, 305], "grain": 300, "control": [300, 305], "manag": [300, 309], "pseudo": 300, "altern": 300, "splittabl": 300, "threefri": 300, "counter": 300, "cycl": 302, "slice": 304, "ellipsi": 304, "syntax": 304, "idx": 304, "mix": 304, "take_along_axi": 304, "lack": 304, "propag": [303, 304], "extrem": [304, 305], "ineffici": [304, 305], "nonzero": 304, "reflect": [304, 306], "dfdx": [303, 304], "record": 305, "nice": [303, 305], "rerun": 305, "dynam": 305, "easier": 305, "worri": 305, "fun1": 305, "expensive_fun": 305, "cost": [198, 305], "code": 305, "consum": 305, "eager": 305, "thank": 305, "weights_fp16": 305, "trade": 305, "too": 305, "bad": 305, "idea": [303, 305], "On": [303, 305], "grow": 305, "computation": 305, "costli": 305, "wide": 305, "pretti": 305, "ten": [303, 305], "okai": 305, "outer": 305, "value_and_grad_fn": 305, "awar": 305, "implicitli": 305, "anytim": 305, "memoryview": [305, 306], "perfectli": 305, "first_lay": 305, "second_layer_a": 305, "second_layer_b": 305, "frequent": 305, "protocol": 306, "receiv": 306, "pep": 306, "3118": 306, "view": 306, "a_view": 306, "owndata": 306, "quit": [303, 306], "power": [303, 306], "extern": 306, "x_view": 306, "modifi": 306, "df": 306, "x\u00b2": 306, "2x": 306, "indirectli": 306, "modif": 306, "seen": 306, "occur": 306, "incorpor": 306, "issu": [303, 306], "incorrect": 306, "experiment": 306, "break": 306, "advis": 306, "intermedi": 306, "jnp": 306, "tf": 306, "inspect": 307, "page": 307, "composit": 307, "archiv": 308, "savez_compress": 308, "save_safetensor": [242, 308], "save_gguf": 308, "arr_0": 308, "pool": 309, "advantag": 309, "don": 309, "parallel": 309, "race": 309, "interest": 309, "albeit": 309, "contriv": [303, 309], "suppos": [303, 309], "d1": 309, "d2": 309, "4096": [303, 309], "dens": 309, "better": [303, 309], "millisecond": 309, "measur": 309, "default_stream": 310, "default_devic": 310, "my_devic": 310, "pypi": 6, "forg": 6, "grep": 6, "cmake_host_system_processor": 6, "arm64": 6, "x86_64": 6, "wipe": 6, "cahc": 6, "rf": 6, "inifn": 105, "behind": 303, "d2fdx2": 303, "differentiaion": 303, "backward": 303, "zero_grad": 303, "detach": 303, "requires_grad": 303, "dloss_dw": 303, "dloss_dx": 303, "lot": 303, "redund": 303, "stop_gradi": 303, "autom": 303, "sake": 303, "clariti": 303, "difficult": 303, "primit": 303, "priorit": 303, "xs": 303, "ys": 303, "naive_add": 303, "vmap_add": 303, "timeit": 303, "total": 303, "390": 303, "wherea": 303, "025": 303, "Of": 303, "handi": 303, "infinit": 13, "dt": 78, "inclus": 96, "outsid": 96, "clamp": 96, "factorizatoin": 113, "q": 113, "894427": 113, "447214": 113, "57771": 113, "return_metadata": 115, "matadata": 115, "obj": 159, "30": 198, "001": 198, "clip_threshold": 198, "decay_r": 198, "scale_paramet": 198, "relative_step": 198, "warmup_init": 198, "sublinear": 198, "epsilon_1": 198, "epsilon_2": 198, "parameter_scal": 198, "clip": 198, "unscal": 198, "softshrink": 216, "gaussian_nll_loss": 216, "glorot_norm": 216, "glorot_uniform": 216, "he_norm": 216, "he_uniform": 216, "lambd": [259, 291], "checkpoint": 261, "chekpoint": 261, "usag": 261, "expens": 261, "init_fn": [262, 263, 264, 265, 266, 267, 268, 269, 294], "glorot": [263, 264], "deviat": [263, 265, 268], "fan_in": [263, 264, 265, 266], "fan_out": [263, 264, 265, 266], "difficulti": [263, 264], "feedforward": [263, 264], "191107": 263, "61278": 263, "150594": 263, "363207": 263, "gain": [263, 264, 265, 266], "89613": 263, "53947": 263, "48095": 263, "995016": 263, "223404": 264, "890597": 264, "379159": 264, "776856": 264, "90041": 264, "02264": 264, "912766": 264, "12451": 264, "fan": [265, 266], "delv": [265, 266], "surpass": [265, 266], "human": [265, 266], "level": [265, 266], "imagenet": [265, 266], "classif": [265, 266], "25211": 265, "458835": 265, "177208": 265, "0137595": 265, "6967": 265, "02765": 265, "15268": 265, "75787": 265, "kaim": 266, "0300242": 266, "0184009": 266, "793615": 266, "666329": 266, "64331": 266, "16506": 266, "08619": 266, "79854": 266, "982273": 268, "534422": 268, "380709": 268, "0645099": 268, "883935": 269, "863726": 269, "617261": 269, "417497": 269, "with_logit": 273, "539245": 273, "prob": 273, "510826": 273, "hot": 275, "0485873": 275, "348587": 275}, "objects": {"mlx.core": [[7, 0, 1, "", "Device"], [8, 0, 1, "", "Dtype"], [9, 0, 1, "", "Stream"], [10, 2, 1, "", "abs"], [11, 2, 1, "", "add"], [12, 2, 1, "", "all"], [13, 2, 1, "", "allclose"], [14, 2, 1, "", "any"], [15, 2, 1, "", "arange"], [16, 2, 1, "", "arccos"], [17, 2, 1, "", "arccosh"], [18, 2, 1, "", "arcsin"], [19, 2, 1, "", "arcsinh"], [20, 2, 1, "", "arctan"], [21, 2, 1, "", "arctanh"], [22, 2, 1, "", "argmax"], [23, 2, 1, "", "argmin"], [24, 2, 1, "", "argpartition"], [25, 2, 1, "", "argsort"], [26, 0, 1, "", "array"], [60, 2, 1, "", "array_equal"], [61, 2, 1, "", "broadcast_to"], [62, 2, 1, "", "ceil"], [63, 2, 1, "", "clip"], [64, 2, 1, "", "concatenate"], [65, 2, 1, "", "conv1d"], [66, 2, 1, "", "conv2d"], [67, 2, 1, "", "convolve"], [68, 2, 1, "", "cos"], [69, 2, 1, "", "cosh"], [70, 2, 1, "", "default_device"], [71, 2, 1, "", "default_stream"], [72, 2, 1, "", "dequantize"], [73, 2, 1, "", "diag"], [74, 2, 1, "", "diagonal"], [75, 2, 1, "", "divide"], [76, 2, 1, "", "divmod"], [77, 2, 1, "", "equal"], [78, 2, 1, "", "erf"], [79, 2, 1, "", "erfinv"], [80, 2, 1, "", "eval"], [81, 2, 1, "", "exp"], [82, 2, 1, "", "expand_dims"], [83, 2, 1, "", "eye"], [96, 2, 1, "", "flatten"], [97, 2, 1, "", "floor"], [98, 2, 1, "", "floor_divide"], [99, 2, 1, "", "full"], [100, 2, 1, "", "grad"], [101, 2, 1, "", "greater"], [102, 2, 1, "", "greater_equal"], [103, 2, 1, "", "identity"], [104, 2, 1, "", "inner"], [105, 2, 1, "", "isinf"], [106, 2, 1, "", "isnan"], [107, 2, 1, "", "isneginf"], [108, 2, 1, "", "isposinf"], [109, 2, 1, "", "jvp"], [110, 2, 1, "", "less"], [111, 2, 1, "", "less_equal"], [114, 2, 1, "", "linspace"], [115, 2, 1, "", "load"], [116, 2, 1, "", "log"], [117, 2, 1, "", "log10"], [118, 2, 1, "", "log1p"], [119, 2, 1, "", "log2"], [120, 2, 1, "", "logaddexp"], [121, 2, 1, "", "logical_and"], [122, 2, 1, "", "logical_not"], [123, 2, 1, "", "logical_or"], [124, 2, 1, "", "logsumexp"], [125, 2, 1, "", "matmul"], [126, 2, 1, "", "max"], [127, 2, 1, "", "maximum"], [128, 2, 1, "", "mean"], [129, 2, 1, "", "min"], [130, 2, 1, "", "minimum"], [131, 2, 1, "", "moveaxis"], [132, 2, 1, "", "multiply"], [133, 2, 1, "", "negative"], [134, 2, 1, "", "new_stream"], [135, 2, 1, "", "ones"], [136, 2, 1, "", "ones_like"], [137, 2, 1, "", "outer"], [138, 2, 1, "", "pad"], [139, 2, 1, "", "partition"], [140, 2, 1, "", "prod"], [141, 2, 1, "", "quantize"], [142, 2, 1, "", "quantized_matmul"], [153, 2, 1, "", "reciprocal"], [154, 2, 1, "", "repeat"], [155, 2, 1, "", "reshape"], [156, 2, 1, "", "round"], [157, 2, 1, "", "rsqrt"], [158, 2, 1, "", "save"], [159, 2, 1, "", "save_gguf"], [160, 2, 1, "", "save_safetensors"], [161, 2, 1, "", "savez"], [162, 2, 1, "", "savez_compressed"], [163, 2, 1, "", "set_default_device"], [164, 2, 1, "", "set_default_stream"], [165, 2, 1, "", "sigmoid"], [166, 2, 1, "", "sign"], [167, 2, 1, "", "sin"], [168, 2, 1, "", "sinh"], [169, 2, 1, "", "softmax"], [170, 2, 1, "", "sort"], [171, 2, 1, "", "split"], [172, 2, 1, "", "sqrt"], [173, 2, 1, "", "square"], [174, 2, 1, "", "squeeze"], [175, 2, 1, "", "stack"], [176, 2, 1, "", "stop_gradient"], [177, 2, 1, "", "subtract"], [178, 2, 1, "", "sum"], [179, 2, 1, "", "swapaxes"], [180, 2, 1, "", "take"], [181, 2, 1, "", "take_along_axis"], [182, 2, 1, "", "tan"], [183, 2, 1, "", "tanh"], [184, 2, 1, "", "tensordot"], [185, 2, 1, "", "transpose"], [186, 2, 1, "", "tri"], [187, 2, 1, "", "tril"], [188, 2, 1, "", "triu"], [189, 2, 1, "", "value_and_grad"], [190, 2, 1, "", "var"], [191, 2, 1, "", "vjp"], [192, 2, 1, "", "vmap"], [193, 2, 1, "", "where"], [194, 2, 1, "", "zeros"], [195, 2, 1, "", "zeros_like"]], "mlx.core.Device": [[7, 1, 1, "", "__init__"]], "mlx.core.Dtype": [[8, 1, 1, "", "__init__"]], "mlx.core.Stream": [[9, 1, 1, "", "__init__"]], "mlx.core.array": [[27, 3, 1, "", "T"], [26, 1, 1, "", "__init__"], [28, 1, 1, "", "abs"], [29, 1, 1, "", "all"], [30, 1, 1, "", "any"], [31, 1, 1, "", "argmax"], [32, 1, 1, "", "argmin"], [33, 1, 1, "", "astype"], [34, 1, 1, "", "cos"], [35, 3, 1, "", "dtype"], [36, 1, 1, "", "exp"], [37, 1, 1, "", "item"], [38, 1, 1, "", "log"], [39, 1, 1, "", "log1p"], [40, 1, 1, "", "logsumexp"], [41, 1, 1, "", "max"], [42, 1, 1, "", "mean"], [43, 1, 1, "", "min"], [44, 3, 1, "", "ndim"], [45, 1, 1, "", "prod"], [46, 1, 1, "", "reciprocal"], [47, 1, 1, "", "reshape"], [48, 1, 1, "", "round"], [49, 1, 1, "", "rsqrt"], [50, 3, 1, "", "shape"], [51, 1, 1, "", "sin"], [52, 3, 1, "", "size"], [53, 1, 1, "", "split"], [54, 1, 1, "", "sqrt"], [55, 1, 1, "", "square"], [56, 1, 1, "", "sum"], [57, 1, 1, "", "tolist"], [58, 1, 1, "", "transpose"], [59, 1, 1, "", "var"]], "mlx.core.fft": [[84, 2, 1, "", "fft"], [85, 2, 1, "", "fft2"], [86, 2, 1, "", "fftn"], [87, 2, 1, "", "ifft"], [88, 2, 1, "", "ifft2"], [89, 2, 1, "", "ifftn"], [90, 2, 1, "", "irfft"], [91, 2, 1, "", "irfft2"], [92, 2, 1, "", "irfftn"], [93, 2, 1, "", "rfft"], [94, 2, 1, "", "rfft2"], [95, 2, 1, "", "rfftn"]], "mlx.core.linalg": [[112, 2, 1, "", "norm"], [113, 2, 1, "", "qr"]], "mlx.core.random": [[143, 2, 1, "", "bernoulli"], [144, 2, 1, "", "categorical"], [145, 2, 1, "", "gumbel"], [146, 2, 1, "", "key"], [147, 2, 1, "", "normal"], [148, 2, 1, "", "randint"], [149, 2, 1, "", "seed"], [150, 2, 1, "", "split"], [151, 2, 1, "", "truncated_normal"], [152, 2, 1, "", "uniform"]], "mlx.nn": [[217, 0, 1, "", "ALiBi"], [218, 0, 1, "", "BatchNorm"], [219, 0, 1, "", "Conv1d"], [220, 0, 1, "", "Conv2d"], [221, 0, 1, "", "Dropout"], [222, 0, 1, "", "Dropout2d"], [223, 0, 1, "", "Dropout3d"], [224, 0, 1, "", "Embedding"], [225, 0, 1, "", "GELU"], [226, 0, 1, "", "GroupNorm"], [227, 0, 1, "", "InstanceNorm"], [228, 0, 1, "", "LayerNorm"], [229, 0, 1, "", "Linear"], [230, 0, 1, "", "Mish"], [297, 0, 1, "", "Module"], [249, 0, 1, "", "MultiHeadAttention"], [250, 0, 1, "", "PReLU"], [251, 0, 1, "", "QuantizedLinear"], [252, 0, 1, "", "RMSNorm"], [253, 0, 1, "", "ReLU"], [254, 0, 1, "", "RoPE"], [255, 0, 1, "", "SELU"], [256, 0, 1, "", "Sequential"], [257, 0, 1, "", "SiLU"], [258, 0, 1, "", "SinusoidalPositionalEncoding"], [259, 0, 1, "", "Softshrink"], [260, 0, 1, "", "Step"], [261, 0, 1, "", "Transformer"], [270, 0, 1, "", "gelu"], [271, 0, 1, "", "gelu_approx"], [272, 0, 1, "", "gelu_fast_approx"], [286, 0, 1, "", "mish"], [287, 0, 1, "", "prelu"], [288, 0, 1, "", "relu"], [289, 0, 1, "", "selu"], [290, 0, 1, "", "silu"], [291, 0, 1, "", "softshrink"], [292, 0, 1, "", "step"], [196, 2, 1, "", "value_and_grad"]], "mlx.nn.Module": [[231, 1, 1, "", "apply"], [232, 1, 1, "", "apply_to_modules"], [233, 1, 1, "", "children"], [234, 1, 1, "", "eval"], [235, 1, 1, "", "filter_and_map"], [236, 1, 1, "", "freeze"], [237, 1, 1, "", "leaf_modules"], [238, 1, 1, "", "load_weights"], [239, 1, 1, "", "modules"], [240, 1, 1, "", "named_modules"], [241, 1, 1, "", "parameters"], [242, 1, 1, "", "save_weights"], [243, 1, 1, "", "train"], [244, 1, 1, "", "trainable_parameters"], [245, 3, 1, "", "training"], [246, 1, 1, "", "unfreeze"], [247, 1, 1, "", "update"], [248, 1, 1, "", "update_modules"]], "mlx.nn.RoPE": [[254, 4, 1, "", "_cos_sin_theta_key"], [254, 4, 1, "", "_cos_sin_theta_value"]], "mlx.nn.init": [[262, 2, 1, "", "constant"], [263, 2, 1, "", "glorot_normal"], [264, 2, 1, "", "glorot_uniform"], [265, 2, 1, "", "he_normal"], [266, 2, 1, "", "he_uniform"], [267, 2, 1, "", "identity"], [268, 2, 1, "", "normal"], [269, 2, 1, "", "uniform"]], "mlx.nn.losses": [[273, 0, 1, "", "binary_cross_entropy"], [274, 0, 1, "", "cosine_similarity_loss"], [275, 0, 1, "", "cross_entropy"], [276, 0, 1, "", "gaussian_nll_loss"], [277, 0, 1, "", "hinge_loss"], [278, 0, 1, "", "huber_loss"], [279, 0, 1, "", "kl_div_loss"], [280, 0, 1, "", "l1_loss"], [281, 0, 1, "", "log_cosh_loss"], [282, 0, 1, "", "mse_loss"], [283, 0, 1, "", "nll_loss"], [284, 0, 1, "", "smooth_l1_loss"], [285, 0, 1, "", "triplet_loss"]], "mlx.optimizers": [[197, 0, 1, "", "AdaDelta"], [198, 0, 1, "", "Adafactor"], [199, 0, 1, "", "Adagrad"], [200, 0, 1, "", "Adam"], [201, 0, 1, "", "AdamW"], [202, 0, 1, "", "Adamax"], [203, 0, 1, "", "Lion"], [204, 0, 1, "", "Optimizer"], [205, 0, 1, "", "OptimizerState"], [206, 0, 1, "", "RMSprop"], [207, 0, 1, "", "SGD"]], "mlx.optimizers.Optimizer": [[204, 4, 1, "", "state"]], "mlx.utils": [[208, 2, 1, "", "tree_flatten"], [209, 2, 1, "", "tree_map"], [210, 2, 1, "", "tree_unflatten"]]}, "objtypes": {"0": "py:class", "1": "py:method", "2": "py:function", "3": "py:property", "4": "py:attribute"}, "objnames": {"0": ["py", "class", "Python class"], "1": ["py", "method", "Python method"], "2": ["py", "function", "Python function"], "3": ["py", "property", "Python property"], "4": ["py", "attribute", "Python attribute"]}, "titleterms": {"oper": [0, 1, 298], "develop": 1, "document": 1, "introduc": 1, "exampl": [1, 5, 309], "primit": 1, "us": [1, 305, 310], "implement": [1, 3], "cpu": 1, "backend": 1, "gpu": 1, "transform": [1, 261, 301, 303, 305, 307], "build": [1, 6], "bind": 1, "python": [1, 5, 6], "cmake": 1, "setuptool": 1, "usag": [1, 5], "result": 1, "script": [1, 3], "download": [1, 3], "code": [1, 3], "linear": [2, 215, 229], "regress": 2, "llm": 3, "infer": 3, "model": 3, "attent": 3, "layer": [3, 4, 295], "encod": 3, "full": [3, 99], "gener": 3, "put": 3, "all": [3, 12, 29], "togeth": 3, "convert": 3, "weight": 3, "load": [3, 115, 308], "benchmark": 3, "multi": 4, "perceptron": 4, "mlx": [5, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292], "instal": [5, 6], "api": [5, 6], "refer": 5, "c": [5, 6], "further": 5, "read": 5, "from": [6, 304], "pypi": [], "troubleshoot": 6, "sourc": 6, "requir": 6, "option": 6, "metal": 6, "found": 6, "x86": 6, "shell": 6, "core": [7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195], "devic": [7, 213], "dtype": [8, 35], "stream": [9, 213, 310], "ab": [10, 28], "add": 11, "allclos": 13, "ani": [14, 30], "arang": 15, "arcco": 16, "arccosh": 17, "arcsin": 18, "arcsinh": 19, "arctan": 20, "arctanh": 21, "argmax": [22, 31], "argmin": [23, 32], "argpartit": 24, "argsort": 25, "arrai": [26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 211, 304, 308], "t": 27, "astyp": 33, "co": [34, 68], "exp": [36, 81], "item": 37, "log": [38, 116], "log1p": [39, 118], "logsumexp": [40, 124], "max": [41, 126], "mean": [42, 128], "min": [43, 129], "ndim": 44, "prod": [45, 140], "reciproc": [46, 153], "reshap": [47, 155], "round": [48, 156], "rsqrt": [49, 157], "shape": 50, "sin": [51, 167], "size": 52, "split": [53, 150, 171], "sqrt": [54, 172], "squar": [55, 173], "sum": [56, 178], "tolist": 57, "transpos": [58, 185], "var": [59, 190], "array_equ": 60, "broadcast_to": 61, "ceil": 62, "clip": 63, "concaten": 64, "conv1d": [65, 219], "conv2d": [66, 220], "convolv": 67, "cosh": 69, "default_devic": 70, "default_stream": 71, "dequant": 72, "divid": 75, "divmod": 76, "equal": 77, "erf": 78, "erfinv": 79, "eval": [80, 234], "expand_dim": 82, "ey": 83, "fft": [84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 214], "fft2": 85, "fftn": 86, "ifft": 87, "ifft2": 88, "ifftn": 89, "irfft": 90, "irfft2": 91, "irfftn": 92, "rfft": 93, "rfft2": 94, "rfftn": 95, "flatten": 96, "floor": 97, "floor_divid": 98, "grad": [100, 216], "greater": 101, "greater_equ": 102, "ident": [103, 267], "inner": 104, "jvp": 109, "less": 110, "less_equ": 111, "linalg": [112, 113], "norm": 112, "linspac": 114, "log10": 117, "log2": 119, "logaddexp": 120, "logical_and": 121, "logical_not": 122, "logical_or": 123, "matmul": 125, "maximum": 127, "minimum": 130, "moveaxi": 131, "multipli": 132, "neg": 133, "new_stream": 134, "ones": 135, "ones_lik": 136, "outer": 137, "pad": 138, "partit": 139, "quantiz": 141, "quantized_matmul": 142, "random": [143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 300], "bernoulli": 143, "categor": 144, "gumbel": 145, "kei": 146, "normal": [147, 268], "randint": 148, "seed": 149, "truncated_norm": 151, "uniform": [152, 269], "repeat": 154, "save": [158, 308], "save_gguf": 159, "save_safetensor": 160, "savez": 161, "savez_compress": 162, "set_default_devic": 163, "set_default_stream": 164, "sigmoid": 165, "sign": 166, "simplifi": [], "sinh": 168, "softmax": 169, "sort": 170, "squeez": 174, "stack": 175, "stop_gradi": 176, "subtract": 177, "swapax": 179, "take": 180, "take_along_axi": 181, "tan": 182, "tanh": 183, "tensordot": 184, "tri": 186, "tril": 187, "triu": 188, "value_and_grad": [189, 196], "vjp": 191, "vmap": 192, "where": 193, "zero": 194, "zeros_lik": 195, "nn": [196, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292], "optim": [197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 299], "adadelta": 197, "adagrad": 199, "adam": 200, "adamw": 201, "adamax": 202, "lion": 203, "optimizerst": 205, "rmsprop": 206, "sgd": 207, "util": [208, 209, 210, 302], "tree_flatten": 208, "tree_map": 209, "tree_unflatten": 210, "data": 212, "type": 212, "support": 212, "algebra": 215, "neural": 216, "network": 216, "quick": [216, 307], "start": [216, 307], "The": 216, "modul": [216, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 297], "class": 216, "paramet": [216, 241], "updat": [216, 247, 304], "inspect": 216, "valu": 216, "alibi": 217, "batchnorm": 218, "dropout": 221, "dropout2d": 222, "dropout3d": 223, "embed": 224, "gelu": [225, 270], "groupnorm": 226, "instancenorm": 227, "layernorm": 228, "mish": [230, 286], "appli": 231, "apply_to_modul": 232, "children": 233, "filter_and_map": 235, "freez": 236, "leaf_modul": 237, "load_weight": 238, "named_modul": 240, "save_weight": 242, "train": [243, 245], "trainable_paramet": 244, "unfreez": 246, "update_modul": 248, "multiheadattent": 249, "prelu": [250, 287], "quantizedlinear": 251, "rmsnorm": 252, "relu": [253, 288], "rope": 254, "selu": [255, 289], "sequenti": 256, "silu": [257, 290], "sinusoidalpositionalencod": 258, "step": [260, 292], "gelu_approx": 271, "gelu_fast_approx": 272, "loss": [273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 296], "binary_cross_entropi": 273, "cosine_similarity_loss": 274, "cross_entropi": 275, "hinge_loss": 277, "huber_loss": 278, "kl_div_loss": 279, "l1_loss": 280, "log_cosh_loss": 281, "mse_loss": 282, "nll_loss": 283, "smooth_l1_loss": 284, "triplet_loss": 285, "function": [293, 296, 303, 307], "tree": 302, "index": 304, "differ": 304, "numpi": [304, 306], "In": 304, "place": 304, "lazi": 305, "evalu": 305, "why": 305, "comput": 305, "graph": [305, 307], "onli": 305, "what": 305, "you": 305, "when": 305, "convers": 306, "other": 306, "framework": 306, "pytorch": 306, "jax": 306, "tensorflow": 306, "guid": 307, "basic": 307, "serial": 308, "format": 308, "unifi": 309, "memori": 309, "A": 309, "simpl": 309, "specifi": 310, "isinf": 105, "isnan": 106, "isneginf": 107, "isposinf": 108, "automat": 303, "differenti": 303, "vector": 303, "diag": 73, "diagon": 74, "qr": 113, "adafactor": 198, "softshrink": [259, 291], "init": [262, 263, 264, 265, 266, 267, 268, 269], "constant": 262, "glorot_norm": 263, "glorot_uniform": 264, "he_norm": 265, "he_uniform": 266, "gaussian_nll_loss": 276, "initi": 294}, "envversion": {"sphinx.domains.c": 2, "sphinx.domains.changeset": 1, "sphinx.domains.citation": 1, "sphinx.domains.cpp": 6, "sphinx.domains.index": 1, "sphinx.domains.javascript": 2, "sphinx.domains.math": 2, "sphinx.domains.python": 3, "sphinx.domains.rst": 2, "sphinx.domains.std": 2, "sphinx.ext.intersphinx": 1, "sphinx": 56}})