mlx/docs/build/html/searchindex.js
Search.setIndex({"docnames": ["cpp/ops", "dev/extensions", "examples/linear_regression", "examples/llama-inference", "examples/mlp", "index", "install", "python/_autosummary/mlx.core.Device", "python/_autosummary/mlx.core.Dtype", "python/_autosummary/mlx.core.abs", "python/_autosummary/mlx.core.add", "python/_autosummary/mlx.core.all", "python/_autosummary/mlx.core.allclose", "python/_autosummary/mlx.core.any", "python/_autosummary/mlx.core.arange", "python/_autosummary/mlx.core.arccos", "python/_autosummary/mlx.core.arccosh", "python/_autosummary/mlx.core.arcsin", "python/_autosummary/mlx.core.arcsinh", "python/_autosummary/mlx.core.arctan", "python/_autosummary/mlx.core.arctanh", "python/_autosummary/mlx.core.argmax", "python/_autosummary/mlx.core.argmin", "python/_autosummary/mlx.core.argpartition", "python/_autosummary/mlx.core.argsort", "python/_autosummary/mlx.core.array", "python/_autosummary/mlx.core.array.T", "python/_autosummary/mlx.core.array.abs", "python/_autosummary/mlx.core.array.all", "python/_autosummary/mlx.core.array.any", "python/_autosummary/mlx.core.array.argmax", "python/_autosummary/mlx.core.array.argmin", "python/_autosummary/mlx.core.array.astype", "python/_autosummary/mlx.core.array.cos", "python/_autosummary/mlx.core.array.dtype", "python/_autosummary/mlx.core.array.exp", "python/_autosummary/mlx.core.array.item", "python/_autosummary/mlx.core.array.log", "python/_autosummary/mlx.core.array.log1p", "python/_autosummary/mlx.core.array.logsumexp", "python/_autosummary/mlx.core.array.max", "python/_autosummary/mlx.core.array.mean", "python/_autosummary/mlx.core.array.min", "python/_autosummary/mlx.core.array.ndim", "python/_autosummary/mlx.core.array.prod", "python/_autosummary/mlx.core.array.reciprocal", "python/_autosummary/mlx.core.array.reshape", "python/_autosummary/mlx.core.array.round", "python/_autosummary/mlx.core.array.rsqrt", "python/_autosummary/mlx.core.array.shape", "python/_autosummary/mlx.core.array.sin", "python/_autosummary/mlx.core.array.size", "python/_autosummary/mlx.core.array.split", "python/_autosummary/mlx.core.array.sqrt", "python/_autosummary/mlx.core.array.square", "python/_autosummary/mlx.core.array.sum", "python/_autosummary/mlx.core.array.tolist", "python/_autosummary/mlx.core.array.transpose", "python/_autosummary/mlx.core.array.var", "python/_autosummary/mlx.core.array_equal", "python/_autosummary/mlx.core.broadcast_to", "python/_autosummary/mlx.core.ceil", "python/_autosummary/mlx.core.clip", "python/_autosummary/mlx.core.compile", "python/_autosummary/mlx.core.concatenate", "python/_autosummary/mlx.core.conv1d", "python/_autosummary/mlx.core.conv2d", "python/_autosummary/mlx.core.convolve", "python/_autosummary/mlx.core.cos", "python/_autosummary/mlx.core.cosh", "python/_autosummary/mlx.core.default_device", "python/_autosummary/mlx.core.default_stream", "python/_autosummary/mlx.core.dequantize", "python/_autosummary/mlx.core.diag", "python/_autosummary/mlx.core.diagonal", "python/_autosummary/mlx.core.disable_compile", "python/_autosummary/mlx.core.divide", "python/_autosummary/mlx.core.divmod", "python/_autosummary/mlx.core.enable_compile", "python/_autosummary/mlx.core.equal", "python/_autosummary/mlx.core.erf", "python/_autosummary/mlx.core.erfinv", "python/_autosummary/mlx.core.eval", "python/_autosummary/mlx.core.exp", "python/_autosummary/mlx.core.expand_dims", "python/_autosummary/mlx.core.eye", "python/_autosummary/mlx.core.fft.fft", "python/_autosummary/mlx.core.fft.fft2", "python/_autosummary/mlx.core.fft.fftn", 
"python/_autosummary/mlx.core.fft.ifft", "python/_autosummary/mlx.core.fft.ifft2", "python/_autosummary/mlx.core.fft.ifftn", "python/_autosummary/mlx.core.fft.irfft", "python/_autosummary/mlx.core.fft.irfft2", "python/_autosummary/mlx.core.fft.irfftn", "python/_autosummary/mlx.core.fft.rfft", "python/_autosummary/mlx.core.fft.rfft2", "python/_autosummary/mlx.core.fft.rfftn", "python/_autosummary/mlx.core.flatten", "python/_autosummary/mlx.core.floor", "python/_autosummary/mlx.core.floor_divide", "python/_autosummary/mlx.core.full", "python/_autosummary/mlx.core.grad", "python/_autosummary/mlx.core.greater", "python/_autosummary/mlx.core.greater_equal", "python/_autosummary/mlx.core.identity", "python/_autosummary/mlx.core.inner", "python/_autosummary/mlx.core.isinf", "python/_autosummary/mlx.core.isnan", "python/_autosummary/mlx.core.isneginf", "python/_autosummary/mlx.core.isposinf", "python/_autosummary/mlx.core.jvp", "python/_autosummary/mlx.core.less", "python/_autosummary/mlx.core.less_equal", "python/_autosummary/mlx.core.linalg.norm", "python/_autosummary/mlx.core.linalg.qr", "python/_autosummary/mlx.core.linspace", "python/_autosummary/mlx.core.load", "python/_autosummary/mlx.core.log", "python/_autosummary/mlx.core.log10", "python/_autosummary/mlx.core.log1p", "python/_autosummary/mlx.core.log2", "python/_autosummary/mlx.core.logaddexp", "python/_autosummary/mlx.core.logical_and", "python/_autosummary/mlx.core.logical_not", "python/_autosummary/mlx.core.logical_or", "python/_autosummary/mlx.core.logsumexp", "python/_autosummary/mlx.core.matmul", "python/_autosummary/mlx.core.max", "python/_autosummary/mlx.core.maximum", "python/_autosummary/mlx.core.mean", "python/_autosummary/mlx.core.min", "python/_autosummary/mlx.core.minimum", "python/_autosummary/mlx.core.moveaxis", "python/_autosummary/mlx.core.multiply", "python/_autosummary/mlx.core.negative", "python/_autosummary/mlx.core.new_stream", "python/_autosummary/mlx.core.ones", "python/_autosummary/mlx.core.ones_like", "python/_autosummary/mlx.core.outer", "python/_autosummary/mlx.core.pad", "python/_autosummary/mlx.core.partition", "python/_autosummary/mlx.core.prod", "python/_autosummary/mlx.core.quantize", "python/_autosummary/mlx.core.quantized_matmul", "python/_autosummary/mlx.core.random.bernoulli", "python/_autosummary/mlx.core.random.categorical", "python/_autosummary/mlx.core.random.gumbel", "python/_autosummary/mlx.core.random.key", "python/_autosummary/mlx.core.random.normal", "python/_autosummary/mlx.core.random.randint", "python/_autosummary/mlx.core.random.seed", "python/_autosummary/mlx.core.random.split", "python/_autosummary/mlx.core.random.truncated_normal", "python/_autosummary/mlx.core.random.uniform", "python/_autosummary/mlx.core.reciprocal", "python/_autosummary/mlx.core.repeat", "python/_autosummary/mlx.core.reshape", "python/_autosummary/mlx.core.round", "python/_autosummary/mlx.core.rsqrt", "python/_autosummary/mlx.core.save", "python/_autosummary/mlx.core.save_gguf", "python/_autosummary/mlx.core.save_safetensors", "python/_autosummary/mlx.core.savez", "python/_autosummary/mlx.core.savez_compressed", "python/_autosummary/mlx.core.set_default_device", "python/_autosummary/mlx.core.set_default_stream", "python/_autosummary/mlx.core.sigmoid", "python/_autosummary/mlx.core.sign", "python/_autosummary/mlx.core.sin", "python/_autosummary/mlx.core.sinh", "python/_autosummary/mlx.core.softmax", "python/_autosummary/mlx.core.sort", "python/_autosummary/mlx.core.split", "python/_autosummary/mlx.core.sqrt", 
"python/_autosummary/mlx.core.square", "python/_autosummary/mlx.core.squeeze", "python/_autosummary/mlx.core.stack", "python/_autosummary/mlx.core.stop_gradient", "python/_autosummary/mlx.core.stream", "python/_autosummary/mlx.core.subtract", "python/_autosummary/mlx.core.sum", "python/_autosummary/mlx.core.swapaxes", "python/_autosummary/mlx.core.take", "python/_autosummary/mlx.core.take_along_axis", "python/_autosummary/mlx.core.tan", "python/_autosummary/mlx.core.tanh", "python/_autosummary/mlx.core.tensordot", "python/_autosummary/mlx.core.transpose", "python/_autosummary/mlx.core.tri", "python/_autosummary/mlx.core.tril", "python/_autosummary/mlx.core.triu", "python/_autosummary/mlx.core.value_and_grad", "python/_autosummary/mlx.core.var", "python/_autosummary/mlx.core.vjp", "python/_autosummary/mlx.core.vmap", "python/_autosummary/mlx.core.where", "python/_autosummary/mlx.core.zeros", "python/_autosummary/mlx.core.zeros_like", "python/_autosummary/mlx.nn.value_and_grad", "python/_autosummary/mlx.utils.tree_flatten", "python/_autosummary/mlx.utils.tree_map", "python/_autosummary/mlx.utils.tree_unflatten", "python/_autosummary/stream_class", "python/array", "python/data_types", "python/devices_and_streams", "python/fft", "python/linalg", "python/nn", "python/nn/_autosummary/mlx.nn.ALiBi", "python/nn/_autosummary/mlx.nn.AvgPool1d", "python/nn/_autosummary/mlx.nn.AvgPool2d", "python/nn/_autosummary/mlx.nn.BatchNorm", "python/nn/_autosummary/mlx.nn.Conv1d", "python/nn/_autosummary/mlx.nn.Conv2d", "python/nn/_autosummary/mlx.nn.Dropout", "python/nn/_autosummary/mlx.nn.Dropout2d", "python/nn/_autosummary/mlx.nn.Dropout3d", "python/nn/_autosummary/mlx.nn.Embedding", "python/nn/_autosummary/mlx.nn.GELU", "python/nn/_autosummary/mlx.nn.GroupNorm", "python/nn/_autosummary/mlx.nn.InstanceNorm", "python/nn/_autosummary/mlx.nn.LayerNorm", "python/nn/_autosummary/mlx.nn.Linear", "python/nn/_autosummary/mlx.nn.MaxPool1d", "python/nn/_autosummary/mlx.nn.MaxPool2d", "python/nn/_autosummary/mlx.nn.Mish", "python/nn/_autosummary/mlx.nn.Module.apply", "python/nn/_autosummary/mlx.nn.Module.apply_to_modules", "python/nn/_autosummary/mlx.nn.Module.children", "python/nn/_autosummary/mlx.nn.Module.eval", "python/nn/_autosummary/mlx.nn.Module.filter_and_map", "python/nn/_autosummary/mlx.nn.Module.freeze", "python/nn/_autosummary/mlx.nn.Module.leaf_modules", "python/nn/_autosummary/mlx.nn.Module.load_weights", "python/nn/_autosummary/mlx.nn.Module.modules", "python/nn/_autosummary/mlx.nn.Module.named_modules", "python/nn/_autosummary/mlx.nn.Module.parameters", "python/nn/_autosummary/mlx.nn.Module.save_weights", "python/nn/_autosummary/mlx.nn.Module.state", "python/nn/_autosummary/mlx.nn.Module.train", "python/nn/_autosummary/mlx.nn.Module.trainable_parameters", "python/nn/_autosummary/mlx.nn.Module.training", "python/nn/_autosummary/mlx.nn.Module.unfreeze", "python/nn/_autosummary/mlx.nn.Module.update", "python/nn/_autosummary/mlx.nn.Module.update_modules", "python/nn/_autosummary/mlx.nn.MultiHeadAttention", "python/nn/_autosummary/mlx.nn.PReLU", "python/nn/_autosummary/mlx.nn.QuantizedLinear", "python/nn/_autosummary/mlx.nn.RMSNorm", "python/nn/_autosummary/mlx.nn.ReLU", "python/nn/_autosummary/mlx.nn.RoPE", "python/nn/_autosummary/mlx.nn.SELU", "python/nn/_autosummary/mlx.nn.Sequential", "python/nn/_autosummary/mlx.nn.SiLU", "python/nn/_autosummary/mlx.nn.SinusoidalPositionalEncoding", "python/nn/_autosummary/mlx.nn.Softshrink", "python/nn/_autosummary/mlx.nn.Step", 
"python/nn/_autosummary/mlx.nn.Transformer", "python/nn/_autosummary/mlx.nn.init.constant", "python/nn/_autosummary/mlx.nn.init.glorot_normal", "python/nn/_autosummary/mlx.nn.init.glorot_uniform", "python/nn/_autosummary/mlx.nn.init.he_normal", "python/nn/_autosummary/mlx.nn.init.he_uniform", "python/nn/_autosummary/mlx.nn.init.identity", "python/nn/_autosummary/mlx.nn.init.normal", "python/nn/_autosummary/mlx.nn.init.uniform", "python/nn/_autosummary_functions/mlx.nn.gelu", "python/nn/_autosummary_functions/mlx.nn.gelu_approx", "python/nn/_autosummary_functions/mlx.nn.gelu_fast_approx", "python/nn/_autosummary_functions/mlx.nn.losses.binary_cross_entropy", "python/nn/_autosummary_functions/mlx.nn.losses.cosine_similarity_loss", "python/nn/_autosummary_functions/mlx.nn.losses.cross_entropy", "python/nn/_autosummary_functions/mlx.nn.losses.gaussian_nll_loss", "python/nn/_autosummary_functions/mlx.nn.losses.hinge_loss", "python/nn/_autosummary_functions/mlx.nn.losses.huber_loss", "python/nn/_autosummary_functions/mlx.nn.losses.kl_div_loss", "python/nn/_autosummary_functions/mlx.nn.losses.l1_loss", "python/nn/_autosummary_functions/mlx.nn.losses.log_cosh_loss", "python/nn/_autosummary_functions/mlx.nn.losses.margin_ranking_loss", "python/nn/_autosummary_functions/mlx.nn.losses.mse_loss", "python/nn/_autosummary_functions/mlx.nn.losses.nll_loss", "python/nn/_autosummary_functions/mlx.nn.losses.smooth_l1_loss", "python/nn/_autosummary_functions/mlx.nn.losses.triplet_loss", "python/nn/_autosummary_functions/mlx.nn.mish", "python/nn/_autosummary_functions/mlx.nn.prelu", "python/nn/_autosummary_functions/mlx.nn.relu", "python/nn/_autosummary_functions/mlx.nn.selu", "python/nn/_autosummary_functions/mlx.nn.silu", "python/nn/_autosummary_functions/mlx.nn.softshrink", "python/nn/_autosummary_functions/mlx.nn.step", "python/nn/functions", "python/nn/init", "python/nn/layers", "python/nn/losses", "python/nn/module", "python/ops", "python/optimizers", "python/optimizers/_autosummary/mlx.optimizers.AdaDelta", "python/optimizers/_autosummary/mlx.optimizers.Adafactor", "python/optimizers/_autosummary/mlx.optimizers.Adagrad", "python/optimizers/_autosummary/mlx.optimizers.Adam", "python/optimizers/_autosummary/mlx.optimizers.AdamW", "python/optimizers/_autosummary/mlx.optimizers.Adamax", "python/optimizers/_autosummary/mlx.optimizers.Lion", "python/optimizers/_autosummary/mlx.optimizers.Optimizer.apply_gradients", "python/optimizers/_autosummary/mlx.optimizers.Optimizer.init", "python/optimizers/_autosummary/mlx.optimizers.Optimizer.state", "python/optimizers/_autosummary/mlx.optimizers.Optimizer.update", "python/optimizers/_autosummary/mlx.optimizers.RMSprop", "python/optimizers/_autosummary/mlx.optimizers.SGD", "python/optimizers/_autosummary/mlx.optimizers.cosine_decay", "python/optimizers/_autosummary/mlx.optimizers.exponential_decay", "python/optimizers/_autosummary/mlx.optimizers.step_decay", "python/optimizers/common_optimizers", "python/optimizers/optimizer", "python/optimizers/schedulers", "python/random", "python/transforms", "python/tree_utils", "usage/compile", "usage/function_transforms", "usage/indexing", "usage/lazy_evaluation", "usage/numpy", "usage/quick_start", "usage/saving_and_loading", "usage/unified_memory", "usage/using_streams"], "filenames": ["cpp/ops.rst", "dev/extensions.rst", "examples/linear_regression.rst", "examples/llama-inference.rst", "examples/mlp.rst", "index.rst", "install.rst", "python/_autosummary/mlx.core.Device.rst", "python/_autosummary/mlx.core.Dtype.rst", 
"python/_autosummary/mlx.core.abs.rst", "python/_autosummary/mlx.core.add.rst", "python/_autosummary/mlx.core.all.rst", "python/_autosummary/mlx.core.allclose.rst", "python/_autosummary/mlx.core.any.rst", "python/_autosummary/mlx.core.arange.rst", "python/_autosummary/mlx.core.arccos.rst", "python/_autosummary/mlx.core.arccosh.rst", "python/_autosummary/mlx.core.arcsin.rst", "python/_autosummary/mlx.core.arcsinh.rst", "python/_autosummary/mlx.core.arctan.rst", "python/_autosummary/mlx.core.arctanh.rst", "python/_autosummary/mlx.core.argmax.rst", "python/_autosummary/mlx.core.argmin.rst", "python/_autosummary/mlx.core.argpartition.rst", "python/_autosummary/mlx.core.argsort.rst", "python/_autosummary/mlx.core.array.rst", "python/_autosummary/mlx.core.array.T.rst", "python/_autosummary/mlx.core.array.abs.rst", "python/_autosummary/mlx.core.array.all.rst", "python/_autosummary/mlx.core.array.any.rst", "python/_autosummary/mlx.core.array.argmax.rst", "python/_autosummary/mlx.core.array.argmin.rst", "python/_autosummary/mlx.core.array.astype.rst", "python/_autosummary/mlx.core.array.cos.rst", "python/_autosummary/mlx.core.array.dtype.rst", "python/_autosummary/mlx.core.array.exp.rst", "python/_autosummary/mlx.core.array.item.rst", "python/_autosummary/mlx.core.array.log.rst", "python/_autosummary/mlx.core.array.log1p.rst", "python/_autosummary/mlx.core.array.logsumexp.rst", "python/_autosummary/mlx.core.array.max.rst", "python/_autosummary/mlx.core.array.mean.rst", "python/_autosummary/mlx.core.array.min.rst", "python/_autosummary/mlx.core.array.ndim.rst", "python/_autosummary/mlx.core.array.prod.rst", "python/_autosummary/mlx.core.array.reciprocal.rst", "python/_autosummary/mlx.core.array.reshape.rst", "python/_autosummary/mlx.core.array.round.rst", "python/_autosummary/mlx.core.array.rsqrt.rst", "python/_autosummary/mlx.core.array.shape.rst", "python/_autosummary/mlx.core.array.sin.rst", "python/_autosummary/mlx.core.array.size.rst", "python/_autosummary/mlx.core.array.split.rst", "python/_autosummary/mlx.core.array.sqrt.rst", "python/_autosummary/mlx.core.array.square.rst", "python/_autosummary/mlx.core.array.sum.rst", "python/_autosummary/mlx.core.array.tolist.rst", "python/_autosummary/mlx.core.array.transpose.rst", "python/_autosummary/mlx.core.array.var.rst", "python/_autosummary/mlx.core.array_equal.rst", "python/_autosummary/mlx.core.broadcast_to.rst", "python/_autosummary/mlx.core.ceil.rst", "python/_autosummary/mlx.core.clip.rst", "python/_autosummary/mlx.core.compile.rst", "python/_autosummary/mlx.core.concatenate.rst", "python/_autosummary/mlx.core.conv1d.rst", "python/_autosummary/mlx.core.conv2d.rst", "python/_autosummary/mlx.core.convolve.rst", "python/_autosummary/mlx.core.cos.rst", "python/_autosummary/mlx.core.cosh.rst", "python/_autosummary/mlx.core.default_device.rst", "python/_autosummary/mlx.core.default_stream.rst", "python/_autosummary/mlx.core.dequantize.rst", "python/_autosummary/mlx.core.diag.rst", "python/_autosummary/mlx.core.diagonal.rst", "python/_autosummary/mlx.core.disable_compile.rst", "python/_autosummary/mlx.core.divide.rst", "python/_autosummary/mlx.core.divmod.rst", "python/_autosummary/mlx.core.enable_compile.rst", "python/_autosummary/mlx.core.equal.rst", "python/_autosummary/mlx.core.erf.rst", "python/_autosummary/mlx.core.erfinv.rst", "python/_autosummary/mlx.core.eval.rst", "python/_autosummary/mlx.core.exp.rst", "python/_autosummary/mlx.core.expand_dims.rst", "python/_autosummary/mlx.core.eye.rst", "python/_autosummary/mlx.core.fft.fft.rst", 
"python/_autosummary/mlx.core.fft.fft2.rst", "python/_autosummary/mlx.core.fft.fftn.rst", "python/_autosummary/mlx.core.fft.ifft.rst", "python/_autosummary/mlx.core.fft.ifft2.rst", "python/_autosummary/mlx.core.fft.ifftn.rst", "python/_autosummary/mlx.core.fft.irfft.rst", "python/_autosummary/mlx.core.fft.irfft2.rst", "python/_autosummary/mlx.core.fft.irfftn.rst", "python/_autosummary/mlx.core.fft.rfft.rst", "python/_autosummary/mlx.core.fft.rfft2.rst", "python/_autosummary/mlx.core.fft.rfftn.rst", "python/_autosummary/mlx.core.flatten.rst", "python/_autosummary/mlx.core.floor.rst", "python/_autosummary/mlx.core.floor_divide.rst", "python/_autosummary/mlx.core.full.rst", "python/_autosummary/mlx.core.grad.rst", "python/_autosummary/mlx.core.greater.rst", "python/_autosummary/mlx.core.greater_equal.rst", "python/_autosummary/mlx.core.identity.rst", "python/_autosummary/mlx.core.inner.rst", "python/_autosummary/mlx.core.isinf.rst", "python/_autosummary/mlx.core.isnan.rst", "python/_autosummary/mlx.core.isneginf.rst", "python/_autosummary/mlx.core.isposinf.rst", "python/_autosummary/mlx.core.jvp.rst", "python/_autosummary/mlx.core.less.rst", "python/_autosummary/mlx.core.less_equal.rst", "python/_autosummary/mlx.core.linalg.norm.rst", "python/_autosummary/mlx.core.linalg.qr.rst", "python/_autosummary/mlx.core.linspace.rst", "python/_autosummary/mlx.core.load.rst", "python/_autosummary/mlx.core.log.rst", "python/_autosummary/mlx.core.log10.rst", "python/_autosummary/mlx.core.log1p.rst", "python/_autosummary/mlx.core.log2.rst", "python/_autosummary/mlx.core.logaddexp.rst", "python/_autosummary/mlx.core.logical_and.rst", "python/_autosummary/mlx.core.logical_not.rst", "python/_autosummary/mlx.core.logical_or.rst", "python/_autosummary/mlx.core.logsumexp.rst", "python/_autosummary/mlx.core.matmul.rst", "python/_autosummary/mlx.core.max.rst", "python/_autosummary/mlx.core.maximum.rst", "python/_autosummary/mlx.core.mean.rst", "python/_autosummary/mlx.core.min.rst", "python/_autosummary/mlx.core.minimum.rst", "python/_autosummary/mlx.core.moveaxis.rst", "python/_autosummary/mlx.core.multiply.rst", "python/_autosummary/mlx.core.negative.rst", "python/_autosummary/mlx.core.new_stream.rst", "python/_autosummary/mlx.core.ones.rst", "python/_autosummary/mlx.core.ones_like.rst", "python/_autosummary/mlx.core.outer.rst", "python/_autosummary/mlx.core.pad.rst", "python/_autosummary/mlx.core.partition.rst", "python/_autosummary/mlx.core.prod.rst", "python/_autosummary/mlx.core.quantize.rst", "python/_autosummary/mlx.core.quantized_matmul.rst", "python/_autosummary/mlx.core.random.bernoulli.rst", "python/_autosummary/mlx.core.random.categorical.rst", "python/_autosummary/mlx.core.random.gumbel.rst", "python/_autosummary/mlx.core.random.key.rst", "python/_autosummary/mlx.core.random.normal.rst", "python/_autosummary/mlx.core.random.randint.rst", "python/_autosummary/mlx.core.random.seed.rst", "python/_autosummary/mlx.core.random.split.rst", "python/_autosummary/mlx.core.random.truncated_normal.rst", "python/_autosummary/mlx.core.random.uniform.rst", "python/_autosummary/mlx.core.reciprocal.rst", "python/_autosummary/mlx.core.repeat.rst", "python/_autosummary/mlx.core.reshape.rst", "python/_autosummary/mlx.core.round.rst", "python/_autosummary/mlx.core.rsqrt.rst", "python/_autosummary/mlx.core.save.rst", "python/_autosummary/mlx.core.save_gguf.rst", "python/_autosummary/mlx.core.save_safetensors.rst", "python/_autosummary/mlx.core.savez.rst", "python/_autosummary/mlx.core.savez_compressed.rst", 
"python/_autosummary/mlx.core.set_default_device.rst", "python/_autosummary/mlx.core.set_default_stream.rst", "python/_autosummary/mlx.core.sigmoid.rst", "python/_autosummary/mlx.core.sign.rst", "python/_autosummary/mlx.core.sin.rst", "python/_autosummary/mlx.core.sinh.rst", "python/_autosummary/mlx.core.softmax.rst", "python/_autosummary/mlx.core.sort.rst", "python/_autosummary/mlx.core.split.rst", "python/_autosummary/mlx.core.sqrt.rst", "python/_autosummary/mlx.core.square.rst", "python/_autosummary/mlx.core.squeeze.rst", "python/_autosummary/mlx.core.stack.rst", "python/_autosummary/mlx.core.stop_gradient.rst", "python/_autosummary/mlx.core.stream.rst", "python/_autosummary/mlx.core.subtract.rst", "python/_autosummary/mlx.core.sum.rst", "python/_autosummary/mlx.core.swapaxes.rst", "python/_autosummary/mlx.core.take.rst", "python/_autosummary/mlx.core.take_along_axis.rst", "python/_autosummary/mlx.core.tan.rst", "python/_autosummary/mlx.core.tanh.rst", "python/_autosummary/mlx.core.tensordot.rst", "python/_autosummary/mlx.core.transpose.rst", "python/_autosummary/mlx.core.tri.rst", "python/_autosummary/mlx.core.tril.rst", "python/_autosummary/mlx.core.triu.rst", "python/_autosummary/mlx.core.value_and_grad.rst", "python/_autosummary/mlx.core.var.rst", "python/_autosummary/mlx.core.vjp.rst", "python/_autosummary/mlx.core.vmap.rst", "python/_autosummary/mlx.core.where.rst", "python/_autosummary/mlx.core.zeros.rst", "python/_autosummary/mlx.core.zeros_like.rst", "python/_autosummary/mlx.nn.value_and_grad.rst", "python/_autosummary/mlx.utils.tree_flatten.rst", "python/_autosummary/mlx.utils.tree_map.rst", "python/_autosummary/mlx.utils.tree_unflatten.rst", "python/_autosummary/stream_class.rst", "python/array.rst", "python/data_types.rst", "python/devices_and_streams.rst", "python/fft.rst", "python/linalg.rst", "python/nn.rst", "python/nn/_autosummary/mlx.nn.ALiBi.rst", "python/nn/_autosummary/mlx.nn.AvgPool1d.rst", "python/nn/_autosummary/mlx.nn.AvgPool2d.rst", "python/nn/_autosummary/mlx.nn.BatchNorm.rst", "python/nn/_autosummary/mlx.nn.Conv1d.rst", "python/nn/_autosummary/mlx.nn.Conv2d.rst", "python/nn/_autosummary/mlx.nn.Dropout.rst", "python/nn/_autosummary/mlx.nn.Dropout2d.rst", "python/nn/_autosummary/mlx.nn.Dropout3d.rst", "python/nn/_autosummary/mlx.nn.Embedding.rst", "python/nn/_autosummary/mlx.nn.GELU.rst", "python/nn/_autosummary/mlx.nn.GroupNorm.rst", "python/nn/_autosummary/mlx.nn.InstanceNorm.rst", "python/nn/_autosummary/mlx.nn.LayerNorm.rst", "python/nn/_autosummary/mlx.nn.Linear.rst", "python/nn/_autosummary/mlx.nn.MaxPool1d.rst", "python/nn/_autosummary/mlx.nn.MaxPool2d.rst", "python/nn/_autosummary/mlx.nn.Mish.rst", "python/nn/_autosummary/mlx.nn.Module.apply.rst", "python/nn/_autosummary/mlx.nn.Module.apply_to_modules.rst", "python/nn/_autosummary/mlx.nn.Module.children.rst", "python/nn/_autosummary/mlx.nn.Module.eval.rst", "python/nn/_autosummary/mlx.nn.Module.filter_and_map.rst", "python/nn/_autosummary/mlx.nn.Module.freeze.rst", "python/nn/_autosummary/mlx.nn.Module.leaf_modules.rst", "python/nn/_autosummary/mlx.nn.Module.load_weights.rst", "python/nn/_autosummary/mlx.nn.Module.modules.rst", "python/nn/_autosummary/mlx.nn.Module.named_modules.rst", "python/nn/_autosummary/mlx.nn.Module.parameters.rst", "python/nn/_autosummary/mlx.nn.Module.save_weights.rst", "python/nn/_autosummary/mlx.nn.Module.state.rst", "python/nn/_autosummary/mlx.nn.Module.train.rst", "python/nn/_autosummary/mlx.nn.Module.trainable_parameters.rst", 
"python/nn/_autosummary/mlx.nn.Module.training.rst", "python/nn/_autosummary/mlx.nn.Module.unfreeze.rst", "python/nn/_autosummary/mlx.nn.Module.update.rst", "python/nn/_autosummary/mlx.nn.Module.update_modules.rst", "python/nn/_autosummary/mlx.nn.MultiHeadAttention.rst", "python/nn/_autosummary/mlx.nn.PReLU.rst", "python/nn/_autosummary/mlx.nn.QuantizedLinear.rst", "python/nn/_autosummary/mlx.nn.RMSNorm.rst", "python/nn/_autosummary/mlx.nn.ReLU.rst", "python/nn/_autosummary/mlx.nn.RoPE.rst", "python/nn/_autosummary/mlx.nn.SELU.rst", "python/nn/_autosummary/mlx.nn.Sequential.rst", "python/nn/_autosummary/mlx.nn.SiLU.rst", "python/nn/_autosummary/mlx.nn.SinusoidalPositionalEncoding.rst", "python/nn/_autosummary/mlx.nn.Softshrink.rst", "python/nn/_autosummary/mlx.nn.Step.rst", "python/nn/_autosummary/mlx.nn.Transformer.rst", "python/nn/_autosummary/mlx.nn.init.constant.rst", "python/nn/_autosummary/mlx.nn.init.glorot_normal.rst", "python/nn/_autosummary/mlx.nn.init.glorot_uniform.rst", "python/nn/_autosummary/mlx.nn.init.he_normal.rst", "python/nn/_autosummary/mlx.nn.init.he_uniform.rst", "python/nn/_autosummary/mlx.nn.init.identity.rst", "python/nn/_autosummary/mlx.nn.init.normal.rst", "python/nn/_autosummary/mlx.nn.init.uniform.rst", "python/nn/_autosummary_functions/mlx.nn.gelu.rst", "python/nn/_autosummary_functions/mlx.nn.gelu_approx.rst", "python/nn/_autosummary_functions/mlx.nn.gelu_fast_approx.rst", "python/nn/_autosummary_functions/mlx.nn.losses.binary_cross_entropy.rst", "python/nn/_autosummary_functions/mlx.nn.losses.cosine_similarity_loss.rst", "python/nn/_autosummary_functions/mlx.nn.losses.cross_entropy.rst", "python/nn/_autosummary_functions/mlx.nn.losses.gaussian_nll_loss.rst", "python/nn/_autosummary_functions/mlx.nn.losses.hinge_loss.rst", "python/nn/_autosummary_functions/mlx.nn.losses.huber_loss.rst", "python/nn/_autosummary_functions/mlx.nn.losses.kl_div_loss.rst", "python/nn/_autosummary_functions/mlx.nn.losses.l1_loss.rst", "python/nn/_autosummary_functions/mlx.nn.losses.log_cosh_loss.rst", "python/nn/_autosummary_functions/mlx.nn.losses.margin_ranking_loss.rst", "python/nn/_autosummary_functions/mlx.nn.losses.mse_loss.rst", "python/nn/_autosummary_functions/mlx.nn.losses.nll_loss.rst", "python/nn/_autosummary_functions/mlx.nn.losses.smooth_l1_loss.rst", "python/nn/_autosummary_functions/mlx.nn.losses.triplet_loss.rst", "python/nn/_autosummary_functions/mlx.nn.mish.rst", "python/nn/_autosummary_functions/mlx.nn.prelu.rst", "python/nn/_autosummary_functions/mlx.nn.relu.rst", "python/nn/_autosummary_functions/mlx.nn.selu.rst", "python/nn/_autosummary_functions/mlx.nn.silu.rst", "python/nn/_autosummary_functions/mlx.nn.softshrink.rst", "python/nn/_autosummary_functions/mlx.nn.step.rst", "python/nn/functions.rst", "python/nn/init.rst", "python/nn/layers.rst", "python/nn/losses.rst", "python/nn/module.rst", "python/ops.rst", "python/optimizers.rst", "python/optimizers/_autosummary/mlx.optimizers.AdaDelta.rst", "python/optimizers/_autosummary/mlx.optimizers.Adafactor.rst", "python/optimizers/_autosummary/mlx.optimizers.Adagrad.rst", "python/optimizers/_autosummary/mlx.optimizers.Adam.rst", "python/optimizers/_autosummary/mlx.optimizers.AdamW.rst", "python/optimizers/_autosummary/mlx.optimizers.Adamax.rst", "python/optimizers/_autosummary/mlx.optimizers.Lion.rst", "python/optimizers/_autosummary/mlx.optimizers.Optimizer.apply_gradients.rst", "python/optimizers/_autosummary/mlx.optimizers.Optimizer.init.rst", "python/optimizers/_autosummary/mlx.optimizers.Optimizer.state.rst", 
"python/optimizers/_autosummary/mlx.optimizers.Optimizer.update.rst", "python/optimizers/_autosummary/mlx.optimizers.RMSprop.rst", "python/optimizers/_autosummary/mlx.optimizers.SGD.rst", "python/optimizers/_autosummary/mlx.optimizers.cosine_decay.rst", "python/optimizers/_autosummary/mlx.optimizers.exponential_decay.rst", "python/optimizers/_autosummary/mlx.optimizers.step_decay.rst", "python/optimizers/common_optimizers.rst", "python/optimizers/optimizer.rst", "python/optimizers/schedulers.rst", "python/random.rst", "python/transforms.rst", "python/tree_utils.rst", "usage/compile.rst", "usage/function_transforms.rst", "usage/indexing.rst", "usage/lazy_evaluation.rst", "usage/numpy.rst", "usage/quick_start.rst", "usage/saving_and_loading.rst", "usage/unified_memory.rst", "usage/using_streams.rst"], "titles": ["Operations", "Developer Documentation", "Linear Regression", "LLM inference", "Multi-Layer Perceptron", "MLX", "Build and Install", "mlx.core.Device", "mlx.core.Dtype", "mlx.core.abs", "mlx.core.add", "mlx.core.all", "mlx.core.allclose", "mlx.core.any", "mlx.core.arange", "mlx.core.arccos", "mlx.core.arccosh", "mlx.core.arcsin", "mlx.core.arcsinh", "mlx.core.arctan", "mlx.core.arctanh", "mlx.core.argmax", "mlx.core.argmin", "mlx.core.argpartition", "mlx.core.argsort", "mlx.core.array", "mlx.core.array.T", "mlx.core.array.abs", "mlx.core.array.all", "mlx.core.array.any", "mlx.core.array.argmax", "mlx.core.array.argmin", "mlx.core.array.astype", "mlx.core.array.cos", "mlx.core.array.dtype", "mlx.core.array.exp", "mlx.core.array.item", "mlx.core.array.log", "mlx.core.array.log1p", "mlx.core.array.logsumexp", "mlx.core.array.max", "mlx.core.array.mean", "mlx.core.array.min", "mlx.core.array.ndim", "mlx.core.array.prod", "mlx.core.array.reciprocal", "mlx.core.array.reshape", "mlx.core.array.round", "mlx.core.array.rsqrt", "mlx.core.array.shape", "mlx.core.array.sin", "mlx.core.array.size", "mlx.core.array.split", "mlx.core.array.sqrt", "mlx.core.array.square", "mlx.core.array.sum", "mlx.core.array.tolist", "mlx.core.array.transpose", "mlx.core.array.var", "mlx.core.array_equal", "mlx.core.broadcast_to", "mlx.core.ceil", "mlx.core.clip", "mlx.core.compile", "mlx.core.concatenate", "mlx.core.conv1d", "mlx.core.conv2d", "mlx.core.convolve", "mlx.core.cos", "mlx.core.cosh", "mlx.core.default_device", "mlx.core.default_stream", "mlx.core.dequantize", "mlx.core.diag", "mlx.core.diagonal", "mlx.core.disable_compile", "mlx.core.divide", "mlx.core.divmod", "mlx.core.enable_compile", "mlx.core.equal", "mlx.core.erf", "mlx.core.erfinv", "mlx.core.eval", "mlx.core.exp", "mlx.core.expand_dims", "mlx.core.eye", "mlx.core.fft.fft", "mlx.core.fft.fft2", "mlx.core.fft.fftn", "mlx.core.fft.ifft", "mlx.core.fft.ifft2", "mlx.core.fft.ifftn", "mlx.core.fft.irfft", "mlx.core.fft.irfft2", "mlx.core.fft.irfftn", "mlx.core.fft.rfft", "mlx.core.fft.rfft2", "mlx.core.fft.rfftn", "mlx.core.flatten", "mlx.core.floor", "mlx.core.floor_divide", "mlx.core.full", "mlx.core.grad", "mlx.core.greater", "mlx.core.greater_equal", "mlx.core.identity", "mlx.core.inner", "mlx.core.isinf", "mlx.core.isnan", "mlx.core.isneginf", "mlx.core.isposinf", "mlx.core.jvp", "mlx.core.less", "mlx.core.less_equal", "mlx.core.linalg.norm", "mlx.core.linalg.qr", "mlx.core.linspace", "mlx.core.load", "mlx.core.log", "mlx.core.log10", "mlx.core.log1p", "mlx.core.log2", "mlx.core.logaddexp", "mlx.core.logical_and", "mlx.core.logical_not", "mlx.core.logical_or", "mlx.core.logsumexp", "mlx.core.matmul", "mlx.core.max", "mlx.core.maximum", 
"mlx.core.mean", "mlx.core.min", "mlx.core.minimum", "mlx.core.moveaxis", "mlx.core.multiply", "mlx.core.negative", "mlx.core.new_stream", "mlx.core.ones", "mlx.core.ones_like", "mlx.core.outer", "mlx.core.pad", "mlx.core.partition", "mlx.core.prod", "mlx.core.quantize", "mlx.core.quantized_matmul", "mlx.core.random.bernoulli", "mlx.core.random.categorical", "mlx.core.random.gumbel", "mlx.core.random.key", "mlx.core.random.normal", "mlx.core.random.randint", "mlx.core.random.seed", "mlx.core.random.split", "mlx.core.random.truncated_normal", "mlx.core.random.uniform", "mlx.core.reciprocal", "mlx.core.repeat", "mlx.core.reshape", "mlx.core.round", "mlx.core.rsqrt", "mlx.core.save", "mlx.core.save_gguf", "mlx.core.save_safetensors", "mlx.core.savez", "mlx.core.savez_compressed", "mlx.core.set_default_device", "mlx.core.set_default_stream", "mlx.core.sigmoid", "mlx.core.sign", "mlx.core.sin", "mlx.core.sinh", "mlx.core.softmax", "mlx.core.sort", "mlx.core.split", "mlx.core.sqrt", "mlx.core.square", "mlx.core.squeeze", "mlx.core.stack", "mlx.core.stop_gradient", "mlx.core.stream", "mlx.core.subtract", "mlx.core.sum", "mlx.core.swapaxes", "mlx.core.take", "mlx.core.take_along_axis", "mlx.core.tan", "mlx.core.tanh", "mlx.core.tensordot", "mlx.core.transpose", "mlx.core.tri", "mlx.core.tril", "mlx.core.triu", "mlx.core.value_and_grad", "mlx.core.var", "mlx.core.vjp", "mlx.core.vmap", "mlx.core.where", "mlx.core.zeros", "mlx.core.zeros_like", "mlx.nn.value_and_grad", "mlx.utils.tree_flatten", "mlx.utils.tree_map", "mlx.utils.tree_unflatten", "mlx.core.Stream", "Array", "Data Types", "Devices and Streams", "FFT", "Linear Algebra", "Neural Networks", "mlx.nn.ALiBi", "mlx.nn.AvgPool1d", "mlx.nn.AvgPool2d", "mlx.nn.BatchNorm", "mlx.nn.Conv1d", "mlx.nn.Conv2d", "mlx.nn.Dropout", "mlx.nn.Dropout2d", "mlx.nn.Dropout3d", "mlx.nn.Embedding", "mlx.nn.GELU", "mlx.nn.GroupNorm", "mlx.nn.InstanceNorm", "mlx.nn.LayerNorm", "mlx.nn.Linear", "mlx.nn.MaxPool1d", "mlx.nn.MaxPool2d", "mlx.nn.Mish", "mlx.nn.Module.apply", "mlx.nn.Module.apply_to_modules", "mlx.nn.Module.children", "mlx.nn.Module.eval", "mlx.nn.Module.filter_and_map", "mlx.nn.Module.freeze", "mlx.nn.Module.leaf_modules", "mlx.nn.Module.load_weights", "mlx.nn.Module.modules", "mlx.nn.Module.named_modules", "mlx.nn.Module.parameters", "mlx.nn.Module.save_weights", "mlx.nn.Module.state", "mlx.nn.Module.train", "mlx.nn.Module.trainable_parameters", "mlx.nn.Module.training", "mlx.nn.Module.unfreeze", "mlx.nn.Module.update", "mlx.nn.Module.update_modules", "mlx.nn.MultiHeadAttention", "mlx.nn.PReLU", "mlx.nn.QuantizedLinear", "mlx.nn.RMSNorm", "mlx.nn.ReLU", "mlx.nn.RoPE", "mlx.nn.SELU", "mlx.nn.Sequential", "mlx.nn.SiLU", "mlx.nn.SinusoidalPositionalEncoding", "mlx.nn.Softshrink", "mlx.nn.Step", "mlx.nn.Transformer", "mlx.nn.init.constant", "mlx.nn.init.glorot_normal", "mlx.nn.init.glorot_uniform", "mlx.nn.init.he_normal", "mlx.nn.init.he_uniform", "mlx.nn.init.identity", "mlx.nn.init.normal", "mlx.nn.init.uniform", "mlx.nn.gelu", "mlx.nn.gelu_approx", "mlx.nn.gelu_fast_approx", "mlx.nn.losses.binary_cross_entropy", "mlx.nn.losses.cosine_similarity_loss", "mlx.nn.losses.cross_entropy", "mlx.nn.losses.gaussian_nll_loss", "mlx.nn.losses.hinge_loss", "mlx.nn.losses.huber_loss", "mlx.nn.losses.kl_div_loss", "mlx.nn.losses.l1_loss", "mlx.nn.losses.log_cosh_loss", "mlx.nn.losses.margin_ranking_loss", "mlx.nn.losses.mse_loss", "mlx.nn.losses.nll_loss", "mlx.nn.losses.smooth_l1_loss", "mlx.nn.losses.triplet_loss", "mlx.nn.mish", "mlx.nn.prelu", "mlx.nn.relu", 
"mlx.nn.selu", "mlx.nn.silu", "mlx.nn.softshrink", "mlx.nn.step", "Functions", "Initializers", "Layers", "Loss Functions", "Module", "Operations", "Optimizers", "mlx.optimizers.AdaDelta", "mlx.optimizers.Adafactor", "mlx.optimizers.Adagrad", "mlx.optimizers.Adam", "mlx.optimizers.AdamW", "mlx.optimizers.Adamax", "mlx.optimizers.Lion", "mlx.optimizers.Optimizer.apply_gradients", "mlx.optimizers.Optimizer.init", "mlx.optimizers.Optimizer.state", "mlx.optimizers.Optimizer.update", "mlx.optimizers.RMSprop", "mlx.optimizers.SGD", "mlx.optimizers.cosine_decay", "mlx.optimizers.exponential_decay", "mlx.optimizers.step_decay", "Common Optimizers", "Optimizer", "Schedulers", "Random", "Transforms", "Tree Utils", "Compilation", "Function Transforms", "Indexing Arrays", "Lazy Evaluation", "Conversion to NumPy and Other Frameworks", "Quick Start Guide", "Saving and Loading Arrays", "Unified Memory", "Using Streams"], "terms": {"mlx": [1, 2, 3, 4, 6, 209, 293, 296, 298, 316, 318, 320, 321, 322, 323, 324, 325, 326, 327, 328], "provid": [1, 3, 72, 102, 187, 192, 201, 209, 228, 233, 235, 244, 245, 246, 249, 259, 292, 296, 327, 329], "open": [1, 6, 14, 150, 154], "flexibl": [1, 5, 246], "which": [1, 3, 4, 5, 6, 14, 32, 63, 74, 82, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 102, 107, 108, 109, 110, 111, 114, 115, 117, 143, 146, 147, 156, 157, 160, 161, 162, 163, 164, 176, 177, 183, 192, 194, 195, 212, 217, 218, 220, 226, 228, 232, 252, 273, 276, 280, 283, 293, 306, 307, 318, 321, 322, 323, 324, 328, 329], "user": [1, 3, 209], "mai": [1, 114, 217, 322, 323], "add": [1, 3, 84, 122, 140, 143, 214, 215, 322, 328], "special": 1, "without": [1, 3, 5, 178, 247, 292, 320, 321, 324, 325, 328], "much": [1, 3, 211, 212, 225, 226, 321, 324], "hassl": 1, "while": [1, 3, 6, 157, 252, 324, 325], "librari": [1, 6, 209], "suppli": 1, "effici": [1, 3, 5, 217, 252, 324, 326], "can": [1, 3, 5, 6, 10, 14, 46, 57, 63, 74, 75, 76, 77, 79, 82, 103, 104, 112, 113, 114, 122, 129, 132, 134, 145, 146, 150, 153, 154, 161, 180, 192, 209, 212, 219, 226, 232, 244, 254, 273, 293, 296, 298, 306, 307, 318, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329], "compos": [1, 5, 209, 321, 322, 326], "ani": [1, 3, 5, 14, 200, 201, 202, 209, 220, 228, 229, 232, 240, 249, 259, 293, 320, 321, 322, 324, 326, 327, 328], "number": [1, 14, 51, 63, 66, 72, 85, 102, 105, 111, 116, 140, 143, 144, 146, 149, 152, 154, 156, 158, 187, 189, 192, 194, 195, 209, 213, 214, 215, 217, 218, 221, 222, 247, 248, 259, 261, 262, 263, 264, 312, 318, 321, 322, 329], "applic": [1, 6], "aris": [1, 325], "case": [1, 3, 88, 91, 92, 94, 95, 96, 97, 98, 115, 127, 157, 176, 212, 217, 226, 253, 258, 283, 288, 290, 291, 306, 307, 321, 322, 326, 327, 328, 329], "where": [1, 4, 85, 143, 192, 195, 211, 212, 213, 214, 215, 216, 217, 218, 220, 221, 222, 223, 224, 225, 226, 232, 248, 250, 253, 255, 258, 263, 264, 268, 269, 270, 274, 280, 286, 288, 289, 291, 307, 322, 323], "new": [1, 4, 60, 74, 133, 136, 157, 177, 188, 201, 247, 296, 298, 309, 321, 323, 324, 325], "function": [1, 2, 3, 4, 5, 12, 63, 77, 80, 81, 102, 111, 114, 115, 127, 167, 192, 194, 195, 199, 201, 209, 220, 227, 229, 233, 244, 248, 254, 257, 258, 259, 268, 269, 270, 285, 290, 291, 293, 298, 307, 318, 320, 323, 324, 325, 327], "highli": [1, 6], "optim": [1, 2, 4, 5, 245, 321, 322, 324], "ar": [1, 2, 3, 4, 5, 6, 12, 14, 59, 60, 62, 63, 67, 74, 82, 85, 87, 88, 90, 91, 93, 94, 96, 97, 98, 102, 107, 108, 109, 110, 111, 114, 115, 117, 127, 139, 140, 141, 143, 144, 145, 146, 147, 150, 153, 154, 163, 164, 176, 177, 183, 
192, 194, 195, 200, 201, 213, 214, 215, 216, 217, 218, 221, 222, 223, 224, 235, 247, 249, 271, 273, 274, 292, 296, 305, 307, 320, 321, 322, 323, 324, 325, 326, 327, 328], "need": [1, 3, 4, 5, 59, 143, 209, 245, 246, 256, 259, 318, 322, 324, 325, 326, 328], "For": [1, 3, 6, 114, 143, 202, 209, 213, 217, 228, 233, 241, 244, 249, 252, 256, 261, 262, 263, 264, 293, 318, 321, 322, 323, 324, 325, 326, 327, 328], "you": [1, 3, 4, 5, 6, 209, 256, 259, 293, 318, 321, 322, 323, 325, 327, 328], "design": [1, 2, 5, 318, 328], "your": [1, 3, 6, 296, 322, 324], "own": [1, 6, 325], "link": [1, 6], "top": [1, 224], "core": [1, 2, 3, 4, 209, 211, 212, 213, 222, 225, 226, 235, 238, 242, 260, 261, 262, 263, 264, 265, 266, 267, 271, 273, 280, 293, 296, 298, 321, 325, 326], "we": [1, 2, 3, 4, 72, 143, 144, 209, 219, 254, 303, 305, 318, 320, 321, 322, 324, 328], "inner": [1, 321], "work": [1, 3, 6, 321, 322, 323, 324], "go": [1, 3, 322], "over": [1, 3, 4, 11, 13, 21, 22, 23, 24, 65, 66, 88, 91, 94, 97, 106, 114, 116, 126, 128, 130, 131, 141, 142, 159, 171, 172, 181, 187, 193, 213, 214, 215, 221, 223, 250, 273, 312, 322], "simpl": [1, 3, 4, 209, 219, 292, 321, 322, 324], "learn": [1, 2, 4, 5, 213, 221, 222, 223, 248, 250, 299, 300, 301, 302, 303, 304, 305, 310, 311], "step": [1, 3, 4, 14, 209, 300, 307, 312, 314, 321], "involv": [1, 298, 321], "ad": [1, 2, 6, 222, 296, 299, 300, 301, 302, 303, 304, 310, 324, 327], "let": [1, 2, 3, 321, 322, 324, 325], "s": [1, 2, 3, 4, 34, 43, 63, 71, 72, 87, 88, 90, 91, 93, 94, 96, 97, 102, 114, 117, 130, 139, 143, 146, 158, 161, 162, 179, 192, 193, 195, 199, 209, 212, 226, 232, 233, 235, 239, 240, 244, 298, 307, 308, 318, 321, 322, 324, 325, 326, 327, 328], "sai": [1, 3, 293, 324], "would": [1, 3, 323, 324, 325, 328], "like": [1, 3, 5, 138, 198, 218, 279, 307, 309, 321, 322, 324, 325, 326, 328], "an": [1, 3, 4, 6, 8, 11, 13, 25, 60, 65, 66, 82, 85, 98, 101, 105, 114, 117, 128, 131, 133, 137, 138, 140, 142, 143, 144, 156, 157, 158, 173, 176, 182, 183, 184, 187, 189, 195, 197, 198, 200, 201, 209, 211, 212, 216, 221, 223, 224, 225, 226, 228, 247, 248, 249, 259, 260, 261, 262, 263, 264, 265, 266, 267, 269, 286, 293, 299, 309, 313, 316, 318, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329], "take": [1, 3, 4, 63, 102, 111, 129, 132, 138, 144, 184, 192, 194, 195, 198, 247, 318, 322, 323, 327, 328, 329], "two": [1, 10, 12, 59, 74, 76, 79, 87, 90, 96, 103, 104, 112, 113, 115, 122, 127, 129, 132, 134, 139, 182, 212, 226, 249, 272, 321, 322, 323, 328], "arrai": [1, 3, 4, 5, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 72, 73, 74, 76, 77, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 209, 213, 228, 235, 238, 242, 248, 260, 261, 262, 263, 264, 265, 266, 267, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 291, 293, 296, 311, 312, 313, 314, 321, 322, 324, 325, 326, 328], "x": [1, 2, 3, 4, 80, 105, 114, 144, 147, 158, 163, 167, 190, 191, 196, 201, 209, 211, 212, 213, 220, 221, 222, 223, 224, 225, 
226, 227, 228, 248, 250, 251, 253, 255, 256, 258, 268, 269, 270, 283, 285, 286, 287, 288, 289, 290, 291, 296, 298, 305, 321, 322, 323, 324, 325, 326, 328], "y": [1, 2, 3, 4, 196, 209, 213, 217, 221, 222, 223, 224, 250, 275, 280, 283, 298, 301, 321, 322, 324, 325], "scale": [1, 3, 72, 143, 144, 149, 217, 218, 247, 252, 253, 256, 288, 300], "them": [1, 3, 209, 233, 244, 328], "both": [1, 10, 76, 77, 79, 103, 104, 112, 113, 114, 122, 129, 132, 134, 146, 180, 211, 212, 222, 225, 226, 298, 321, 322, 326, 328], "some": [1, 2, 3, 4, 233, 244, 307, 321, 322, 324], "coeffici": [1, 299, 300, 302, 303, 304, 305], "alpha": [1, 143, 253, 284, 286, 288, 303, 310], "beta": [1, 72, 143, 213, 221, 222, 223, 283, 302, 303, 304, 305], "respect": [1, 2, 4, 102, 143, 192, 201, 209, 213, 220, 221, 222, 223, 296, 322, 326], "togeth": [1, 4, 143, 201], "get": [1, 2, 4, 6, 66, 70, 71, 148, 209, 321, 322, 324, 328], "z": [1, 321, 324], "well": [1, 3, 209, 233, 244, 247, 324], "veri": [1, 3, 247, 324, 328], "easili": 1, "do": [1, 3, 6, 209, 234, 244, 293, 296, 303, 321, 322, 324], "just": [1, 4, 321, 323], "write": [1, 3, 209, 325], "out": [1, 6, 211, 212, 217, 218, 225, 226, 241, 321, 322, 323], "follow": [1, 3, 4, 5, 6, 14, 67, 72, 114, 143, 209, 269, 270, 277, 299, 300, 301, 302, 303, 304, 305, 311, 318, 321, 322, 328], "import": [1, 2, 3, 4, 6, 114, 163, 192, 200, 201, 202, 209, 211, 212, 213, 222, 225, 226, 235, 271, 273, 280, 293, 296, 321, 322, 323, 324, 325, 326], "mx": [1, 2, 3, 4, 98, 114, 115, 117, 163, 192, 209, 211, 212, 213, 222, 225, 226, 228, 235, 239, 251, 260, 261, 262, 263, 264, 265, 266, 267, 271, 272, 273, 277, 280, 287, 293, 296, 298, 318, 321, 322, 323, 324, 325, 326, 327, 328, 329], "def": [1, 2, 3, 4, 192, 209, 296, 321, 322, 323, 324, 325, 328], "simple_axpbi": 1, "float": [1, 12, 14, 56, 100, 101, 114, 144, 145, 149, 150, 153, 154, 205, 213, 216, 217, 218, 221, 222, 223, 228, 250, 252, 256, 258, 259, 260, 261, 262, 263, 264, 266, 267, 272, 273, 274, 276, 280, 283, 284, 290, 291, 299, 300, 301, 302, 303, 304, 305, 310, 311, 312, 313, 314], "return": [1, 2, 3, 4, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 32, 36, 49, 56, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 72, 73, 74, 76, 77, 79, 80, 81, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 152, 153, 154, 155, 156, 157, 158, 159, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 209, 230, 232, 234, 236, 237, 238, 242, 249, 260, 261, 262, 263, 264, 265, 266, 267, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 293, 296, 306, 320, 321, 322, 323, 324, 325, 327, 328], "thi": [1, 3, 4, 6, 11, 12, 13, 14, 21, 22, 23, 24, 78, 111, 114, 115, 122, 126, 127, 128, 130, 131, 141, 142, 146, 166, 171, 172, 173, 181, 183, 193, 209, 216, 217, 218, 229, 230, 232, 233, 236, 237, 238, 242, 244, 245, 246, 247, 249, 258, 261, 262, 263, 264, 269, 270, 279, 291, 296, 307, 320, 321, 322, 324, 325, 327], "perform": [1, 3, 5, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 127, 144, 158, 171, 183, 209, 221, 259, 263, 264, 321, 323, 324, 328], "leav": [1, 82, 201], "differenti": [1, 5], "howev": [1, 209, 220, 221, 307, 318, 321, 
324, 325], "vector": [1, 2, 5, 106, 111, 114, 183, 194, 195, 219, 273, 326], "math": [1, 3, 284, 321], "often": [1, 218], "realiz": 1, "axpbi": 1, "routin": 1, "defin": [1, 2, 3, 4, 6, 114, 144, 200, 325], "same": [1, 3, 6, 12, 59, 60, 63, 66, 67, 92, 95, 96, 97, 102, 111, 140, 146, 158, 194, 196, 209, 212, 213, 216, 221, 222, 226, 249, 260, 261, 262, 263, 264, 265, 266, 267, 273, 284, 296, 306, 318, 321, 323, 328], "realli": 1, "part": [1, 322, 323], "doe": [1, 3, 6, 209, 321, 323, 324, 325], "fast": [1, 220, 270, 328], "so": [1, 3, 6, 102, 192, 216, 298, 321, 324, 328], "decid": [1, 201, 232], "want": [1, 3, 322, 328], "reli": 1, "acceler": [1, 213], "framework": [1, 5], "continu": [1, 322], "impos": 1, "our": [1, 3, 4, 254, 299, 300, 301, 302, 304, 305], "assumpt": 1, "also": [1, 3, 4, 5, 6, 10, 75, 76, 77, 79, 88, 91, 94, 97, 103, 104, 112, 113, 122, 129, 132, 134, 143, 180, 199, 209, 232, 245, 247, 249, 253, 255, 268, 288, 289, 292, 298, 321, 322, 323, 324, 325, 326, 329], "assum": [1, 3, 115, 201, 209, 211, 212, 221, 225, 226], "how": [1, 3, 4, 209, 211, 212, 214, 215, 219, 225, 226, 306, 321, 323, 328], "gradient": [1, 2, 4, 102, 178, 192, 199, 209, 233, 245, 249, 259, 279, 296, 298, 299, 300, 302, 303, 304, 305, 306, 309, 311, 321, 322, 323, 324, 325, 326], "ins": 1, "what": [1, 3, 201], "coincid": 1, "right": [1, 6, 143, 211, 212, 220, 225, 226, 269, 270, 274, 276, 284], "place": [1, 3, 158, 324, 325], "cours": [1, 322], "The": [1, 3, 4, 5, 6, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 32, 34, 43, 49, 56, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 72, 73, 74, 76, 77, 79, 80, 81, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 152, 153, 154, 155, 156, 157, 161, 162, 167, 168, 169, 170, 171, 172, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 205, 211, 212, 213, 214, 215, 216, 217, 218, 219, 221, 222, 223, 224, 225, 226, 229, 235, 239, 240, 245, 246, 247, 249, 250, 252, 254, 256, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 291, 293, 296, 298, 299, 300, 301, 302, 303, 304, 305, 308, 310, 311, 312, 316, 321, 322, 323, 324, 325, 326, 327, 328, 329], "structur": [1, 306, 322], "from": [1, 3, 4, 5, 72, 74, 93, 94, 96, 97, 101, 114, 117, 127, 138, 143, 145, 146, 147, 148, 150, 153, 163, 176, 178, 180, 183, 184, 196, 198, 200, 201, 202, 209, 224, 233, 235, 247, 261, 262, 263, 264, 266, 267, 274, 283, 293, 320, 321, 322, 324, 325, 326, 327, 328], "frontend": 1, "api": [1, 322], "redirect": 1, "when": [1, 3, 5, 6, 114, 117, 214, 215, 263, 264, 277, 283, 296, 318, 321, 328], "appropri": [1, 321], "fallback": 1, "metal": 1, "vjp": [1, 326], "jvp": [1, 326], "In": [1, 3, 4, 127, 143, 201, 209, 217, 221, 296, 299, 301, 302, 304, 305, 306, 320, 321, 322, 324, 327, 328], "one": [1, 3, 6, 56, 62, 66, 84, 85, 114, 120, 127, 144, 146, 176, 180, 244, 273, 328], "sentenc": 1, "comput": [1, 2, 3, 4, 5, 6, 72, 102, 111, 114, 122, 130, 139, 143, 171, 178, 187, 192, 193, 194, 199, 209, 213, 221, 222, 223, 233, 245, 249, 250, 252, 259, 261, 262, 263, 264, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 298, 299, 300, 302, 
303, 304, 305, 309, 321, 322, 326, 328], "graph": [1, 3, 4, 5, 322], "rule": 1, "evalu": [1, 3, 4, 5, 82, 111, 194, 209, 231, 241, 296, 298, 321, 326], "said": [1, 3], "start": [1, 2, 3, 5, 6, 14, 116, 173, 321, 323, 328], "discuss": 1, "more": [1, 4, 8, 56, 74, 127, 161, 162, 209, 213, 217, 252, 256, 259, 261, 262, 263, 264, 318, 321, 322, 323, 326, 328], "detail": [1, 8, 209, 217, 252, 256, 261, 262, 263, 264, 299, 301, 302, 304, 305, 323, 326], "thei": [1, 2, 3, 12, 67, 254, 275, 296, 305, 320, 321, 324, 326, 327, 328], "c": [1, 3, 114, 205, 211, 212, 213, 214, 215, 217, 218, 222, 225, 226, 325, 326, 328], "scalar": [1, 10, 12, 25, 36, 56, 59, 60, 62, 76, 77, 79, 100, 101, 102, 103, 104, 112, 113, 114, 116, 122, 123, 124, 125, 127, 129, 132, 134, 140, 150, 153, 154, 161, 180, 192, 196, 199, 284, 322, 324, 326], "sum": [1, 2, 10, 106, 114, 126, 171, 187, 209, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 323, 325], "element": [1, 9, 10, 15, 16, 17, 18, 19, 20, 23, 51, 61, 68, 69, 72, 76, 77, 79, 80, 81, 83, 85, 99, 100, 103, 104, 107, 108, 109, 110, 112, 113, 118, 119, 120, 121, 122, 123, 124, 125, 129, 132, 134, 135, 141, 143, 144, 155, 156, 159, 167, 168, 169, 170, 174, 175, 180, 183, 185, 186, 192, 196, 216, 217, 218, 227, 248, 252, 255, 285, 286, 289, 321, 322], "wise": [1, 9, 10, 15, 16, 17, 18, 19, 20, 61, 68, 69, 76, 77, 79, 80, 81, 83, 99, 100, 103, 104, 112, 113, 118, 119, 120, 121, 122, 123, 124, 125, 129, 132, 134, 135, 155, 159, 167, 168, 169, 170, 174, 175, 180, 185, 186, 217, 218, 227, 248, 255, 285, 286, 289, 321], "numpi": [1, 3, 4, 5, 10, 12, 14, 60, 76, 77, 79, 103, 104, 112, 113, 122, 127, 129, 132, 134, 180, 324, 326, 327], "style": [1, 10, 12, 76, 77, 79, 103, 104, 112, 113, 122, 127, 129, 132, 134, 180], "broadcast": [1, 10, 12, 60, 62, 76, 77, 79, 101, 103, 104, 112, 113, 122, 127, 129, 132, 134, 145, 146, 153, 154, 180, 184, 196, 247], "between": [1, 5, 62, 98, 259, 272, 275, 276, 279, 324, 328], "input": [1, 2, 3, 9, 10, 11, 12, 13, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 73, 74, 76, 77, 79, 80, 81, 83, 84, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 102, 103, 104, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 138, 139, 140, 141, 142, 143, 144, 152, 155, 156, 157, 158, 159, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 180, 181, 182, 183, 184, 185, 186, 187, 188, 190, 191, 192, 193, 195, 196, 198, 211, 212, 213, 214, 215, 217, 218, 219, 221, 222, 223, 224, 225, 226, 247, 249, 250, 252, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 271, 272, 274, 275, 276, 277, 279, 280, 282, 284, 291, 293, 321, 322, 323, 326, 327], "upcast": 1, "const": [1, 274], "factor": [1, 115, 273, 313, 314], "streamordevic": 1, "stream": [1, 5, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 27, 28, 29, 30, 31, 32, 33, 35, 37, 38, 39, 40, 41, 42, 44, 45, 46, 47, 48, 50, 52, 53, 54, 55, 57, 58, 59, 60, 61, 62, 64, 65, 66, 67, 68, 69, 71, 72, 73, 74, 76, 77, 79, 80, 81, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 103, 104, 105, 106, 107, 108, 109, 110, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 149, 150, 152, 153, 154, 155, 156, 157, 158, 159, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 180, 181, 
182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 193, 196, 197, 198, 328], "schedul": [1, 298, 312, 313, 314, 316, 328], "itself": [1, 307], "call": [1, 3, 4, 26, 100, 209, 219, 233, 244, 254, 296, 298, 307, 321, 322, 324], "other": [1, 3, 5, 114, 209, 234, 296, 305, 321, 323, 324, 326], "within": [1, 23], "simplest": [1, 209], "wai": [1, 3, 6, 209, 321, 322, 323], "about": [1, 3, 4, 324, 328], "term": [1, 274, 299, 300, 301, 302, 303, 304, 310], "exist": [1, 3, 233, 244], "auto": [1, 6], "ax": [1, 11, 13, 21, 22, 57, 84, 87, 88, 90, 91, 93, 94, 96, 97, 98, 106, 114, 126, 128, 130, 131, 140, 142, 171, 176, 181, 182, 187, 188, 193, 322], "multipli": [1, 143, 144, 216, 256], "earlier": 1, "goal": 1, "themselv": [1, 321], "contain": [1, 3, 23, 24, 49, 63, 74, 92, 93, 94, 114, 123, 124, 125, 143, 173, 196, 209, 232, 234, 235, 240, 259, 280, 293, 296, 321, 322], "act": [1, 279], "data": [1, 4, 5, 8, 14, 85, 95, 96, 101, 105, 116, 137, 153, 189, 197, 218, 260, 261, 262, 263, 264, 265, 266, 267, 321, 323, 325], "nor": [1, 102, 192], "rather": [1, 322, 328], "easi": [1, 209], "interfac": 1, "block": [1, 3, 259], "A": [1, 3, 5, 6, 7, 49, 59, 63, 102, 111, 114, 115, 117, 126, 127, 143, 145, 146, 147, 149, 150, 153, 154, 173, 177, 179, 192, 194, 195, 199, 200, 201, 202, 203, 209, 213, 217, 221, 222, 223, 227, 232, 236, 237, 245, 246, 250, 254, 256, 259, 261, 262, 264, 270, 284, 285, 296, 298, 302, 304, 306, 307, 309, 321, 322, 324, 325], "It": [1, 3, 6, 102, 166, 192, 209, 246, 249, 306, 316, 325, 327], "creat": [1, 3, 6, 85, 105, 179, 209, 296, 298, 321, 323, 325], "output": [1, 3, 6, 11, 12, 13, 14, 23, 60, 63, 85, 92, 95, 96, 97, 101, 102, 105, 114, 116, 126, 128, 130, 131, 137, 138, 141, 142, 145, 146, 147, 149, 150, 153, 154, 163, 164, 171, 176, 181, 184, 189, 192, 193, 194, 195, 196, 197, 198, 211, 212, 213, 214, 215, 222, 224, 225, 226, 247, 249, 258, 259, 261, 262, 263, 264, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 291, 293, 321, 322, 323, 324, 325, 326, 327, 328], "given": [1, 11, 13, 23, 60, 62, 64, 72, 74, 82, 84, 86, 87, 88, 89, 90, 91, 95, 96, 97, 101, 114, 126, 128, 130, 131, 136, 142, 150, 158, 166, 171, 173, 181, 189, 190, 191, 193, 203, 211, 212, 216, 225, 226, 232, 247, 272, 274, 280], "set": [1, 3, 4, 6, 75, 78, 165, 166, 179, 220, 224, 231, 233, 240, 241, 244, 245, 249, 252, 258, 272, 284, 291, 296, 300, 307, 318, 322, 324], "further": [1, 6, 322], "class": [1, 3, 4, 7, 8, 25, 203, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 273, 296, 299, 300, 301, 302, 303, 304, 305, 310, 311, 316], "under": [1, 114], "These": [1, 63, 184, 273, 328], "word": 1, "bit": [1, 72, 143, 144, 205, 228, 249], "abstract": 1, "back": [1, 3, 325], "give": [1, 3, 4, 23, 321], "ourselv": 1, "concret": [1, 224, 324, 328], "imag": [1, 215, 217, 218], "public": [1, 209], "explicit": [1, 307, 318, 325], "alpha_": 1, "beta_": 1, "must": [1, 6, 62, 101, 114, 145, 146, 150, 153, 154, 196, 325], "know": [1, 3], "popul": 1, "To": [1, 2, 3, 4, 6, 209, 293, 321, 322, 326], "avoid": [1, 321], "unnecessari": [1, 3], "alloc": [1, 296], "respons": 1, "space": [1, 116, 282], "void": 1, "eval_cpu": 1, "std": [1, 266], "overrid": [1, 78], "eval_gpu": 1, "jacobian": [1, 111, 194, 326], "product": [1, 106, 111, 127, 139, 142, 187, 194, 247, 326], "primal": [1, 111, 194], "tangent": [1, 19, 20, 111, 185, 186], "int": [1, 3, 4, 7, 11, 13, 14, 21, 22, 23, 24, 28, 29, 30, 31, 39, 40, 41, 42, 
44, 47, 49, 52, 55, 56, 58, 60, 64, 65, 66, 72, 73, 74, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 101, 102, 105, 114, 116, 126, 128, 130, 131, 133, 137, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 156, 157, 158, 171, 172, 173, 176, 177, 181, 182, 183, 184, 187, 188, 189, 190, 191, 192, 193, 195, 197, 203, 209, 211, 212, 213, 214, 215, 219, 221, 222, 223, 224, 225, 226, 247, 249, 250, 252, 256, 259, 272, 273, 277, 282, 284, 296, 312, 314], "argnum": [1, 102, 192, 322], "cotan": 1, "across": [1, 221], "pair": [1, 140, 235, 252], "repres": [1, 3, 280, 284, 325], "axi": [1, 3, 4, 11, 13, 21, 22, 23, 24, 28, 29, 30, 31, 39, 40, 41, 42, 44, 52, 55, 58, 64, 74, 84, 86, 89, 92, 93, 94, 95, 96, 97, 98, 114, 126, 128, 130, 131, 133, 140, 141, 142, 146, 156, 171, 172, 173, 176, 177, 181, 182, 183, 184, 188, 193, 195, 211, 212, 225, 226, 272, 273, 277, 282, 284, 323], "correspond": [1, 11, 13, 56, 62, 72, 74, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 126, 128, 131, 142, 181, 187, 195, 201, 322], "dimens": [1, 3, 11, 13, 21, 22, 43, 49, 56, 66, 74, 84, 93, 94, 96, 97, 98, 106, 114, 115, 126, 127, 128, 130, 131, 142, 143, 146, 152, 181, 184, 187, 188, 193, 213, 214, 215, 217, 218, 221, 222, 223, 247, 250, 252, 259, 273, 321, 322], "vmap": [1, 322, 324, 326], "print": [1, 2, 3, 4, 6, 200, 201, 202, 209, 318, 321, 322, 323, 324, 325, 326], "ostream": 1, "os": [1, 6], "equival": [1, 26, 46, 57, 77, 100, 183, 220, 246, 248, 249, 257], "check": [1, 6, 59, 235, 322, 323], "bool": [1, 11, 12, 13, 21, 22, 28, 29, 30, 31, 39, 40, 41, 42, 44, 55, 56, 58, 59, 114, 117, 126, 128, 130, 131, 142, 144, 145, 150, 153, 154, 181, 193, 213, 214, 215, 221, 222, 223, 224, 228, 232, 233, 235, 241, 244, 247, 249, 252, 256, 259, 271, 274, 300, 311], "is_equival": 1, "privat": 1, "fall": 1, "eval": [1, 2, 3, 4, 209, 296, 298, 321, 322, 324, 326], "deriv": [1, 322, 324], "base": [1, 114, 119, 121, 252, 259, 296, 298, 304, 316, 318, 321, 323], "abov": [1, 3, 6, 143, 190, 209, 303, 322, 323, 324, 328], "demonstr": [1, 325], "treat": [1, 93, 94, 96, 97, 183, 321], "paramet": [1, 2, 3, 4, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 32, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 72, 73, 74, 76, 77, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 228, 229, 232, 233, 235, 240, 241, 244, 245, 246, 247, 248, 249, 250, 252, 254, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 291, 292, 293, 296, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 309, 310, 311, 312, 313, 314, 316, 321, 322, 324], "produc": [1, 63, 247, 293], "through": [1, 178, 259, 305, 321, 322, 325], "construct": [1, 4, 73, 101, 137, 197], "its": [1, 6, 127, 141, 152, 189, 199, 202, 209, 249, 302, 303, 304, 325, 328], "type": [1, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 32, 
49, 56, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 72, 73, 74, 76, 77, 79, 80, 81, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 152, 153, 154, 155, 156, 157, 158, 159, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 200, 209, 259, 260, 261, 262, 263, 264, 265, 266, 267, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 321, 323], "shape": [1, 3, 4, 46, 59, 60, 65, 66, 74, 86, 89, 92, 95, 96, 97, 101, 111, 127, 137, 138, 145, 146, 147, 149, 150, 153, 154, 157, 184, 194, 196, 197, 198, 209, 211, 212, 213, 214, 215, 217, 218, 222, 224, 225, 226, 235, 260, 261, 262, 263, 264, 265, 266, 267, 273, 284, 298, 321, 322, 323, 326, 328], "pass": [1, 3, 4, 46, 57, 139, 140, 192, 199, 200, 201, 209, 233, 244, 245, 246, 249, 254, 321, 324], "re": [1, 4, 6, 293], "now": [1, 3, 6, 249, 321, 325], "promot": 1, "dtype": [1, 3, 14, 25, 32, 56, 85, 98, 101, 105, 114, 115, 116, 137, 147, 149, 150, 153, 154, 189, 197, 205, 260, 261, 262, 263, 264, 265, 266, 267, 271, 273, 280, 312, 313, 314, 321, 322, 323, 325, 326, 327], "promoted_dtyp": 1, "promote_typ": 1, "float32": [1, 14, 85, 105, 114, 115, 116, 137, 147, 149, 153, 154, 189, 197, 205, 260, 261, 262, 263, 264, 265, 266, 267, 271, 273, 280, 312, 313, 314, 321, 322, 323, 324, 325, 326, 327], "non": [1, 6, 227, 242, 285, 296], "point": [1, 2, 3, 6, 100, 144, 205], "out_dtyp": 1, "is_floating_point": 1, "cast": [1, 32, 95, 96, 97, 117, 228, 325], "up": [1, 3, 249, 321], "determin": [1, 74, 239, 327], "x_cast": 1, "astyp": [1, 3, 228, 325], "y_cast": 1, "broadcasted_input": 1, "broadcast_arrai": 1, "out_shap": 1, "0": [1, 2, 3, 4, 6, 7, 14, 47, 52, 58, 64, 65, 66, 73, 74, 85, 98, 102, 114, 115, 140, 145, 149, 154, 156, 158, 173, 177, 189, 190, 191, 192, 193, 195, 200, 209, 211, 212, 213, 214, 215, 216, 217, 218, 220, 221, 222, 223, 225, 226, 248, 251, 252, 253, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 269, 270, 271, 273, 275, 276, 280, 283, 284, 286, 287, 288, 290, 291, 293, 296, 299, 300, 302, 303, 304, 305, 307, 310, 311, 312, 313, 314, 318, 321, 322, 323, 324, 325, 326, 327], "unique_ptr": 1, "make_uniqu": 1, "to_stream": 1, "handl": [1, 209, 321], "resolv": 1, "No": [1, 3], "happen": [1, 3, 259, 298, 321, 324], "alon": [1, 325], "effect": [1, 217, 321, 324], "onli": [1, 3, 5, 6, 59, 65, 66, 114, 143, 205, 209, 232, 233, 235, 241, 244, 245, 246, 296, 321, 322, 327, 328], "execut": [1, 6, 325, 328], "depend": [1, 2, 56, 114, 323, 327, 328], "devic": [1, 5, 6, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 27, 28, 29, 30, 31, 32, 33, 35, 37, 38, 39, 40, 41, 42, 44, 45, 46, 47, 48, 50, 52, 53, 54, 55, 57, 58, 59, 60, 61, 62, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 76, 77, 79, 80, 81, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 103, 104, 105, 106, 107, 108, 109, 110, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 149, 150, 152, 153, 154, 155, 156, 157, 158, 159, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 
186, 187, 188, 189, 190, 191, 193, 196, 197, 198, 203, 328, 329], "specifi": [1, 14, 32, 66, 74, 93, 94, 101, 102, 114, 116, 133, 137, 146, 156, 182, 183, 184, 187, 188, 192, 195, 197, 213, 258, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 291, 322, 328], "memori": [1, 5, 259, 296, 300, 321, 324, 325], "ha": [1, 3, 4, 5, 56, 63, 74, 92, 93, 95, 96, 97, 102, 146, 213, 224, 296, 298, 321, 323, 324, 326, 328], "been": [1, 3, 324], "try": [1, 6], "naiv": [1, 322], "gener": [1, 2, 14, 85, 93, 94, 116, 145, 149, 150, 153, 154, 259, 318, 321, 323, 324, 329], "version": [1, 6, 72, 122, 126, 143, 171, 195, 318, 322, 323], "declar": 1, "member": [1, 209, 238, 242], "method": [1, 3, 7, 8, 25, 203, 209, 239, 296, 299, 300, 301, 302, 303, 304, 305, 307, 310, 311, 316], "each": [1, 49, 72, 82, 127, 140, 143, 144, 146, 156, 163, 164, 173, 188, 195, 196, 217, 218, 219, 221, 252, 259, 271, 273, 318, 321, 324], "find": [1, 2, 6], "pointwis": 1, "captur": [1, 63, 209, 321], "templat": 1, "axpby_impl": 1, "typenam": 1, "t": [1, 3, 80, 144, 192, 209, 211, 225, 299, 300, 301, 302, 303, 304, 305, 310, 311, 321, 322, 328], "readi": 1, "fill": [1, 101, 138, 189, 198, 260, 261, 262, 263, 264, 266, 267], "malloc_or_wait": 1, "synchron": [1, 321], "avail": [1, 2, 3, 4, 6, 8, 205, 328], "There": [1, 209, 321], "wait": [1, 3], "here": [1, 3, 321, 322, 324, 327, 328], "request": 1, "pressur": 1, "condit": [1, 196, 328], "set_data": 1, "nbyte": 1, "collect": [1, 201, 320], "pointer": 1, "x_ptr": 1, "y_ptr": 1, "out_ptr": 1, "relev": 1, "static_cast": 1, "size_t": 1, "out_idx": 1, "size": [1, 3, 4, 49, 66, 72, 84, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 101, 105, 114, 143, 144, 146, 157, 173, 176, 209, 211, 212, 214, 215, 219, 222, 225, 226, 249, 300, 324, 325], "map": [1, 4, 117, 201, 219, 228], "linear": [1, 3, 4, 5, 201, 209, 220, 235, 249, 251, 253, 255, 268, 269, 270, 287, 288, 289, 293, 296, 307, 321], "indic": [1, 12, 21, 22, 23, 24, 102, 107, 108, 109, 110, 173, 183, 184, 192, 241, 243, 273, 280, 323], "offset": [1, 3, 74], "x_offset": 1, "elem_to_loc": 1, "stride": [1, 65, 66, 211, 212, 214, 215, 225, 226, 252, 323], "y_offset": 1, "contigu": 1, "regularli": 1, "default": [1, 6, 11, 12, 13, 14, 21, 22, 23, 24, 59, 63, 64, 65, 66, 70, 71, 72, 73, 74, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 102, 105, 114, 115, 116, 117, 126, 128, 130, 131, 137, 141, 142, 143, 144, 145, 146, 147, 149, 150, 152, 153, 154, 156, 157, 158, 165, 166, 172, 173, 176, 177, 179, 181, 187, 188, 189, 190, 191, 192, 193, 195, 197, 205, 211, 212, 213, 214, 215, 222, 224, 225, 226, 228, 233, 235, 241, 244, 247, 248, 249, 252, 256, 257, 259, 260, 261, 262, 263, 264, 265, 266, 267, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 296, 299, 300, 301, 302, 303, 304, 305, 310, 311, 318, 320, 321, 322, 325, 327, 329], "row": [1, 85, 105, 143, 189], "major": 1, "henc": [1, 143, 321], "doesn": [1, 209], "addit": [1, 3, 10, 117, 213, 221, 223, 247, 250, 296, 322], "abl": [1, 143], "all": [1, 4, 6, 12, 23, 66, 85, 88, 91, 94, 97, 127, 140, 141, 176, 209, 228, 229, 233, 236, 237, 238, 242, 244, 247, 249, 256, 259, 293, 296, 316, 318, 321, 323, 324, 326, 329], "incom": 1, "accordingli": 1, "dispatch": 1, "float16": [1, 117, 205, 228, 324, 325], "bfloat16": [1, 325], "complex64": 1, "throw": 1, "error": [1, 6, 80, 81, 173, 220, 249, 268, 269, 270, 279, 281, 322, 325], "encount": [1, 322], "unexpect": [1, 14], "regist": [1, 4], "op": [1, 139, 233, 324], "assert": 1, "2": [1, 2, 3, 4, 66, 73, 
74, 80, 87, 90, 92, 93, 94, 95, 96, 97, 98, 114, 115, 121, 127, 143, 152, 187, 189, 190, 191, 205, 209, 211, 212, 215, 220, 225, 226, 250, 256, 260, 261, 262, 263, 264, 265, 266, 267, 269, 273, 274, 276, 283, 284, 293, 296, 299, 301, 302, 303, 307, 310, 321, 322, 323, 324, 325, 326, 327, 328], "1": [1, 3, 4, 14, 23, 24, 65, 66, 73, 74, 86, 87, 89, 90, 92, 93, 94, 95, 96, 97, 98, 106, 114, 115, 127, 139, 141, 143, 146, 149, 154, 167, 172, 183, 192, 205, 209, 211, 212, 213, 214, 215, 216, 217, 218, 220, 221, 222, 223, 224, 225, 226, 248, 250, 252, 253, 256, 258, 261, 262, 263, 264, 265, 266, 267, 269, 270, 271, 272, 273, 274, 275, 276, 277, 279, 280, 282, 283, 284, 288, 291, 293, 296, 298, 299, 300, 301, 302, 303, 304, 305, 307, 310, 311, 312, 313, 314, 321, 322, 323, 325, 326, 327, 328], "correct": [1, 6, 302, 303, 304, 323, 324], "els": [1, 3, 209, 233, 324], "float16_t": 1, "bfloat16_t": 1, "complex64_t": 1, "runtime_error": 1, "support": [1, 3, 5, 6, 12, 65, 66, 98, 115, 117, 127, 143, 322, 323, 325, 327], "have": [1, 3, 6, 12, 59, 93, 94, 96, 97, 127, 146, 200, 247, 254, 305, 307, 320, 321, 323, 324, 328], "rememb": 1, "3": [1, 3, 6, 98, 114, 115, 262, 264, 300, 305, 318, 321, 323, 325, 326], "complic": 1, "keep": [1, 11, 13, 21, 22, 126, 128, 130, 131, 142, 181, 193, 209, 232, 322, 324], "mind": [1, 3], "half": [1, 14, 150, 154, 252, 324], "precis": [1, 3, 209, 220, 306, 321], "direct": [1, 3, 230, 305, 328], "fix": [1, 3, 6, 324], "possibl": [1, 3, 127, 173, 219, 321, 323, 328], "due": 1, "transpos": [1, 3, 26, 144], "aren": 1, "guarante": 1, "fit": [1, 143, 328], "requir": [1, 3, 209, 324, 325], "column": [1, 85, 105, 143], "inplac": 1, "expect": [1, 3, 214, 215, 216, 217, 218, 256, 259, 274, 321, 323], "answer": 1, "copi": [1, 3, 5, 141, 172, 325], "simpli": [1, 3, 6, 251, 287, 296, 321, 322], "catlas_saxpbi": 1, "axpby_impl_acceler": 1, "first": [1, 2, 3, 4, 6, 74, 98, 102, 123, 125, 127, 141, 152, 182, 187, 192, 200, 209, 212, 221, 226, 272, 280, 300, 302, 303, 304, 307, 321, 322, 325, 328], "mode": [1, 67, 231, 241, 243, 263, 264], "i": [1, 3, 111, 114, 209, 214, 215, 217, 218, 233, 279, 303, 321, 322], "e": [1, 4, 6, 80, 111, 167, 213, 214, 215, 217, 218, 221, 222, 223, 233, 250, 292, 298, 301, 321, 324, 329], "match": [1, 6, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 235, 273, 323, 325], "transposit": 1, "data_s": 1, "items": 1, "flag": [1, 321, 325], "copy_inplac": 1, "copytyp": 1, "n": [1, 3, 25, 65, 66, 85, 86, 88, 89, 91, 92, 95, 97, 105, 189, 193, 211, 212, 213, 214, 215, 217, 218, 225, 226, 279, 284], "incx": 1, "inci": 1, "great": 1, "But": [1, 328], "criteria": 1, "luckili": [1, 324], "alwai": [1, 200, 322], "With": 1, "final": [1, 2, 3, 4], "singl": [1, 4, 82, 111, 117, 140, 194, 212, 226, 321, 323, 327], "row_contigu": 1, "col_contigu": 1, "common": [1, 298, 321, 324], "hit": 1, "mileston": 1, "enough": [1, 324], "run": [1, 3, 4, 5, 6, 7, 139, 203, 213, 228, 299, 300, 302, 303, 304, 321, 324, 328, 329], "If": [1, 3, 6, 11, 12, 13, 14, 21, 22, 23, 24, 56, 59, 62, 64, 67, 73, 74, 82, 95, 96, 97, 100, 101, 102, 114, 117, 126, 127, 128, 130, 131, 137, 140, 141, 142, 146, 156, 171, 172, 173, 181, 183, 184, 187, 192, 193, 195, 197, 201, 213, 214, 215, 221, 223, 224, 233, 235, 244, 249, 252, 254, 256, 271, 273, 284, 300, 321, 322, 324, 327, 328, 329], "plan": [1, 321], "stop": [1, 3, 14, 116, 178, 322, 323], "enjoi": 1, "speed": 1, "appl": [1, 3, 5, 6, 328], "silicon": [1, 3, 5, 6, 328], "address": 1, "shade": 1, "languag": [1, 205], "kernel": [1, 65, 66, 211, 225, 321, 
323], "written": 1, "help": [1, 3, 321, 328], "resourc": 1, "walkthrough": 1, "pipelin": 1, "specif": [1, 6, 322], "cpp": 1, "algorithm": [1, 305], "launch": [1, 323], "exactli": [1, 3, 235, 322], "mani": [1, 173, 214, 215, 219, 321, 324], "thread": 1, "pick": 1, "updat": [1, 2, 3, 4, 63, 201, 213, 228, 235, 240, 246, 298, 300, 303, 305, 306, 307, 311, 312, 313, 314, 321, 324], "assign": [1, 296], "axpby_gener": 1, "buffer": [1, 325], "constant": [1, 3, 6, 140, 209, 213, 221, 223, 250, 274, 284, 310, 312, 321, 325], "4": [1, 3, 72, 98, 114, 143, 144, 163, 205, 211, 212, 213, 222, 225, 226, 249, 259, 261, 262, 263, 271, 321, 323, 326, 328], "5": [1, 2, 3, 6, 114, 145, 211, 213, 216, 217, 218, 222, 225, 257, 260, 263, 264, 283, 290, 293, 310, 312, 313, 321, 322, 323], "x_stride": 1, "6": [1, 3, 114, 163, 259, 262, 269, 270, 274, 284, 310, 321, 323, 326], "y_stride": 1, "7": [1, 3, 114, 143, 323], "ndim": [1, 98, 114], "8": [1, 3, 6, 114, 143, 205, 212, 222, 226, 259, 272, 299, 300, 301, 302, 303, 304, 310, 321, 323, 326, 328], "uint": 1, "index": [1, 5, 7, 23, 84, 85, 102, 141, 183, 184, 192, 203], "thread_position_in_grid": 1, "convert": [1, 56, 98, 249, 324, 325, 326], "instanti": [1, 4, 324], "uniqu": [1, 318], "host": 1, "name": [1, 117, 143, 144, 161, 162, 163, 164, 209, 221, 232, 235, 237, 323, 327], "identifi": [1, 200, 320], "instantiate_axpbi": 1, "type_nam": 1, "host_nam": 1, "axpby_general_": 1, "compil": [1, 5, 6, 75, 78, 322, 324], "mlx_ext": 1, "metallib": [1, 6], "see": [1, 3, 4, 6, 8, 27, 28, 29, 30, 31, 33, 35, 37, 38, 39, 40, 41, 42, 44, 45, 46, 47, 48, 50, 52, 53, 54, 55, 57, 58, 114, 161, 162, 209, 213, 217, 220, 231, 248, 249, 252, 253, 256, 257, 261, 262, 263, 264, 268, 269, 270, 288, 321, 322, 323, 326, 328], "later": [1, 6], "co": [1, 256, 322], "locat": [1, 245, 246, 328], "share": [1, 5, 72, 143, 144], "register_librari": 1, "potenti": 1, "path": [1, 6, 163, 164, 235], "tri": 1, "load": [1, 4, 5, 235], "hasn": 1, "alreadi": [1, 3], "static": [1, 6], "object": [1, 8, 25, 36, 56, 145, 150, 153, 154, 195, 200, 201, 217, 320], "why": [1, 3], "packag": [1, 2, 4, 293], "process": [1, 3, 67, 201, 218, 219, 259, 320], "logic": [1, 123, 124, 125], "grid": 1, "shown": 1, "below": [1, 6, 114, 189, 191, 205, 324], "prepar": [1, 3], "carri": 1, "should": [1, 2, 3, 4, 6, 74, 111, 143, 184, 192, 194, 200, 209, 214, 215, 217, 218, 241, 247, 254, 273, 275, 280, 296, 320, 321, 322, 324, 325, 329], "d": [1, 3, 73, 74, 106, 114, 127, 139, 183, 189, 190, 191, 202, 218, 299, 302, 304, 328], "ostringstream": 1, "kname": 1, "axpby_": 1, "general_": 1, "type_to_nam": 1, "make": [1, 3, 4, 6, 127, 136, 166, 209, 312, 313, 314, 321, 324, 326, 328], "sure": [1, 3, 6, 209, 321], "look": [1, 3], "folder": 1, "get_colocated_mtllib_path": 1, "get_kernel": 1, "str": [1, 67, 102, 114, 117, 160, 161, 162, 163, 164, 192, 200, 202, 228, 229, 232, 233, 235, 237, 239, 244, 263, 264, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284], "encod": [1, 252, 256, 259, 273], "compute_encod": 1, "get_command_encod": 1, "setcomputepipelinest": 1, "those": [1, 3, 209], "nelem": 1, "set_array_buff": 1, "setbyt": 1, "sizeof": 1, "threadgroup": 1, "higher": [1, 106, 280, 322], "than": [1, 3, 56, 67, 74, 77, 103, 104, 112, 113, 127, 201, 252, 258, 280, 283, 291, 300, 305, 321, 322, 328], "max": [1, 114, 129, 225, 226, 248, 272, 274, 275, 280, 284, 286, 300, 304, 321, 322, 328], "allow": [1, 209, 246, 296, 316, 323, 326], "tgp_size": 1, "min": [1, 114, 132, 248, 286], "maxtotalthreadsperthreadgroup": 
1, "3d": [1, 213, 218], "mtl": 1, "group_dim": 1, "grid_dim": 1, "divid": [1, 100, 143], "among": 1, "dispatchthread": 1, "few": [1, 3, 4, 5, 324, 326], "thing": [1, 3], "note": [1, 3, 6, 12, 65, 66, 93, 94, 114, 143, 146, 209, 325, 327], "befor": [1, 3, 6, 23, 141, 232, 259, 307, 323, 324], "move": [1, 133, 328], "track": [1, 209, 213], "activ": [1, 6, 217, 227, 258, 259, 285, 290, 291, 292, 321], "command": [1, 6], "instead": [1, 6, 209, 246, 256, 322, 324], "end_encod": 1, "end": [1, 74, 143, 212, 226, 253, 258, 276, 283, 288, 290, 291], "until": [1, 324, 326], "limit": [1, 62, 323], "flush": 1, "enqueu": 1, "commit": 1, "associ": [1, 163, 164, 324], "suggest": 1, "deeper": 1, "dive": 1, "studi": 1, "come": [1, 3, 322], "far": [1, 298], "built": [1, 6, 324], "includ": [1, 229, 240, 249, 274, 321, 322, 323, 326, 327, 329], "forward": [1, 192, 321, 324], "diff": 1, "push": 1, "along": [1, 21, 22, 63, 64, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 114, 156, 171, 173, 177, 183, 184, 187, 209], "similarli": [1, 6, 127, 322, 324], "scale_arr": 1, "contribut": 1, "tangent_x": 1, "tangent_i": 1, "revers": [1, 188, 256], "arg": [1, 3, 8, 46, 57, 82, 163, 164], "push_back": 1, "fulli": [1, 5, 321, 325, 328], "overal": 1, "directori": [1, 3, 6], "extens": [1, 117, 205, 239, 327], "h": [1, 65, 66, 114, 212, 213, 215, 217, 218, 226, 322, 324], "mlx_sample_extens": 1, "__init__": [1, 3, 4, 7, 8, 25, 203, 209, 296], "py": [1, 3, 6], "cmakelist": 1, "txt": 1, "setup": [1, 2, 4, 6, 321], "hold": [1, 3, 8, 114, 321], "instal": 1, "pybind11": [1, 6], "sinc": [1, 3, 4, 296, 305, 325, 328], "compon": [1, 3], "etc": [1, 143, 209], "pybind11_modul": 1, "m": [1, 6, 85, 114, 189, 211, 212, 225, 226, 299], "doc": [1, 4], "sampl": [1, 2, 3, 116, 145, 146, 147, 150, 153, 154, 261, 262, 263, 264, 266, 267, 274, 280, 284, 318, 321], "_a": 1, "pos_onli": 1, "kw_onli": 1, "none": [1, 3, 7, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 29, 30, 31, 32, 33, 35, 37, 38, 39, 40, 41, 42, 44, 45, 46, 47, 48, 50, 52, 53, 54, 55, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 162, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 195, 196, 197, 198, 200, 201, 203, 211, 212, 220, 225, 226, 228, 232, 233, 244, 247, 256, 259, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 300, 316, 323], "r": [1, 3, 115, 192, 217], "pbdoc": 1, "most": [1, 146, 209, 309, 321, 322, 323, 324], "complex": [1, 93, 94, 95, 96, 97, 145, 150, 153, 154, 200, 209, 246, 321, 322], "bell": 1, "whistl": 1, "liter": [1, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284], "string": [1, 325, 327], "modul": [1, 3, 4, 199, 249, 254, 259, 293, 309, 320, 321, 324], "ensur": [1, 6, 279], "caster": 1, "find_packag": 1, "config": 1, "add_librari": 1, "sourc": [1, 133, 188], "target_sourc": 1, "cmake_current_list_dir": 1, "header": 1, "target_include_directori": 1, "target_link_librari": 1, "attach": 1, "conveni": [1, 4], "mlx_build_metallib": 1, "target": [1, 192, 271, 273, 274, 275, 276, 277, 278, 279, 
280, 281, 282, 283, 321], "destin": [1, 133], "automat": [1, 5, 117, 326, 327, 328], "practic": [1, 321], "mlx_build_met": [1, 6], "mlx_ext_metallib": 1, "titl": 1, "include_dir": 1, "project_source_dir": 1, "mlx_include_dir": 1, "output_directori": 1, "cmake_library_output_directori": 1, "add_depend": 1, "endif": 1, "pybind11_add_modul": 1, "build_shared_lib": 1, "target_link_opt": 1, "wl": 1, "rpath": 1, "loader_path": 1, "onc": [1, 321], "describ": [1, 324], "util": [1, 3, 5, 6, 163, 209], "__name__": [1, 3], "__main__": [1, 3], "descript": [1, 3, 205], "ext_modul": 1, "cmakeextens": 1, "cmdclass": 1, "build_ext": 1, "cmakebuild": 1, "package_dir": 1, "package_data": 1, "dylib": 1, "zip_saf": 1, "fals": [1, 3, 11, 12, 13, 21, 22, 28, 29, 30, 31, 39, 40, 41, 42, 44, 55, 58, 59, 114, 117, 126, 128, 130, 131, 142, 181, 193, 196, 200, 201, 205, 221, 222, 224, 233, 235, 244, 247, 249, 252, 256, 259, 271, 274, 300, 311, 325], "python_requir": 1, "even": [1, 3, 321, 324, 325], "though": [1, 3, 321, 324, 325], "j8": 1, "libmlx_ext": 1, "cpython": 1, "3x": 1, "darwin": 1, "pip": [1, 6], "after": [1, 3, 4, 23, 98, 100, 141, 143, 213, 221, 223, 247, 259, 283, 321, 328], "plai": [1, 3], "ones": [1, 3, 138, 163, 189, 245, 246, 249, 323], "b": [1, 3, 10, 12, 59, 76, 77, 79, 100, 103, 104, 106, 112, 113, 114, 122, 123, 125, 127, 129, 132, 134, 139, 143, 180, 187, 192, 224, 322, 323, 324, 325, 326, 327, 328], "f": [1, 2, 4, 114, 209, 303, 321, 325], "item": [1, 2, 3, 4, 201, 324, 325, 326], "true": [1, 2, 3, 12, 59, 114, 117, 144, 171, 196, 200, 201, 205, 209, 213, 214, 215, 221, 222, 223, 224, 232, 233, 235, 241, 244, 249, 252, 256, 259, 271, 279, 300], "quick": [1, 5], "benchmark": [1, 321], "compar": [1, 59, 321], "time": [1, 3, 6, 209, 211, 212, 225, 226, 321, 322, 324, 328], "set_default_devic": 1, "256": [1, 4], "512": [1, 3, 259, 328], "random": [1, 2, 3, 4, 5, 211, 212, 213, 222, 225, 226, 235, 241, 321, 322, 328, 329], "normal": [1, 2, 3, 153, 209, 211, 212, 213, 221, 222, 223, 225, 226, 250, 259, 261, 263, 325, 328], "bench": 1, "warm": [1, 321], "rang": [1, 2, 3, 4, 6, 14, 98, 116, 262, 264, 269, 270, 298, 312, 313, 314, 318, 321, 322, 324, 328], "100": [1, 2, 3, 321, 322, 324, 328], "5000": 1, "simple_tim": 1, "custom_tim": 1, "3f": [1, 4, 321], "custom": [1, 259], "114": 1, "109": 1, "modest": 1, "improv": [1, 3, 299, 300, 301, 302, 303, 304, 310, 321], "awai": [1, 3], "good": [1, 6, 321, 328], "nn": [1, 3, 4, 163, 201, 209, 293, 296, 298, 307, 309, 321, 324], "grad": [1, 2, 4, 192, 298, 306, 321, 322, 323, 324, 326], "full": [1, 4, 46, 57, 67, 171, 245, 246, 274, 321, 324], "implement": [2, 4, 114, 219, 232, 247, 252, 254, 256, 258, 259, 291, 299, 300, 301, 302, 304, 305, 306, 316, 321, 322, 325], "basic": [2, 158, 322], "model": [2, 4, 5, 163, 199, 201, 209, 228, 231, 233, 235, 239, 241, 243, 244, 245, 247, 259, 293, 296, 298, 306, 307, 309, 321, 324], "problem": [2, 4, 209], "metadata": [2, 117, 161, 162], "num_featur": [2, 213], "num_exampl": 2, "1_000": 2, "num_it": 2, "10_000": 2, "iter": [2, 4, 201, 318, 321, 324], "sgd": [2, 4, 298, 305, 307, 312, 313, 314, 321], "lr": [2, 305], "01": [2, 303], "rate": [2, 299, 300, 301, 302, 303, 304, 305, 310, 311], "ll": [2, 4, 276, 321, 322], "synthet": 2, "dataset": [2, 324], "matrix": [2, 72, 73, 85, 105, 114, 115, 127, 143, 144, 249, 265, 293], "ground": [2, 3, 273, 283], "truth": [2, 273, 283], "w_star": 2, "valu": [2, 3, 9, 12, 14, 21, 22, 36, 56, 59, 62, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 101, 114, 116, 140, 145, 146, 
147, 149, 150, 153, 154, 161, 183, 184, 192, 195, 199, 200, 201, 205, 212, 216, 217, 218, 222, 224, 226, 232, 247, 248, 257, 258, 259, 260, 271, 272, 273, 274, 275, 276, 278, 279, 280, 281, 282, 283, 291, 296, 300, 303, 312, 313, 314, 322], "gaussian": [2, 220, 268, 269, 270, 274], "nois": 2, "exampl": [2, 3, 4, 14, 98, 114, 115, 179, 183, 209, 211, 212, 213, 222, 225, 226, 233, 235, 241, 244, 260, 261, 262, 263, 264, 265, 266, 267, 271, 273, 280, 293, 298, 307, 312, 313, 314, 318, 322, 323, 324, 325, 326, 327], "noisi": 2, "label": [2, 273, 280], "ep": [2, 213, 221, 222, 223, 250, 272, 274, 284, 299, 300, 301, 302, 303, 304, 310], "1e": [2, 4, 12, 213, 221, 222, 223, 250, 272, 274, 284, 299, 300, 301, 302, 303, 304, 307, 310, 312, 313, 314], "us": [2, 3, 4, 5, 6, 14, 72, 75, 77, 98, 114, 115, 127, 143, 144, 156, 157, 200, 209, 212, 217, 219, 220, 224, 226, 228, 232, 239, 245, 246, 247, 249, 252, 256, 259, 263, 264, 269, 270, 272, 293, 296, 298, 299, 300, 302, 303, 304, 305, 306, 307, 316, 318, 320, 321, 322, 323, 326, 328], "weight": [2, 65, 66, 201, 209, 235, 239, 249, 271, 273, 296, 300, 303, 305, 307, 311, 322, 324], "squar": [2, 3, 105, 159, 174, 192, 201, 209, 250, 281, 283, 299, 300, 302, 303, 304, 322, 325], "loss": [2, 4, 192, 209, 298, 321, 322, 324], "loss_fn": [2, 4, 298, 321, 322], "w": [2, 66, 72, 143, 144, 192, 212, 213, 215, 217, 218, 224, 226, 311, 322], "mean": [2, 3, 4, 149, 192, 209, 213, 221, 233, 250, 266, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 321, 322, 325], "grad_fn": [2, 321, 322], "initi": [2, 3, 209, 213, 221, 222, 223, 224, 248, 250, 260, 261, 262, 263, 264, 265, 266, 267, 296, 307, 312, 313, 314, 321, 324], "randomli": [2, 3, 216, 217, 218], "Then": [2, 6], "repeatedli": 2, "_": [2, 3, 209, 312, 313, 314, 318, 321, 324, 328], "verifi": [2, 6], "close": [2, 5, 6, 12], "error_norm": 2, "5f": 2, "someth": [2, 3, 323], "00005": 2, "00364": 2, "complet": [2, 3, 6, 245, 246, 322, 328], "logist": [2, 167, 255, 269, 270, 289], "github": [2, 4, 6, 321], "repo": [2, 4, 6, 321], "enabl": [3, 6, 78, 311], "larg": [3, 209, 247, 279, 321, 324], "ish": 3, "transform": [3, 5, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 199, 209, 213, 221, 223, 224, 232, 233, 244, 249, 252, 323], "compromis": 3, "eas": 3, "llama": 3, "famili": 3, "less": [3, 23, 113, 141, 252, 283], "200": 3, "line": [3, 324, 325], "python": [3, 36, 49, 56, 82, 200, 201, 202, 296, 306, 307, 309, 320, 322, 325], "neural": [3, 5, 219, 227, 261, 262, 285, 293, 296, 310], "network": [3, 5, 213, 217, 219, 261, 262, 293, 296, 310], "build": [3, 5, 263, 296, 321], "concis": 3, "architectur": [3, 6, 209, 246, 328], "notabl": [3, 5], "rope": [3, 209], "posit": [3, 23, 74, 98, 102, 110, 133, 141, 192, 201, 209, 214, 215, 247, 252, 256, 274, 284], "option": [3, 11, 13, 14, 21, 22, 23, 24, 25, 30, 31, 63, 64, 65, 66, 67, 72, 73, 74, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 101, 102, 105, 109, 110, 114, 115, 116, 117, 126, 128, 130, 131, 137, 140, 141, 142, 143, 144, 145, 146, 147, 149, 150, 152, 153, 154, 156, 157, 162, 171, 172, 173, 176, 177, 181, 183, 184, 187, 188, 189, 190, 191, 192, 193, 195, 197, 200, 201, 211, 212, 213, 214, 215, 224, 225, 226, 228, 232, 233, 235, 244, 247, 249, 252, 256, 259, 260, 261, 262, 263, 264, 265, 266, 267, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 299, 300, 301, 302, 303, 304, 305, 307, 310, 311, 318, 321, 327, 329], "kei": [3, 145, 146, 147, 149, 150, 152, 153, 154, 200, 201, 232, 233, 244, 247, 307, 318, 320, 
322], "cach": [3, 321], "concaten": 3, "project": [3, 247], "llamaattent": 3, "self": [3, 4, 7, 25, 26, 27, 28, 29, 30, 31, 32, 33, 35, 36, 37, 38, 39, 40, 41, 42, 44, 45, 46, 47, 48, 50, 52, 53, 54, 55, 56, 57, 58, 203, 209, 227, 285, 296], "dim": [3, 219, 221, 222, 223, 247, 250, 252, 256, 259], "num_head": [3, 247, 259], "super": [3, 4, 209, 296], "tradit": [3, 217, 218, 252], "query_proj": 3, "bia": [3, 72, 143, 144, 201, 209, 214, 215, 224, 233, 235, 244, 247, 249, 302, 303, 304, 307, 322], "key_proj": 3, "value_proj": 3, "out_proj": [3, 296], "__call__": [3, 4, 209, 296], "queri": [3, 247], "mask": [3, 241, 247, 323], "extract": [3, 73, 74, 209, 232, 296], "l": [3, 4, 209, 211, 213, 214, 225, 283], "reshap": [3, 114, 323], "combin": 3, "key_cach": 3, "value_cach": 3, "sqrt": [3, 80, 213, 221, 222, 223, 224, 250, 256, 261, 262, 263, 264, 299, 301, 302, 303, 310, 321], "score": [3, 280], "softmax": [3, 273], "values_hat": 3, "rm": [3, 6, 300], "swiglu": 3, "rmsnorm": [3, 209], "llamaencoderlay": 3, "mlp_dim": [3, 259], "norm1": 3, "norm2": 3, "linear1": 3, "linear2": 3, "linear3": 3, "sigmoid": [3, 255, 269, 270, 289], "instanc": [3, 143, 202, 209, 222, 228, 229, 230, 233, 236, 237, 244, 246, 254, 296, 325], "embed": [3, 209, 252, 256, 272], "emb": [3, 219, 256], "token": [3, 219], "num_lay": [3, 4, 298], "vocab_s": 3, "norm": [3, 221, 284, 304, 305], "multiheadattent": [3, 209], "create_additive_causal_mask": 3, "list": [3, 8, 11, 13, 25, 28, 29, 39, 40, 41, 42, 44, 52, 55, 56, 58, 60, 63, 64, 82, 84, 87, 88, 90, 91, 93, 94, 96, 97, 101, 102, 111, 114, 126, 128, 130, 131, 137, 140, 142, 145, 146, 147, 149, 150, 153, 154, 157, 161, 171, 173, 176, 177, 181, 187, 188, 192, 193, 194, 197, 200, 202, 209, 233, 235, 236, 237, 238, 242, 244, 245, 246, 296, 302, 303, 304, 305, 320, 321, 322, 324], "still": [3, 6, 114, 321, 324], "consid": [3, 12, 59, 200, 201, 221, 320], "train": [3, 4, 209, 213, 216, 217, 218, 231, 233, 244, 261, 262], "ignor": [3, 62, 63, 82, 300], "whatsoev": 3, "rest": [3, 201, 252], "subsect": 3, "prompt": 3, "autoregress": 3, "yield": [3, 4, 318], "temp": 3, "causal": 3, "save": [3, 5, 117, 143, 161, 162, 163, 164, 239, 324], "append": [3, 127, 321, 324], "store": 3, "per": [3, 4, 72, 143, 144, 213, 221, 222, 223, 250, 316, 321, 324], "care": [3, 324], "last": [3, 24, 56, 88, 91, 93, 94, 96, 97, 98, 106, 115, 127, 146, 172, 187, 214, 215, 217, 218, 221, 325], "logit": [3, 146, 271, 273, 321], "next": [3, 4], "categor": 3, "lazili": [3, 209], "noth": [3, 209, 324], "yet": [3, 114, 209, 296, 307, 322, 323, 324, 326], "forc": [3, 4, 209, 326], "choos": [3, 252], "pars": 3, "feed": 3, "loop": [3, 4, 321, 322, 324], "unsqueez": 3, "sequenc": [3, 213, 214, 259, 318, 328], "length": [3, 176, 213, 214], "len": [3, 88, 91, 94, 97], "overwrit": 3, "discard": [3, 200], "old": 3, "moment": [3, 300, 302, 303, 304], "anymor": 3, "everyth": 3, "small": [3, 213, 221, 223, 250, 274, 279, 284, 321, 328], "10": [3, 4, 119, 158, 163, 201, 209, 235, 293, 314, 321, 323], "12": 3, "8192": 3, "1024": 3, "actual": [3, 14, 235, 296, 324], "materi": [3, 5], "could": [3, 209], "20_000": 3, "machin": [3, 5, 6, 310], "8gb": 3, "ram": 3, "32": [3, 4, 143, 144, 205, 212, 226, 321], "44": 3, "doubl": 3, "bracket": 3, "becaus": [3, 209, 324], "batch": [3, 127, 213, 214, 215, 217, 218, 247, 324], "zip": [3, 4], "haven": 3, "anyth": [3, 192, 324], "result": [3, 14, 56, 72, 106, 114, 117, 127, 139, 144, 156, 158, 177, 187, 196, 201, 256, 321, 322, 325], "similar": [3, 201, 245, 246, 247, 272, 325, 327], 
"runtim": [3, 321], "section": [3, 6, 173, 284, 321, 322], "access": [3, 36, 209, 296, 307, 324, 328], "origin": [3, 74, 213, 240, 261, 262, 263, 264, 299, 300, 301, 302, 304, 305, 325], "sentencepiec": 3, "pytorch": [3, 5, 221, 322], "compat": [3, 146, 327], "npz": [3, 117, 163, 164, 235, 239, 327], "file": [3, 6, 117, 160, 161, 162, 163, 164, 235, 239, 322, 327], "directli": 3, "argpars": 3, "itertool": [3, 201], "starmap": [3, 201], "np": [3, 4, 325, 326], "torch": [3, 325], "map_torch_to_mlx": 3, "tok_embed": 3, "elif": 3, "replac": [3, 245, 246, 259, 283], "attention_norm": 3, "ffn_norm": 3, "wq": 3, "wk": 3, "wv": 3, "wo": 3, "w1": 3, "w2": 3, "w3": 3, "ffn": 3, "separ": [3, 46, 57, 221, 280], "submodul": [3, 4, 209, 233, 234, 244, 246], "feed_forward": 3, "parser": 3, "argumentpars": 3, "add_argu": 3, "torch_weight": 3, "output_fil": 3, "parse_arg": 3, "state": [3, 4, 209, 298, 307, 318, 321], "savez": [3, 239, 327], "k": [3, 73, 85, 189, 190, 191, 211, 224, 225, 233], "v": [3, 67, 209, 233, 325], "left": [3, 114, 143, 211, 212, 220, 225, 226, 252, 269, 270, 274, 276, 284], "disk": 3, "text": [3, 211, 212, 225, 226, 227, 253, 258, 261, 262, 263, 264, 274, 275, 276, 279, 280, 283, 285, 286, 288, 290, 291, 300, 305], "format": [3, 117, 160, 161, 162, 163, 164, 325], "oper": [3, 5, 7, 32, 171, 178, 184, 203, 209, 259, 305, 321, 322, 323, 324, 325, 326, 328, 329], "dictionari": [3, 63, 117, 161, 162, 200, 209, 232, 240, 245, 246, 308, 320, 327], "represent": [3, 143, 200, 202], "tree_unflatten": 3, "helper": [3, 321], "weight_fil": 3, "incur": 3, "sever": [3, 65, 66, 163, 164, 321, 327], "futur": [3, 249, 323, 324], "pth": 3, "current": [3, 5, 6, 65, 66, 143, 209, 300, 324], "around": 3, "m1": [3, 321, 322, 328], "ultra": 3, "7b": 3, "me": 3, "ishmael": 3, "year": 3, "ago": 3, "never": [3, 324], "long": 3, "info": [3, 6], "247": 3, "press": [3, 114], "enter": 3, "littl": 3, "monei": 3, "my": [3, 6], "purs": 3, "greater": [3, 23, 104, 141, 258, 291], "consequ": 3, "walk": 3, "down": 3, "gower": 3, "street": 3, "afternoon": 3, "heavi": 3, "rain": 3, "saw": [3, 322], "off": [3, 6, 324], "man": 3, "rag": 3, "who": 3, "sat": 3, "upon": [3, 201], "hi": 3, "bundl": 3, "hard": 3, "wet": 3, "he": [3, 263, 264], "were": [3, 328], "cry": 3, "watch": [3, 321], "him": 3, "observ": 3, "numer": [3, 114, 122, 126, 171, 213, 221, 222, 223, 250, 272, 274, 284, 299, 300, 301, 302, 303, 304, 310, 321, 324], "crowd": 3, "wa": [3, 324], "hurri": 3, "437": 3, "330": 3, "second": [3, 74, 123, 125, 127, 182, 192, 212, 226, 272, 280, 300, 302, 303, 304, 322, 328], "spent": 3, "amount": [3, 211, 225], "39": 3, "ms": [3, 321], "By": [3, 322, 325], "bigger": [3, 300], "remain": [3, 192, 216, 217, 218], "almost": 3, "nobodi": 3, "took": 3, "least": [3, 62, 115, 143], "notic": [3, 322, 327], "distanc": [3, 284], "had": 3, "doubt": 3, "minut": 3, "straight": 3, "slowli": 3, "rais": [3, 114, 173, 235], "ey": 3, "speak": [3, 114], "resum": 3, "postur": 3, "stood": 3, "feel": 3, "pain": 3, "heart": 3, "smile": 3, "face": 3, "am": 3, "someon": 3, "three": 3, "quarter": 3, "hour": 3, "made": 3, "immedi": [3, 228], "repli": 3, "again": [3, 6, 209, 321], "hand": [3, 322, 324], "did": 3, "accustom": 3, "thu": [3, 209], "question": [3, 324], "reason": [3, 323], "tell": [3, 321, 325], "understand": [3, 261, 262], "579": 3, "690": 3, "num": [3, 116, 152], "500": [3, 328], "628": 3, "went": 3, "nervou": 3, "trembl": 3, "told": 3, "And": 3, "perhap": 3, "surpris": 3, "matter": [3, 209], "shall": 3, "anyhow": 3, "friend": 3, "ye": 
3, "slight": [3, 324], "kind": 3, "longer": [3, 67, 322], "soon": 3, "unless": [3, 12, 114, 296], "unlik": [3, 12, 217, 218, 240], "strang": 3, "amus": 3, "That": 3, "secret": 3, "disappoint": 3, "mine": 3, "cannot": [3, 62, 323, 325], "happi": 3, "ask": 3, "Is": [3, 256, 259], "shop": 3, "bui": 3, "food": 3, "633": 3, "21": [3, 314], "475": 3, "su": 3, "j": [3, 6, 114, 217, 301, 302, 304], "lu": 3, "pan": 3, "murtadha": 3, "wen": 3, "liu": 3, "2021": 3, "roform": [3, 252], "enhanc": [3, 252, 324], "rotari": [3, 252], "arxiv": [3, 221, 222, 223, 227, 250, 285, 299, 305], "preprint": [3, 299, 305], "2104": 3, "09864": 3, "zhang": 3, "sennrich": 3, "2019": [3, 303], "root": [3, 159, 174, 250], "advanc": [3, 321], "inform": [3, 4, 6, 161, 162, 209, 213, 220, 247, 322, 328], "system": [3, 6], "shazeer": 3, "2020": 3, "glu": 3, "variant": [3, 283, 304], "2002": 3, "05202": 3, "classifi": 4, "mnist": 4, "As": [4, 183, 209, 321], "mlp": [4, 209, 259, 298], "inherit": [4, 320], "standard": [4, 36, 56, 127, 147, 149, 259, 261, 263, 266, 326], "idiom": [4, 321], "input_dim": [4, 209, 224, 249], "hidden_dim": [4, 296, 298], "output_dim": [4, 209, 224, 249], "layer_s": 4, "idim": 4, "odim": 4, "maximum": [4, 21, 62, 209, 251, 256, 269, 270, 287, 296, 324], "cross": [4, 271, 273], "entropi": [4, 271, 273], "sub": [4, 74, 152], "commonli": [4, 245, 293, 321], "cross_entropi": [4, 209], "accuraci": 4, "valid": [4, 67, 98, 195, 200, 233, 244, 320], "eval_fn": 4, "argmax": 4, "loader": 4, "num_class": [4, 298], "batch_siz": [4, 298], "num_epoch": [4, 298], "learning_r": [4, 298, 299, 300, 301, 302, 303, 304, 305, 307, 310, 311, 312, 313, 314, 321], "train_imag": [4, 298], "train_label": [4, 298], "test_imag": 4, "test_label": 4, "shuffl": 4, "minibatch": 4, "batch_iter": [4, 298], "perm": 4, "permut": 4, "id": [4, 6], "put": [4, 321], "trainabl": [4, 199, 209, 296], "loss_and_grad_fn": [4, 298, 321, 322], "value_and_grad": [4, 209, 245, 296, 298, 309, 321, 322, 325, 326], "epoch": 4, "test": [4, 6], "confus": 4, "decent": 4, "95": 4, "brought": 5, "research": 5, "except": [5, 85, 92, 93, 95, 96, 97, 221, 235, 323, 325], "featur": [5, 65, 66, 213, 221, 222, 223, 224, 249, 250, 252, 259, 321, 324], "main": [5, 74, 85, 201, 209], "differ": [5, 180, 283, 322], "lazi": [5, 296, 326], "multi": [5, 214, 215, 323, 325], "cpu": [5, 115, 321, 328], "gpu": [5, 321, 323, 328], "inspir": 5, "jax": [5, 318], "arrayfir": 5, "unifi": 5, "live": [5, 328], "guid": 5, "convers": 5, "regress": [5, 279], "layer": [5, 209, 211, 212, 217, 218, 221, 223, 224, 225, 226, 241, 246, 249, 254, 259, 292, 296], "perceptron": 5, "llm": 5, "infer": [5, 101, 117], "fft": 5, "algebra": 5, "tree": [5, 63, 82, 102, 192, 195, 200, 201, 202, 306, 307, 309, 316, 322], "develop": [5, 6], "document": [5, 46, 57, 161, 162, 321, 322, 323], "pypi": 6, "meet": 6, "seri": 6, "chip": 6, "nativ": 6, "maco": 6, "13": 6, "recommend": [6, 305], "14": 6, "sonoma": 6, "conda": 6, "forg": 6, "distribut": [6, 145, 146, 147, 149, 153, 154, 224, 261, 262, 263, 264, 266, 267, 274, 277, 282, 284, 293], "probabl": [6, 150, 216, 217, 218, 249, 271, 273, 277, 328], "platform": 6, "processor": 6, "arm": [6, 205], "i386": 6, "switch": 6, "17": 6, "g": [6, 114, 143, 292, 310, 311, 324, 329], "clang": 6, "cmake": 6, "24": 6, "xcode": 6, "15": [6, 114, 321], "environ": [6, 75, 78], "via": [6, 306, 309, 324, 325], "rosetta": 6, "unam": 6, "p": [6, 145, 209, 216, 217, 218, 284, 302, 304], "clone": 6, "git": 6, "com": 6, "ml": 6, "explor": 6, "cd": 6, "brew": 6, "global": [6, 
75, 78, 151, 318, 321], "env": 6, "cmake_build_parallel_level": 6, "edit": [6, 246], "unittest": 6, "discov": 6, "stub": 6, "dev": 6, "generate_stub": 6, "mkdir": 6, "either": [6, 10, 46, 56, 57, 62, 76, 77, 79, 100, 103, 104, 112, 113, 114, 122, 127, 129, 132, 134, 180, 192, 212, 226, 254, 263, 264], "libmlx": 6, "preprocessor": 6, "metal_path": 6, "mlx_build_test": 6, "ON": 6, "mlx_build_exampl": 6, "mlx_build_benchmark": 6, "mlx_build_python_bind": 6, "multipl": [6, 127, 134, 143, 144, 247, 256, 313, 314, 321, 324, 327], "wish": 6, "variabl": [6, 63, 75, 78, 102, 111, 192, 194, 195], "export": 6, "developer_dir": 6, "app": 6, "content": [6, 232, 321], "sdk": 6, "xcrun": 6, "macosx": 6, "show": [6, 205, 321], "unabl": 6, "tool": 6, "select": [6, 196, 228, 232], "sudo": 6, "ouptut": 6, "finder": 6, "iterm": 6, "termin": 6, "click": 6, "uncheck": 6, "window": [6, 211, 212, 225, 226], "restart": 6, "grep": 6, "cmake_host_system_processor": 6, "arm64": 6, "x86_64": 6, "wipe": 6, "cahc": 6, "rf": 6, "devicetyp": 7, "attribut": [7, 8, 25, 203, 240, 296, 316], "kwarg": [8, 163, 164, 329], "union": [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 29, 30, 31, 32, 33, 35, 37, 38, 39, 40, 41, 42, 44, 45, 46, 47, 48, 50, 52, 53, 54, 55, 57, 58, 59, 60, 61, 62, 64, 65, 66, 67, 68, 69, 72, 73, 74, 76, 77, 79, 80, 81, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 149, 150, 152, 153, 154, 155, 156, 157, 158, 159, 161, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 196, 197, 198, 211, 212, 215, 225, 226, 233, 235, 244, 311], "absolut": [9, 12, 269, 270, 283], "semant": [10, 60, 76, 77, 79, 103, 104, 112, 113, 122, 127, 129, 132, 134, 180, 328], "keepdim": [11, 13, 21, 22, 28, 29, 30, 31, 39, 40, 41, 42, 44, 55, 58, 114, 126, 128, 130, 131, 142, 171, 181, 193], "reduct": [11, 13, 126, 128, 131, 142, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284], "reduc": [11, 13, 21, 22, 126, 128, 130, 131, 142, 181, 193, 213, 259, 279], "unspecifi": [11, 13, 14, 21, 22, 23, 24, 64, 101, 126, 128, 130, 131, 137, 141, 142, 156, 171, 172, 181, 183, 193, 197, 329], "entir": [11, 13, 21, 22, 126, 128, 130, 131, 142, 181, 193, 217, 218], "singleton": [11, 13, 21, 22, 126, 127, 128, 130, 131, 142, 181, 193], "rtol": 12, "05": [12, 213, 221, 222, 223, 250], "atol": 12, "08": [12, 272, 301, 302, 303, 304, 310], "equal_nan": [12, 59], "approxim": [12, 220, 268, 269, 270], "comparison": [12, 79, 103, 104, 112, 113], "infinit": 12, "equal": [12, 23, 59, 85, 104, 113, 141, 150, 173, 222, 224], "sign": [12, 205, 305], "nan": [12, 59, 108], "ab": [12, 114, 192, 221, 222, 223, 227, 250, 285, 321], "array_equ": 12, "rel": [12, 300, 321], "toler": 12, "boolean": [12, 59, 107, 108, 109, 110, 123, 124, 125, 205, 243, 323], "interv": [14, 116, 150, 154], "increment": 14, "otherwis": [14, 200, 201, 233, 235, 244, 258, 259, 271, 276, 283, 290, 291, 324, 325], "int32": [14, 98, 114, 150, 205, 323, 326], "convent": [14, 67, 303], "lead": [14, 321], "fraction": 14, "integr": [14, 183, 324], "invers": [15, 16, 17, 18, 19, 20, 81, 89, 90, 91, 92, 93, 94], "cosin": [15, 16, 68, 69, 272, 312, 322], "hyperbol": [16, 18, 20, 69, 170, 186], "sine": [17, 18, 169, 170, 
322], "uint32": [21, 22, 23, 24, 146, 205], "minimum": [22, 62, 256, 272], "kth": [23, 141], "partit": 23, "order": [23, 114, 141, 143, 209, 221, 245, 254, 307, 321, 322], "undefin": [23, 141, 323], "sort": [23, 24, 141], "flatten": [23, 24, 114, 139, 141, 156, 172, 183, 184, 200], "dimension": [25, 86, 87, 88, 89, 90, 91, 95, 96, 97, 211, 212, 213, 214, 215, 219, 224, 225, 226, 249, 256, 323, 325], "val": [25, 101], "tupl": [25, 46, 49, 57, 64, 66, 77, 82, 84, 111, 114, 115, 140, 143, 157, 176, 192, 194, 200, 201, 202, 211, 212, 215, 225, 226, 235, 237, 254, 300, 302, 303, 304, 305, 320, 322], "ndarrai": [25, 323, 324, 326], "properti": [26, 34, 43, 49, 51, 240, 243, 308, 322], "argument": [26, 46, 57, 63, 82, 102, 192, 201, 209, 318, 322, 327, 328, 329], "decim": [47, 158], "indices_or_sect": [52, 173], "nest": [56, 63, 209, 296, 320, 322], "ddof": [58, 193], "a_min": 62, "a_max": 62, "edg": [62, 140, 321], "At": 62, "anoth": [62, 127, 180, 196, 209, 228, 321, 322, 323, 328], "fun": [63, 102, 111, 192, 194, 195, 321, 323, 324, 328], "dict": [63, 82, 117, 161, 162, 163, 238, 242, 245, 246, 296, 306, 307, 309, 320, 322, 327], "dure": [63, 216, 217, 218, 325], "arbitrarili": [63, 209, 320, 322, 326], "leaf": [63, 200, 201, 232], "node": [63, 82, 195], "pad": [65, 66, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 211, 212, 214, 215, 225, 226], "dilat": [65, 66], "group": [65, 66, 72, 143, 144, 221, 249], "1d": [65, 67, 161, 184], "convolut": [65, 66, 67, 214, 215, 217, 218], "channel": [65, 66, 213, 214, 215, 217, 218], "c_in": [65, 66], "c_out": [65, 66], "convolv": [65, 66], "2d": [66, 74, 143, 213, 217], "spatial": [66, 211, 221, 225], "symmetr": 66, "discret": [67, 86, 87, 88, 89, 90, 91, 95, 96, 97, 219], "swap": [67, 182, 246, 249], "conv": 67, "filter": [67, 214, 215, 228, 232], "flip": 67, "signal": 67, "bias": [72, 143, 144, 233, 244, 247], "group_siz": [72, 143, 144, 249], "64": [72, 143, 144, 205, 249], "configur": 72, "formal": [72, 143], "notat": [72, 200, 237], "quantiz": [72, 117, 144, 249], "w_i": [72, 143], "hat": [72, 143], "occupi": [72, 143, 144], "diagon": [73, 85, 189, 190, 191], "th": [73, 85], "axis1": [74, 182], "axis2": [74, 182], "subarrai": [74, 173], "remov": [74, 127, 146, 176, 273], "insert": [74, 84, 328], "neg": [74, 98, 109, 225, 226, 247, 274, 282, 284, 323], "taken": [74, 183], "disabl": [75, 321], "mlx_disable_compil": [75, 78, 321], "divis": [76, 100, 143], "quotient": [76, 77, 100], "remaind": 77, "fuction": 77, "faster": [77, 268, 321, 322], "mathrm": [80, 167, 222], "frac": [80, 143, 167, 211, 212, 213, 216, 217, 218, 221, 222, 223, 224, 225, 226, 250, 261, 262, 263, 264, 272, 274, 276, 279, 299, 301, 302, 303, 304, 310], "pi": [80, 256, 322], "int_0": 80, "dt": 80, "erf": [81, 321], "exponenti": [83, 253, 288, 313], "ident": [85, 178, 209, 241], "zero": [85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 189, 190, 191, 198, 209, 211, 212, 216, 217, 218, 235, 260, 261, 262, 263, 264, 265, 266, 267, 293, 300, 323], "whose": [85, 199], "One": [86, 89, 95, 159, 321, 322], "fourier": [86, 87, 88, 89, 90, 91, 95, 96, 97], "truncat": [86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 153], "dft": [86, 87, 88, 89, 90, 91, 95, 96, 97], "rfft": 92, "real": [92, 93, 94, 95, 96, 97], "rfft2": 93, "rfftn": 94, "silent": [95, 96, 97], "start_axi": 98, "end_axi": 98, "inclus": 98, "outsid": 98, "clamp": 98, "integ": [100, 114, 140, 143, 144, 145, 150, 173, 187, 195, 205, 219, 323], "floor": 100, "argnam": [102, 192], "neither": [102, 192], "keyword": [102, 163, 
164, 192, 201, 209, 318, 327, 329], "strict": [103, 112, 233, 235, 244], "ordinari": 106, "inifn": 107, "infin": [107, 109, 110, 225, 226, 304], "ord": 114, "tabl": [114, 205, 219], "frobeniu": 114, "matric": [114, 115], "strictli": 114, "mathemat": 114, "variou": 114, "purpos": 114, "calcul": [114, 274, 280, 300], "fro": 114, "inf": [114, 247], "largest": 114, "sing": 114, "smallest": 114, "singular": 114, "nuclear": 114, "_f": 114, "sum_": [114, 211, 212, 279], "a_": 114, "valueerror": [114, 235, 322], "refer": [114, 222, 227, 240, 261, 262, 263, 264, 285, 323], "golub": 114, "van": 114, "loan": 114, "baltimor": 114, "md": 114, "john": 114, "hopkin": 114, "univers": 114, "1985": 114, "pg": 114, "la": 114, "arang": [114, 323, 325], "9": [114, 273, 299, 302, 303, 304, 305, 307, 313, 314, 325], "74597": 114, "20": 114, "84804": 114, "41421": 114, "23607": [114, 115], "74166": 114, "24264": 114, "11": 114, "225": 114, "factorizatoin": 115, "q": 115, "894427": 115, "447214": 115, "57771": 115, "50": 116, "evenli": 116, "return_metadata": 117, "binari": [117, 160, 161, 162, 163, 164, 258, 271, 291, 321], "npy": [117, 160, 327], "safetensor": [117, 162, 235, 239, 324, 327], "gguf": [117, 161, 327], "matadata": 117, "unsupport": 117, "tensor": [117, 187, 211, 212, 225, 226, 284, 325], "natur": [118, 120, 324], "logarithm": [118, 119, 120, 121], "log": [120, 122, 126, 274, 277, 279, 282], "plu": 120, "exp": [122, 126, 147, 171, 253, 277, 288, 321, 328], "stabl": [122, 126, 171, 279], "prepend": 127, "negat": 135, "beforehand": 139, "pad_with": 140, "constant_valu": 140, "pad_width": 140, "before_1": 140, "after_1": 140, "before_2": 140, "after_2": 140, "before_n": 140, "after_n": 140, "before_i": 140, "after_i": 140, "extend": 140, "side": [140, 211, 212, 225, 226, 321], "smaller": [141, 305, 321], "everi": [143, 201, 314, 322], "particular": [143, 221], "consecut": [143, 252], "w_1": 143, "w_g": 143, "begin": [143, 212, 226, 253, 258, 276, 283, 288, 290, 291], "align": [143, 212, 226], "max_i": 143, "min_i": 143, "textrm": [143, 220, 268], "round": 143, "pack": [143, 144], "unsign": [143, 144, 205], "lower": [143, 150, 153, 154, 189, 267], "upper": [143, 150, 153, 154, 267], "1st": 143, "signific": 143, "2nd": 143, "dequant": 143, "w_q": 143, "whether": [144, 232, 247, 271, 274, 280], "prng": [145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 318], "num_sampl": 146, "unnorm": [146, 271, 273], "draw": 146, "cdf": [147, 220, 268], "accord": [147, 196, 247, 261, 262, 263, 264], "seed": 148, "loc": 149, "deviat": [149, 261, 263, 266], "low": [150, 154, 267, 293], "high": [150, 154, 209, 219, 267, 293], "bound": [150, 153, 154, 220, 267, 321, 323, 328], "roadcast": 150, "domain": 153, "uniformli": 154, "repetit": 156, "preserv": [157, 322], "reciproc": 159, "arr": [160, 323], "obj": 161, "uncompress": 163, "my_path": 163, "tree_flatten": [163, 201, 202, 209], "transformerencod": 163, "128": [163, 209], "flat_param": 163, "compress": 164, "being": [178, 209], "prevent": [178, 284, 325], "flow": [178, 324], "unchang": [178, 252], "prior": [183, 184], "exclud": 184, "dot": [187, 200, 237, 247], "elsewher": [189, 323], "col": 189, "triangl": 189, "mse": 192, "param": [192, 209, 293, 322], "lvalu": 192, "dlvalu": 192, "dparam": 192, "lasso": 192, "l1": [192, 276, 278, 279, 283], "varianc": [193, 213, 221, 274], "divisor": 193, "cotang": 194, "in_ax": [195, 322], "out_ax": [195, 322], "prefix": [195, 200], "fn": [199, 201, 326], "callabl": [199, 200, 201, 228, 229, 232, 254, 259, 260, 261, 262, 263, 264, 
265, 266, 267, 311], "wrt": 199, "rho": 299, "06": [274, 284, 299], "paper": [213, 256, 299, 300, 301, 302, 304, 305], "zeiler": 299, "2012": [299, 310], "adapt": [299, 300, 301], "1212": 299, "5701": 299, "v_": [299, 301, 302, 303, 304, 310, 311], "v_t": [299, 301, 302, 303, 304, 310, 311], "g_t": [299, 301, 302, 303, 304, 305, 310, 311], "delta": [276, 299], "w_": [212, 226, 299, 300, 301, 302, 303, 304, 305, 310, 311], "u_t": 299, "epsilon": [213, 221, 222, 223, 250, 272, 274, 299, 301, 302, 303, 304, 310], "u_": 299, "w_t": [299, 301, 302, 303, 304, 305, 310, 311], "lambda": [201, 209, 228, 233, 253, 257, 288, 290, 299, 300, 301, 302, 303, 304, 305, 310, 311, 321, 322], "averag": [211, 212, 299, 300, 302, 303, 304], "denomin": [222, 272, 299, 301, 302, 303, 304, 310], "stabil": [213, 221, 222, 223, 250, 272, 274, 299, 300, 301, 302, 303, 304, 310], "30": 300, "001": 300, "clip_threshold": 300, "decay_r": [300, 313, 314], "beta_1": [300, 302, 303, 304, 305], "weight_decai": [300, 303, 305, 311], "scale_paramet": 300, "relative_step": 300, "warmup_init": 300, "sublinear": 300, "cost": [300, 324], "epsilon_1": 300, "epsilon_2": 300, "parameter_scal": 300, "clip": 300, "unscal": 300, "decai": [300, 303, 305, 311, 312, 313, 314], "duchi": 301, "hazan": 301, "singer": 301, "2011": 301, "subgradi": 301, "onlin": 301, "stochast": [301, 302, 304, 311, 324], "jmlr": 301, "999": [302, 303, 304], "omit": [302, 304], "estim": [302, 304], "kingma": [302, 304], "ba": [302, 304], "2015": [217, 302, 304], "iclr": [302, 303, 304], "m_": [302, 303, 304, 305], "m_t": [302, 303, 304, 305], "beta_2": [302, 303, 304, 305], "contrast": 303, "loshchilov": 303, "hutter": 303, "decoupl": 303, "regular": [217, 227, 285, 303, 321, 323], "adam": [298, 304, 305], "99": [305, 310], "tend": 305, "larger": [252, 305], "10x": 305, "adamw": [298, 305], "maintain": [217, 218, 305], "strength": [305, 311], "wd": 305, "chen": 305, "symbol": 305, "discoveri": 305, "2302": 305, "06675": 305, "c_": 305, "eta": 305, "c_t": 305, "momentum": [213, 305, 307, 311, 321], "appli": [201, 209, 211, 212, 213, 214, 215, 217, 218, 220, 221, 222, 223, 224, 225, 226, 227, 229, 241, 248, 249, 250, 251, 253, 255, 257, 258, 268, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 293, 306, 309, 316, 321], "opt": 306, "superset": [201, 306], "trainable_paramet": [209, 232, 307], "tieleman": 310, "hinton": 310, "lectur": 310, "coursera": 310, "smooth": [273, 283, 310], "dampen": 311, "nesterov": 311, "descent": [311, 321, 324], "mu": 311, "tau": 311, "l2": [276, 279, 311], "penalti": 311, "is_leaf": [200, 201], "arbitrari": [200, 296], "depth": [200, 218, 322], "hello": [200, 202], "charact": 200, "flat": [200, 202], "extra": 201, "closer": 201, "constitut": 201, "dict_kei": [201, 307], "recreat": 202, "world": 202, "42": 202, "byte": 205, "bool_": 205, "uint8": 205, "uint16": 205, "16": [205, 211, 222, 225, 228, 296], "uint64": 205, "int8": 205, "int16": 205, "int64": 205, "done": [209, 216, 321, 324, 325], "manual": 209, "explicitli": [209, 318], "solv": 209, "intuit": 209, "freez": [209, 244, 296], "finetun": 209, "in_dim": [209, 296], "out_dim": [209, 296], "enumer": 209, "caus": [209, 321, 324], "local": [209, 217], "scope": 209, "l2_loss": 209, "y_hat": 209, "loss_and_grad": 209, "workhors": 209, "Its": 209, "recurs": [209, 232, 233, 238, 242, 244, 296], "frozen": [209, 233, 242, 244, 249, 296], "individu": [209, 217, 218], "subset": [209, 232], "action": 209, "displai": 209, "tree_map": 
209, "count": 209, "num_param": 209, "preclud": 209, "pure": [209, 298], "pattern": [209, 324], "achiev": 209, "other_input": 209, "necessari": 209, "wrap": 209, "apply_to_modul": [209, 233], "children": 209, "filter_and_map": 209, "leaf_modul": 209, "load_weight": [209, 324], "named_modul": 209, "save_weight": 209, "unfreez": [209, 233], "update_modul": 209, "alibi": 209, "batchnorm": 209, "conv1d": 209, "conv2d": 209, "dropout": [209, 217, 218, 241, 259, 321], "dropout2d": 209, "dropout3d": 209, "gelu": [209, 269, 270, 321], "groupnorm": 209, "instancenorm": 209, "layernorm": 209, "mish": 209, "prelu": 209, "quantizedlinear": 209, "relu": [209, 248, 259, 286, 293], "selu": 209, "sequenti": [209, 293], "silu": 209, "sinusoidalpositionalencod": 209, "softshrink": 209, "gelu_approx": [209, 220, 268], "gelu_fast_approx": [209, 220, 268], "binary_cross_entropi": [209, 321], "cosine_similarity_loss": 209, "gaussian_nll_loss": 209, "hinge_loss": 209, "huber_loss": 209, "kl_div_loss": 209, "l1_loss": 209, "log_cosh_loss": 209, "margin_ranking_loss": 209, "mse_loss": 209, "nll_loss": 209, "smooth_l1_loss": 209, "triplet_loss": 209, "init": [209, 248, 293, 298, 312, 313, 314], "uniform": [209, 224, 235, 262, 264, 293, 318, 321, 322, 328], "glorot_norm": 209, "glorot_uniform": 209, "he_norm": 209, "he_uniform": 209, "affin": [213, 221, 222, 223, 224, 249], "track_running_stat": 213, "var": [213, 221, 222, 223, 274], "gamma": [213, 221, 222, 223, 250, 261, 262, 263, 264], "nc": 213, "nlc": [213, 214], "four": 213, "nhwc": [213, 215], "height": [212, 213, 215, 217, 218, 226], "width": [212, 213, 215, 217, 218, 226, 249], "deep": [213, 261, 262, 263, 264], "intern": 213, "covari": 213, "shift": 213, "bn": 213, "in_channel": [214, 215], "out_channel": [214, 215], "kernel_s": [211, 212, 214, 215, 225, 226], "learnabl": [214, 215, 254], "portion": 216, "independ": [217, 218], "nwhc": 217, "whc": 217, "entri": [217, 218], "benefici": [217, 218, 324], "earli": 217, "adjac": 217, "pixel": 217, "correl": 217, "thompson": 217, "goroshin": 217, "jain": 217, "lecun": 217, "bregler": 217, "cvpr": 217, "ndhwc": 218, "dhwc": 218, "medic": 218, "video": 218, "num_embed": 219, "lookup": 219, "typic": [219, 298, 321, 324], "usual": [219, 320, 324], "vocabulari": 219, "approx": 220, "unit": [220, 251, 253, 255, 261, 262, 263, 264, 268, 269, 270, 287, 288, 289], "phi": [220, 268], "geluapprox": 220, "sigma": [220, 255, 261, 262, 263, 264, 269, 270, 289], "60033": [220, 269], "0433603": [220, 269], "gelufast": 220, "773": [220, 270], "regard": 220, "num_group": 221, "pytorch_compat": 221, "split": 221, "preced": 221, "http": [221, 222, 223, 227, 250, 285], "org": [221, 222, 223, 227, 250, 285], "1803": 221, "08494": 221, "inorm": 222, "1607": [222, 223], "08022": 222, "06450": 223, "mathcal": 224, "u": 224, "d_i": 224, "monoton": [227, 285], "1908": [227, 285], "08681": [227, 285], "tanh": [227, 285], "softplu": [227, 285], "map_fn": [228, 232], "filter_fn": [228, 232], "valid_parameter_filt": 228, "apply_fn": 229, "descend": 230, "is_leaf_fn": 232, "found": 232, "drop": 232, "idempot": [233, 244], "attent": [233, 247, 256, 259], "endswith": 233, "file_or_weight": 235, "miss": [235, 327], "ok": [235, 322], "save_safetensor": [239, 327], "reflect": [240, 321, 323, 325], "certain": [241, 321], "ie": 244, "noop": 244, "unfrozen": 244, "chang": [166, 245, 249, 276, 283, 321, 325], "tracer": 245, "partial": [245, 246, 321, 324], "child": 246, "programmat": 246, "query_input_dim": 247, "key_input_dim": 247, "value_input_dim": 
247, "value_dim": 247, "value_output_dim": 247, "head": [247, 259], "aggreg": 247, "linearli": 247, "attend": 247, "num_paramet": 248, "25": 248, "parametr": [248, 286], "classmethod": 249, "from_linear": 249, "quantize_modul": 249, "1910": 250, "07467": 250, "rectifi": [251, 263, 264, 287], "10000": 252, "rotat": 252, "slightli": [252, 328], "angular": 252, "frequenc": [252, 256], "_cos_sin_theta_kei": [], "precomput": [], "_cos_sin_theta_valu": [], "leq": [253, 276, 288], "0507": [253, 288], "67326": [253, 288], "elu": [253, 288], "plain": 254, "known": [255, 289], "swish": [255, 289], "cdot": [255, 269, 270, 272, 275, 289], "min_freq": 256, "0001": 256, "max_freq": 256, "cos_first": 256, "full_turn": 256, "sinusoid": 256, "sin": [256, 322, 326], "lambd": [257, 290], "threshold": [258, 276, 283, 291], "geq": [258, 291], "num_encoder_lay": 259, "num_decoder_lay": 259, "custom_encod": 259, "custom_decod": 259, "norm_first": 259, "checkpoint": 259, "decod": 259, "interact": 259, "mechan": 259, "hidden": 259, "chekpoint": 259, "usag": [259, 321], "expens": 259, "init_fn": [260, 261, 262, 263, 264, 265, 266, 267, 293], "glorot": [261, 262], "fan_in": [261, 262, 263, 264], "fan_out": [261, 262, 263, 264], "difficulti": [261, 262], "feedforward": [261, 262], "191107": 261, "61278": 261, "150594": 261, "363207": 261, "gain": [261, 262, 263, 264], "89613": 261, "53947": 261, "48095": 261, "995016": 261, "223404": 262, "890597": 262, "379159": 262, "776856": 262, "90041": 262, "02264": 262, "912766": 262, "12451": 262, "fan": [263, 264], "delv": [263, 264], "surpass": [263, 264], "human": [263, 264], "level": [263, 264], "imagenet": [263, 264], "classif": [263, 264], "25211": 263, "458835": 263, "177208": 263, "0137595": 263, "6967": 263, "02765": 263, "15268": 263, "75787": 263, "kaim": 264, "0300242": 264, "0184009": 264, "793615": 264, "666329": 264, "64331": 264, "16506": 264, "08619": 264, "79854": 264, "982273": 266, "534422": 266, "380709": 266, "0645099": 266, "883935": 267, "863726": 267, "617261": 267, "417497": 267, "exact": [269, 270], "0003": 269, "015": 270, "with_logit": 271, "predict": [271, 274, 275, 276, 277, 278, 279, 281, 282, 283], "105361": 271, "223144": 271, "20397": 271, "916291": 271, "539245": 271, "prob": 271, "510826": 271, "x1": 272, "x2": 272, "x_1": [272, 280], "x_2": [272, 280], "label_smooth": 273, "hot": 273, "0485873": 273, "348587": 273, "likelihood": [274, 282], "nll": [274, 282], "hing": 275, "y_": [275, 279], "pred": [275, 279], "huber": 276, "l_": [211, 225, 276], "kullback": 277, "leibler": 277, "diverg": 277, "cosh": 279, "logcosh": 279, "sensit": 279, "outlier": 279, "dual": 279, "behavior": [279, 323, 324], "offer": 279, "balanc": 279, "robust": 279, "approach": [279, 322], "task": 279, "inputs1": 280, "inputs2": 280, "margin": [280, 284], "rank": 280, "573409": 280, "765166": 280, "0638": 280, "75596": 280, "225763": 280, "256995": 280, "773433": 280, "formula": 283, "anchor": 284, "triplet": 284, "_p": 284, "degre": 284, "pairwis": 284, "instabl": 284, "subclass": 296, "concept": 296, "mymlp": 296, "in_proj": 296, "basi": 316, "subsequ": 298, "apply_gradi": 298, "implicit": [318, 321, 322], "fine": [318, 324], "grain": 318, "control": [318, 324], "manag": [179, 318, 328], "pseudo": 318, "altern": 318, "splittabl": 318, "threefri": 318, "counter": 318, "cycl": 320, "merg": 321, "fuse": 321, "big": 321, "awar": [321, 324], "36788": 321, "compiled_fun": 321, "code": [321, 324], "slow": 321, "Not": 321, "recompil": 321, "stack": 321, "rerun": [321, 324], 
"too": [321, 324], "frequent": [321, 324], "destroi": 321, "anonym": 321, "don": [321, 328], "nonlinear": 321, "unari": 321, "overhead": [321, 324, 328], "bandwidth": 321, "fusibl": 321, "consider": 321, "versu": 321, "timeit": [321, 322], "tic": 321, "perf_count": 321, "toc": 321, "tpi": 321, "1e3": 321, "1000": [312, 321], "4096": [321, 322, 328], "On": [321, 322, 324], "millisecond": [321, 328], "five": 321, "latest": 321, "won": 321, "trace": 321, "placehold": 321, "insid": 321, "crash": 321, "inspect": [321, 326], "disable_compil": 321, "okai": [321, 324], "intend": 321, "deal": 321, "pretti": [321, 324], "inconveni": 321, "functool": 321, "particularli": 321, "backward": [321, 322], "squeez": 321, "checkout": 321, "compiled_grad_fn": 321, "71828": 321, "outer": [321, 324], "opportun": 321, "idea": [322, 324], "behind": 322, "dfdx": [322, 323], "d2fdx2": 322, "differentiaion": 322, "zero_grad": 322, "detach": 322, "requires_grad": 322, "dloss_dw": 322, "dloss_dx": 322, "lot": 322, "redund": 322, "suppos": [322, 328], "nice": [322, 324], "propag": [322, 323], "stop_gradi": 322, "autom": 322, "contriv": [322, 328], "sake": 322, "clariti": 322, "quit": [322, 325], "power": [322, 325], "difficult": 322, "primit": 322, "issu": [322, 325], "priorit": 322, "xs": 322, "ys": 322, "naive_add": 322, "vmap_add": 322, "total": 322, "390": 322, "wherea": 322, "025": 322, "ten": [322, 324], "Of": 322, "better": [322, 328], "handi": 322, "slice": 323, "ellipsi": 323, "syntax": 323, "idx": 323, "mix": 323, "take_along_axi": 323, "lack": 323, "extrem": [323, 324], "ineffici": [323, 324], "nonzero": 323, "record": 324, "dynam": 324, "easier": 324, "worri": 324, "fun1": 324, "expensive_fun": 324, "consum": 324, "eager": 324, "thank": 324, "weights_fp16": 324, "trade": 324, "bad": 324, "grow": 324, "computation": 324, "costli": 324, "wide": 324, "thousand": 324, "value_and_grad_fn": 324, "implicitli": 324, "anytim": 324, "memoryview": [324, 325], "perfectli": 324, "first_lay": 324, "second_layer_a": 324, "second_layer_b": 324, "protocol": 325, "receiv": 325, "pep": 325, "3118": 325, "view": 325, "a_view": 325, "owndata": 325, "extern": 325, "x_view": 325, "modifi": 325, "df": 325, "x\u00b2": 325, "2x": 325, "indirectli": 325, "modif": 325, "seen": 325, "occur": 325, "incorpor": 325, "incorrect": 325, "experiment": 325, "break": 325, "advis": 325, "intermedi": 325, "jnp": 325, "tf": 325, "page": 326, "composit": 326, "archiv": 327, "savez_compress": 327, "save_gguf": 327, "arr_0": 327, "pool": [211, 212, 225, 226, 328], "advantag": 328, "parallel": 328, "race": 328, "interest": 328, "albeit": 328, "d1": 328, "d2": 328, "matmul": 328, "dens": 328, "twice": 328, "measur": 328, "default_stream": 329, "default_devic": 329, "my_devic": 329, "streamcontext": 179, "context": 179, "avgpool1d": 209, "avgpool2d": 209, "maxpool1d": 209, "maxpool2d": [209, 212], "n_i": [211, 212, 225, 226], "c_j": [211, 212, 225, 226], "ldot": [211, 212, 225, 226], "lfloor": [211, 212, 225, 226], "rfloor": [211, 212, 225, 226], "k_h": [212, 226], "k_w": [212, 226], "h_": [212, 226], "max_": [225, 226], "rmsprop": 298, "adagrad": 298, "adafactor": 298, "adadelta": 298, "adamax": 298, "lion": 298, "step_decai": 298, "exponential_decai": 298, "cosine_decai": 298, "decay_step": 312, "beyond": 312, "lr_schedul": [312, 313, 314], "0999961": 312, "06561": 313, "step_siz": 314, "081": 314}, "objects": {"mlx.core": [[7, 0, 1, "", "Device"], [8, 0, 1, "", "Dtype"], [203, 0, 1, "", "Stream"], [9, 2, 1, "", "abs"], [10, 2, 1, "", "add"], [11, 2, 
1, "", "all"], [12, 2, 1, "", "allclose"], [13, 2, 1, "", "any"], [14, 2, 1, "", "arange"], [15, 2, 1, "", "arccos"], [16, 2, 1, "", "arccosh"], [17, 2, 1, "", "arcsin"], [18, 2, 1, "", "arcsinh"], [19, 2, 1, "", "arctan"], [20, 2, 1, "", "arctanh"], [21, 2, 1, "", "argmax"], [22, 2, 1, "", "argmin"], [23, 2, 1, "", "argpartition"], [24, 2, 1, "", "argsort"], [25, 0, 1, "", "array"], [59, 2, 1, "", "array_equal"], [60, 2, 1, "", "broadcast_to"], [61, 2, 1, "", "ceil"], [62, 2, 1, "", "clip"], [63, 2, 1, "", "compile"], [64, 2, 1, "", "concatenate"], [65, 2, 1, "", "conv1d"], [66, 2, 1, "", "conv2d"], [67, 2, 1, "", "convolve"], [68, 2, 1, "", "cos"], [69, 2, 1, "", "cosh"], [70, 2, 1, "", "default_device"], [71, 2, 1, "", "default_stream"], [72, 2, 1, "", "dequantize"], [73, 2, 1, "", "diag"], [74, 2, 1, "", "diagonal"], [75, 2, 1, "", "disable_compile"], [76, 2, 1, "", "divide"], [77, 2, 1, "", "divmod"], [78, 2, 1, "", "enable_compile"], [79, 2, 1, "", "equal"], [80, 2, 1, "", "erf"], [81, 2, 1, "", "erfinv"], [82, 2, 1, "", "eval"], [83, 2, 1, "", "exp"], [84, 2, 1, "", "expand_dims"], [85, 2, 1, "", "eye"], [98, 2, 1, "", "flatten"], [99, 2, 1, "", "floor"], [100, 2, 1, "", "floor_divide"], [101, 2, 1, "", "full"], [102, 2, 1, "", "grad"], [103, 2, 1, "", "greater"], [104, 2, 1, "", "greater_equal"], [105, 2, 1, "", "identity"], [106, 2, 1, "", "inner"], [107, 2, 1, "", "isinf"], [108, 2, 1, "", "isnan"], [109, 2, 1, "", "isneginf"], [110, 2, 1, "", "isposinf"], [111, 2, 1, "", "jvp"], [112, 2, 1, "", "less"], [113, 2, 1, "", "less_equal"], [116, 2, 1, "", "linspace"], [117, 2, 1, "", "load"], [118, 2, 1, "", "log"], [119, 2, 1, "", "log10"], [120, 2, 1, "", "log1p"], [121, 2, 1, "", "log2"], [122, 2, 1, "", "logaddexp"], [123, 2, 1, "", "logical_and"], [124, 2, 1, "", "logical_not"], [125, 2, 1, "", "logical_or"], [126, 2, 1, "", "logsumexp"], [127, 2, 1, "", "matmul"], [128, 2, 1, "", "max"], [129, 2, 1, "", "maximum"], [130, 2, 1, "", "mean"], [131, 2, 1, "", "min"], [132, 2, 1, "", "minimum"], [133, 2, 1, "", "moveaxis"], [134, 2, 1, "", "multiply"], [135, 2, 1, "", "negative"], [136, 2, 1, "", "new_stream"], [137, 2, 1, "", "ones"], [138, 2, 1, "", "ones_like"], [139, 2, 1, "", "outer"], [140, 2, 1, "", "pad"], [141, 2, 1, "", "partition"], [142, 2, 1, "", "prod"], [143, 2, 1, "", "quantize"], [144, 2, 1, "", "quantized_matmul"], [155, 2, 1, "", "reciprocal"], [156, 2, 1, "", "repeat"], [157, 2, 1, "", "reshape"], [158, 2, 1, "", "round"], [159, 2, 1, "", "rsqrt"], [160, 2, 1, "", "save"], [161, 2, 1, "", "save_gguf"], [162, 2, 1, "", "save_safetensors"], [163, 2, 1, "", "savez"], [164, 2, 1, "", "savez_compressed"], [165, 2, 1, "", "set_default_device"], [166, 2, 1, "", "set_default_stream"], [167, 2, 1, "", "sigmoid"], [168, 2, 1, "", "sign"], [169, 2, 1, "", "sin"], [170, 2, 1, "", "sinh"], [171, 2, 1, "", "softmax"], [172, 2, 1, "", "sort"], [173, 2, 1, "", "split"], [174, 2, 1, "", "sqrt"], [175, 2, 1, "", "square"], [176, 2, 1, "", "squeeze"], [177, 2, 1, "", "stack"], [178, 2, 1, "", "stop_gradient"], [179, 2, 1, "", "stream"], [180, 2, 1, "", "subtract"], [181, 2, 1, "", "sum"], [182, 2, 1, "", "swapaxes"], [183, 2, 1, "", "take"], [184, 2, 1, "", "take_along_axis"], [185, 2, 1, "", "tan"], [186, 2, 1, "", "tanh"], [187, 2, 1, "", "tensordot"], [188, 2, 1, "", "transpose"], [189, 2, 1, "", "tri"], [190, 2, 1, "", "tril"], [191, 2, 1, "", "triu"], [192, 2, 1, "", "value_and_grad"], [193, 2, 1, "", "var"], [194, 2, 1, "", "vjp"], [195, 2, 1, "", "vmap"], [196, 2, 1, "", 
"where"], [197, 2, 1, "", "zeros"], [198, 2, 1, "", "zeros_like"]], "mlx.core.Device": [[7, 1, 1, "", "__init__"]], "mlx.core.Dtype": [[8, 1, 1, "", "__init__"]], "mlx.core.Stream": [[203, 1, 1, "", "__init__"]], "mlx.core.array": [[26, 3, 1, "", "T"], [25, 1, 1, "", "__init__"], [27, 1, 1, "", "abs"], [28, 1, 1, "", "all"], [29, 1, 1, "", "any"], [30, 1, 1, "", "argmax"], [31, 1, 1, "", "argmin"], [32, 1, 1, "", "astype"], [33, 1, 1, "", "cos"], [34, 3, 1, "", "dtype"], [35, 1, 1, "", "exp"], [36, 1, 1, "", "item"], [37, 1, 1, "", "log"], [38, 1, 1, "", "log1p"], [39, 1, 1, "", "logsumexp"], [40, 1, 1, "", "max"], [41, 1, 1, "", "mean"], [42, 1, 1, "", "min"], [43, 3, 1, "", "ndim"], [44, 1, 1, "", "prod"], [45, 1, 1, "", "reciprocal"], [46, 1, 1, "", "reshape"], [47, 1, 1, "", "round"], [48, 1, 1, "", "rsqrt"], [49, 3, 1, "", "shape"], [50, 1, 1, "", "sin"], [51, 3, 1, "", "size"], [52, 1, 1, "", "split"], [53, 1, 1, "", "sqrt"], [54, 1, 1, "", "square"], [55, 1, 1, "", "sum"], [56, 1, 1, "", "tolist"], [57, 1, 1, "", "transpose"], [58, 1, 1, "", "var"]], "mlx.core.fft": [[86, 2, 1, "", "fft"], [87, 2, 1, "", "fft2"], [88, 2, 1, "", "fftn"], [89, 2, 1, "", "ifft"], [90, 2, 1, "", "ifft2"], [91, 2, 1, "", "ifftn"], [92, 2, 1, "", "irfft"], [93, 2, 1, "", "irfft2"], [94, 2, 1, "", "irfftn"], [95, 2, 1, "", "rfft"], [96, 2, 1, "", "rfft2"], [97, 2, 1, "", "rfftn"]], "mlx.core.linalg": [[114, 2, 1, "", "norm"], [115, 2, 1, "", "qr"]], "mlx.core.random": [[145, 2, 1, "", "bernoulli"], [146, 2, 1, "", "categorical"], [147, 2, 1, "", "gumbel"], [148, 2, 1, "", "key"], [149, 2, 1, "", "normal"], [150, 2, 1, "", "randint"], [151, 2, 1, "", "seed"], [152, 2, 1, "", "split"], [153, 2, 1, "", "truncated_normal"], [154, 2, 1, "", "uniform"]], "mlx.nn": [[210, 0, 1, "", "ALiBi"], [211, 0, 1, "", "AvgPool1d"], [212, 0, 1, "", "AvgPool2d"], [213, 0, 1, "", "BatchNorm"], [214, 0, 1, "", "Conv1d"], [215, 0, 1, "", "Conv2d"], [216, 0, 1, "", "Dropout"], [217, 0, 1, "", "Dropout2d"], [218, 0, 1, "", "Dropout3d"], [219, 0, 1, "", "Embedding"], [220, 0, 1, "", "GELU"], [221, 0, 1, "", "GroupNorm"], [222, 0, 1, "", "InstanceNorm"], [223, 0, 1, "", "LayerNorm"], [224, 0, 1, "", "Linear"], [225, 0, 1, "", "MaxPool1d"], [226, 0, 1, "", "MaxPool2d"], [227, 0, 1, "", "Mish"], [296, 0, 1, "", "Module"], [247, 0, 1, "", "MultiHeadAttention"], [248, 0, 1, "", "PReLU"], [249, 0, 1, "", "QuantizedLinear"], [250, 0, 1, "", "RMSNorm"], [251, 0, 1, "", "ReLU"], [252, 0, 1, "", "RoPE"], [253, 0, 1, "", "SELU"], [254, 0, 1, "", "Sequential"], [255, 0, 1, "", "SiLU"], [256, 0, 1, "", "SinusoidalPositionalEncoding"], [257, 0, 1, "", "Softshrink"], [258, 0, 1, "", "Step"], [259, 0, 1, "", "Transformer"], [268, 2, 1, "", "gelu"], [269, 2, 1, "", "gelu_approx"], [270, 2, 1, "", "gelu_fast_approx"], [285, 2, 1, "", "mish"], [286, 2, 1, "", "prelu"], [287, 2, 1, "", "relu"], [288, 2, 1, "", "selu"], [289, 2, 1, "", "silu"], [290, 2, 1, "", "softshrink"], [291, 2, 1, "", "step"], [199, 2, 1, "", "value_and_grad"]], "mlx.nn.Module": [[228, 1, 1, "", "apply"], [229, 1, 1, "", "apply_to_modules"], [230, 1, 1, "", "children"], [231, 1, 1, "", "eval"], [232, 1, 1, "", "filter_and_map"], [233, 1, 1, "", "freeze"], [234, 1, 1, "", "leaf_modules"], [235, 1, 1, "", "load_weights"], [236, 1, 1, "", "modules"], [237, 1, 1, "", "named_modules"], [238, 1, 1, "", "parameters"], [239, 1, 1, "", "save_weights"], [240, 3, 1, "", "state"], [241, 1, 1, "", "train"], [242, 1, 1, "", "trainable_parameters"], [243, 3, 1, "", "training"], [244, 1, 1, "", 
"unfreeze"], [245, 1, 1, "", "update"], [246, 1, 1, "", "update_modules"]], "mlx.nn.init": [[260, 2, 1, "", "constant"], [261, 2, 1, "", "glorot_normal"], [262, 2, 1, "", "glorot_uniform"], [263, 2, 1, "", "he_normal"], [264, 2, 1, "", "he_uniform"], [265, 2, 1, "", "identity"], [266, 2, 1, "", "normal"], [267, 2, 1, "", "uniform"]], "mlx.nn.losses": [[271, 2, 1, "", "binary_cross_entropy"], [272, 2, 1, "", "cosine_similarity_loss"], [273, 2, 1, "", "cross_entropy"], [274, 2, 1, "", "gaussian_nll_loss"], [275, 2, 1, "", "hinge_loss"], [276, 2, 1, "", "huber_loss"], [277, 2, 1, "", "kl_div_loss"], [278, 2, 1, "", "l1_loss"], [279, 2, 1, "", "log_cosh_loss"], [280, 2, 1, "", "margin_ranking_loss"], [281, 2, 1, "", "mse_loss"], [282, 2, 1, "", "nll_loss"], [283, 2, 1, "", "smooth_l1_loss"], [284, 2, 1, "", "triplet_loss"]], "mlx.optimizers": [[299, 0, 1, "", "AdaDelta"], [300, 0, 1, "", "Adafactor"], [301, 0, 1, "", "Adagrad"], [302, 0, 1, "", "Adam"], [303, 0, 1, "", "AdamW"], [304, 0, 1, "", "Adamax"], [305, 0, 1, "", "Lion"], [316, 0, 1, "", "Optimizer"], [310, 0, 1, "", "RMSprop"], [311, 0, 1, "", "SGD"], [312, 2, 1, "", "cosine_decay"], [313, 2, 1, "", "exponential_decay"], [314, 2, 1, "", "step_decay"]], "mlx.optimizers.Optimizer": [[306, 1, 1, "", "apply_gradients"], [307, 1, 1, "", "init"], [308, 3, 1, "", "state"], [309, 1, 1, "", "update"]], "mlx.utils": [[200, 2, 1, "", "tree_flatten"], [201, 2, 1, "", "tree_map"], [202, 2, 1, "", "tree_unflatten"]]}, "objtypes": {"0": "py:class", "1": "py:method", "2": "py:function", "3": "py:property"}, "objnames": {"0": ["py", "class", "Python class"], "1": ["py", "method", "Python method"], "2": ["py", "function", "Python function"], "3": ["py", "property", "Python property"]}, "titleterms": {"oper": [0, 1, 297], "develop": 1, "document": 1, "introduc": 1, "exampl": [1, 5, 321, 328], "primit": 1, "us": [1, 324, 329], "implement": [1, 3], "cpu": 1, "backend": 1, "gpu": 1, "transform": [1, 259, 319, 321, 322, 324, 326], "build": [1, 6], "bind": 1, "python": [1, 5, 6], "cmake": 1, "setuptool": 1, "usag": [1, 5], "result": 1, "script": [1, 3], "download": [1, 3], "code": [1, 3], "linear": [2, 208, 224], "regress": 2, "llm": 3, "infer": 3, "model": 3, "attent": 3, "layer": [3, 4, 294], "encod": 3, "full": [3, 101], "gener": 3, "put": 3, "all": [3, 11, 28], "togeth": 3, "convert": 3, "weight": 3, "load": [3, 117, 327], "benchmark": 3, "multi": 4, "perceptron": 4, "mlx": [5, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 
251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314], "instal": [5, 6], "api": [5, 6], "refer": 5, "c": [5, 6], "further": 5, "read": 5, "troubleshoot": 6, "from": [6, 323], "sourc": 6, "requir": 6, "option": 6, "metal": 6, "found": 6, "x86": 6, "shell": 6, "core": [7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 203], "devic": [7, 206], "dtype": [8, 34], "stream": [179, 203, 206, 329], "ab": [9, 27], "add": 10, "allclos": 12, "ani": [13, 29], "arang": 14, "arcco": 15, "arccosh": 16, "arcsin": 17, "arcsinh": 18, "arctan": 19, "arctanh": 20, "argmax": [21, 30], "argmin": [22, 31], "argpartit": 23, "argsort": 24, "arrai": [25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 204, 323, 327], "t": 26, "astyp": 32, "co": [33, 68], "exp": [35, 83], "item": 36, "log": [37, 118], "log1p": [38, 120], "logsumexp": [39, 126], "max": [40, 128], "mean": [41, 130], "min": [42, 131], "ndim": 43, "prod": [44, 142], "reciproc": [45, 155], "reshap": [46, 157], "round": [47, 158], "rsqrt": [48, 159], "shape": 49, "sin": [50, 169], "size": 51, "split": [52, 152, 173], "sqrt": [53, 174], "squar": [54, 175], "sum": [55, 181], "tolist": 56, "transpos": [57, 188], "var": [58, 193], "array_equ": 59, "broadcast_to": 60, "ceil": 61, "clip": 62, "compil": [63, 321], "concaten": 64, "conv1d": [65, 214], "conv2d": [66, 215], "convolv": 67, "cosh": 69, "default_devic": 70, "default_stream": 71, "dequant": 72, "diag": 73, "diagon": 74, "disable_compil": 75, "divid": 76, "divmod": 77, "enable_compil": 78, "equal": 79, "erf": 80, "erfinv": 81, "eval": [82, 231], "expand_dim": 84, "ey": 85, "fft": [86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 207], "fft2": 87, "fftn": 88, "ifft": 89, "ifft2": 90, "ifftn": 91, "irfft": 92, "irfft2": 93, "irfftn": 94, "rfft": 95, "rfft2": 96, "rfftn": 97, "flatten": 98, "floor": 99, "floor_divid": 100, "grad": [102, 209], "greater": 103, "greater_equ": 104, "ident": [105, 265], "inner": 106, "isinf": 107, "isnan": 108, "isneginf": 109, "isposinf": 110, "jvp": 111, "less": 112, "less_equ": 113, "linalg": [114, 115], "norm": 114, "qr": 115, "linspac": 116, "log10": 119, "log2": 121, "logaddexp": 122, "logical_and": 123, "logical_not": 124, "logical_or": 125, "matmul": 127, "maximum": 129, "minimum": 132, "moveaxi": 133, "multipli": 134, "neg": 135, "new_stream": 136, "ones": 137, "ones_lik": 138, "outer": 139, "pad": 140, "partit": 141, "quantiz": 143, "quantized_matmul": 144, 
"random": [145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 318], "bernoulli": 145, "categor": 146, "gumbel": 147, "kei": 148, "normal": [149, 266], "randint": 150, "seed": 151, "truncated_norm": 153, "uniform": [154, 267], "repeat": 156, "save": [160, 327], "save_gguf": 161, "save_safetensor": 162, "savez": 163, "savez_compress": 164, "set_default_devic": 165, "set_default_stream": 166, "sigmoid": 167, "sign": 168, "sinh": 170, "softmax": 171, "sort": 172, "squeez": 176, "stack": 177, "stop_gradi": 178, "subtract": 180, "swapax": 182, "take": 183, "take_along_axi": 184, "tan": 185, "tanh": 186, "tensordot": 187, "tri": 189, "tril": 190, "triu": 191, "value_and_grad": [192, 199], "vjp": 194, "vmap": 195, "where": 196, "zero": 197, "zeros_lik": 198, "nn": [199, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291], "optim": [298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316], "adadelta": 299, "adafactor": 300, "adagrad": 301, "adam": 302, "adamw": 303, "adamax": 304, "lion": 305, "apply_gradi": 306, "init": [260, 261, 262, 263, 264, 265, 266, 267, 307], "state": [240, 308], "updat": [209, 245, 309, 323], "rmsprop": 310, "sgd": 311, "util": [200, 201, 202, 320], "tree_flatten": 200, "tree_map": 201, "tree_unflatten": 202, "data": 205, "type": 205, "support": 205, "algebra": 208, "neural": 209, "network": 209, "quick": [209, 326], "start": [209, 326], "The": 209, "modul": [209, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 296], "class": 209, "paramet": [209, 238], "inspect": 209, "valu": 209, "alibi": 210, "batchnorm": 213, "dropout": 216, "dropout2d": 217, "dropout3d": 218, "embed": 219, "gelu": [220, 268], "groupnorm": 221, "instancenorm": 222, "layernorm": 223, "mish": [227, 285], "appli": 228, "apply_to_modul": 229, "children": 230, "filter_and_map": 232, "freez": 233, "leaf_modul": 234, "load_weight": 235, "named_modul": 237, "save_weight": 239, "train": [241, 243, 321], "trainable_paramet": 242, "unfreez": 244, "update_modul": 246, "multiheadattent": 247, "prelu": [248, 286], "quantizedlinear": 249, "rmsnorm": 250, "relu": [251, 287], "rope": 252, "selu": [253, 288], "sequenti": 254, "silu": [255, 289], "sinusoidalpositionalencod": 256, "softshrink": [257, 290], "step": [258, 291], "constant": 260, "glorot_norm": 261, "glorot_uniform": 262, "he_norm": 263, "he_uniform": 264, "gelu_approx": 269, "gelu_fast_approx": 270, "loss": [271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 295], "binary_cross_entropi": 271, "cosine_similarity_loss": 272, "cross_entropi": 273, "gaussian_nll_loss": 274, "hinge_loss": 275, "huber_loss": 276, "kl_div_loss": 277, "l1_loss": 278, "log_cosh_loss": 279, "margin_ranking_loss": 280, "mse_loss": 281, "nll_loss": 282, "smooth_l1_loss": 283, "triplet_loss": 284, "function": [292, 295, 321, 322, 326], "initi": 293, "tree": 320, "basic": [321, 326], "speedup": 321, "debug": 321, "pure": 321, "graph": [321, 324, 326], "automat": 322, "differenti": 322, "vector": 322, "index": 323, "differ": 323, "numpi": [323, 325], "In": 323, "place": 323, "lazi": 324, "evalu": 324, "why": 324, "comput": 324, "onli": 
324, "what": 324, "you": 324, "when": 324, "convers": 325, "other": 325, "framework": 325, "pytorch": 325, "jax": 325, "tensorflow": 325, "guid": 326, "serial": 327, "format": 327, "unifi": 328, "memori": 328, "A": 328, "simpl": 328, "specifi": 329, "avgpool1d": 211, "avgpool2d": 212, "maxpool1d": 225, "maxpool2d": 226, "cosine_decai": 312, "exponential_decai": 313, "step_decai": 314, "common": 315, "schedul": 317}, "envversion": {"sphinx.domains.c": 2, "sphinx.domains.changeset": 1, "sphinx.domains.citation": 1, "sphinx.domains.cpp": 6, "sphinx.domains.index": 1, "sphinx.domains.javascript": 2, "sphinx.domains.math": 2, "sphinx.domains.python": 3, "sphinx.domains.rst": 2, "sphinx.domains.std": 2, "sphinx.ext.intersphinx": 1, "sphinx": 56}})