mlx/docs/build/html/searchindex.js
2025-06-04 01:01:48 +00:00

1 line
84 KiB
JavaScript

Search.setIndex({"docnames": ["cpp/ops", "dev/extensions", "examples/linear_regression", "examples/llama-inference", "examples/mlp", "index", "install", "python/_autosummary/mlx.core.Device", "python/_autosummary/mlx.core.Dtype", "python/_autosummary/mlx.core.Stream", "python/_autosummary/mlx.core.abs", "python/_autosummary/mlx.core.add", "python/_autosummary/mlx.core.all", "python/_autosummary/mlx.core.allclose", "python/_autosummary/mlx.core.any", "python/_autosummary/mlx.core.arange", "python/_autosummary/mlx.core.arccos", "python/_autosummary/mlx.core.arccosh", "python/_autosummary/mlx.core.arcsin", "python/_autosummary/mlx.core.arcsinh", "python/_autosummary/mlx.core.arctan", "python/_autosummary/mlx.core.arctanh", "python/_autosummary/mlx.core.argmax", "python/_autosummary/mlx.core.argmin", "python/_autosummary/mlx.core.argpartition", "python/_autosummary/mlx.core.argsort", "python/_autosummary/mlx.core.array", "python/_autosummary/mlx.core.array.T", "python/_autosummary/mlx.core.array.abs", "python/_autosummary/mlx.core.array.all", "python/_autosummary/mlx.core.array.any", "python/_autosummary/mlx.core.array.argmax", "python/_autosummary/mlx.core.array.argmin", "python/_autosummary/mlx.core.array.astype", "python/_autosummary/mlx.core.array.cos", "python/_autosummary/mlx.core.array.dtype", "python/_autosummary/mlx.core.array.exp", "python/_autosummary/mlx.core.array.item", "python/_autosummary/mlx.core.array.log", "python/_autosummary/mlx.core.array.log1p", "python/_autosummary/mlx.core.array.logsumexp", "python/_autosummary/mlx.core.array.max", "python/_autosummary/mlx.core.array.mean", "python/_autosummary/mlx.core.array.min", "python/_autosummary/mlx.core.array.ndim", "python/_autosummary/mlx.core.array.prod", "python/_autosummary/mlx.core.array.reciprocal", "python/_autosummary/mlx.core.array.reshape", "python/_autosummary/mlx.core.array.rsqrt", "python/_autosummary/mlx.core.array.shape", "python/_autosummary/mlx.core.array.sin", "python/_autosummary/mlx.core.array.size", "python/_autosummary/mlx.core.array.split", "python/_autosummary/mlx.core.array.sqrt", "python/_autosummary/mlx.core.array.square", "python/_autosummary/mlx.core.array.sum", "python/_autosummary/mlx.core.array.tolist", "python/_autosummary/mlx.core.array.transpose", "python/_autosummary/mlx.core.array.var", "python/_autosummary/mlx.core.array_equal", "python/_autosummary/mlx.core.broadcast_to", "python/_autosummary/mlx.core.concatenate", "python/_autosummary/mlx.core.conv1d", "python/_autosummary/mlx.core.conv2d", "python/_autosummary/mlx.core.convolve", "python/_autosummary/mlx.core.cos", "python/_autosummary/mlx.core.cosh", "python/_autosummary/mlx.core.default_device", "python/_autosummary/mlx.core.default_stream", "python/_autosummary/mlx.core.divide", "python/_autosummary/mlx.core.equal", "python/_autosummary/mlx.core.erf", "python/_autosummary/mlx.core.erfinv", "python/_autosummary/mlx.core.eval", "python/_autosummary/mlx.core.exp", "python/_autosummary/mlx.core.expand_dims", "python/_autosummary/mlx.core.eye", "python/_autosummary/mlx.core.fft.fft", "python/_autosummary/mlx.core.fft.fft2", "python/_autosummary/mlx.core.fft.fftn", "python/_autosummary/mlx.core.fft.ifft", "python/_autosummary/mlx.core.fft.ifft2", "python/_autosummary/mlx.core.fft.ifftn", "python/_autosummary/mlx.core.fft.irfft", "python/_autosummary/mlx.core.fft.irfft2", "python/_autosummary/mlx.core.fft.irfftn", "python/_autosummary/mlx.core.fft.rfft", "python/_autosummary/mlx.core.fft.rfft2", "python/_autosummary/mlx.core.fft.rfftn", "python/_autosummary/mlx.core.full", "python/_autosummary/mlx.core.grad", "python/_autosummary/mlx.core.greater", "python/_autosummary/mlx.core.greater_equal", "python/_autosummary/mlx.core.identity", "python/_autosummary/mlx.core.jvp", "python/_autosummary/mlx.core.less", "python/_autosummary/mlx.core.less_equal", "python/_autosummary/mlx.core.load", "python/_autosummary/mlx.core.log", "python/_autosummary/mlx.core.log10", "python/_autosummary/mlx.core.log1p", "python/_autosummary/mlx.core.log2", "python/_autosummary/mlx.core.logaddexp", "python/_autosummary/mlx.core.logical_not", "python/_autosummary/mlx.core.logsumexp", "python/_autosummary/mlx.core.matmul", "python/_autosummary/mlx.core.max", "python/_autosummary/mlx.core.maximum", "python/_autosummary/mlx.core.mean", "python/_autosummary/mlx.core.min", "python/_autosummary/mlx.core.minimum", "python/_autosummary/mlx.core.multiply", "python/_autosummary/mlx.core.negative", "python/_autosummary/mlx.core.new_stream", "python/_autosummary/mlx.core.ones", "python/_autosummary/mlx.core.ones_like", "python/_autosummary/mlx.core.pad", "python/_autosummary/mlx.core.partition", "python/_autosummary/mlx.core.prod", "python/_autosummary/mlx.core.random.bernoulli", "python/_autosummary/mlx.core.random.categorical", "python/_autosummary/mlx.core.random.gumbel", "python/_autosummary/mlx.core.random.key", "python/_autosummary/mlx.core.random.normal", "python/_autosummary/mlx.core.random.randint", "python/_autosummary/mlx.core.random.seed", "python/_autosummary/mlx.core.random.split", "python/_autosummary/mlx.core.random.truncated_normal", "python/_autosummary/mlx.core.random.uniform", "python/_autosummary/mlx.core.reciprocal", "python/_autosummary/mlx.core.reshape", "python/_autosummary/mlx.core.rsqrt", "python/_autosummary/mlx.core.save", "python/_autosummary/mlx.core.savez", "python/_autosummary/mlx.core.savez_compressed", "python/_autosummary/mlx.core.set_default_device", "python/_autosummary/mlx.core.set_default_stream", "python/_autosummary/mlx.core.sigmoid", "python/_autosummary/mlx.core.sign", "python/_autosummary/mlx.core.sin", "python/_autosummary/mlx.core.sinh", "python/_autosummary/mlx.core.softmax", "python/_autosummary/mlx.core.sort", "python/_autosummary/mlx.core.split", "python/_autosummary/mlx.core.sqrt", "python/_autosummary/mlx.core.square", "python/_autosummary/mlx.core.squeeze", "python/_autosummary/mlx.core.stop_gradient", "python/_autosummary/mlx.core.subtract", "python/_autosummary/mlx.core.sum", "python/_autosummary/mlx.core.take", "python/_autosummary/mlx.core.take_along_axis", "python/_autosummary/mlx.core.tan", "python/_autosummary/mlx.core.tanh", "python/_autosummary/mlx.core.transpose", "python/_autosummary/mlx.core.value_and_grad", "python/_autosummary/mlx.core.var", "python/_autosummary/mlx.core.vjp", "python/_autosummary/mlx.core.vmap", "python/_autosummary/mlx.core.where", "python/_autosummary/mlx.core.zeros", "python/_autosummary/mlx.core.zeros_like", "python/_autosummary/mlx.nn.Conv1d", "python/_autosummary/mlx.nn.Conv2d", "python/_autosummary/mlx.nn.Embedding", "python/_autosummary/mlx.nn.GELU", "python/_autosummary/mlx.nn.GroupNorm", "python/_autosummary/mlx.nn.LayerNorm", "python/_autosummary/mlx.nn.Linear", "python/_autosummary/mlx.nn.Mish", "python/_autosummary/mlx.nn.MultiHeadAttention", "python/_autosummary/mlx.nn.PReLU", "python/_autosummary/mlx.nn.RMSNorm", "python/_autosummary/mlx.nn.ReLU", "python/_autosummary/mlx.nn.RoPE", "python/_autosummary/mlx.nn.SELU", "python/_autosummary/mlx.nn.Sequential", "python/_autosummary/mlx.nn.SiLU", "python/_autosummary/mlx.nn.Step", "python/_autosummary/mlx.nn.value_and_grad", "python/_autosummary/mlx.optimizers.Adam", "python/_autosummary/mlx.optimizers.Optimizer", "python/_autosummary/mlx.optimizers.OptimizerState", "python/_autosummary/mlx.optimizers.SGD", "python/_autosummary/mlx.utils.tree_flatten", "python/_autosummary/mlx.utils.tree_map", "python/_autosummary/mlx.utils.tree_unflatten", "python/_autosummary_functions/mlx.nn.gelu", "python/_autosummary_functions/mlx.nn.gelu_approx", "python/_autosummary_functions/mlx.nn.gelu_fast_approx", "python/_autosummary_functions/mlx.nn.losses.binary_cross_entropy", "python/_autosummary_functions/mlx.nn.losses.cross_entropy", "python/_autosummary_functions/mlx.nn.losses.kl_div_loss", "python/_autosummary_functions/mlx.nn.losses.l1_loss", "python/_autosummary_functions/mlx.nn.losses.mse_loss", "python/_autosummary_functions/mlx.nn.losses.nll_loss", "python/_autosummary_functions/mlx.nn.mish", "python/_autosummary_functions/mlx.nn.prelu", "python/_autosummary_functions/mlx.nn.relu", "python/_autosummary_functions/mlx.nn.selu", "python/_autosummary_functions/mlx.nn.silu", "python/_autosummary_functions/mlx.nn.step", "python/array", "python/data_types", "python/devices_and_streams", "python/fft", "python/nn", "python/nn/module", "python/ops", "python/optimizers", "python/random", "python/transforms", "python/tree_utils", "quick_start", "unified_memory", "using_streams"], "filenames": ["cpp/ops.rst", "dev/extensions.rst", "examples/linear_regression.rst", "examples/llama-inference.rst", "examples/mlp.rst", "index.rst", "install.rst", "python/_autosummary/mlx.core.Device.rst", "python/_autosummary/mlx.core.Dtype.rst", "python/_autosummary/mlx.core.Stream.rst", "python/_autosummary/mlx.core.abs.rst", "python/_autosummary/mlx.core.add.rst", "python/_autosummary/mlx.core.all.rst", "python/_autosummary/mlx.core.allclose.rst", "python/_autosummary/mlx.core.any.rst", "python/_autosummary/mlx.core.arange.rst", "python/_autosummary/mlx.core.arccos.rst", "python/_autosummary/mlx.core.arccosh.rst", "python/_autosummary/mlx.core.arcsin.rst", "python/_autosummary/mlx.core.arcsinh.rst", "python/_autosummary/mlx.core.arctan.rst", "python/_autosummary/mlx.core.arctanh.rst", "python/_autosummary/mlx.core.argmax.rst", "python/_autosummary/mlx.core.argmin.rst", "python/_autosummary/mlx.core.argpartition.rst", "python/_autosummary/mlx.core.argsort.rst", "python/_autosummary/mlx.core.array.rst", "python/_autosummary/mlx.core.array.T.rst", "python/_autosummary/mlx.core.array.abs.rst", "python/_autosummary/mlx.core.array.all.rst", "python/_autosummary/mlx.core.array.any.rst", "python/_autosummary/mlx.core.array.argmax.rst", "python/_autosummary/mlx.core.array.argmin.rst", "python/_autosummary/mlx.core.array.astype.rst", "python/_autosummary/mlx.core.array.cos.rst", "python/_autosummary/mlx.core.array.dtype.rst", "python/_autosummary/mlx.core.array.exp.rst", "python/_autosummary/mlx.core.array.item.rst", "python/_autosummary/mlx.core.array.log.rst", "python/_autosummary/mlx.core.array.log1p.rst", "python/_autosummary/mlx.core.array.logsumexp.rst", "python/_autosummary/mlx.core.array.max.rst", "python/_autosummary/mlx.core.array.mean.rst", "python/_autosummary/mlx.core.array.min.rst", "python/_autosummary/mlx.core.array.ndim.rst", "python/_autosummary/mlx.core.array.prod.rst", "python/_autosummary/mlx.core.array.reciprocal.rst", "python/_autosummary/mlx.core.array.reshape.rst", "python/_autosummary/mlx.core.array.rsqrt.rst", "python/_autosummary/mlx.core.array.shape.rst", "python/_autosummary/mlx.core.array.sin.rst", "python/_autosummary/mlx.core.array.size.rst", "python/_autosummary/mlx.core.array.split.rst", "python/_autosummary/mlx.core.array.sqrt.rst", "python/_autosummary/mlx.core.array.square.rst", "python/_autosummary/mlx.core.array.sum.rst", "python/_autosummary/mlx.core.array.tolist.rst", "python/_autosummary/mlx.core.array.transpose.rst", "python/_autosummary/mlx.core.array.var.rst", "python/_autosummary/mlx.core.array_equal.rst", "python/_autosummary/mlx.core.broadcast_to.rst", "python/_autosummary/mlx.core.concatenate.rst", "python/_autosummary/mlx.core.conv1d.rst", "python/_autosummary/mlx.core.conv2d.rst", "python/_autosummary/mlx.core.convolve.rst", "python/_autosummary/mlx.core.cos.rst", "python/_autosummary/mlx.core.cosh.rst", "python/_autosummary/mlx.core.default_device.rst", "python/_autosummary/mlx.core.default_stream.rst", "python/_autosummary/mlx.core.divide.rst", "python/_autosummary/mlx.core.equal.rst", "python/_autosummary/mlx.core.erf.rst", "python/_autosummary/mlx.core.erfinv.rst", "python/_autosummary/mlx.core.eval.rst", "python/_autosummary/mlx.core.exp.rst", "python/_autosummary/mlx.core.expand_dims.rst", "python/_autosummary/mlx.core.eye.rst", "python/_autosummary/mlx.core.fft.fft.rst", "python/_autosummary/mlx.core.fft.fft2.rst", "python/_autosummary/mlx.core.fft.fftn.rst", "python/_autosummary/mlx.core.fft.ifft.rst", "python/_autosummary/mlx.core.fft.ifft2.rst", "python/_autosummary/mlx.core.fft.ifftn.rst", "python/_autosummary/mlx.core.fft.irfft.rst", "python/_autosummary/mlx.core.fft.irfft2.rst", "python/_autosummary/mlx.core.fft.irfftn.rst", "python/_autosummary/mlx.core.fft.rfft.rst", "python/_autosummary/mlx.core.fft.rfft2.rst", "python/_autosummary/mlx.core.fft.rfftn.rst", "python/_autosummary/mlx.core.full.rst", "python/_autosummary/mlx.core.grad.rst", "python/_autosummary/mlx.core.greater.rst", "python/_autosummary/mlx.core.greater_equal.rst", "python/_autosummary/mlx.core.identity.rst", "python/_autosummary/mlx.core.jvp.rst", "python/_autosummary/mlx.core.less.rst", "python/_autosummary/mlx.core.less_equal.rst", "python/_autosummary/mlx.core.load.rst", "python/_autosummary/mlx.core.log.rst", "python/_autosummary/mlx.core.log10.rst", "python/_autosummary/mlx.core.log1p.rst", "python/_autosummary/mlx.core.log2.rst", "python/_autosummary/mlx.core.logaddexp.rst", "python/_autosummary/mlx.core.logical_not.rst", "python/_autosummary/mlx.core.logsumexp.rst", "python/_autosummary/mlx.core.matmul.rst", "python/_autosummary/mlx.core.max.rst", "python/_autosummary/mlx.core.maximum.rst", "python/_autosummary/mlx.core.mean.rst", "python/_autosummary/mlx.core.min.rst", "python/_autosummary/mlx.core.minimum.rst", "python/_autosummary/mlx.core.multiply.rst", "python/_autosummary/mlx.core.negative.rst", "python/_autosummary/mlx.core.new_stream.rst", "python/_autosummary/mlx.core.ones.rst", "python/_autosummary/mlx.core.ones_like.rst", "python/_autosummary/mlx.core.pad.rst", "python/_autosummary/mlx.core.partition.rst", "python/_autosummary/mlx.core.prod.rst", "python/_autosummary/mlx.core.random.bernoulli.rst", "python/_autosummary/mlx.core.random.categorical.rst", "python/_autosummary/mlx.core.random.gumbel.rst", "python/_autosummary/mlx.core.random.key.rst", "python/_autosummary/mlx.core.random.normal.rst", "python/_autosummary/mlx.core.random.randint.rst", "python/_autosummary/mlx.core.random.seed.rst", "python/_autosummary/mlx.core.random.split.rst", "python/_autosummary/mlx.core.random.truncated_normal.rst", "python/_autosummary/mlx.core.random.uniform.rst", "python/_autosummary/mlx.core.reciprocal.rst", "python/_autosummary/mlx.core.reshape.rst", "python/_autosummary/mlx.core.rsqrt.rst", "python/_autosummary/mlx.core.save.rst", "python/_autosummary/mlx.core.savez.rst", "python/_autosummary/mlx.core.savez_compressed.rst", "python/_autosummary/mlx.core.set_default_device.rst", "python/_autosummary/mlx.core.set_default_stream.rst", "python/_autosummary/mlx.core.sigmoid.rst", "python/_autosummary/mlx.core.sign.rst", "python/_autosummary/mlx.core.sin.rst", "python/_autosummary/mlx.core.sinh.rst", "python/_autosummary/mlx.core.softmax.rst", "python/_autosummary/mlx.core.sort.rst", "python/_autosummary/mlx.core.split.rst", "python/_autosummary/mlx.core.sqrt.rst", "python/_autosummary/mlx.core.square.rst", "python/_autosummary/mlx.core.squeeze.rst", "python/_autosummary/mlx.core.stop_gradient.rst", "python/_autosummary/mlx.core.subtract.rst", "python/_autosummary/mlx.core.sum.rst", "python/_autosummary/mlx.core.take.rst", "python/_autosummary/mlx.core.take_along_axis.rst", "python/_autosummary/mlx.core.tan.rst", "python/_autosummary/mlx.core.tanh.rst", "python/_autosummary/mlx.core.transpose.rst", "python/_autosummary/mlx.core.value_and_grad.rst", "python/_autosummary/mlx.core.var.rst", "python/_autosummary/mlx.core.vjp.rst", "python/_autosummary/mlx.core.vmap.rst", "python/_autosummary/mlx.core.where.rst", "python/_autosummary/mlx.core.zeros.rst", "python/_autosummary/mlx.core.zeros_like.rst", "python/_autosummary/mlx.nn.Conv1d.rst", "python/_autosummary/mlx.nn.Conv2d.rst", "python/_autosummary/mlx.nn.Embedding.rst", "python/_autosummary/mlx.nn.GELU.rst", "python/_autosummary/mlx.nn.GroupNorm.rst", "python/_autosummary/mlx.nn.LayerNorm.rst", "python/_autosummary/mlx.nn.Linear.rst", "python/_autosummary/mlx.nn.Mish.rst", "python/_autosummary/mlx.nn.MultiHeadAttention.rst", "python/_autosummary/mlx.nn.PReLU.rst", "python/_autosummary/mlx.nn.RMSNorm.rst", "python/_autosummary/mlx.nn.ReLU.rst", "python/_autosummary/mlx.nn.RoPE.rst", "python/_autosummary/mlx.nn.SELU.rst", "python/_autosummary/mlx.nn.Sequential.rst", "python/_autosummary/mlx.nn.SiLU.rst", "python/_autosummary/mlx.nn.Step.rst", "python/_autosummary/mlx.nn.value_and_grad.rst", "python/_autosummary/mlx.optimizers.Adam.rst", "python/_autosummary/mlx.optimizers.Optimizer.rst", "python/_autosummary/mlx.optimizers.OptimizerState.rst", "python/_autosummary/mlx.optimizers.SGD.rst", "python/_autosummary/mlx.utils.tree_flatten.rst", "python/_autosummary/mlx.utils.tree_map.rst", "python/_autosummary/mlx.utils.tree_unflatten.rst", "python/_autosummary_functions/mlx.nn.gelu.rst", "python/_autosummary_functions/mlx.nn.gelu_approx.rst", "python/_autosummary_functions/mlx.nn.gelu_fast_approx.rst", "python/_autosummary_functions/mlx.nn.losses.binary_cross_entropy.rst", "python/_autosummary_functions/mlx.nn.losses.cross_entropy.rst", "python/_autosummary_functions/mlx.nn.losses.kl_div_loss.rst", "python/_autosummary_functions/mlx.nn.losses.l1_loss.rst", "python/_autosummary_functions/mlx.nn.losses.mse_loss.rst", "python/_autosummary_functions/mlx.nn.losses.nll_loss.rst", "python/_autosummary_functions/mlx.nn.mish.rst", "python/_autosummary_functions/mlx.nn.prelu.rst", "python/_autosummary_functions/mlx.nn.relu.rst", "python/_autosummary_functions/mlx.nn.selu.rst", "python/_autosummary_functions/mlx.nn.silu.rst", "python/_autosummary_functions/mlx.nn.step.rst", "python/array.rst", "python/data_types.rst", "python/devices_and_streams.rst", "python/fft.rst", "python/nn.rst", "python/nn/module.rst", "python/ops.rst", "python/optimizers.rst", "python/random.rst", "python/transforms.rst", "python/tree_utils.rst", "quick_start.rst", "unified_memory.rst", "using_streams.rst"], "titles": ["Operations", "Developer Documentation", "Linear Regression", "LLM inference", "Multi-Layer Perceptron", "MLX", "Build and Install", "mlx.core.Device", "mlx.core.Dtype", "mlx.core.Stream", "mlx.core.abs", "mlx.core.add", "mlx.core.all", "mlx.core.allclose", "mlx.core.any", "mlx.core.arange", "mlx.core.arccos", "mlx.core.arccosh", "mlx.core.arcsin", "mlx.core.arcsinh", "mlx.core.arctan", "mlx.core.arctanh", "mlx.core.argmax", "mlx.core.argmin", "mlx.core.argpartition", "mlx.core.argsort", "mlx.core.array", "mlx.core.array.T", "mlx.core.array.abs", "mlx.core.array.all", "mlx.core.array.any", "mlx.core.array.argmax", "mlx.core.array.argmin", "mlx.core.array.astype", "mlx.core.array.cos", "mlx.core.array.dtype", "mlx.core.array.exp", "mlx.core.array.item", "mlx.core.array.log", "mlx.core.array.log1p", "mlx.core.array.logsumexp", "mlx.core.array.max", "mlx.core.array.mean", "mlx.core.array.min", "mlx.core.array.ndim", "mlx.core.array.prod", "mlx.core.array.reciprocal", "mlx.core.array.reshape", "mlx.core.array.rsqrt", "mlx.core.array.shape", "mlx.core.array.sin", "mlx.core.array.size", "mlx.core.array.split", "mlx.core.array.sqrt", "mlx.core.array.square", "mlx.core.array.sum", "mlx.core.array.tolist", "mlx.core.array.transpose", "mlx.core.array.var", "mlx.core.array_equal", "mlx.core.broadcast_to", "mlx.core.concatenate", "mlx.core.conv1d", "mlx.core.conv2d", "mlx.core.convolve", "mlx.core.cos", "mlx.core.cosh", "mlx.core.default_device", "mlx.core.default_stream", "mlx.core.divide", "mlx.core.equal", "mlx.core.erf", "mlx.core.erfinv", "mlx.core.eval", "mlx.core.exp", "mlx.core.expand_dims", "mlx.core.eye", "mlx.core.fft.fft", "mlx.core.fft.fft2", "mlx.core.fft.fftn", "mlx.core.fft.ifft", "mlx.core.fft.ifft2", "mlx.core.fft.ifftn", "mlx.core.fft.irfft", "mlx.core.fft.irfft2", "mlx.core.fft.irfftn", "mlx.core.fft.rfft", "mlx.core.fft.rfft2", "mlx.core.fft.rfftn", "mlx.core.full", "mlx.core.grad", "mlx.core.greater", "mlx.core.greater_equal", "mlx.core.identity", "mlx.core.jvp", "mlx.core.less", "mlx.core.less_equal", "mlx.core.load", "mlx.core.log", "mlx.core.log10", "mlx.core.log1p", "mlx.core.log2", "mlx.core.logaddexp", "mlx.core.logical_not", "mlx.core.logsumexp", "mlx.core.matmul", "mlx.core.max", "mlx.core.maximum", "mlx.core.mean", "mlx.core.min", "mlx.core.minimum", "mlx.core.multiply", "mlx.core.negative", "mlx.core.new_stream", "mlx.core.ones", "mlx.core.ones_like", "mlx.core.pad", "mlx.core.partition", "mlx.core.prod", "mlx.core.random.bernoulli", "mlx.core.random.categorical", "mlx.core.random.gumbel", "mlx.core.random.key", "mlx.core.random.normal", "mlx.core.random.randint", "mlx.core.random.seed", "mlx.core.random.split", "mlx.core.random.truncated_normal", "mlx.core.random.uniform", "mlx.core.reciprocal", "mlx.core.reshape", "mlx.core.rsqrt", "mlx.core.save", "mlx.core.savez", "mlx.core.savez_compressed", "mlx.core.set_default_device", "mlx.core.set_default_stream", "mlx.core.sigmoid", "mlx.core.sign", "mlx.core.sin", "mlx.core.sinh", "mlx.core.softmax", "mlx.core.sort", "mlx.core.split", "mlx.core.sqrt", "mlx.core.square", "mlx.core.squeeze", "mlx.core.stop_gradient", "mlx.core.subtract", "mlx.core.sum", "mlx.core.take", "mlx.core.take_along_axis", "mlx.core.tan", "mlx.core.tanh", "mlx.core.transpose", "mlx.core.value_and_grad", "mlx.core.var", "mlx.core.vjp", "mlx.core.vmap", "mlx.core.where", "mlx.core.zeros", "mlx.core.zeros_like", "mlx.nn.Conv1d", "mlx.nn.Conv2d", "mlx.nn.Embedding", "mlx.nn.GELU", "mlx.nn.GroupNorm", "mlx.nn.LayerNorm", "mlx.nn.Linear", "mlx.nn.Mish", "mlx.nn.MultiHeadAttention", "mlx.nn.PReLU", "mlx.nn.RMSNorm", "mlx.nn.ReLU", "mlx.nn.RoPE", "mlx.nn.SELU", "mlx.nn.Sequential", "mlx.nn.SiLU", "mlx.nn.Step", "mlx.nn.value_and_grad", "mlx.optimizers.Adam", "mlx.optimizers.Optimizer", "mlx.optimizers.OptimizerState", "mlx.optimizers.SGD", "mlx.utils.tree_flatten", "mlx.utils.tree_map", "mlx.utils.tree_unflatten", "mlx.nn.gelu", "mlx.nn.gelu_approx", "mlx.nn.gelu_fast_approx", "mlx.nn.losses.binary_cross_entropy", "mlx.nn.losses.cross_entropy", "mlx.nn.losses.kl_div_loss", "mlx.nn.losses.l1_loss", "mlx.nn.losses.mse_loss", "mlx.nn.losses.nll_loss", "mlx.nn.mish", "mlx.nn.prelu", "mlx.nn.relu", "mlx.nn.selu", "mlx.nn.silu", "mlx.nn.step", "Array", "Data Types", "Devices and Streams", "FFT", "Neural Networks", "mlx.nn.Module", "Operations", "Optimizers", "Random", "Transforms", "Tree Utils", "Quick Start Guide", "Unified Memory", "Using Streams"], "terms": {"mlx": [1, 2, 3, 4, 6, 206, 209, 210, 212, 213, 214], "provid": [1, 3, 90, 155, 170, 185, 206, 207, 215], "open": [1, 15, 124, 128], "flexibl": [1, 5], "which": [1, 3, 4, 5, 6, 15, 33, 73, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 90, 94, 97, 120, 121, 130, 132, 133, 134, 146, 150, 155, 157, 158, 165, 174, 191, 207, 210, 214, 215], "user": [1, 3, 206], "mai": 1, "add": [1, 3, 75, 102, 116, 162, 163, 214], "special": 1, "without": [1, 3, 5, 147, 170, 206, 212, 214], "much": [1, 3], "hassl": 1, "while": [1, 3, 6, 130, 174], "librari": [1, 6, 206], "suppli": 1, "effici": [1, 3, 5, 174, 213], "can": [1, 3, 5, 6, 11, 15, 47, 57, 69, 70, 73, 91, 92, 95, 96, 102, 107, 110, 111, 119, 120, 124, 127, 128, 148, 155, 164, 176, 206, 207, 209, 210, 212, 213, 214, 215], "compos": [1, 5, 206, 213], "ani": [1, 3, 5, 15, 165, 184, 185, 186, 206, 207, 212, 213, 214], "number": [1, 15, 51, 63, 76, 90, 93, 94, 116, 120, 123, 126, 128, 155, 157, 158, 162, 163, 166, 170, 210, 215], "applic": 1, "aris": 1, "case": [1, 3, 79, 82, 83, 85, 86, 87, 88, 105, 130, 146, 175, 178, 199, 201, 213, 214, 215], "where": [1, 4, 76, 155, 158, 162, 163, 165, 166, 167, 172, 175, 177, 178, 187, 188, 189, 199, 200, 201, 207], "new": [1, 4, 60, 130, 154, 170, 185, 207, 209], "function": [1, 2, 3, 4, 5, 13, 71, 72, 73, 90, 94, 105, 137, 155, 157, 158, 165, 169, 176, 178, 179, 185, 187, 188, 189, 196, 197, 201, 207, 209, 210, 212], "highli": [1, 6], "optim": [1, 2, 4, 5, 207], "ar": [1, 2, 3, 4, 5, 6, 13, 15, 59, 60, 64, 76, 78, 79, 81, 82, 84, 85, 87, 88, 90, 94, 105, 116, 117, 119, 120, 121, 124, 127, 128, 133, 134, 146, 150, 155, 157, 158, 162, 163, 166, 167, 170, 184, 185, 206, 207, 212, 213, 214], "need": [1, 3, 4, 5, 59, 206, 207, 210, 213, 214], "For": [1, 3, 6, 186, 207, 210, 213, 214], "you": [1, 3, 5, 6, 210, 214], "design": [1, 2, 5, 210, 214], "your": [1, 3, 6, 207], "own": [1, 6], "link": [1, 6], "top": 1, "core": [1, 2, 3, 4, 190, 206, 207, 209, 213], "we": [1, 2, 3, 4, 164, 176, 206, 210, 212, 214], "inner": 1, "work": [1, 3, 6], "go": [1, 3], "over": [1, 3, 4, 12, 14, 22, 23, 24, 25, 62, 63, 79, 82, 85, 88, 104, 106, 108, 109, 117, 118, 131, 141, 142, 149, 156, 162, 163, 166, 167, 172, 191], "simpl": [1, 3, 4, 164, 206], "learn": [1, 2, 4, 5, 166, 167, 172, 183], "step": [1, 3, 4, 15], "involv": [1, 209], "ad": [1, 2, 6, 207], "let": [1, 2, 3], "s": [1, 2, 3, 4, 35, 44, 78, 79, 81, 82, 84, 85, 87, 88, 90, 97, 108, 120, 155, 156, 158, 179, 181, 206, 207, 209, 210, 213, 214], "sai": [1, 3], "would": [1, 3, 214], "like": [1, 3, 5, 115, 161, 213, 214], "an": [1, 3, 4, 6, 8, 12, 14, 26, 60, 62, 63, 73, 76, 89, 93, 106, 109, 114, 115, 116, 118, 130, 143, 146, 150, 151, 158, 160, 161, 166, 167, 168, 170, 181, 182, 184, 188, 197, 206, 207, 210, 212, 213, 214, 215], "take": [1, 3, 4, 90, 94, 107, 110, 115, 151, 155, 157, 158, 161, 210, 214, 215], "two": [1, 11, 13, 59, 69, 70, 78, 81, 87, 91, 92, 95, 96, 102, 105, 107, 110, 111, 214], "arrai": [1, 3, 4, 5, 8, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 59, 60, 61, 62, 63, 64, 65, 66, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 126, 127, 128, 129, 130, 131, 132, 133, 134, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 190, 191, 192, 193, 194, 195, 196, 197, 201, 206, 207, 213, 214], "x": [1, 2, 3, 4, 71, 93, 121, 133, 137, 159, 165, 166, 167, 169, 172, 173, 175, 177, 178, 185, 187, 188, 189, 196, 197, 198, 199, 200, 201, 206, 207, 209, 213, 214], "y": [1, 2, 3, 4, 159, 166, 167, 172, 206, 209], "scale": [1, 3, 170, 175, 199], "them": [1, 3, 206, 207, 214], "both": [1, 11, 69, 70, 91, 92, 95, 96, 102, 107, 110, 111, 120, 148, 209, 213, 214], "some": [1, 2, 3, 4, 207], "coeffic": 1, "alpha": [1, 175, 197, 199], "beta": [1, 166, 167, 180], "respect": [1, 2, 4, 90, 155, 165, 166, 167, 185, 206, 207, 213], "togeth": [1, 4, 185], "get": [1, 2, 4, 63, 122, 182, 214], "z": 1, "well": [1, 3, 170, 206, 207], "veri": [1, 3, 170, 214], "easili": 1, "do": [1, 3, 6, 207], "just": [1, 4], "write": [1, 3, 206], "out": [1, 6], "follow": [1, 3, 4, 5, 6, 15, 64, 180, 183, 188, 189, 192, 210, 214], "import": [1, 2, 3, 4, 6, 133, 155, 184, 185, 186, 190, 206, 207, 213], "mx": [1, 2, 3, 4, 133, 155, 173, 190, 191, 192, 193, 194, 195, 198, 206, 207, 209, 210, 213, 214, 215], "def": [1, 2, 3, 4, 155, 206, 207, 214], "simple_axpbi": 1, "float": [1, 13, 15, 56, 89, 119, 124, 127, 128, 166, 167, 172, 178, 180, 183, 201, 203, 207], "return": [1, 2, 3, 4, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 33, 37, 49, 56, 59, 60, 61, 62, 63, 64, 65, 66, 69, 70, 71, 72, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 126, 127, 128, 129, 130, 131, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 179, 184, 185, 186, 190, 191, 192, 193, 194, 195, 206, 207, 212, 214], "thi": [1, 3, 4, 6, 12, 13, 14, 15, 22, 23, 24, 25, 73, 94, 102, 104, 105, 106, 108, 109, 117, 118, 120, 141, 142, 143, 149, 150, 156, 178, 188, 189, 201, 206, 207, 212], "perform": [1, 3, 5, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 105, 141, 150, 166, 206, 214], "leav": [1, 185], "differenti": [1, 5], "howev": [1, 165, 166, 206, 210], "vector": [1, 2, 5, 94, 150, 157, 158, 164, 213], "math": [1, 3], "often": 1, "realiz": 1, "axpbi": 1, "routin": 1, "defin": [1, 2, 3, 4, 6, 182, 184], "same": [1, 3, 6, 59, 60, 63, 64, 83, 86, 87, 88, 90, 94, 116, 120, 157, 159, 166, 206, 207, 210, 214], "realli": 1, "part": 1, "doe": [1, 3, 6, 206], "fast": [1, 165, 189, 214], "so": [1, 3, 6, 90, 155, 209, 214], "decid": [1, 207], "want": [1, 3, 214], "reli": 1, "acceler": 1, "framework": [1, 5], "continu": 1, "impos": 1, "our": [1, 3, 176, 180], "assumpt": 1, "also": [1, 3, 4, 5, 11, 69, 70, 79, 82, 85, 88, 91, 92, 95, 96, 102, 107, 110, 111, 148, 170, 175, 177, 179, 182, 187, 199, 200, 206, 207, 209, 213, 215], "assum": [1, 3, 166, 185, 206], "how": [1, 3, 4, 162, 163, 164, 170, 206, 214], "gradient": [1, 2, 4, 90, 147, 155, 179, 183, 206, 207, 209, 213], "ins": 1, "what": [1, 3], "coincid": 1, "right": [1, 165, 188, 189], "place": [1, 3], "cours": 1, "The": [1, 3, 4, 5, 6, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 33, 35, 44, 49, 56, 59, 60, 61, 62, 63, 64, 65, 66, 69, 70, 71, 72, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 126, 127, 128, 129, 130, 137, 138, 139, 140, 141, 142, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 166, 167, 168, 170, 172, 174, 176, 178, 179, 181, 182, 183, 184, 185, 186, 190, 191, 192, 193, 194, 195, 201, 203, 207, 209, 213, 214, 215], "structur": [1, 73], "from": [1, 3, 4, 5, 84, 85, 87, 88, 89, 97, 105, 115, 119, 120, 121, 122, 124, 127, 133, 146, 147, 148, 150, 151, 159, 161, 170, 184, 185, 186, 206, 207, 212, 213, 214], "frontend": 1, "api": 1, "redirect": 1, "when": [1, 3, 5, 162, 163, 192, 206, 207, 210, 214], "appropri": 1, "fallback": 1, "metal": [1, 6], "vjp": [1, 213], "jvp": [1, 213], "In": [1, 3, 4, 105, 166, 180, 185, 206, 207, 212, 214], "one": [1, 3, 6, 56, 63, 75, 76, 100, 105, 120, 146, 148, 207, 214], "sentenc": 1, "comput": [1, 2, 3, 4, 5, 6, 90, 94, 102, 108, 141, 147, 155, 156, 157, 166, 167, 172, 179, 188, 189, 190, 191, 192, 193, 194, 195, 206, 207, 209, 213, 214], "graph": [1, 3, 4, 5, 73, 132], "rule": 1, "evalu": [1, 3, 4, 73, 94, 132, 157, 206, 207, 209, 213], "said": [1, 3], "start": [1, 2, 3, 5, 6, 15, 143, 214], "discuss": 1, "more": [1, 4, 8, 56, 105, 206, 210, 214], "detail": [1, 8, 180, 206], "thei": [1, 2, 3, 64, 176, 207, 212, 213, 214], "c": [1, 3, 162, 163, 203, 213, 214], "scalar": [1, 11, 13, 26, 37, 56, 59, 60, 69, 70, 89, 90, 91, 92, 95, 96, 102, 103, 105, 107, 110, 111, 116, 124, 127, 128, 148, 155, 159, 179, 213], "sum": [1, 2, 11, 104, 141, 190, 191, 192, 193, 194, 195], "elementwis": 1, "numpi": [1, 3, 4, 5, 11, 13, 15, 60, 69, 70, 91, 92, 95, 96, 102, 105, 107, 110, 111, 148, 213], "style": [1, 11, 13, 69, 70, 91, 92, 95, 96, 102, 105, 107, 110, 111, 148], "broadcast": [1, 11, 13, 60, 69, 70, 89, 91, 92, 95, 96, 102, 105, 107, 110, 111, 119, 120, 127, 128, 148, 151, 159, 170], "between": [1, 5, 190, 191, 192, 193, 194, 195, 214], "input": [1, 2, 3, 10, 11, 12, 13, 14, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 59, 60, 61, 62, 63, 64, 65, 66, 69, 70, 71, 72, 74, 75, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 90, 91, 92, 94, 95, 96, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 115, 116, 117, 118, 126, 129, 130, 131, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 158, 159, 161, 162, 163, 164, 166, 167, 168, 170, 172, 174, 178, 190, 192, 195, 201, 213], "upcast": 1, "const": 1, "factor": 1, "streamordevic": 1, "stream": [1, 5, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 28, 29, 30, 31, 32, 33, 34, 36, 38, 39, 40, 41, 42, 43, 45, 46, 47, 48, 50, 52, 53, 54, 55, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 68, 69, 70, 71, 72, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 91, 92, 93, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 123, 124, 126, 127, 128, 129, 130, 131, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 156, 159, 160, 161, 214], "schedul": [1, 214], "itself": 1, "call": [1, 3, 4, 27, 164, 176, 206, 207, 209], "other": [1, 3, 170, 206, 207, 213], "within": [1, 24], "simplest": 1, "wai": [1, 3, 6, 206], "about": [1, 3, 4, 214], "term": 1, "exist": [1, 3, 207], "auto": 1, "ax": [1, 12, 14, 22, 23, 57, 75, 78, 79, 81, 82, 84, 85, 87, 88, 104, 106, 108, 109, 116, 118, 141, 146, 149, 154, 156], "multipli": 1, "earlier": 1, "goal": 1, "themselv": 1, "contain": [1, 3, 49, 73, 83, 84, 85, 103, 143, 159, 206, 207], "act": 1, "data": [1, 4, 5, 8, 15, 76, 86, 87, 89, 93, 114, 127, 160], "nor": [1, 90, 155], "rather": [1, 214], "easi": [1, 206], "interfac": 1, "block": [1, 3], "A": [1, 3, 5, 6, 49, 59, 90, 94, 104, 105, 119, 120, 121, 123, 124, 127, 128, 143, 155, 157, 158, 166, 167, 169, 172, 176, 179, 180, 184, 185, 186, 189, 196, 206, 207, 209], "It": [1, 3, 6, 90, 155, 181, 206], "creat": [1, 3, 6, 76, 93, 206, 207, 209], "output": [1, 3, 6, 12, 13, 14, 15, 22, 23, 24, 60, 76, 83, 86, 87, 88, 89, 90, 93, 104, 106, 108, 109, 114, 115, 117, 118, 119, 120, 121, 123, 124, 127, 128, 133, 134, 141, 146, 149, 151, 155, 156, 157, 158, 159, 160, 161, 162, 163, 168, 170, 178, 190, 191, 192, 193, 194, 195, 201, 213, 214], "given": [1, 12, 14, 24, 60, 61, 73, 75, 77, 78, 79, 80, 81, 82, 86, 87, 88, 89, 104, 106, 108, 109, 118, 124, 141, 143, 149, 156, 170, 207], "set": [1, 3, 4, 165, 168, 174, 178, 182, 201, 207, 210], "further": [1, 6], "class": [1, 3, 4, 7, 8, 9, 26, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 180, 181, 182, 183, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 207], "under": 1, "These": [1, 151, 214], "word": 1, "bit": [1, 203, 207], "abstract": 1, "back": [1, 3], "give": [1, 3, 4, 24], "ourselv": 1, "concret": [1, 214], "imag": [1, 163], "public": [1, 206], "explicit": [1, 210], "alpha_": 1, "beta_": 1, "must": [1, 6, 73, 89, 119, 120, 124, 127, 128, 159], "know": [1, 3], "popul": 1, "To": [1, 2, 3, 6, 213], "avoid": 1, "unecessari": [], "alloc": [1, 207], "respons": 1, "space": [1, 195], "void": 1, "eval_cpu": 1, "std": 1, "overrid": 1, "eval_gpu": 1, "jacobian": [1, 94, 157, 213], "product": [1, 94, 105, 118, 157, 170, 213], "primal": [1, 94, 157], "tangent": [1, 20, 21, 94, 152, 153], "int": [1, 3, 4, 7, 9, 12, 14, 15, 22, 23, 24, 25, 29, 30, 31, 32, 40, 41, 42, 43, 45, 49, 52, 55, 56, 58, 60, 61, 62, 63, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 93, 104, 106, 108, 109, 114, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 130, 141, 142, 143, 146, 149, 150, 151, 154, 155, 156, 158, 160, 162, 163, 164, 166, 167, 168, 170, 172, 174, 191, 192, 195, 206, 207], "argnum": [1, 90, 155], "cotan": 1, "accross": [1, 166], "pair": [1, 116, 174], "repres": [1, 3], "axi": [1, 3, 4, 12, 14, 22, 23, 24, 25, 29, 30, 31, 32, 40, 41, 42, 43, 45, 52, 55, 58, 61, 75, 77, 80, 83, 84, 85, 86, 87, 88, 104, 106, 108, 109, 116, 117, 118, 120, 141, 142, 143, 146, 149, 150, 151, 154, 156, 158, 191, 192, 195], "correspond": [1, 12, 14, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 104, 106, 109, 118, 149, 158, 185], "dimens": [1, 3, 12, 14, 22, 23, 44, 49, 56, 63, 75, 84, 85, 87, 88, 104, 105, 106, 108, 109, 118, 120, 126, 149, 151, 154, 156, 162, 163, 166, 167, 170, 172, 174], "vmap": [1, 213], "print": [1, 2, 3, 4, 6, 184, 185, 186, 206, 210, 213], "ostream": 1, "os": [1, 6], "equival": [1, 27, 47, 57, 165], "check": [1, 59], "bool": [1, 12, 14, 22, 23, 29, 30, 31, 32, 40, 41, 42, 43, 45, 55, 56, 58, 59, 73, 104, 106, 108, 109, 118, 119, 124, 127, 128, 132, 149, 156, 162, 163, 166, 167, 168, 170, 174, 183, 207], "is_equival": 1, "privat": 1, "fall": 1, "eval": [1, 2, 3, 4, 206, 207, 209, 213], "deriv": 1, "base": [1, 73, 99, 101, 181, 207, 209, 210], "abov": [1, 3, 6, 214], "demonstr": 1, "treat": [1, 59, 84, 85, 87, 88, 150], "paramet": [1, 2, 3, 4, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 33, 59, 60, 61, 62, 63, 64, 65, 66, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 170, 172, 174, 176, 178, 179, 181, 183, 184, 185, 186, 190, 191, 192, 193, 194, 195, 201, 207, 209], "produc": [1, 170], "through": [1, 147], "construct": [1, 4, 89, 114, 160], "its": [1, 6, 105, 117, 126, 179, 186, 206, 214], "type": [1, 5, 7, 8, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 33, 49, 56, 59, 60, 61, 62, 63, 64, 65, 66, 69, 70, 71, 72, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 126, 127, 128, 129, 130, 131, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 181, 184, 190, 191, 192, 193, 194, 195, 206], "shape": [1, 3, 4, 47, 59, 60, 62, 63, 77, 80, 83, 86, 87, 88, 89, 94, 105, 114, 115, 119, 120, 121, 123, 124, 127, 128, 130, 151, 157, 159, 160, 161, 162, 163, 206, 209, 213, 214], "pass": [1, 3, 4, 47, 57, 116, 155, 176, 179, 184, 206, 207], "re": [1, 4], "now": [1, 3], "promot": 1, "dtype": [1, 3, 15, 26, 33, 56, 76, 89, 93, 114, 121, 123, 124, 127, 128, 160, 203, 213], "promoted_dtyp": 1, "promote_typ": 1, "float32": [1, 15, 76, 93, 114, 121, 123, 127, 128, 160, 203, 213], "non": [1, 6, 169, 196, 207], "point": [1, 2, 3, 6, 203], "out_dtyp": 1, "is_floating_point": 1, "cast": [1, 33, 86, 87, 88, 207], "up": [1, 3], "determin": 1, "x_cast": 1, "astyp": [1, 3, 207], "y_cast": 1, "broadcasted_input": 1, "broadcast_arrai": 1, "out_shap": 1, "0": [1, 2, 3, 4, 6, 7, 15, 52, 58, 61, 62, 63, 76, 90, 116, 119, 128, 143, 155, 156, 158, 162, 163, 165, 166, 167, 171, 173, 175, 178, 180, 183, 184, 188, 189, 190, 197, 198, 199, 201, 206, 207, 210, 213], "unique_ptr": 1, "make_uniqu": 1, "to_stream": 1, "handl": [1, 206], "resolv": 1, "No": [1, 3], "happen": [1, 3, 209], "alon": 1, "effect": 1, "onli": [1, 3, 5, 6, 59, 62, 63, 203, 206, 207, 214], "execut": [1, 6, 214], "depend": [1, 2, 56, 214], "devic": [1, 5, 6, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 28, 29, 30, 31, 32, 33, 34, 36, 38, 39, 40, 41, 42, 43, 45, 46, 47, 48, 50, 52, 53, 54, 55, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 91, 92, 93, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 123, 124, 126, 127, 128, 129, 130, 131, 135, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 156, 159, 160, 161, 214, 215], "specifi": [1, 15, 33, 63, 84, 85, 89, 90, 114, 120, 150, 151, 154, 155, 158, 160, 178, 190, 191, 192, 193, 194, 195, 201, 214], "memori": [1, 5, 207], "ha": [1, 3, 4, 5, 56, 83, 84, 86, 87, 88, 90, 120, 207, 209, 213, 214], "been": [1, 3], "try": 1, "naiv": 1, "gener": [1, 2, 15, 76, 84, 85, 119, 123, 124, 127, 128, 210, 215], "version": [1, 6, 102, 104, 141, 158, 210], "declar": 1, "member": [1, 206, 207], "method": [1, 3, 7, 8, 9, 26, 180, 181, 182, 183], "each": [1, 49, 73, 105, 116, 120, 133, 134, 143, 154, 158, 159, 164, 166, 210], "element": [1, 10, 11, 16, 17, 18, 19, 20, 21, 24, 65, 66, 69, 70, 71, 72, 74, 76, 91, 92, 95, 96, 98, 99, 100, 101, 102, 103, 107, 110, 111, 112, 117, 129, 131, 137, 138, 139, 140, 144, 145, 148, 150, 152, 153, 155, 159, 169, 174, 177, 196, 197, 200], "find": [1, 2, 6], "pointwis": 1, "captur": [1, 206], "templat": 1, "axpby_impl": 1, "typenam": 1, "t": [1, 3, 71, 155, 180, 183, 206, 214], "readi": 1, "fill": [1, 89, 115, 161], "malloc_or_wait": 1, "synchron": 1, "avail": [1, 2, 3, 4, 6, 8, 203, 214], "There": [1, 206], "wait": [1, 3], "here": [1, 3, 197, 214], "request": 1, "pressur": 1, "condit": [1, 159, 214], "set_data": 1, "nbyte": 1, "collect": [1, 182, 185, 212], "pointer": 1, "x_ptr": 1, "y_ptr": 1, "out_ptr": 1, "relev": 1, "static_cast": 1, "size_t": 1, "out_idx": 1, "size": [1, 3, 4, 49, 63, 75, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 93, 120, 130, 143, 146, 162, 163, 164], "map": [1, 4, 97, 164, 185, 207], "linear": [1, 3, 4, 5, 165, 173, 175, 177, 185, 187, 188, 189, 198, 199, 200, 206, 207], "indic": [1, 13, 22, 23, 24, 25, 73, 90, 143, 150, 151, 155], "offset": [1, 3], "x_offset": 1, "elem_to_loc": 1, "stride": [1, 62, 63, 162, 163, 174], "y_offset": 1, "contigu": 1, "regularli": 1, "default": [1, 6, 12, 14, 15, 22, 23, 24, 25, 59, 61, 62, 63, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 90, 93, 104, 106, 108, 109, 114, 117, 118, 119, 120, 121, 123, 124, 126, 127, 128, 130, 132, 142, 143, 146, 149, 154, 155, 156, 158, 160, 162, 163, 170, 174, 182, 183, 190, 191, 192, 193, 194, 195, 203, 207, 210, 212, 215], "row": [1, 76, 93], "major": 1, "henc": 1, "doesn": [1, 206], "additon": 1, "abl": 1, "all": [1, 4, 6, 13, 24, 63, 73, 76, 79, 82, 85, 88, 105, 116, 117, 146, 170, 181, 206, 207, 210, 213, 215], "incom": 1, "accordingli": 1, "dispatch": 1, "float16": [1, 203, 207], "bfloat16": 1, "complex64": 1, "throw": 1, "error": [1, 71, 72, 143, 165, 187, 188, 189, 194], "encount": 1, "unexpect": [1, 15], "regist": [1, 4], "op": 1, "contruct": 1, "assert": 1, "2": [1, 2, 3, 4, 63, 71, 78, 81, 83, 84, 85, 86, 87, 88, 101, 105, 126, 163, 165, 172, 180, 188, 190, 203, 206, 207, 213, 214], "1": [1, 3, 4, 15, 24, 25, 62, 63, 77, 78, 80, 81, 83, 84, 85, 86, 87, 88, 105, 117, 120, 128, 137, 142, 150, 155, 162, 163, 165, 166, 167, 171, 172, 174, 175, 178, 180, 183, 188, 189, 190, 191, 192, 195, 199, 201, 203, 207, 209, 213, 214], "correct": [1, 180], "els": [1, 3, 206, 207], "float16_t": 1, "bfloat16_t": 1, "complex64_t": 1, "runtime_error": 1, "support": [1, 3, 5, 6, 13, 62, 63, 105], "have": [1, 3, 6, 59, 84, 85, 87, 88, 105, 120, 170, 176, 184, 212, 214], "rememb": 1, "3": [1, 3, 6, 190, 210, 213], "complic": 1, "keep": [1, 12, 14, 22, 23, 104, 106, 108, 109, 118, 149, 156, 206, 207], "mind": [1, 3], "half": [1, 15, 124, 128, 174], "precis": [1, 3, 165, 206], "direct": [1, 3, 207, 214], "fix": [1, 3], "possibl": [1, 3, 105, 143, 164, 214], "due": 1, "transpos": [1, 3, 27], "aren": 1, "guarante": 1, "fit": [1, 214], "requir": [1, 3, 206], "column": [1, 76, 93], "inplac": 1, "expect": [1, 3, 162, 163, 170], "answer": 1, "copi": [1, 3, 5, 117, 142], "simpli": [1, 3, 6, 173, 198, 207], "catlas_saxpbi": 1, "axpby_impl_acceler": 1, "first": [1, 2, 3, 4, 6, 90, 105, 117, 126, 155, 166, 180, 184, 206, 214], "mode": [1, 64], "i": [1, 3, 94, 162, 163, 206], "e": [1, 4, 6, 71, 94, 137, 162, 163, 166, 167, 172, 206, 209, 215], "match": [1, 6, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88], "transposit": 1, "data_s": 1, "items": 1, "flag": 1, "copy_inplac": 1, "copytyp": 1, "n": [1, 3, 26, 62, 63, 76, 77, 79, 80, 82, 83, 86, 88, 93, 156, 162, 163], "incx": 1, "inci": 1, "great": 1, "But": [1, 214], "criteria": 1, "luckili": 1, "alwai": [1, 184], "With": 1, "final": [1, 2, 3, 4], "singl": [1, 4, 73, 94, 116, 157], "row_contigu": 1, "col_contigu": 1, "common": 1, "hit": 1, "mileston": 1, "enough": 1, "run": [1, 3, 4, 5, 6, 207, 214, 215], "If": [1, 3, 6, 12, 14, 15, 22, 23, 24, 25, 56, 59, 61, 64, 73, 86, 87, 88, 89, 90, 104, 105, 106, 108, 109, 114, 116, 117, 118, 120, 141, 142, 143, 149, 150, 151, 155, 156, 158, 160, 162, 163, 166, 167, 168, 170, 174, 176, 185, 207, 214, 215], "plan": 1, "stop": [1, 3, 15, 147], "enjoi": 1, "speed": 1, "appl": [1, 3, 5, 6, 214], "silicon": [1, 3, 5, 6, 214], "address": 1, "shade": 1, "languag": [1, 203], "kernel": [1, 62, 63], "written": 1, "help": [1, 3, 214], "resourc": 1, "walkthrough": 1, "pipelin": 1, "specif": [1, 6], "cpp": 1, "algorithm": 1, "launch": 1, "exactli": [1, 3], "mani": [1, 143, 162, 163, 164, 170], "thread": 1, "pick": 1, "updat": [1, 2, 3, 4, 183, 185, 207, 209], "assign": [1, 207], "axpby_gener": 1, "buffer": 1, "constant": [1, 3, 6, 116, 166, 167, 172], "4": [1, 3, 133, 190, 203, 213, 214], "5": [1, 2, 3, 6, 119], "x_stride": 1, "6": [1, 3, 133, 188, 189, 213], "y_stride": 1, "7": [1, 3], "ndim": 1, "8": [1, 3, 6, 203, 213, 214], "uint": 1, "index": [1, 7, 9, 24, 75, 76, 90, 117, 150, 151, 155], "thread_position_in_grid": 1, "convert": [1, 56, 213], "instanti": [1, 4], "uniqu": [1, 210], "host": 1, "name": [1, 97, 133, 134, 166, 182, 206, 207], "identifi": [1, 184, 212], "instantiate_axpbi": 1, "type_nam": 1, "host_nam": 1, "axpby_general_": 1, "bflot16": 1, "compil": [1, 6], "mlx_ext": 1, "metallib": [1, 6], "see": [1, 3, 4, 8, 28, 29, 30, 31, 32, 34, 36, 38, 39, 40, 41, 42, 43, 45, 46, 47, 48, 50, 52, 53, 54, 55, 57, 58, 165, 175, 187, 188, 189, 199, 214], "later": [1, 6], "co": 1, "locat": [1, 207, 214], "share": [1, 5], "register_librari": 1, "potenti": 1, "path": [1, 6, 133, 134], "tri": 1, "load": [1, 4, 207], "hasn": 1, "alreadi": [1, 3], "static": [1, 6], "object": [1, 8, 26, 37, 56, 119, 124, 127, 128, 158, 184, 212], "why": [1, 3], "packag": [1, 2, 4], "process": [1, 3, 64, 164, 185, 212], "logic": [1, 103], "grid": 1, "shown": 1, "below": [1, 203], "prepar": [1, 3], "carri": 1, "should": [1, 2, 3, 4, 6, 73, 94, 151, 155, 157, 162, 163, 170, 176, 184, 206, 207, 212, 215], "d": [1, 3, 105, 150, 180, 186, 214], "ostringstream": 1, "kname": 1, "axpby_": 1, "general_": 1, "type_to_nam": 1, "make": [1, 3, 4, 6, 105, 206, 213, 214], "sure": [1, 3, 6, 206], "look": [1, 3], "folder": 1, "get_colocated_mtllib_path": 1, "get_kernel": 1, "str": [1, 64, 90, 97, 132, 133, 134, 155, 184, 186, 190, 191, 192, 193, 194, 195, 207], "encod": [1, 174], "compute_encod": 1, "get_command_encod": 1, "setcomputepipelinest": 1, "those": [1, 3, 206], "decelar": 1, "nelem": 1, "set_array_buff": 1, "setbyt": 1, "sizeof": 1, "threadgroup": 1, "higher": 1, "than": [1, 3, 56, 64, 91, 92, 95, 96, 105, 174, 178, 185, 201, 214], "max": [1, 107, 197, 214], "allow": [1, 181, 206, 207, 213], "tgp_size": 1, "min": [1, 110, 197], "maxtotalthreadsperthreadgroup": 1, "3d": 1, "mtl": 1, "group_dim": 1, "grid_dim": 1, "divd": 1, "among": 1, "dispatchthread": 1, "few": [1, 3, 4, 5, 213], "thing": [1, 3], "note": [1, 3, 6, 13, 62, 63, 84, 85, 120, 206], "befor": [1, 3, 6, 24, 117, 132, 207], "move": [1, 214], "track": [1, 206], "activ": [1, 169, 178, 196, 201, 206], "command": [1, 6], "instead": [1, 206], "end_encod": 1, "end": [1, 175, 178, 199, 201], "until": [1, 213], "limit": 1, "flush": 1, "enqueu": 1, "commit": 1, "associ": [1, 133, 134], "suggest": 1, "deeper": 1, "dive": 1, "studi": 1, "come": [1, 3], "far": [1, 209], "built": [1, 6], "includ": [1, 207, 213, 215], "forward": [1, 155], "diff": 1, "push": 1, "along": [1, 22, 23, 61, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 141, 143, 150, 151], "primtiv": 1, "similarli": [1, 6, 105], "scale_arr": 1, "contribut": 1, "tangent_x": 1, "tangent_i": 1, "revers": [1, 154], "arg": [1, 3, 8, 47, 57, 73, 133, 134], "push_back": 1, "fulli": [1, 5, 214], "primitv": 1, "overal": 1, "directori": [1, 3, 6], "extens": [1, 203], "h": [1, 62, 63, 163], "mlx_sample_extens": 1, "__init__": [1, 3, 4, 7, 8, 9, 26, 206, 207], "py": [1, 3], "cmakelist": 1, "txt": 1, "setup": [1, 2, 4], "strucutr": 1, "hold": [1, 3, 8, 181], "instal": 1, "pybind11": [1, 6], "sinc": [1, 3, 4, 207, 214], "compon": [1, 3], "etc": [1, 206], "becom": 1, "pybind11_modul": 1, "m": [1, 6, 76], "doc": [1, 4], "sampl": [1, 2, 3, 119, 120, 121, 124, 127, 128, 210], "_a": 1, "pos_onli": 1, "kw_onli": 1, "none": [1, 3, 7, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 28, 29, 30, 31, 32, 33, 34, 36, 38, 39, 40, 41, 42, 43, 45, 46, 47, 48, 50, 52, 53, 54, 55, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 114, 115, 116, 117, 118, 119, 120, 121, 123, 124, 125, 126, 127, 128, 129, 130, 131, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 158, 159, 160, 161, 165, 170, 184, 190, 191, 192, 193, 194, 195, 207], "r": [1, 3, 155], "pbdoc": 1, "most": [1, 120, 206], "complex": [1, 84, 85, 86, 87, 88, 119, 124, 127, 128, 184, 206], "addit": [1, 3, 11, 166, 167, 170, 172, 207], "bell": 1, "whistl": 1, "liter": 1, "string": 1, "modul": [1, 3, 4, 176, 179, 212], "ensur": 1, "caster": 1, "find_packag": 1, "config": 1, "add_librari": 1, "sourc": [1, 154], "target_sourc": 1, "cmake_current_list_dir": 1, "header": 1, "target_include_directori": 1, "target_link_librari": 1, "attach": 1, "conveni": [1, 4], "mlx_build_metallib": 1, "target": [1, 155, 190, 191, 192, 193, 194, 195], "destin": 1, "automat": [1, 5, 213, 214], "practic": 1, "mlx_build_met": [1, 6], "mlx_ext_metallib": 1, "titl": 1, "include_dir": 1, "project_source_dir": 1, "mlx_include_dir": 1, "output_directori": 1, "cmake_library_output_directori": 1, "add_depend": 1, "endif": 1, "pybind11_add_modul": 1, "build_shared_lib": 1, "target_link_opt": 1, "wl": 1, "rpath": 1, "loader_path": 1, "onc": 1, "describ": 1, "util": [1, 3, 5, 133], "__name__": [1, 3], "__main__": [1, 3], "descript": [1, 3, 203], "ext_modul": 1, "cmakeextens": 1, "cmdclass": 1, "build_ext": 1, "cmakebuild": 1, "package_dir": 1, "package_data": 1, "dylib": 1, "zip_saf": 1, "fals": [1, 3, 12, 14, 22, 23, 29, 30, 31, 32, 40, 41, 42, 43, 45, 55, 58, 59, 73, 104, 106, 108, 109, 118, 149, 156, 159, 166, 168, 170, 174, 183, 184, 203, 207], "python_requir": 1, "even": [1, 3], "though": [1, 3], "j8": 1, "libmlx_ext": 1, "cpython": 1, "3x": 1, "darwin": 1, "pip": [1, 6], "after": [1, 3, 4, 24, 117, 166, 167, 170, 214], "plai": [1, 3], "ones": [1, 3, 115, 133, 207], "b": [1, 3, 11, 13, 59, 69, 70, 91, 92, 95, 96, 102, 105, 107, 110, 111, 148, 155, 213, 214], "f": [1, 2, 4, 206], "item": [1, 2, 3, 4, 185, 213], "true": [1, 2, 3, 59, 132, 141, 159, 162, 163, 166, 167, 168, 174, 184, 203, 207], "quick": [1, 5], "benchmark": 1, "compar": [1, 59], "time": [1, 3, 6, 206, 214], "set_default_devic": 1, "256": [1, 4], "512": [1, 3, 214], "random": [1, 2, 3, 4, 5, 214, 215], "normal": [1, 2, 3, 127, 166, 167, 172, 182, 214], "bench": 1, "warm": 1, "rang": [1, 2, 3, 4, 6, 15, 188, 189, 209, 210, 214], "100": [1, 2, 3, 214], "5000": 1, "simple_tim": 1, "custom_tim": 1, "3f": [1, 4], "custom": 1, "114": 1, "109": 1, "modest": 1, "improv": [1, 3], "awai": [1, 3], "good": [1, 6, 214], "nn": [1, 3, 4, 133, 185, 206, 209], "grad": [1, 2, 4, 155, 209, 213], "simplifi": 1, "full": [1, 4, 47, 57, 64, 141, 207], "implement": [2, 4, 164, 170, 174, 176, 178, 180, 181, 182, 201, 207], "basic": 2, "model": [2, 4, 5, 133, 170, 179, 185, 206, 207, 209], "problem": [2, 4, 206], "metadata": 2, "num_featur": 2, "num_exampl": 2, "1_000": 2, "num_it": 2, "10_000": 2, "iter": [2, 4, 185, 210], "sgd": [2, 4, 209], "lr": 2, "01": 2, "rate": 2, "ll": [2, 4], "synthet": 2, "dataset": 2, "matrix": [2, 76, 93, 105], "ground": [2, 3], "truth": 2, "w_star": 2, "valu": [2, 3, 10, 15, 22, 23, 37, 56, 59, 73, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 116, 119, 120, 121, 123, 124, 127, 128, 150, 151, 155, 158, 170, 178, 179, 182, 184, 185, 190, 191, 193, 194, 195, 201, 203, 207], "gaussian": [2, 165, 187, 188, 189], "nois": 2, "exampl": [2, 3, 4, 15, 150, 190, 209, 210, 213], "noisi": 2, "label": [2, 190], "ep": [2, 166, 167, 172, 180], "1e": [2, 4, 13, 166, 167, 172, 180], "us": [2, 3, 4, 5, 6, 15, 105, 130, 164, 165, 168, 170, 181, 184, 188, 189, 206, 207, 209, 210, 212, 213, 214], "weight": [2, 62, 63, 183, 185, 206, 207], "squar": [2, 3, 93, 131, 144, 155, 172, 185, 194, 206], "loss": [2, 4, 155, 209], "loss_fn": [2, 4, 209], "w": [2, 63, 155, 163, 183], "mean": [2, 3, 4, 155, 166, 172, 190, 191, 192, 193, 194, 195, 206, 207], "grad_fn": 2, "initi": [2, 3, 166, 167, 172, 206, 207], "randomli": [2, 3], "Then": [2, 6], "repeatedli": 2, "_": [2, 3, 206, 210, 214], "verifi": 2, "close": [2, 5, 13], "error_norm": 2, "5f": 2, "someth": [2, 3], "00005": 2, "00364": 2, "complet": [2, 3, 207, 214], "logist": [2, 137, 177, 188, 189, 200], "github": [2, 4, 6], "repo": [2, 4, 6], "enabl": [3, 73, 183], "larg": [3, 206], "ish": 3, "transform": [3, 5, 73, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 166, 167, 168, 179, 206, 207], "compromis": 3, "eas": 3, "llama": 3, "famili": 3, "less": [3, 24, 96, 117, 174], "200": 3, "line": 3, "python": [3, 37, 49, 56, 73, 184, 185, 186, 207, 212], "neural": [3, 5, 164, 169, 196, 207], "network": [3, 5, 164, 207], "build": [3, 5, 207], "concis": 3, "architectur": [3, 214], "notabl": 3, "rope": 3, "posit": [3, 24, 90, 117, 155, 162, 163, 170, 174, 185, 206], "option": [3, 12, 14, 15, 22, 23, 24, 25, 26, 31, 32, 61, 62, 63, 64, 73, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 93, 104, 106, 108, 109, 114, 116, 117, 118, 119, 120, 121, 123, 124, 126, 127, 128, 130, 132, 141, 142, 143, 146, 149, 150, 151, 154, 155, 156, 158, 160, 162, 163, 170, 183, 184, 190, 191, 192, 193, 194, 195, 207, 210, 215], "kei": [3, 119, 120, 121, 123, 124, 126, 127, 128, 170, 182, 184, 185, 207, 210, 212], "cach": 3, "concaten": 3, "project": [3, 170], "llamaattent": 3, "self": [3, 4, 7, 9, 26, 27, 28, 29, 30, 31, 32, 33, 34, 36, 37, 38, 39, 40, 41, 42, 43, 45, 46, 47, 48, 50, 52, 53, 54, 55, 56, 57, 58, 169, 196, 206, 207], "dim": [3, 164, 166, 167, 170, 172, 174], "num_head": [3, 170], "super": [3, 4, 206, 207], "tradit": [3, 174], "query_proj": 3, "bia": [3, 162, 163, 168, 170, 180, 185, 207], "key_proj": 3, "value_proj": 3, "out_proj": [3, 207], "__call__": [3, 4, 206, 207], "queri": [3, 170], "mask": [3, 170], "extract": [3, 206, 207], "l": [3, 4, 162, 206], "reshap": 3, "combin": 3, "key_cach": 3, "value_cach": 3, "sqrt": [3, 71, 166, 167, 172, 180], "score": 3, "softmax": [3, 191], "values_hat": 3, "rm": 3, "swiglu": 3, "rmsnorm": 3, "llamaencoderlay": 3, "mlp_dim": 3, "norm1": 3, "norm2": 3, "linear1": 3, "linear2": 3, "linear3": 3, "sigmoid": [3, 177, 188, 189, 190, 200], "instanc": [3, 176, 186, 206, 207], "embed": 3, "emb": [3, 164], "token": [3, 164], "num_lay": [3, 4, 209], "vocab_s": 3, "norm": [3, 166], "multiheadattent": 3, "create_additive_causal_mask": 3, "list": [3, 8, 12, 14, 26, 29, 30, 40, 41, 42, 43, 45, 49, 52, 55, 56, 58, 60, 61, 73, 75, 78, 79, 81, 82, 84, 85, 87, 88, 89, 90, 94, 104, 106, 108, 109, 114, 116, 118, 119, 120, 121, 123, 124, 127, 128, 130, 141, 143, 146, 149, 154, 155, 156, 157, 160, 180, 184, 186, 206, 207, 212], "still": [3, 6], "consid": [3, 13, 59, 166, 184, 212], "train": [3, 4, 207], "ignor": 3, "whatsoev": 3, "rest": [3, 174, 185], "subsect": 3, "prompt": 3, "autoregress": 3, "yield": [3, 4, 210], "temp": 3, "causal": 3, "save": [3, 97, 133, 134, 207], "append": [3, 105], "store": 3, "per": [3, 4, 166, 167, 172, 181], "care": 3, "last": [3, 25, 56, 79, 82, 84, 85, 87, 88, 105, 120, 142, 162, 163, 166], "logit": [3, 120, 191], "next": [3, 4], "categor": 3, "lazili": [3, 206], "noth": [3, 206], "yet": [3, 206, 207, 213], "forc": [3, 4, 206, 213], "choos": [3, 174], "pars": 3, "feed": 3, "loop": [3, 4], "unsqueez": 3, "sequenc": [3, 162, 210, 214], "length": [3, 146, 162], "len": [3, 79, 82, 85, 88], "overwrit": 3, "discard": [3, 184], "old": 3, "moment": [3, 180], "anymor": 3, "everyth": 3, "small": [3, 166, 167, 172, 214], "10": [3, 4, 99, 133, 185, 206], "12": 3, "8192": 3, "1024": 3, "actual": [3, 15, 207], "materi": [3, 5], "could": [3, 206], "20_000": 3, "machin": [3, 5, 6], "8gb": 3, "ram": 3, "32": [3, 4, 203], "44": 3, "doubl": 3, "bracket": 3, "becaus": [3, 206], "batch": [3, 105, 162, 163, 170], "zip": [3, 4], "haven": 3, "anyth": [3, 155], "result": [3, 15, 56, 97, 105, 159, 185], "similar": [3, 170, 207], "runtim": 3, "section": [3, 143], "access": [3, 37, 206, 207, 214], "origin": [3, 180], "sentencepiec": 3, "pytorch": [3, 5, 166], "compat": [3, 120], "npz": [3, 97, 133, 134, 207], "file": [3, 6, 97, 132, 133, 134, 207], "directli": 3, "argpars": 3, "itertool": [3, 185], "starmap": [3, 185], "np": [3, 4, 213], "torch": 3, "map_torch_to_mlx": 3, "tok_embed": 3, "elif": 3, "replac": [3, 207], "attention_norm": 3, "ffn_norm": 3, "wq": 3, "wk": 3, "wv": 3, "wo": 3, "w1": 3, "w2": 3, "w3": 3, "ffn": 3, "separ": [3, 47, 57, 166], "submodul": [3, 4, 206, 207], "feed_forward": 3, "parser": 3, "argumentpars": 3, "add_argu": 3, "torch_weight": 3, "output_fil": 3, "parse_arg": 3, "state": [3, 4, 181, 182, 206, 209, 210], "savez": 3, "k": [3, 76, 207], "v": [3, 64, 207], "left": [3, 165, 174, 188, 189], "disk": 3, "text": [3, 169, 175, 178, 196, 197, 199, 201], "format": [3, 97, 132, 133, 134], "oper": [3, 5, 33, 141, 147, 151, 206, 213, 214, 215], "dictionari": [3, 181, 182, 184, 206, 207, 212], "represent": [3, 184, 186], "tree_unflatten": 3, "helper": 3, "weight_fil": 3, "incur": 3, "sever": [3, 62, 63, 133, 134], "unnecessari": [1, 3], "futur": 3, "pth": 3, "current": [3, 5, 6, 62, 63, 206], "around": 3, "m1": [3, 214], "ultra": 3, "7b": 3, "me": 3, "ishmael": 3, "year": 3, "ago": 3, "never": 3, "long": 3, "info": 3, "247": 3, "press": 3, "enter": 3, "littl": 3, "monei": 3, "my": [3, 6], "purs": 3, "greater": [3, 24, 92, 117, 178, 201], "consequ": 3, "walk": 3, "down": 3, "gower": 3, "street": 3, "afternoon": 3, "heavi": 3, "rain": 3, "saw": 3, "off": [3, 6], "man": 3, "rag": 3, "who": 3, "sat": 3, "upon": [3, 185], "hi": 3, "bundl": 3, "hard": 3, "wet": 3, "he": 3, "were": [3, 214], "cry": 3, "watch": 3, "him": 3, "observ": 3, "numer": [3, 102, 104, 141, 166, 167, 172], "crowd": 3, "wa": [3, 182], "hurri": 3, "437": 3, "330": 3, "second": [3, 105, 155, 180, 214], "spent": 3, "amount": 3, "39": 3, "ms": 3, "By": 3, "bigger": 3, "remain": [3, 155], "almost": 3, "nobodi": 3, "took": 3, "least": 3, "notic": 3, "distanc": 3, "had": 3, "doubt": 3, "minut": 3, "straight": 3, "slowli": 3, "rais": [3, 143], "ey": 3, "speak": 3, "resum": 3, "postur": 3, "stood": 3, "feel": 3, "pain": 3, "heart": 3, "smile": 3, "face": 3, "am": 3, "someon": 3, "three": 3, "quarter": 3, "hour": 3, "made": 3, "immedi": [3, 207], "repli": 3, "again": [3, 206], "hand": 3, "did": 3, "accustom": 3, "thu": [3, 206], "question": 3, "reason": 3, "tell": 3, "understand": 3, "579": 3, "690": 3, "num": [3, 126], "500": [3, 214], "628": 3, "went": 3, "nervou": 3, "trembl": 3, "told": 3, "And": 3, "perhap": 3, "surpris": 3, "matter": [3, 206], "shall": 3, "anyhow": 3, "friend": 3, "ye": 3, "slight": 3, "kind": 3, "longer": [3, 64], "soon": 3, "unless": [3, 207], "unlik": [3, 13], "strang": 3, "amus": 3, "That": 3, "secret": 3, "disappoint": 3, "mine": 3, "cannot": 3, "happi": 3, "ask": 3, "Is": 3, "shop": 3, "bui": 3, "food": 3, "633": 3, "21": 3, "475": 3, "su": 3, "j": [3, 6, 180], "lu": 3, "pan": 3, "murtadha": 3, "wen": 3, "liu": 3, "2021": 3, "roform": 3, "enhanc": 3, "rotari": [3, 174], "arxiv": [3, 166, 167, 169, 172, 174, 196], "preprint": 3, "2104": [3, 174], "09864": [3, 174], "zhang": 3, "sennrich": 3, "2019": 3, "root": [3, 131, 144, 172], "advanc": 3, "inform": [3, 4, 165, 170, 214], "system": 3, "shazeer": 3, "2020": 3, "glu": 3, "variant": 3, "2002": 3, "05202": 3, "classifi": 4, "mnist": 4, "As": [4, 150], "mlp": [4, 206, 209], "inherit": [4, 212], "standard": [4, 37, 56, 105, 121, 213], "idiom": 4, "input_dim": [4, 168], "hidden_dim": [4, 207, 209], "output_dim": [4, 168], "layer_s": 4, "idim": 4, "odim": 4, "maximum": [4, 22, 173, 188, 189, 198, 206, 207], "cross": [4, 190, 191], "entropi": [4, 190, 191], "sub": [4, 126], "commonli": [4, 207], "cross_entropi": 4, "accuraci": 4, "valid": [4, 64, 158, 184, 207, 212], "eval_fn": 4, "argmax": 4, "num_class": [4, 209], "batch_siz": [4, 209], "num_epoch": [4, 209], "learning_r": [4, 180, 183, 209], "train_imag": [4, 209], "train_label": [4, 209], "test_imag": 4, "test_label": 4, "shuffl": 4, "minibatch": 4, "batch_iter": [4, 209], "perm": 4, "permut": 4, "id": 4, "put": 4, "trainabl": [4, 179, 206, 207], "loss_and_grad_fn": [4, 209], "value_and_grad": [4, 206, 207, 209, 213], "epoch": 4, "test": [4, 6], "confus": 4, "decent": 4, "95": 4, "except": [5, 76, 83, 84, 86, 87, 88, 166], "featur": [5, 62, 63, 166, 167, 168, 172, 174], "main": [5, 76, 185, 206], "differ": [5, 148], "lazi": [5, 207, 213], "multi": [5, 162, 163], "cpu": [5, 214], "gpu": [5, 214], "strongli": [], "inspir": 5, "jax": [5, 210], "arrayfir": 5, "noteabl": 5, "unifi": 5, "live": [5, 214], "guid": 5, "regress": 5, "layer": [5, 166, 167, 168, 176, 207], "perceptron": 5, "llm": 5, "infer": [5, 89], "fft": 5, "tree": [5, 73, 90, 155, 158, 181, 184, 185, 186], "develop": [5, 6], "document": [5, 47, 57], "17": 6, "g": [6, 183, 206, 215], "clang": 6, "cmake": 6, "24": 6, "clone": 6, "git": 6, "com": 6, "ml": 6, "explor": 6, "cd": 6, "brew": 6, "conda": 6, "global": [6, 125, 210], "env": 6, "cmake_build_parallel_level": 6, "edit": 6, "unittest": 6, "discov": 6, "mkdir": 6, "p": [6, 119, 180], "either": [6, 11, 47, 56, 57, 69, 70, 91, 92, 95, 96, 102, 105, 107, 110, 111, 148, 155, 176], "libmlx": 6, "preprocessor": 6, "metal_path": 6, "mlx_build_test": 6, "ON": 6, "mlx_build_exampl": 6, "mlx_build_benchmark": 6, "mlx_build_python_bind": 6, "devicetyp": 7, "attribut": [7, 8, 9, 26], "kwarg": [8, 133, 134, 215], "union": [10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 28, 29, 30, 31, 32, 33, 34, 36, 38, 39, 40, 41, 42, 43, 45, 46, 47, 48, 50, 52, 53, 54, 55, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 69, 70, 71, 72, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 114, 115, 116, 117, 118, 119, 120, 121, 123, 124, 126, 127, 128, 129, 130, 131, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 159, 160, 161, 163, 207], "wise": [10, 11, 16, 17, 18, 19, 20, 21, 65, 66, 69, 70, 71, 72, 74, 91, 92, 95, 96, 98, 99, 100, 101, 102, 103, 107, 110, 111, 112, 129, 131, 137, 138, 139, 140, 144, 145, 148, 152, 153, 169, 177, 196, 197, 200], "absolut": [10, 13, 188, 189], "semant": [11, 60, 69, 70, 91, 92, 95, 96, 102, 105, 107, 110, 111, 148, 214], "keepdim": [12, 14, 22, 23, 29, 30, 31, 32, 40, 41, 42, 43, 45, 55, 58, 104, 106, 108, 109, 118, 141, 149, 156], "reduct": [12, 14, 104, 106, 109, 118, 190, 191, 192, 193, 194, 195], "reduc": [12, 14, 22, 23, 104, 106, 108, 109, 118, 149, 156], "unspecifi": [12, 14, 15, 22, 23, 24, 25, 61, 89, 104, 106, 108, 109, 114, 117, 118, 141, 142, 149, 150, 156, 160, 215], "entir": [12, 14, 22, 23, 104, 106, 108, 109, 118, 149, 156], "singleton": [12, 14, 22, 23, 104, 105, 106, 108, 109, 118, 149, 156], "rtol": 13, "05": [13, 166, 167, 172], "atol": 13, "08": [13, 180], "approxim": [13, 165, 187, 188, 189], "comparison": [13, 70, 91, 92, 95, 96], "equal": [13, 24, 59, 76, 92, 96, 117, 124, 143], "ab": [13, 155, 166, 167, 169, 172, 174, 196], "array_equ": 13, "rel": 13, "toler": 13, "boolean": [13, 59, 103, 203], "interv": [15, 124, 128], "increment": 15, "otherwis": [15, 178, 184, 201, 207], "int32": [15, 124, 203, 213], "convent": [15, 64], "lead": 15, "fraction": 15, "integr": [15, 150], "invers": [16, 17, 18, 19, 20, 21, 72, 80, 81, 82, 83, 84, 85], "cosin": [16, 17, 65, 66], "hyperbol": [17, 19, 21, 66, 140, 153], "sine": [18, 19, 139, 140], "minimum": [22, 23], "kth": [24, 117], "partit": 24, "order": [24, 117, 166, 176, 206, 207], "undefin": [24, 117], "sort": [24, 25, 117], "partiton": 24, "flatten": [24, 25, 117, 142, 150, 151, 184], "dimension": [26, 77, 78, 79, 80, 81, 82, 86, 87, 88, 162, 163, 164, 168], "overload": [], "val": [26, 89], "tupl": [26, 47, 57, 61, 63, 73, 75, 94, 116, 130, 146, 155, 157, 163, 176, 184, 185, 186, 207, 212], "ndarrai": [26, 213], "properti": [27, 35, 44, 49, 51], "argument": [27, 47, 57, 73, 90, 132, 155, 185, 206, 210, 214, 215], "elment": 51, "indices_or_sect": [52, 143], "nest": [56, 206, 207, 212], "correpsond": 56, "ddof": [58, 156], "equal_nan": 59, "nan": 59, "pad": [62, 63, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 162, 163], "dilat": [62, 63], "group": [62, 63, 166], "1d": [62, 64, 151], "convolut": [62, 63, 64, 162, 163], "channel": [62, 63, 162, 163], "c_in": [62, 63], "c_out": [62, 63], "convolv": [62, 63], "2d": 63, "spatial": [63, 166], "symmetr": 63, "discret": [64, 77, 78, 79, 80, 81, 82, 86, 87, 88, 164], "swap": 64, "conv": 64, "filter": [64, 162, 163, 207], "flip": 64, "signal": 64, "divis": 69, "quotient": 69, "mathrm": [71, 137], "frac": [71, 137, 166, 167, 172, 180], "pi": 71, "int_0": 71, "dx": 71, "erf": 72, "retain_graph": [73, 132], "node": [73, 158], "dict": [73, 97, 133, 207, 212], "leaf": [73, 184, 207], "preserv": [73, 130], "intend": 73, "control": [73, 210], "flow": [73, 147], "exponenti": [74, 175, 199], "insert": [75, 214], "One": [77, 80, 86, 131], "fourier": [77, 78, 79, 80, 81, 82, 86, 87, 88], "truncat": [77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 127], "zero": [76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 161, 206], "dft": [77, 78, 79, 80, 81, 82, 86, 87, 88], "rfft": 83, "real": [83, 84, 85, 86, 87, 88], "rfft2": 84, "rfftn": 85, "silent": [86, 87, 88], "fun": [90, 94, 155, 157, 158, 214], "cpp_function": [90, 155, 158], "variabl": [6, 90, 94, 155, 157, 158], "strict": [91, 95, 207], "binari": [97, 132, 133, 134, 178, 190, 201], "npy": [97, 132], "natur": [98, 100], "logarithm": [98, 99, 100, 101], "log": [100, 102, 104, 192, 195], "plu": 100, "exp": [102, 104, 121, 141, 175, 192, 199, 214], "stabl": [102, 104, 141], "multipl": [6, 105, 111, 170], "prepend": 105, "remov": [105, 120, 146], "anoth": [105, 148, 159, 207, 214], "negat": 112, "pad_width": 116, "constant_valu": 116, "edg": 116, "before_1": 116, "after_1": 116, "before_2": 116, "after_2": 116, "before_n": 116, "after_n": 116, "integ": [116, 119, 124, 143, 158, 164, 203], "before_i": 116, "after_i": 116, "extend": 116, "side": 116, "smaller": 117, "distribut": [6, 119, 120, 121, 123, 127, 128, 192, 195], "prng": [119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 210], "num_sampl": 120, "unnorm": 120, "draw": 120, "uint32": [120, 203], "cdf": [121, 165, 187], "accord": [121, 159, 170], "seed": 122, "low": [124, 128], "high": [124, 128, 164, 206], "probabl": [6, 124, 190, 192, 214], "lower": [124, 127, 128], "upper": [124, 127, 128], "bound": [124, 127, 128, 165, 214], "roadcast": 124, "domain": 127, "optino": 127, "uniformli": 128, "reciproc": 131, "arr": 132, "retain": 132, "dure": 132, "uncompress": 133, "my_path": 133, "tree_flatten": [133, 186], "transformerencod": 133, "128": [133, 206], "flat_param": 133, "keyword": [90, 133, 134, 155, 206, 210, 215], "compress": 134, "subarrai": 143, "being": [147, 206], "ident": [76, 147], "prevent": 147, "unchang": [147, 174], "taken": 150, "prior": [150, 151], "equial": 150, "exclud": 151, "mse": 155, "param": [155, 206], "lvalu": 155, "dlvalu": 155, "dparam": 155, "lasso": 155, "l1": [155, 193], "varianc": [156, 166], "divisor": 156, "cotang": 157, "in_ax": 158, "out_ax": 158, "prefix": [158, 184], "select": [159, 207], "in_channel": [162, 163], "out_channel": [162, 163], "kernel_s": [162, 163], "appli": [162, 163, 165, 166, 167, 168, 169, 172, 173, 175, 177, 178, 181, 185, 187, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 207], "nlc": 162, "learnabl": [162, 163, 176], "nhwc": 163, "height": 163, "width": 163, "num_embed": 164, "lookup": 164, "tabl": [164, 203], "typic": [164, 209], "usual": [164, 212], "vocabulari": 164, "approx": 165, "unit": [165, 173, 175, 177, 187, 188, 189, 198, 199, 200], "textrm": [165, 187], "phi": [165, 187], "geluapprox": 165, "sigma": [165, 177, 188, 189, 200], "60033": [165, 188], "0433603": [165, 188], "gelufast": 165, "773": [165, 189], "gelu_approx": [165, 187], "gelu_fast_approx": [165, 187], "regard": 165, "num_group": 166, "affin": [166, 167, 168], "pytorch_compat": 166, "var": [166, 167], "epsilon": [166, 167, 172, 180], "gamma": [166, 167, 172], "particular": 166, "split": 166, "preced": 166, "http": [166, 167, 169, 172, 174, 196], "org": [166, 167, 169, 172, 174, 196], "1803": 166, "08494": 166, "stabil": [166, 167, 172], "1607": 167, "06450": 167, "query_input_dim": 170, "key_input_dim": 170, "value_input_dim": 170, "value_dim": 170, "value_output_dim": 170, "dot": [170, 184, 207], "attent": [170, 207], "head": 170, "aggreg": 170, "lineari": [], "bias": [170, 207], "inf": 170, "neg": [170, 195], "attend": 170, "1910": 172, "07467": 172, "rectifi": [173, 198], "rotat": 174, "consecut": 174, "larger": 174, "slightli": [174, 214], "callabl": [176, 179, 184, 185, 207], "plain": 176, "cdot": [177, 188, 189, 200], "fn": [179, 185, 213], "wrt": 179, "whose": [76, 179], "9": 180, "999": 180, "paper": 180, "omit": 180, "estim": 180, "m_": 180, "beta_1": 180, "m_t": 180, "g_t": [180, 183], "v_": [180, 183], "beta_2": 180, "v_t": [180, 183], "w_": [180, 183], "w_t": [180, 183], "lambda": [175, 180, 183, 185, 199, 207], "kingma": 180, "ba": 180, "2015": 180, "stochast": [180, 183], "iclr": 180, "basi": 181, "optimizerst": 181, "recurs": [182, 206, 207], "defaultdict": 182, "miss": 182, "contrast": 182, "present": 182, "momentum": 183, "descent": 183, "mu": 183, "strength": 183, "is_leaf": 184, "notat": [184, 207], "arbitrari": [184, 207], "depth": 184, "hello": [184, 186], "charact": 184, "flat": [184, 186], "everi": 185, "superset": 185, "extra": 185, "closer": 185, "dict_kei": 185, "recreat": 186, "world": 186, "42": 186, "faster": 187, "gelu": [188, 189], "exact": [188, 189], "0003": 188, "015": 189, "show": [6, 203], "byte": 203, "bool_": 203, "uint8": 203, "unsign": 203, "uint16": 203, "16": [203, 207], "int8": 203, "sign": 203, "int16": 203, "int64": 203, "64": 203, "arm": [6, 203], "arbitrarili": [206, 212, 213], "done": 206, "manual": 206, "explicitli": [206, 210], "solv": 206, "intuit": 206, "freez": [206, 207], "finetun": 206, "in_dim": [206, 207], "out_dim": [206, 207], "enumer": 206, "caus": 206, "local": 206, "scope": 206, "l2_loss": 206, "y_hat": 206, "trainable_paramet": [206, 207], "loss_and_grad": 206, "workhors": 206, "Its": 206, "frozen": [206, 207], "subset": [206, 207], "individu": 206, "action": 206, "preclud": 206, "pure": [206, 209], "pattern": 206, "achiev": 206, "other_input": 206, "necessari": 206, "wrap": 206, "subclass": 207, "concept": 207, "mymlp": 207, "in_proj": 207, "map_fn": 207, "filter_fn": 207, "valid_parameter_filt": 207, "apply_to_modul": 207, "apply_fn": 207, "children": 207, "descend": 207, "filter_and_map": 207, "is_leaf_fn": 207, "content": [6, 207], "found": 207, "whether": 207, "drop": 207, "idempot": 207, "ie": 207, "noop": 207, "unfreez": 207, "endswith": 207, "leaf_modul": 207, "load_weight": 207, "named_modul": 207, "save_weight": 207, "unfrozen": 207, "chang": 207, "tracer": 207, "partial": 207, "subsequ": 209, "implicit": 210, "fine": 210, "grain": 210, "manag": [210, 214], "uniform": [210, 214], "pseudo": 210, "altern": 210, "splittabl": 210, "threefri": 210, "counter": 210, "cycl": 212, "inspect": 213, "composit": 213, "sin": 213, "default_stream": 215, "default_devic": 215, "my_devic": 215, "brought": 5, "research": 5, "maco": 6, "13": 6, "recommend": 6, "14": 6, "sonoma": 6, "xcode": 6, "15": 6, "wish": 6, "environ": 6, "export": 6, "developer_dir": 6, "app": 6, "sdk": 6, "xcrun": 6, "macosx": 6, "meet": 6, "seri": 6, "chip": 6, "nativ": 6, "platform": 6, "processor": 6, "i386": 6, "switch": 6, "argnam": [90, 155], "neither": [90, 155], "pool": 214, "advantag": 214, "don": 214, "parallel": 214, "race": 214, "interest": 214, "albeit": 214, "contriv": 214, "suppos": 214, "d1": 214, "d2": 214, "matmul": 214, "4096": 214, "dens": 214, "better": 214, "overhead": 214, "millisecond": 214, "twice": 214, "measur": 214, "diagon": 76, "th": 76, "pad_with": 116, "regular": [169, 196], "monoton": [169, 196], "refer": [169, 196], "1908": [169, 196], "08681": [169, 196], "tanh": [169, 196], "softplu": [169, 196], "linearli": 170, "num_paramet": 171, "init": 171, "25": 171, "begin": [175, 178, 199, 201], "leq": [175, 199], "0507": [175, 199], "67326": [175, 199], "elu": [175, 199], "known": [177, 200], "swish": [177, 200], "threshold": [178, 201], "geq": [178, 201], "weight_decai": 183, "dampen": 183, "nesterov": 183, "decai": 183, "l2": 183, "penalti": 183, "tau": 183, "predict": [190, 191, 192, 193, 194, 195], "post": 190, "612192": 190, "kullback": 192, "leibler": 192, "diverg": 192, "likelihood": 195, "nll": 195}, "objects": {"mlx.core": [[7, 0, 1, "", "Device"], [8, 0, 1, "", "Dtype"], [9, 0, 1, "", "Stream"], [10, 2, 1, "", "abs"], [11, 2, 1, "", "add"], [12, 2, 1, "", "all"], [13, 2, 1, "", "allclose"], [14, 2, 1, "", "any"], [15, 2, 1, "", "arange"], [16, 2, 1, "", "arccos"], [17, 2, 1, "", "arccosh"], [18, 2, 1, "", "arcsin"], [19, 2, 1, "", "arcsinh"], [20, 2, 1, "", "arctan"], [21, 2, 1, "", "arctanh"], [22, 2, 1, "", "argmax"], [23, 2, 1, "", "argmin"], [24, 2, 1, "", "argpartition"], [25, 2, 1, "", "argsort"], [26, 0, 1, "", "array"], [59, 2, 1, "", "array_equal"], [60, 2, 1, "", "broadcast_to"], [61, 2, 1, "", "concatenate"], [62, 2, 1, "", "conv1d"], [63, 2, 1, "", "conv2d"], [64, 2, 1, "", "convolve"], [65, 2, 1, "", "cos"], [66, 2, 1, "", "cosh"], [67, 2, 1, "", "default_device"], [68, 2, 1, "", "default_stream"], [69, 2, 1, "", "divide"], [70, 2, 1, "", "equal"], [71, 2, 1, "", "erf"], [72, 2, 1, "", "erfinv"], [73, 2, 1, "", "eval"], [74, 2, 1, "", "exp"], [75, 2, 1, "", "expand_dims"], [76, 2, 1, "", "eye"], [89, 2, 1, "", "full"], [90, 2, 1, "", "grad"], [91, 2, 1, "", "greater"], [92, 2, 1, "", "greater_equal"], [93, 2, 1, "", "identity"], [94, 2, 1, "", "jvp"], [95, 2, 1, "", "less"], [96, 2, 1, "", "less_equal"], [97, 2, 1, "", "load"], [98, 2, 1, "", "log"], [99, 2, 1, "", "log10"], [100, 2, 1, "", "log1p"], [101, 2, 1, "", "log2"], [102, 2, 1, "", "logaddexp"], [103, 2, 1, "", "logical_not"], [104, 2, 1, "", "logsumexp"], [105, 2, 1, "", "matmul"], [106, 2, 1, "", "max"], [107, 2, 1, "", "maximum"], [108, 2, 1, "", "mean"], [109, 2, 1, "", "min"], [110, 2, 1, "", "minimum"], [111, 2, 1, "", "multiply"], [112, 2, 1, "", "negative"], [113, 2, 1, "", "new_stream"], [114, 2, 1, "", "ones"], [115, 2, 1, "", "ones_like"], [116, 2, 1, "", "pad"], [117, 2, 1, "", "partition"], [118, 2, 1, "", "prod"], [129, 2, 1, "", "reciprocal"], [130, 2, 1, "", "reshape"], [131, 2, 1, "", "rsqrt"], [132, 2, 1, "", "save"], [133, 2, 1, "", "savez"], [134, 2, 1, "", "savez_compressed"], [135, 2, 1, "", "set_default_device"], [136, 2, 1, "", "set_default_stream"], [137, 2, 1, "", "sigmoid"], [138, 2, 1, "", "sign"], [139, 2, 1, "", "sin"], [140, 2, 1, "", "sinh"], [141, 2, 1, "", "softmax"], [142, 2, 1, "", "sort"], [143, 2, 1, "", "split"], [144, 2, 1, "", "sqrt"], [145, 2, 1, "", "square"], [146, 2, 1, "", "squeeze"], [147, 2, 1, "", "stop_gradient"], [148, 2, 1, "", "subtract"], [149, 2, 1, "", "sum"], [150, 2, 1, "", "take"], [151, 2, 1, "", "take_along_axis"], [152, 2, 1, "", "tan"], [153, 2, 1, "", "tanh"], [154, 2, 1, "", "transpose"], [155, 2, 1, "", "value_and_grad"], [156, 2, 1, "", "var"], [157, 2, 1, "", "vjp"], [158, 2, 1, "", "vmap"], [159, 2, 1, "", "where"], [160, 2, 1, "", "zeros"], [161, 2, 1, "", "zeros_like"]], "mlx.core.Device": [[7, 1, 1, "", "__init__"]], "mlx.core.Dtype": [[8, 1, 1, "", "__init__"]], "mlx.core.Stream": [[9, 1, 1, "", "__init__"]], "mlx.core.array": [[27, 3, 1, "", "T"], [26, 1, 1, "", "__init__"], [28, 1, 1, "", "abs"], [29, 1, 1, "", "all"], [30, 1, 1, "", "any"], [31, 1, 1, "", "argmax"], [32, 1, 1, "", "argmin"], [33, 1, 1, "", "astype"], [34, 1, 1, "", "cos"], [35, 3, 1, "", "dtype"], [36, 1, 1, "", "exp"], [37, 1, 1, "", "item"], [38, 1, 1, "", "log"], [39, 1, 1, "", "log1p"], [40, 1, 1, "", "logsumexp"], [41, 1, 1, "", "max"], [42, 1, 1, "", "mean"], [43, 1, 1, "", "min"], [44, 3, 1, "", "ndim"], [45, 1, 1, "", "prod"], [46, 1, 1, "", "reciprocal"], [47, 1, 1, "", "reshape"], [48, 1, 1, "", "rsqrt"], [49, 3, 1, "", "shape"], [50, 1, 1, "", "sin"], [51, 3, 1, "", "size"], [52, 1, 1, "", "split"], [53, 1, 1, "", "sqrt"], [54, 1, 1, "", "square"], [55, 1, 1, "", "sum"], [56, 1, 1, "", "tolist"], [57, 1, 1, "", "transpose"], [58, 1, 1, "", "var"]], "mlx.core.fft": [[77, 2, 1, "", "fft"], [78, 2, 1, "", "fft2"], [79, 2, 1, "", "fftn"], [80, 2, 1, "", "ifft"], [81, 2, 1, "", "ifft2"], [82, 2, 1, "", "ifftn"], [83, 2, 1, "", "irfft"], [84, 2, 1, "", "irfft2"], [85, 2, 1, "", "irfftn"], [86, 2, 1, "", "rfft"], [87, 2, 1, "", "rfft2"], [88, 2, 1, "", "rfftn"]], "mlx.core.random": [[119, 2, 1, "", "bernoulli"], [120, 2, 1, "", "categorical"], [121, 2, 1, "", "gumbel"], [122, 2, 1, "", "key"], [123, 2, 1, "", "normal"], [124, 2, 1, "", "randint"], [125, 2, 1, "", "seed"], [126, 2, 1, "", "split"], [127, 2, 1, "", "truncated_normal"], [128, 2, 1, "", "uniform"]], "mlx.nn": [[162, 0, 1, "", "Conv1d"], [163, 0, 1, "", "Conv2d"], [164, 0, 1, "", "Embedding"], [165, 0, 1, "", "GELU"], [166, 0, 1, "", "GroupNorm"], [167, 0, 1, "", "LayerNorm"], [168, 0, 1, "", "Linear"], [169, 0, 1, "", "Mish"], [207, 0, 1, "", "Module"], [170, 0, 1, "", "MultiHeadAttention"], [171, 0, 1, "", "PReLU"], [172, 0, 1, "", "RMSNorm"], [173, 0, 1, "", "ReLU"], [174, 0, 1, "", "RoPE"], [175, 0, 1, "", "SELU"], [176, 0, 1, "", "Sequential"], [177, 0, 1, "", "SiLU"], [178, 0, 1, "", "Step"], [187, 0, 1, "", "gelu"], [188, 0, 1, "", "gelu_approx"], [189, 0, 1, "", "gelu_fast_approx"], [196, 0, 1, "", "mish"], [197, 0, 1, "", "prelu"], [198, 0, 1, "", "relu"], [199, 0, 1, "", "selu"], [200, 0, 1, "", "silu"], [201, 0, 1, "", "step"], [179, 2, 1, "", "value_and_grad"]], "mlx.nn.Module": [[207, 1, 1, "", "apply"], [207, 1, 1, "", "apply_to_modules"], [207, 1, 1, "", "children"], [207, 1, 1, "", "filter_and_map"], [207, 1, 1, "", "freeze"], [207, 1, 1, "", "leaf_modules"], [207, 1, 1, "", "load_weights"], [207, 1, 1, "", "modules"], [207, 1, 1, "", "named_modules"], [207, 1, 1, "", "parameters"], [207, 1, 1, "", "save_weights"], [207, 1, 1, "", "trainable_parameters"], [207, 1, 1, "", "unfreeze"], [207, 1, 1, "", "update"]], "mlx.nn.losses": [[190, 0, 1, "", "binary_cross_entropy"], [191, 0, 1, "", "cross_entropy"], [192, 0, 1, "", "kl_div_loss"], [193, 0, 1, "", "l1_loss"], [194, 0, 1, "", "mse_loss"], [195, 0, 1, "", "nll_loss"]], "mlx.optimizers": [[180, 0, 1, "", "Adam"], [181, 0, 1, "", "Optimizer"], [182, 0, 1, "", "OptimizerState"], [183, 0, 1, "", "SGD"]], "mlx.optimizers.Optimizer": [[181, 4, 1, "", "state"]], "mlx.utils": [[184, 2, 1, "", "tree_flatten"], [185, 2, 1, "", "tree_map"], [186, 2, 1, "", "tree_unflatten"]]}, "objtypes": {"0": "py:class", "1": "py:method", "2": "py:function", "3": "py:property", "4": "py:attribute"}, "objnames": {"0": ["py", "class", "Python class"], "1": ["py", "method", "Python method"], "2": ["py", "function", "Python function"], "3": ["py", "property", "Python property"], "4": ["py", "attribute", "Python attribute"]}, "titleterms": {"oper": [0, 1, 208], "develop": 1, "document": 1, "introduc": 1, "exampl": [1, 5, 214], "primit": 1, "us": [1, 215], "implement": [1, 3], "cpu": 1, "backend": 1, "gpu": 1, "transform": [1, 211, 213], "build": [1, 6], "bind": 1, "python": [1, 5, 6], "cmake": 1, "setuptool": 1, "usag": [1, 5], "result": 1, "script": [1, 3], "download": [1, 3], "code": [1, 3], "linear": [2, 168], "regress": 2, "llm": 3, "infer": 3, "model": 3, "attent": 3, "layer": [3, 4, 206], "encod": 3, "full": [3, 89], "gener": 3, "put": 3, "all": [3, 12, 29], "togeth": 3, "convert": 3, "weight": 3, "load": [3, 97], "benchmark": 3, "multi": 4, "perceptron": 4, "mlx": [5, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 207], "instal": [5, 6], "api": [5, 6], "refer": 5, "c": [5, 6], "further": 5, "read": 5, "from": 6, "pypi": 6, "sourc": 6, "requir": 6, "option": 6, "core": [7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161], "devic": [7, 204], "dtype": [8, 35], "stream": [9, 204, 215], "ab": [10, 28], "add": 11, "allclos": 13, "ani": [14, 30], "arang": 15, "arcco": 16, "arccosh": 17, "arcsin": 18, "arcsinh": 19, "arctan": 20, "arctanh": 21, "argmax": [22, 31], "argmin": [23, 32], "argpartit": 24, "argsort": 25, "arrai": [26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 202], "t": 27, "astyp": 33, "co": [34, 65], "exp": [36, 74], "item": 37, "log": [38, 98], "log1p": [39, 100], "logsumexp": [40, 104], "max": [41, 106], "mean": [42, 108], "min": [43, 109], "ndim": 44, "prod": [45, 118], "reciproc": [46, 129], "reshap": [47, 130], "rsqrt": [48, 131], "shape": 49, "sin": [50, 139], "size": 51, "split": [52, 126, 143], "sqrt": [53, 144], "squar": [54, 145], "sum": [55, 149], "tolist": 56, "transpos": [57, 154], "var": [58, 156], "array_equ": 59, "broadcast_to": 60, "concaten": 61, "conv1d": [62, 162], "conv2d": [63, 163], "convolv": 64, "cosh": 66, "default_devic": 67, "default_stream": 68, "divid": 69, "equal": 70, "erf": 71, "erfinv": 72, "eval": 73, "expand_dim": 75, "fft": [77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 205], "fft2": 78, "fftn": 79, "ifft": 80, "ifft2": 81, "ifftn": 82, "irfft": 83, "irfft2": 84, "irfftn": 85, "rfft": 86, "rfft2": 87, "rfftn": 88, "grad": [90, 206], "greater": 91, "greater_equ": 92, "jvp": 94, "less": 95, "less_equ": 96, "log10": 99, "log2": 101, "logaddexp": 102, "logical_not": 103, "matmul": 105, "maximum": 107, "minimum": 110, "multipli": 111, "neg": 112, "new_stream": 113, "ones": 114, "ones_lik": 115, "pad": 116, "partit": 117, "random": [119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 210], "bernoulli": 119, "categor": 120, "gumbel": 121, "kei": 122, "normal": 123, "randint": 124, "seed": 125, "truncated_norm": 127, "uniform": 128, "save": 132, "savez": 133, "savez_compress": 134, "set_default_devic": 135, "set_default_stream": 136, "sigmoid": 137, "sign": 138, "sinh": 140, "softmax": 141, "sort": 142, "squeez": 146, "stop_gradi": 147, "subtract": 148, "take": 150, "take_along_axi": 151, "tan": 152, "tanh": 153, "value_and_grad": [155, 179], "vjp": 157, "vmap": 158, "where": 159, "zero": 160, "zeros_lik": 161, "nn": [162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 207], "embed": 164, "gelu": [165, 187], "groupnorm": 166, "layernorm": 167, "multiheadattent": 170, "rmsnorm": 172, "relu": [173, 198], "rope": 174, "sequenti": 176, "silu": [177, 200], "optim": [180, 181, 182, 183, 209], "adam": 180, "optimizerst": 182, "sgd": 183, "util": [184, 185, 186, 212], "tree_flatten": 184, "tree_map": 185, "tree_unflatten": 186, "gelu_approx": 188, "gelu_fast_approx": 189, "data": 203, "type": 203, "support": 203, "neural": 206, "network": 206, "quick": [206, 213], "start": [206, 213], "The": 206, "modul": [206, 207], "class": 206, "paramet": 206, "updat": 206, "valu": 206, "tree": 212, "guid": 213, "basic": 213, "function": [206, 213], "graph": 213, "specifi": 215, "troubleshoot": 6, "unifi": 214, "memori": 214, "A": 214, "simpl": 214, "ey": 76, "ident": 93, "mish": [169, 196], "prelu": [171, 197], "selu": [175, 199], "step": [178, 201], "loss": [190, 191, 192, 193, 194, 195, 206], "binary_cross_entropi": 190, "cross_entropi": 191, "kl_div_loss": 192, "l1_loss": 193, "mse_loss": 194, "nll_loss": 195}, "envversion": {"sphinx.domains.c": 2, "sphinx.domains.changeset": 1, "sphinx.domains.citation": 1, "sphinx.domains.cpp": 6, "sphinx.domains.index": 1, "sphinx.domains.javascript": 2, "sphinx.domains.math": 2, "sphinx.domains.python": 3, "sphinx.domains.rst": 2, "sphinx.domains.std": 2, "sphinx.ext.intersphinx": 1, "sphinx": 56}})