From ffe51a69ca3d6def78074d09268a3b87da47433d Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Sun, 17 Dec 2023 13:23:03 -0800 Subject: [PATCH] docs --- docs/build/html/_sources/dev/extensions.rst | 2 +- docs/build/html/_sources/examples/mlp.rst | 5 +- docs/build/html/_sources/install.rst | 44 +- .../python/_autosummary/mlx.core.array.rst | 3 + .../python/_autosummary/mlx.core.ceil.rst | 6 + .../python/_autosummary/mlx.core.flatten.rst | 6 + .../python/_autosummary/mlx.core.floor.rst | 6 + .../python/_autosummary/mlx.core.moveaxis.rst | 6 + .../python/_autosummary/mlx.core.simplify.rst | 6 + .../python/_autosummary/mlx.core.stack.rst | 6 + .../python/_autosummary/mlx.core.swapaxes.rst | 6 + .../python/_autosummary/mlx.core.tri.rst | 6 + .../python/_autosummary/mlx.core.tril.rst | 6 + .../python/_autosummary/mlx.core.triu.rst | 6 + .../python/_autosummary/mlx.nn.Module.rst | 58 ++ .../_autosummary/mlx.optimizers.AdaDelta.rst | 18 + .../_autosummary/mlx.optimizers.Adagrad.rst | 18 + .../_autosummary/mlx.optimizers.AdamW.rst | 18 + .../_autosummary/mlx.optimizers.Adamax.rst | 18 + .../_autosummary/mlx.optimizers.RMSprop.rst | 18 + docs/build/html/_sources/python/nn.rst | 113 ++- .../{ => nn}/_autosummary/mlx.nn.Conv1d.rst | 0 .../{ => nn}/_autosummary/mlx.nn.Conv2d.rst | 0 .../_autosummary/mlx.nn.Embedding.rst | 0 .../{ => nn}/_autosummary/mlx.nn.GELU.rst | 0 .../_autosummary/mlx.nn.GroupNorm.rst | 0 .../_autosummary/mlx.nn.LayerNorm.rst | 0 .../{ => nn}/_autosummary/mlx.nn.Linear.rst | 0 .../{ => nn}/_autosummary/mlx.nn.Mish.rst | 0 .../mlx.nn.MultiHeadAttention.rst | 0 .../{ => nn}/_autosummary/mlx.nn.PReLU.rst | 0 .../{ => nn}/_autosummary/mlx.nn.RMSNorm.rst | 0 .../{ => nn}/_autosummary/mlx.nn.ReLU.rst | 0 .../{ => nn}/_autosummary/mlx.nn.RoPE.rst | 0 .../{ => nn}/_autosummary/mlx.nn.SELU.rst | 0 .../_autosummary/mlx.nn.Sequential.rst | 0 .../{ => nn}/_autosummary/mlx.nn.SiLU.rst | 0 .../{ => nn}/_autosummary/mlx.nn.Step.rst | 0 .../_autosummary_functions/mlx.nn.gelu.rst | 0 .../mlx.nn.gelu_approx.rst | 0 .../mlx.nn.gelu_fast_approx.rst | 0 .../mlx.nn.losses.binary_cross_entropy.rst | 0 .../mlx.nn.losses.cross_entropy.rst | 0 .../mlx.nn.losses.kl_div_loss.rst | 0 .../mlx.nn.losses.l1_loss.rst | 0 .../mlx.nn.losses.mse_loss.rst | 0 .../mlx.nn.losses.nll_loss.rst | 0 .../_autosummary_functions/mlx.nn.mish.rst | 0 .../_autosummary_functions/mlx.nn.prelu.rst | 0 .../_autosummary_functions/mlx.nn.relu.rst | 0 .../_autosummary_functions/mlx.nn.selu.rst | 0 .../_autosummary_functions/mlx.nn.silu.rst | 0 .../_autosummary_functions/mlx.nn.step.rst | 0 .../html/_sources/python/nn/functions.rst | 23 + docs/build/html/_sources/python/nn/layers.rst | 28 + docs/build/html/_sources/python/nn/losses.rst | 17 + docs/build/html/_sources/python/nn/module.rst | 7 - docs/build/html/_sources/python/ops.rst | 9 + .../build/html/_sources/python/optimizers.rst | 5 + .../build/html/_sources/python/transforms.rst | 1 + docs/build/html/cpp/ops.html | 93 +- docs/build/html/dev/extensions.html | 97 +- .../html/examples/linear_regression.html | 93 +- docs/build/html/examples/llama-inference.html | 101 +- docs/build/html/examples/mlp.html | 101 +- docs/build/html/genindex.html | 227 +++-- docs/build/html/index.html | 93 +- docs/build/html/install.html | 141 ++- docs/build/html/objects.inv | Bin 5020 -> 5279 bytes .../python/_autosummary/mlx.core.Device.html | 93 +- .../python/_autosummary/mlx.core.Dtype.html | 93 +- .../python/_autosummary/mlx.core.Stream.html | 93 +- .../python/_autosummary/mlx.core.abs.html | 93 +- .../python/_autosummary/mlx.core.add.html | 93 +- .../python/_autosummary/mlx.core.all.html | 93 +- .../_autosummary/mlx.core.allclose.html | 93 +- .../python/_autosummary/mlx.core.any.html | 93 +- .../python/_autosummary/mlx.core.arange.html | 93 +- .../python/_autosummary/mlx.core.arccos.html | 93 +- .../python/_autosummary/mlx.core.arccosh.html | 93 +- .../python/_autosummary/mlx.core.arcsin.html | 93 +- .../python/_autosummary/mlx.core.arcsinh.html | 93 +- .../python/_autosummary/mlx.core.arctan.html | 93 +- .../python/_autosummary/mlx.core.arctanh.html | 93 +- .../python/_autosummary/mlx.core.argmax.html | 93 +- .../python/_autosummary/mlx.core.argmin.html | 93 +- .../_autosummary/mlx.core.argpartition.html | 93 +- .../python/_autosummary/mlx.core.argsort.html | 93 +- .../python/_autosummary/mlx.core.array.T.html | 93 +- .../_autosummary/mlx.core.array.abs.html | 93 +- .../_autosummary/mlx.core.array.all.html | 93 +- .../_autosummary/mlx.core.array.any.html | 93 +- .../_autosummary/mlx.core.array.argmax.html | 93 +- .../_autosummary/mlx.core.array.argmin.html | 93 +- .../_autosummary/mlx.core.array.astype.html | 93 +- .../_autosummary/mlx.core.array.cos.html | 93 +- .../_autosummary/mlx.core.array.dtype.html | 93 +- .../_autosummary/mlx.core.array.exp.html | 93 +- .../python/_autosummary/mlx.core.array.html | 126 ++- .../_autosummary/mlx.core.array.item.html | 93 +- .../_autosummary/mlx.core.array.log.html | 93 +- .../_autosummary/mlx.core.array.log1p.html | 93 +- .../mlx.core.array.logsumexp.html | 93 +- .../_autosummary/mlx.core.array.max.html | 93 +- .../_autosummary/mlx.core.array.mean.html | 93 +- .../_autosummary/mlx.core.array.min.html | 93 +- .../_autosummary/mlx.core.array.ndim.html | 93 +- .../_autosummary/mlx.core.array.prod.html | 93 +- .../mlx.core.array.reciprocal.html | 93 +- .../_autosummary/mlx.core.array.reshape.html | 93 +- .../_autosummary/mlx.core.array.rsqrt.html | 93 +- .../_autosummary/mlx.core.array.shape.html | 93 +- .../_autosummary/mlx.core.array.sin.html | 93 +- .../_autosummary/mlx.core.array.size.html | 93 +- .../_autosummary/mlx.core.array.split.html | 93 +- .../_autosummary/mlx.core.array.sqrt.html | 93 +- .../_autosummary/mlx.core.array.square.html | 93 +- .../_autosummary/mlx.core.array.sum.html | 93 +- .../_autosummary/mlx.core.array.tolist.html | 93 +- .../mlx.core.array.transpose.html | 93 +- .../_autosummary/mlx.core.array.var.html | 93 +- .../_autosummary/mlx.core.array_equal.html | 93 +- .../_autosummary/mlx.core.broadcast_to.html | 99 +- ...{mlx.nn.Conv1d.html => mlx.core.ceil.html} | 155 ++-- .../_autosummary/mlx.core.concatenate.html | 99 +- .../python/_autosummary/mlx.core.conv1d.html | 93 +- .../python/_autosummary/mlx.core.conv2d.html | 93 +- .../_autosummary/mlx.core.convolve.html | 93 +- .../python/_autosummary/mlx.core.cos.html | 93 +- .../python/_autosummary/mlx.core.cosh.html | 93 +- .../_autosummary/mlx.core.default_device.html | 93 +- .../_autosummary/mlx.core.default_stream.html | 93 +- .../python/_autosummary/mlx.core.divide.html | 93 +- .../python/_autosummary/mlx.core.equal.html | 93 +- .../python/_autosummary/mlx.core.erf.html | 93 +- .../python/_autosummary/mlx.core.erfinv.html | 93 +- .../python/_autosummary/mlx.core.eval.html | 93 +- .../python/_autosummary/mlx.core.exp.html | 93 +- .../_autosummary/mlx.core.expand_dims.html | 93 +- .../python/_autosummary/mlx.core.eye.html | 99 +- .../python/_autosummary/mlx.core.fft.fft.html | 93 +- .../_autosummary/mlx.core.fft.fft2.html | 93 +- .../_autosummary/mlx.core.fft.fftn.html | 93 +- .../_autosummary/mlx.core.fft.ifft.html | 93 +- .../_autosummary/mlx.core.fft.ifft2.html | 93 +- .../_autosummary/mlx.core.fft.ifftn.html | 93 +- .../_autosummary/mlx.core.fft.irfft.html | 93 +- .../_autosummary/mlx.core.fft.irfft2.html | 93 +- .../_autosummary/mlx.core.fft.irfftn.html | 93 +- .../_autosummary/mlx.core.fft.rfft.html | 93 +- .../_autosummary/mlx.core.fft.rfft2.html | 93 +- .../_autosummary/mlx.core.fft.rfftn.html | 93 +- .../python/_autosummary/mlx.core.flatten.html | 695 ++++++++++++++ .../{mlx.nn.RoPE.html => mlx.core.floor.html} | 146 +-- .../python/_autosummary/mlx.core.full.html | 99 +- .../python/_autosummary/mlx.core.grad.html | 93 +- .../python/_autosummary/mlx.core.greater.html | 93 +- .../_autosummary/mlx.core.greater_equal.html | 93 +- .../_autosummary/mlx.core.identity.html | 93 +- .../python/_autosummary/mlx.core.jvp.html | 93 +- .../python/_autosummary/mlx.core.less.html | 93 +- .../_autosummary/mlx.core.less_equal.html | 93 +- .../python/_autosummary/mlx.core.load.html | 93 +- .../python/_autosummary/mlx.core.log.html | 93 +- .../python/_autosummary/mlx.core.log10.html | 93 +- .../python/_autosummary/mlx.core.log1p.html | 93 +- .../python/_autosummary/mlx.core.log2.html | 93 +- .../_autosummary/mlx.core.logaddexp.html | 93 +- .../_autosummary/mlx.core.logical_not.html | 93 +- .../_autosummary/mlx.core.logsumexp.html | 93 +- .../python/_autosummary/mlx.core.matmul.html | 93 +- .../python/_autosummary/mlx.core.max.html | 93 +- .../python/_autosummary/mlx.core.maximum.html | 93 +- .../python/_autosummary/mlx.core.mean.html | 93 +- .../python/_autosummary/mlx.core.min.html | 93 +- .../python/_autosummary/mlx.core.minimum.html | 99 +- .../_autosummary/mlx.core.moveaxis.html | 693 ++++++++++++++ .../_autosummary/mlx.core.multiply.html | 99 +- .../_autosummary/mlx.core.negative.html | 93 +- .../_autosummary/mlx.core.new_stream.html | 93 +- .../python/_autosummary/mlx.core.ones.html | 93 +- .../_autosummary/mlx.core.ones_like.html | 93 +- .../python/_autosummary/mlx.core.pad.html | 93 +- .../_autosummary/mlx.core.partition.html | 93 +- .../python/_autosummary/mlx.core.prod.html | 93 +- .../mlx.core.random.bernoulli.html | 93 +- .../mlx.core.random.categorical.html | 93 +- .../_autosummary/mlx.core.random.gumbel.html | 93 +- .../_autosummary/mlx.core.random.key.html | 93 +- .../_autosummary/mlx.core.random.normal.html | 93 +- .../_autosummary/mlx.core.random.randint.html | 93 +- .../_autosummary/mlx.core.random.seed.html | 93 +- .../_autosummary/mlx.core.random.split.html | 93 +- .../mlx.core.random.truncated_normal.html | 93 +- .../_autosummary/mlx.core.random.uniform.html | 93 +- .../_autosummary/mlx.core.reciprocal.html | 93 +- .../python/_autosummary/mlx.core.reshape.html | 93 +- .../python/_autosummary/mlx.core.rsqrt.html | 93 +- .../python/_autosummary/mlx.core.save.html | 93 +- .../python/_autosummary/mlx.core.savez.html | 93 +- .../mlx.core.savez_compressed.html | 93 +- .../mlx.core.set_default_device.html | 93 +- .../mlx.core.set_default_stream.html | 93 +- .../python/_autosummary/mlx.core.sigmoid.html | 93 +- .../python/_autosummary/mlx.core.sign.html | 93 +- ....nn.Linear.html => mlx.core.simplify.html} | 158 ++-- .../python/_autosummary/mlx.core.sin.html | 93 +- .../python/_autosummary/mlx.core.sinh.html | 93 +- .../python/_autosummary/mlx.core.softmax.html | 93 +- .../python/_autosummary/mlx.core.sort.html | 93 +- .../python/_autosummary/mlx.core.split.html | 93 +- .../python/_autosummary/mlx.core.sqrt.html | 93 +- .../python/_autosummary/mlx.core.square.html | 93 +- .../python/_autosummary/mlx.core.squeeze.html | 99 +- ...mlx.nn.Conv2d.html => mlx.core.stack.html} | 157 ++-- .../_autosummary/mlx.core.stop_gradient.html | 99 +- .../_autosummary/mlx.core.subtract.html | 93 +- .../python/_autosummary/mlx.core.sum.html | 99 +- .../_autosummary/mlx.core.swapaxes.html | 693 ++++++++++++++ .../python/_autosummary/mlx.core.take.html | 99 +- .../mlx.core.take_along_axis.html | 93 +- .../python/_autosummary/mlx.core.tan.html | 93 +- .../python/_autosummary/mlx.core.tanh.html | 93 +- .../_autosummary/mlx.core.transpose.html | 99 +- .../python/_autosummary/mlx.core.tri.html | 695 ++++++++++++++ .../python/_autosummary/mlx.core.tril.html | 693 ++++++++++++++ .../python/_autosummary/mlx.core.triu.html | 693 ++++++++++++++ .../_autosummary/mlx.core.value_and_grad.html | 93 +- .../python/_autosummary/mlx.core.var.html | 99 +- .../python/_autosummary/mlx.core.vjp.html | 93 +- .../python/_autosummary/mlx.core.vmap.html | 99 +- .../python/_autosummary/mlx.core.where.html | 93 +- .../python/_autosummary/mlx.core.zeros.html | 93 +- .../_autosummary/mlx.core.zeros_like.html | 93 +- .../html/python/_autosummary/mlx.nn.GELU.html | 669 -------------- .../html/python/_autosummary/mlx.nn.Mish.html | 658 ------------- ...x.nn.Embedding.html => mlx.nn.Module.html} | 279 ++++-- .../python/_autosummary/mlx.nn.PReLU.html | 652 ------------- .../html/python/_autosummary/mlx.nn.ReLU.html | 654 ------------- .../html/python/_autosummary/mlx.nn.SELU.html | 661 -------------- .../_autosummary/mlx.nn.Sequential.html | 661 -------------- .../html/python/_autosummary/mlx.nn.SiLU.html | 656 ------------- .../html/python/_autosummary/mlx.nn.Step.html | 666 -------------- .../_autosummary/mlx.nn.value_and_grad.html | 103 ++- ...Norm.html => mlx.optimizers.AdaDelta.html} | 156 ++-- ...pNorm.html => mlx.optimizers.Adagrad.html} | 161 ++-- .../_autosummary/mlx.optimizers.Adam.html | 121 ++- .../_autosummary/mlx.optimizers.AdamW.html | 713 +++++++++++++++ ...ention.html => mlx.optimizers.Adamax.html} | 168 ++-- .../mlx.optimizers.Optimizer.html | 93 +- .../mlx.optimizers.OptimizerState.html | 93 +- ...SNorm.html => mlx.optimizers.RMSprop.html} | 151 +-- .../_autosummary/mlx.optimizers.SGD.html | 111 ++- .../_autosummary/mlx.utils.tree_flatten.html | 93 +- .../_autosummary/mlx.utils.tree_map.html | 93 +- .../mlx.utils.tree_unflatten.html | 93 +- .../_autosummary_functions/mlx.nn.gelu.html | 659 ------------- .../mlx.nn.gelu_approx.html | 660 -------------- .../mlx.nn.gelu_fast_approx.html | 660 -------------- .../mlx.nn.losses.cross_entropy.html | 670 -------------- .../mlx.nn.losses.l1_loss.html | 669 -------------- .../mlx.nn.losses.mse_loss.html | 669 -------------- .../_autosummary_functions/mlx.nn.mish.html | 658 ------------- .../_autosummary_functions/mlx.nn.prelu.html | 657 ------------- .../_autosummary_functions/mlx.nn.relu.html | 654 ------------- .../_autosummary_functions/mlx.nn.selu.html | 661 -------------- .../_autosummary_functions/mlx.nn.silu.html | 656 ------------- .../_autosummary_functions/mlx.nn.step.html | 666 -------------- docs/build/html/python/array.html | 93 +- docs/build/html/python/data_types.html | 93 +- .../html/python/devices_and_streams.html | 93 +- docs/build/html/python/fft.html | 99 +- docs/build/html/python/nn.html | 326 ++++--- .../python/nn/_autosummary/mlx.nn.Conv1d.html | 701 ++++++++++++++ .../python/nn/_autosummary/mlx.nn.Conv2d.html | 702 ++++++++++++++ .../nn/_autosummary/mlx.nn.Embedding.html | 689 ++++++++++++++ .../python/nn/_autosummary/mlx.nn.GELU.html | 694 ++++++++++++++ .../nn/_autosummary/mlx.nn.GroupNorm.html | 703 ++++++++++++++ .../nn/_autosummary/mlx.nn.LayerNorm.html | 695 ++++++++++++++ .../python/nn/_autosummary/mlx.nn.Linear.html | 693 ++++++++++++++ .../python/nn/_autosummary/mlx.nn.Mish.html | 683 ++++++++++++++ .../mlx.nn.MultiHeadAttention.html | 700 ++++++++++++++ .../python/nn/_autosummary/mlx.nn.PReLU.html | 677 ++++++++++++++ .../nn/_autosummary/mlx.nn.RMSNorm.html | 693 ++++++++++++++ .../python/nn/_autosummary/mlx.nn.ReLU.html | 679 ++++++++++++++ .../python/nn/_autosummary/mlx.nn.RoPE.html | 694 ++++++++++++++ .../python/nn/_autosummary/mlx.nn.SELU.html | 686 ++++++++++++++ .../nn/_autosummary/mlx.nn.Sequential.html | 686 ++++++++++++++ .../python/nn/_autosummary/mlx.nn.SiLU.html | 681 ++++++++++++++ .../python/nn/_autosummary/mlx.nn.Step.html | 691 ++++++++++++++ .../_autosummary_functions/mlx.nn.gelu.html | 684 ++++++++++++++ .../mlx.nn.gelu_approx.html | 685 ++++++++++++++ .../mlx.nn.gelu_fast_approx.html | 685 ++++++++++++++ .../mlx.nn.losses.binary_cross_entropy.html | 704 ++++++++++++++ .../mlx.nn.losses.cross_entropy.html | 695 ++++++++++++++ .../mlx.nn.losses.kl_div_loss.html | 700 ++++++++++++++ .../mlx.nn.losses.l1_loss.html | 694 ++++++++++++++ .../mlx.nn.losses.mse_loss.html | 694 ++++++++++++++ .../mlx.nn.losses.nll_loss.html | 695 ++++++++++++++ .../_autosummary_functions/mlx.nn.mish.html | 683 ++++++++++++++ .../_autosummary_functions/mlx.nn.prelu.html | 682 ++++++++++++++ .../_autosummary_functions/mlx.nn.relu.html | 679 ++++++++++++++ .../_autosummary_functions/mlx.nn.selu.html | 686 ++++++++++++++ .../_autosummary_functions/mlx.nn.silu.html | 681 ++++++++++++++ .../_autosummary_functions/mlx.nn.step.html | 691 ++++++++++++++ .../functions.html} | 173 ++-- .../layers.html} | 200 ++-- .../losses.html} | 171 ++-- docs/build/html/python/nn/module.html | 863 ------------------ docs/build/html/python/ops.html | 200 ++-- docs/build/html/python/optimizers.html | 114 ++- docs/build/html/python/random.html | 93 +- docs/build/html/python/transforms.html | 96 +- docs/build/html/python/tree_utils.html | 99 +- docs/build/html/quick_start.html | 93 +- docs/build/html/search.html | 93 +- docs/build/html/searchindex.js | 2 +- docs/build/html/unified_memory.html | 93 +- docs/build/html/using_streams.html | 93 +- 319 files changed, 39952 insertions(+), 21579 deletions(-) create mode 100644 docs/build/html/_sources/python/_autosummary/mlx.core.ceil.rst create mode 100644 docs/build/html/_sources/python/_autosummary/mlx.core.flatten.rst create mode 100644 docs/build/html/_sources/python/_autosummary/mlx.core.floor.rst create mode 100644 docs/build/html/_sources/python/_autosummary/mlx.core.moveaxis.rst create mode 100644 docs/build/html/_sources/python/_autosummary/mlx.core.simplify.rst create mode 100644 docs/build/html/_sources/python/_autosummary/mlx.core.stack.rst create mode 100644 docs/build/html/_sources/python/_autosummary/mlx.core.swapaxes.rst create mode 100644 docs/build/html/_sources/python/_autosummary/mlx.core.tri.rst create mode 100644 docs/build/html/_sources/python/_autosummary/mlx.core.tril.rst create mode 100644 docs/build/html/_sources/python/_autosummary/mlx.core.triu.rst create mode 100644 docs/build/html/_sources/python/_autosummary/mlx.nn.Module.rst create mode 100644 docs/build/html/_sources/python/_autosummary/mlx.optimizers.AdaDelta.rst create mode 100644 docs/build/html/_sources/python/_autosummary/mlx.optimizers.Adagrad.rst create mode 100644 docs/build/html/_sources/python/_autosummary/mlx.optimizers.AdamW.rst create mode 100644 docs/build/html/_sources/python/_autosummary/mlx.optimizers.Adamax.rst create mode 100644 docs/build/html/_sources/python/_autosummary/mlx.optimizers.RMSprop.rst rename docs/build/html/_sources/python/{ => nn}/_autosummary/mlx.nn.Conv1d.rst (100%) rename docs/build/html/_sources/python/{ => nn}/_autosummary/mlx.nn.Conv2d.rst (100%) rename docs/build/html/_sources/python/{ => nn}/_autosummary/mlx.nn.Embedding.rst (100%) rename docs/build/html/_sources/python/{ => nn}/_autosummary/mlx.nn.GELU.rst (100%) rename docs/build/html/_sources/python/{ => nn}/_autosummary/mlx.nn.GroupNorm.rst (100%) rename docs/build/html/_sources/python/{ => nn}/_autosummary/mlx.nn.LayerNorm.rst (100%) rename docs/build/html/_sources/python/{ => nn}/_autosummary/mlx.nn.Linear.rst (100%) rename docs/build/html/_sources/python/{ => nn}/_autosummary/mlx.nn.Mish.rst (100%) rename docs/build/html/_sources/python/{ => nn}/_autosummary/mlx.nn.MultiHeadAttention.rst (100%) rename docs/build/html/_sources/python/{ => nn}/_autosummary/mlx.nn.PReLU.rst (100%) rename docs/build/html/_sources/python/{ => nn}/_autosummary/mlx.nn.RMSNorm.rst (100%) rename docs/build/html/_sources/python/{ => nn}/_autosummary/mlx.nn.ReLU.rst (100%) rename docs/build/html/_sources/python/{ => nn}/_autosummary/mlx.nn.RoPE.rst (100%) rename docs/build/html/_sources/python/{ => nn}/_autosummary/mlx.nn.SELU.rst (100%) rename docs/build/html/_sources/python/{ => nn}/_autosummary/mlx.nn.Sequential.rst (100%) rename docs/build/html/_sources/python/{ => nn}/_autosummary/mlx.nn.SiLU.rst (100%) rename docs/build/html/_sources/python/{ => nn}/_autosummary/mlx.nn.Step.rst (100%) rename docs/build/html/_sources/python/{ => nn}/_autosummary_functions/mlx.nn.gelu.rst (100%) rename docs/build/html/_sources/python/{ => nn}/_autosummary_functions/mlx.nn.gelu_approx.rst (100%) rename docs/build/html/_sources/python/{ => nn}/_autosummary_functions/mlx.nn.gelu_fast_approx.rst (100%) rename docs/build/html/_sources/python/{ => nn}/_autosummary_functions/mlx.nn.losses.binary_cross_entropy.rst (100%) rename docs/build/html/_sources/python/{ => nn}/_autosummary_functions/mlx.nn.losses.cross_entropy.rst (100%) rename docs/build/html/_sources/python/{ => nn}/_autosummary_functions/mlx.nn.losses.kl_div_loss.rst (100%) rename docs/build/html/_sources/python/{ => nn}/_autosummary_functions/mlx.nn.losses.l1_loss.rst (100%) rename docs/build/html/_sources/python/{ => nn}/_autosummary_functions/mlx.nn.losses.mse_loss.rst (100%) rename docs/build/html/_sources/python/{ => nn}/_autosummary_functions/mlx.nn.losses.nll_loss.rst (100%) rename docs/build/html/_sources/python/{ => nn}/_autosummary_functions/mlx.nn.mish.rst (100%) rename docs/build/html/_sources/python/{ => nn}/_autosummary_functions/mlx.nn.prelu.rst (100%) rename docs/build/html/_sources/python/{ => nn}/_autosummary_functions/mlx.nn.relu.rst (100%) rename docs/build/html/_sources/python/{ => nn}/_autosummary_functions/mlx.nn.selu.rst (100%) rename docs/build/html/_sources/python/{ => nn}/_autosummary_functions/mlx.nn.silu.rst (100%) rename docs/build/html/_sources/python/{ => nn}/_autosummary_functions/mlx.nn.step.rst (100%) create mode 100644 docs/build/html/_sources/python/nn/functions.rst create mode 100644 docs/build/html/_sources/python/nn/layers.rst create mode 100644 docs/build/html/_sources/python/nn/losses.rst delete mode 100644 docs/build/html/_sources/python/nn/module.rst rename docs/build/html/python/_autosummary/{mlx.nn.Conv1d.html => mlx.core.ceil.html} (75%) create mode 100644 docs/build/html/python/_autosummary/mlx.core.flatten.html rename docs/build/html/python/_autosummary/{mlx.nn.RoPE.html => mlx.core.floor.html} (74%) create mode 100644 docs/build/html/python/_autosummary/mlx.core.moveaxis.html rename docs/build/html/python/_autosummary/{mlx.nn.Linear.html => mlx.core.simplify.html} (74%) rename docs/build/html/python/_autosummary/{mlx.nn.Conv2d.html => mlx.core.stack.html} (72%) create mode 100644 docs/build/html/python/_autosummary/mlx.core.swapaxes.html create mode 100644 docs/build/html/python/_autosummary/mlx.core.tri.html create mode 100644 docs/build/html/python/_autosummary/mlx.core.tril.html create mode 100644 docs/build/html/python/_autosummary/mlx.core.triu.html delete mode 100644 docs/build/html/python/_autosummary/mlx.nn.GELU.html delete mode 100644 docs/build/html/python/_autosummary/mlx.nn.Mish.html rename docs/build/html/python/_autosummary/{mlx.nn.Embedding.html => mlx.nn.Module.html} (60%) delete mode 100644 docs/build/html/python/_autosummary/mlx.nn.PReLU.html delete mode 100644 docs/build/html/python/_autosummary/mlx.nn.ReLU.html delete mode 100644 docs/build/html/python/_autosummary/mlx.nn.SELU.html delete mode 100644 docs/build/html/python/_autosummary/mlx.nn.Sequential.html delete mode 100644 docs/build/html/python/_autosummary/mlx.nn.SiLU.html delete mode 100644 docs/build/html/python/_autosummary/mlx.nn.Step.html rename docs/build/html/python/_autosummary/{mlx.nn.LayerNorm.html => mlx.optimizers.AdaDelta.html} (73%) rename docs/build/html/python/_autosummary/{mlx.nn.GroupNorm.html => mlx.optimizers.Adagrad.html} (75%) create mode 100644 docs/build/html/python/_autosummary/mlx.optimizers.AdamW.html rename docs/build/html/python/_autosummary/{mlx.nn.MultiHeadAttention.html => mlx.optimizers.Adamax.html} (72%) rename docs/build/html/python/_autosummary/{mlx.nn.RMSNorm.html => mlx.optimizers.RMSprop.html} (74%) delete mode 100644 docs/build/html/python/_autosummary_functions/mlx.nn.gelu.html delete mode 100644 docs/build/html/python/_autosummary_functions/mlx.nn.gelu_approx.html delete mode 100644 docs/build/html/python/_autosummary_functions/mlx.nn.gelu_fast_approx.html delete mode 100644 docs/build/html/python/_autosummary_functions/mlx.nn.losses.cross_entropy.html delete mode 100644 docs/build/html/python/_autosummary_functions/mlx.nn.losses.l1_loss.html delete mode 100644 docs/build/html/python/_autosummary_functions/mlx.nn.losses.mse_loss.html delete mode 100644 docs/build/html/python/_autosummary_functions/mlx.nn.mish.html delete mode 100644 docs/build/html/python/_autosummary_functions/mlx.nn.prelu.html delete mode 100644 docs/build/html/python/_autosummary_functions/mlx.nn.relu.html delete mode 100644 docs/build/html/python/_autosummary_functions/mlx.nn.selu.html delete mode 100644 docs/build/html/python/_autosummary_functions/mlx.nn.silu.html delete mode 100644 docs/build/html/python/_autosummary_functions/mlx.nn.step.html create mode 100644 docs/build/html/python/nn/_autosummary/mlx.nn.Conv1d.html create mode 100644 docs/build/html/python/nn/_autosummary/mlx.nn.Conv2d.html create mode 100644 docs/build/html/python/nn/_autosummary/mlx.nn.Embedding.html create mode 100644 docs/build/html/python/nn/_autosummary/mlx.nn.GELU.html create mode 100644 docs/build/html/python/nn/_autosummary/mlx.nn.GroupNorm.html create mode 100644 docs/build/html/python/nn/_autosummary/mlx.nn.LayerNorm.html create mode 100644 docs/build/html/python/nn/_autosummary/mlx.nn.Linear.html create mode 100644 docs/build/html/python/nn/_autosummary/mlx.nn.Mish.html create mode 100644 docs/build/html/python/nn/_autosummary/mlx.nn.MultiHeadAttention.html create mode 100644 docs/build/html/python/nn/_autosummary/mlx.nn.PReLU.html create mode 100644 docs/build/html/python/nn/_autosummary/mlx.nn.RMSNorm.html create mode 100644 docs/build/html/python/nn/_autosummary/mlx.nn.ReLU.html create mode 100644 docs/build/html/python/nn/_autosummary/mlx.nn.RoPE.html create mode 100644 docs/build/html/python/nn/_autosummary/mlx.nn.SELU.html create mode 100644 docs/build/html/python/nn/_autosummary/mlx.nn.Sequential.html create mode 100644 docs/build/html/python/nn/_autosummary/mlx.nn.SiLU.html create mode 100644 docs/build/html/python/nn/_autosummary/mlx.nn.Step.html create mode 100644 docs/build/html/python/nn/_autosummary_functions/mlx.nn.gelu.html create mode 100644 docs/build/html/python/nn/_autosummary_functions/mlx.nn.gelu_approx.html create mode 100644 docs/build/html/python/nn/_autosummary_functions/mlx.nn.gelu_fast_approx.html create mode 100644 docs/build/html/python/nn/_autosummary_functions/mlx.nn.losses.binary_cross_entropy.html create mode 100644 docs/build/html/python/nn/_autosummary_functions/mlx.nn.losses.cross_entropy.html create mode 100644 docs/build/html/python/nn/_autosummary_functions/mlx.nn.losses.kl_div_loss.html create mode 100644 docs/build/html/python/nn/_autosummary_functions/mlx.nn.losses.l1_loss.html create mode 100644 docs/build/html/python/nn/_autosummary_functions/mlx.nn.losses.mse_loss.html create mode 100644 docs/build/html/python/nn/_autosummary_functions/mlx.nn.losses.nll_loss.html create mode 100644 docs/build/html/python/nn/_autosummary_functions/mlx.nn.mish.html create mode 100644 docs/build/html/python/nn/_autosummary_functions/mlx.nn.prelu.html create mode 100644 docs/build/html/python/nn/_autosummary_functions/mlx.nn.relu.html create mode 100644 docs/build/html/python/nn/_autosummary_functions/mlx.nn.selu.html create mode 100644 docs/build/html/python/nn/_autosummary_functions/mlx.nn.silu.html create mode 100644 docs/build/html/python/nn/_autosummary_functions/mlx.nn.step.html rename docs/build/html/python/{_autosummary_functions/mlx.nn.losses.nll_loss.html => nn/functions.html} (78%) rename docs/build/html/python/{_autosummary_functions/mlx.nn.losses.kl_div_loss.html => nn/layers.html} (74%) rename docs/build/html/python/{_autosummary_functions/mlx.nn.losses.binary_cross_entropy.html => nn/losses.html} (79%) delete mode 100644 docs/build/html/python/nn/module.html diff --git a/docs/build/html/_sources/dev/extensions.rst b/docs/build/html/_sources/dev/extensions.rst index 9482be725..9aae931a3 100644 --- a/docs/build/html/_sources/dev/extensions.rst +++ b/docs/build/html/_sources/dev/extensions.rst @@ -150,7 +150,7 @@ back and go to our example to give ourselves a more concrete image. const std::vector& argnums) override; /** - * The primitive must know how to vectorize itself accross + * The primitive must know how to vectorize itself across * the given axes. The output is a pair containing the array * representing the vectorized computation and the axis which * corresponds to the output vectorized dimension. diff --git a/docs/build/html/_sources/examples/mlp.rst b/docs/build/html/_sources/examples/mlp.rst index c003618ce..36890e95c 100644 --- a/docs/build/html/_sources/examples/mlp.rst +++ b/docs/build/html/_sources/examples/mlp.rst @@ -61,7 +61,10 @@ set: def eval_fn(model, X, y): return mx.mean(mx.argmax(model(X), axis=1) == y) -Next, setup the problem parameters and load the data: +Next, setup the problem parameters and load the data. To load the data, you need our +`mnist data loader +`_, which +we will import as `mnist`. .. code-block:: python diff --git a/docs/build/html/_sources/install.rst b/docs/build/html/_sources/install.rst index 682f09f38..92669ab6e 100644 --- a/docs/build/html/_sources/install.rst +++ b/docs/build/html/_sources/install.rst @@ -15,11 +15,11 @@ To install from PyPI you must meet the following requirements: - Using an M series chip (Apple silicon) - Using a native Python >= 3.8 -- MacOS >= 13.3 +- macOS >= 13.3 .. note:: - MLX is only available on devices running MacOS >= 13.3 - It is highly recommended to use MacOS 14 (Sonoma) + MLX is only available on devices running macOS >= 13.3 + It is highly recommended to use macOS 14 (Sonoma) Troubleshooting ^^^^^^^^^^^^^^^ @@ -35,8 +35,7 @@ Probably you are using a non-native Python. The output of should be ``arm``. If it is ``i386`` (and you have M series machine) then you are using a non-native Python. Switch your Python to a native Python. A good -way to do this is with -`Conda `_. +way to do this is with `Conda `_. Build from source @@ -47,7 +46,7 @@ Build Requirements - A C++ compiler with C++17 support (e.g. Clang >= 5.0) - `cmake `_ -- version 3.24 or later, and ``make`` -- Xcode >= 14.3 (Xcode >= 15.0 for MacOS 14 and above) +- Xcode >= 14.3 (Xcode >= 15.0 for macOS 14 and above) Python API @@ -88,6 +87,13 @@ To make sure the install is working run the tests with: pip install ".[testing]" python -m unittest discover python/tests +Optional: Install stubs to enable auto completions and type checking from your IDE: + +.. code-block:: shell + + pip install ".[dev]" + python setup.py generate_stubs + C++ API ^^^^^^^ @@ -154,8 +160,32 @@ should point to the path to the built metal library. export DEVELOPER_DIR="/path/to/Xcode.app/Contents/Developer/" Further, you can use the following command to find out which - MacOS SDK will be used + macOS SDK will be used .. code-block:: shell xcrun -sdk macosx --show-sdk-version + +Troubleshooting +^^^^^^^^^^^^^^^ + +Metal not found +~~~~~~~~~~~~~~~ + +You see the following error when you try to build: + +.. code-block:: shell + + error: unable to find utility "metal", not a developer tool or in PATH + +To fix this, first make sure you have Xcode installed: + +.. code-block:: shell + + xcode-select --install + +Then set the active developer directory: + +.. code-block:: shell + + sudo xcode-select --switch /Applications/Xcode.app/Contents/Developer diff --git a/docs/build/html/_sources/python/_autosummary/mlx.core.array.rst b/docs/build/html/_sources/python/_autosummary/mlx.core.array.rst index 21e66b5e4..a93bbadcd 100644 --- a/docs/build/html/_sources/python/_autosummary/mlx.core.array.rst +++ b/docs/build/html/_sources/python/_autosummary/mlx.core.array.rst @@ -26,6 +26,7 @@ ~array.cumprod ~array.cumsum ~array.exp + ~array.flatten ~array.item ~array.log ~array.log10 @@ -35,6 +36,7 @@ ~array.max ~array.mean ~array.min + ~array.moveaxis ~array.prod ~array.reciprocal ~array.reshape @@ -45,6 +47,7 @@ ~array.square ~array.squeeze ~array.sum + ~array.swapaxes ~array.tolist ~array.transpose ~array.var diff --git a/docs/build/html/_sources/python/_autosummary/mlx.core.ceil.rst b/docs/build/html/_sources/python/_autosummary/mlx.core.ceil.rst new file mode 100644 index 000000000..bbd0a6656 --- /dev/null +++ b/docs/build/html/_sources/python/_autosummary/mlx.core.ceil.rst @@ -0,0 +1,6 @@ +mlx.core.ceil +============= + +.. currentmodule:: mlx.core + +.. autofunction:: ceil \ No newline at end of file diff --git a/docs/build/html/_sources/python/_autosummary/mlx.core.flatten.rst b/docs/build/html/_sources/python/_autosummary/mlx.core.flatten.rst new file mode 100644 index 000000000..90470d914 --- /dev/null +++ b/docs/build/html/_sources/python/_autosummary/mlx.core.flatten.rst @@ -0,0 +1,6 @@ +mlx.core.flatten +================ + +.. currentmodule:: mlx.core + +.. autofunction:: flatten \ No newline at end of file diff --git a/docs/build/html/_sources/python/_autosummary/mlx.core.floor.rst b/docs/build/html/_sources/python/_autosummary/mlx.core.floor.rst new file mode 100644 index 000000000..a05f6d451 --- /dev/null +++ b/docs/build/html/_sources/python/_autosummary/mlx.core.floor.rst @@ -0,0 +1,6 @@ +mlx.core.floor +============== + +.. currentmodule:: mlx.core + +.. autofunction:: floor \ No newline at end of file diff --git a/docs/build/html/_sources/python/_autosummary/mlx.core.moveaxis.rst b/docs/build/html/_sources/python/_autosummary/mlx.core.moveaxis.rst new file mode 100644 index 000000000..ed69d670c --- /dev/null +++ b/docs/build/html/_sources/python/_autosummary/mlx.core.moveaxis.rst @@ -0,0 +1,6 @@ +mlx.core.moveaxis +================= + +.. currentmodule:: mlx.core + +.. autofunction:: moveaxis \ No newline at end of file diff --git a/docs/build/html/_sources/python/_autosummary/mlx.core.simplify.rst b/docs/build/html/_sources/python/_autosummary/mlx.core.simplify.rst new file mode 100644 index 000000000..c0b518497 --- /dev/null +++ b/docs/build/html/_sources/python/_autosummary/mlx.core.simplify.rst @@ -0,0 +1,6 @@ +mlx.core.simplify +================= + +.. currentmodule:: mlx.core + +.. autofunction:: simplify \ No newline at end of file diff --git a/docs/build/html/_sources/python/_autosummary/mlx.core.stack.rst b/docs/build/html/_sources/python/_autosummary/mlx.core.stack.rst new file mode 100644 index 000000000..fdb8721a2 --- /dev/null +++ b/docs/build/html/_sources/python/_autosummary/mlx.core.stack.rst @@ -0,0 +1,6 @@ +mlx.core.stack +============== + +.. currentmodule:: mlx.core + +.. autofunction:: stack \ No newline at end of file diff --git a/docs/build/html/_sources/python/_autosummary/mlx.core.swapaxes.rst b/docs/build/html/_sources/python/_autosummary/mlx.core.swapaxes.rst new file mode 100644 index 000000000..07b724a0f --- /dev/null +++ b/docs/build/html/_sources/python/_autosummary/mlx.core.swapaxes.rst @@ -0,0 +1,6 @@ +mlx.core.swapaxes +================= + +.. currentmodule:: mlx.core + +.. autofunction:: swapaxes \ No newline at end of file diff --git a/docs/build/html/_sources/python/_autosummary/mlx.core.tri.rst b/docs/build/html/_sources/python/_autosummary/mlx.core.tri.rst new file mode 100644 index 000000000..ef760035b --- /dev/null +++ b/docs/build/html/_sources/python/_autosummary/mlx.core.tri.rst @@ -0,0 +1,6 @@ +mlx.core.tri +============ + +.. currentmodule:: mlx.core + +.. autofunction:: tri \ No newline at end of file diff --git a/docs/build/html/_sources/python/_autosummary/mlx.core.tril.rst b/docs/build/html/_sources/python/_autosummary/mlx.core.tril.rst new file mode 100644 index 000000000..89b45b090 --- /dev/null +++ b/docs/build/html/_sources/python/_autosummary/mlx.core.tril.rst @@ -0,0 +1,6 @@ +mlx.core.tril +============= + +.. currentmodule:: mlx.core + +.. autofunction:: tril \ No newline at end of file diff --git a/docs/build/html/_sources/python/_autosummary/mlx.core.triu.rst b/docs/build/html/_sources/python/_autosummary/mlx.core.triu.rst new file mode 100644 index 000000000..1d6aa7626 --- /dev/null +++ b/docs/build/html/_sources/python/_autosummary/mlx.core.triu.rst @@ -0,0 +1,6 @@ +mlx.core.triu +============= + +.. currentmodule:: mlx.core + +.. autofunction:: triu \ No newline at end of file diff --git a/docs/build/html/_sources/python/_autosummary/mlx.nn.Module.rst b/docs/build/html/_sources/python/_autosummary/mlx.nn.Module.rst new file mode 100644 index 000000000..79f55b253 --- /dev/null +++ b/docs/build/html/_sources/python/_autosummary/mlx.nn.Module.rst @@ -0,0 +1,58 @@ +mlx.nn.Module +============= + +.. currentmodule:: mlx.nn + +.. autoclass:: Module + + + .. automethod:: __init__ + + + .. rubric:: Methods + + .. autosummary:: + + ~Module.__init__ + ~Module.apply + ~Module.apply_to_modules + ~Module.children + ~Module.clear + ~Module.copy + ~Module.eval + ~Module.filter_and_map + ~Module.freeze + ~Module.fromkeys + ~Module.get + ~Module.is_module + ~Module.items + ~Module.keys + ~Module.leaf_modules + ~Module.load_weights + ~Module.modules + ~Module.named_modules + ~Module.parameters + ~Module.pop + ~Module.popitem + ~Module.save_weights + ~Module.setdefault + ~Module.train + ~Module.trainable_parameter_filter + ~Module.trainable_parameters + ~Module.unfreeze + ~Module.update + ~Module.valid_child_filter + ~Module.valid_parameter_filter + ~Module.values + + + + + + .. rubric:: Attributes + + .. autosummary:: + + ~Module.training + + \ No newline at end of file diff --git a/docs/build/html/_sources/python/_autosummary/mlx.optimizers.AdaDelta.rst b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.AdaDelta.rst new file mode 100644 index 000000000..2ea7cda8a --- /dev/null +++ b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.AdaDelta.rst @@ -0,0 +1,18 @@ +mlx.optimizers.AdaDelta +======================= + +.. currentmodule:: mlx.optimizers + +.. autoclass:: AdaDelta + + + + + .. rubric:: Methods + + .. autosummary:: + + ~AdaDelta.__init__ + ~AdaDelta.apply_single + + diff --git a/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Adagrad.rst b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Adagrad.rst new file mode 100644 index 000000000..8a12fc43c --- /dev/null +++ b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Adagrad.rst @@ -0,0 +1,18 @@ +mlx.optimizers.Adagrad +====================== + +.. currentmodule:: mlx.optimizers + +.. autoclass:: Adagrad + + + + + .. rubric:: Methods + + .. autosummary:: + + ~Adagrad.__init__ + ~Adagrad.apply_single + + diff --git a/docs/build/html/_sources/python/_autosummary/mlx.optimizers.AdamW.rst b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.AdamW.rst new file mode 100644 index 000000000..b5259844f --- /dev/null +++ b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.AdamW.rst @@ -0,0 +1,18 @@ +mlx.optimizers.AdamW +==================== + +.. currentmodule:: mlx.optimizers + +.. autoclass:: AdamW + + + + + .. rubric:: Methods + + .. autosummary:: + + ~AdamW.__init__ + ~AdamW.apply_single + + diff --git a/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Adamax.rst b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Adamax.rst new file mode 100644 index 000000000..58e6c95ca --- /dev/null +++ b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Adamax.rst @@ -0,0 +1,18 @@ +mlx.optimizers.Adamax +===================== + +.. currentmodule:: mlx.optimizers + +.. autoclass:: Adamax + + + + + .. rubric:: Methods + + .. autosummary:: + + ~Adamax.__init__ + ~Adamax.apply_single + + diff --git a/docs/build/html/_sources/python/_autosummary/mlx.optimizers.RMSprop.rst b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.RMSprop.rst new file mode 100644 index 000000000..217b4619f --- /dev/null +++ b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.RMSprop.rst @@ -0,0 +1,18 @@ +mlx.optimizers.RMSprop +====================== + +.. currentmodule:: mlx.optimizers + +.. autoclass:: RMSprop + + + + + .. rubric:: Methods + + .. autosummary:: + + ~RMSprop.__init__ + ~RMSprop.apply_single + + diff --git a/docs/build/html/_sources/python/nn.rst b/docs/build/html/_sources/python/nn.rst index 93cfd8c78..bc19a8162 100644 --- a/docs/build/html/_sources/python/nn.rst +++ b/docs/build/html/_sources/python/nn.rst @@ -64,7 +64,6 @@ Quick Start with Neural Networks # gradient with respect to `mlp.trainable_parameters()` loss_and_grad = nn.value_and_grad(mlp, l2_loss) - .. _module_class: The Module Class @@ -86,20 +85,58 @@ name should not start with ``_``). It can be arbitrarily nested in other :meth:`Module.parameters` can be used to extract a nested dictionary with all the parameters of a module and its submodules. -A :class:`Module` can also keep track of "frozen" parameters. -:meth:`Module.trainable_parameters` returns only the subset of -:meth:`Module.parameters` that is not frozen. When using -:meth:`mlx.nn.value_and_grad` the gradients returned will be with respect to these -trainable parameters. +A :class:`Module` can also keep track of "frozen" parameters. See the +:meth:`Module.freeze` method for more details. :meth:`mlx.nn.value_and_grad` +the gradients returned will be with respect to these trainable parameters. -Updating the parameters + +Updating the Parameters ^^^^^^^^^^^^^^^^^^^^^^^ MLX modules allow accessing and updating individual parameters. However, most times we need to update large subsets of a module's parameters. This action is performed by :meth:`Module.update`. -Value and grad + +Inspecting Modules +^^^^^^^^^^^^^^^^^^ + +The simplest way to see the model architecture is to print it. Following along with +the above example, you can print the ``MLP`` with: + +.. code-block:: python + + print(mlp) + +This will display: + +.. code-block:: shell + + MLP( + (layers.0): Linear(input_dims=2, output_dims=128, bias=True) + (layers.1): Linear(input_dims=128, output_dims=128, bias=True) + (layers.2): Linear(input_dims=128, output_dims=10, bias=True) + ) + +To get more detailed information on the arrays in a :class:`Module` you can use +:func:`mlx.utils.tree_map` on the parameters. For example, to see the shapes of +all the parameters in a :class:`Module` do: + +.. code-block:: python + + from mlx.utils import tree_map + shapes = tree_map(lambda p: p.shape, mlp.parameters()) + +As another example, you can count the number of parameters in a :class:`Module` +with: + +.. code-block:: python + + from mlx.utils import tree_flatten + num_params = sum(v.size for _, v in tree_flatten(mlp.parameters())) + + +Value and Grad -------------- Using a :class:`Module` does not preclude using MLX's high order function @@ -133,62 +170,14 @@ In detail: :meth:`mlx.core.value_and_grad` .. autosummary:: + :recursive: :toctree: _autosummary value_and_grad + Module -Neural Network Layers ---------------------- +.. toctree:: -.. autosummary:: - :toctree: _autosummary - :template: nn-module-template.rst - - Embedding - ReLU - PReLU - GELU - SiLU - Step - SELU - Mish - Linear - Conv1d - Conv2d - LayerNorm - RMSNorm - GroupNorm - RoPE - MultiHeadAttention - Sequential - -Layers without parameters (e.g. activation functions) are also provided as -simple functions. - -.. autosummary:: - :toctree: _autosummary_functions - :template: nn-module-template.rst - - gelu - gelu_approx - gelu_fast_approx - relu - prelu - silu - step - selu - mish - -Loss Functions --------------- - -.. autosummary:: - :toctree: _autosummary_functions - :template: nn-module-template.rst - - losses.cross_entropy - losses.binary_cross_entropy - losses.l1_loss - losses.mse_loss - losses.nll_loss - losses.kl_div_loss + nn/layers + nn/functions + nn/losses diff --git a/docs/build/html/_sources/python/_autosummary/mlx.nn.Conv1d.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Conv1d.rst similarity index 100% rename from docs/build/html/_sources/python/_autosummary/mlx.nn.Conv1d.rst rename to docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Conv1d.rst diff --git a/docs/build/html/_sources/python/_autosummary/mlx.nn.Conv2d.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Conv2d.rst similarity index 100% rename from docs/build/html/_sources/python/_autosummary/mlx.nn.Conv2d.rst rename to docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Conv2d.rst diff --git a/docs/build/html/_sources/python/_autosummary/mlx.nn.Embedding.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Embedding.rst similarity index 100% rename from docs/build/html/_sources/python/_autosummary/mlx.nn.Embedding.rst rename to docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Embedding.rst diff --git a/docs/build/html/_sources/python/_autosummary/mlx.nn.GELU.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.GELU.rst similarity index 100% rename from docs/build/html/_sources/python/_autosummary/mlx.nn.GELU.rst rename to docs/build/html/_sources/python/nn/_autosummary/mlx.nn.GELU.rst diff --git a/docs/build/html/_sources/python/_autosummary/mlx.nn.GroupNorm.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.GroupNorm.rst similarity index 100% rename from docs/build/html/_sources/python/_autosummary/mlx.nn.GroupNorm.rst rename to docs/build/html/_sources/python/nn/_autosummary/mlx.nn.GroupNorm.rst diff --git a/docs/build/html/_sources/python/_autosummary/mlx.nn.LayerNorm.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.LayerNorm.rst similarity index 100% rename from docs/build/html/_sources/python/_autosummary/mlx.nn.LayerNorm.rst rename to docs/build/html/_sources/python/nn/_autosummary/mlx.nn.LayerNorm.rst diff --git a/docs/build/html/_sources/python/_autosummary/mlx.nn.Linear.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Linear.rst similarity index 100% rename from docs/build/html/_sources/python/_autosummary/mlx.nn.Linear.rst rename to docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Linear.rst diff --git a/docs/build/html/_sources/python/_autosummary/mlx.nn.Mish.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Mish.rst similarity index 100% rename from docs/build/html/_sources/python/_autosummary/mlx.nn.Mish.rst rename to docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Mish.rst diff --git a/docs/build/html/_sources/python/_autosummary/mlx.nn.MultiHeadAttention.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.MultiHeadAttention.rst similarity index 100% rename from docs/build/html/_sources/python/_autosummary/mlx.nn.MultiHeadAttention.rst rename to docs/build/html/_sources/python/nn/_autosummary/mlx.nn.MultiHeadAttention.rst diff --git a/docs/build/html/_sources/python/_autosummary/mlx.nn.PReLU.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.PReLU.rst similarity index 100% rename from docs/build/html/_sources/python/_autosummary/mlx.nn.PReLU.rst rename to docs/build/html/_sources/python/nn/_autosummary/mlx.nn.PReLU.rst diff --git a/docs/build/html/_sources/python/_autosummary/mlx.nn.RMSNorm.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.RMSNorm.rst similarity index 100% rename from docs/build/html/_sources/python/_autosummary/mlx.nn.RMSNorm.rst rename to docs/build/html/_sources/python/nn/_autosummary/mlx.nn.RMSNorm.rst diff --git a/docs/build/html/_sources/python/_autosummary/mlx.nn.ReLU.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.ReLU.rst similarity index 100% rename from docs/build/html/_sources/python/_autosummary/mlx.nn.ReLU.rst rename to docs/build/html/_sources/python/nn/_autosummary/mlx.nn.ReLU.rst diff --git a/docs/build/html/_sources/python/_autosummary/mlx.nn.RoPE.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.RoPE.rst similarity index 100% rename from docs/build/html/_sources/python/_autosummary/mlx.nn.RoPE.rst rename to docs/build/html/_sources/python/nn/_autosummary/mlx.nn.RoPE.rst diff --git a/docs/build/html/_sources/python/_autosummary/mlx.nn.SELU.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.SELU.rst similarity index 100% rename from docs/build/html/_sources/python/_autosummary/mlx.nn.SELU.rst rename to docs/build/html/_sources/python/nn/_autosummary/mlx.nn.SELU.rst diff --git a/docs/build/html/_sources/python/_autosummary/mlx.nn.Sequential.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Sequential.rst similarity index 100% rename from docs/build/html/_sources/python/_autosummary/mlx.nn.Sequential.rst rename to docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Sequential.rst diff --git a/docs/build/html/_sources/python/_autosummary/mlx.nn.SiLU.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.SiLU.rst similarity index 100% rename from docs/build/html/_sources/python/_autosummary/mlx.nn.SiLU.rst rename to docs/build/html/_sources/python/nn/_autosummary/mlx.nn.SiLU.rst diff --git a/docs/build/html/_sources/python/_autosummary/mlx.nn.Step.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Step.rst similarity index 100% rename from docs/build/html/_sources/python/_autosummary/mlx.nn.Step.rst rename to docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Step.rst diff --git a/docs/build/html/_sources/python/_autosummary_functions/mlx.nn.gelu.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.gelu.rst similarity index 100% rename from docs/build/html/_sources/python/_autosummary_functions/mlx.nn.gelu.rst rename to docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.gelu.rst diff --git a/docs/build/html/_sources/python/_autosummary_functions/mlx.nn.gelu_approx.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.gelu_approx.rst similarity index 100% rename from docs/build/html/_sources/python/_autosummary_functions/mlx.nn.gelu_approx.rst rename to docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.gelu_approx.rst diff --git a/docs/build/html/_sources/python/_autosummary_functions/mlx.nn.gelu_fast_approx.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.gelu_fast_approx.rst similarity index 100% rename from docs/build/html/_sources/python/_autosummary_functions/mlx.nn.gelu_fast_approx.rst rename to docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.gelu_fast_approx.rst diff --git a/docs/build/html/_sources/python/_autosummary_functions/mlx.nn.losses.binary_cross_entropy.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.binary_cross_entropy.rst similarity index 100% rename from docs/build/html/_sources/python/_autosummary_functions/mlx.nn.losses.binary_cross_entropy.rst rename to docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.binary_cross_entropy.rst diff --git a/docs/build/html/_sources/python/_autosummary_functions/mlx.nn.losses.cross_entropy.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.cross_entropy.rst similarity index 100% rename from docs/build/html/_sources/python/_autosummary_functions/mlx.nn.losses.cross_entropy.rst rename to docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.cross_entropy.rst diff --git a/docs/build/html/_sources/python/_autosummary_functions/mlx.nn.losses.kl_div_loss.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.kl_div_loss.rst similarity index 100% rename from docs/build/html/_sources/python/_autosummary_functions/mlx.nn.losses.kl_div_loss.rst rename to docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.kl_div_loss.rst diff --git a/docs/build/html/_sources/python/_autosummary_functions/mlx.nn.losses.l1_loss.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.l1_loss.rst similarity index 100% rename from docs/build/html/_sources/python/_autosummary_functions/mlx.nn.losses.l1_loss.rst rename to docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.l1_loss.rst diff --git a/docs/build/html/_sources/python/_autosummary_functions/mlx.nn.losses.mse_loss.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.mse_loss.rst similarity index 100% rename from docs/build/html/_sources/python/_autosummary_functions/mlx.nn.losses.mse_loss.rst rename to docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.mse_loss.rst diff --git a/docs/build/html/_sources/python/_autosummary_functions/mlx.nn.losses.nll_loss.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.nll_loss.rst similarity index 100% rename from docs/build/html/_sources/python/_autosummary_functions/mlx.nn.losses.nll_loss.rst rename to docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.nll_loss.rst diff --git a/docs/build/html/_sources/python/_autosummary_functions/mlx.nn.mish.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.mish.rst similarity index 100% rename from docs/build/html/_sources/python/_autosummary_functions/mlx.nn.mish.rst rename to docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.mish.rst diff --git a/docs/build/html/_sources/python/_autosummary_functions/mlx.nn.prelu.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.prelu.rst similarity index 100% rename from docs/build/html/_sources/python/_autosummary_functions/mlx.nn.prelu.rst rename to docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.prelu.rst diff --git a/docs/build/html/_sources/python/_autosummary_functions/mlx.nn.relu.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.relu.rst similarity index 100% rename from docs/build/html/_sources/python/_autosummary_functions/mlx.nn.relu.rst rename to docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.relu.rst diff --git a/docs/build/html/_sources/python/_autosummary_functions/mlx.nn.selu.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.selu.rst similarity index 100% rename from docs/build/html/_sources/python/_autosummary_functions/mlx.nn.selu.rst rename to docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.selu.rst diff --git a/docs/build/html/_sources/python/_autosummary_functions/mlx.nn.silu.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.silu.rst similarity index 100% rename from docs/build/html/_sources/python/_autosummary_functions/mlx.nn.silu.rst rename to docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.silu.rst diff --git a/docs/build/html/_sources/python/_autosummary_functions/mlx.nn.step.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.step.rst similarity index 100% rename from docs/build/html/_sources/python/_autosummary_functions/mlx.nn.step.rst rename to docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.step.rst diff --git a/docs/build/html/_sources/python/nn/functions.rst b/docs/build/html/_sources/python/nn/functions.rst new file mode 100644 index 000000000..f13cbe7b4 --- /dev/null +++ b/docs/build/html/_sources/python/nn/functions.rst @@ -0,0 +1,23 @@ +.. _nn_functions: + +.. currentmodule:: mlx.nn + +Functions +--------- + +Layers without parameters (e.g. activation functions) are also provided as +simple functions. + +.. autosummary:: + :toctree: _autosummary_functions + :template: nn-module-template.rst + + gelu + gelu_approx + gelu_fast_approx + relu + prelu + silu + step + selu + mish diff --git a/docs/build/html/_sources/python/nn/layers.rst b/docs/build/html/_sources/python/nn/layers.rst new file mode 100644 index 000000000..5628134d6 --- /dev/null +++ b/docs/build/html/_sources/python/nn/layers.rst @@ -0,0 +1,28 @@ +.. _layers: + +.. currentmodule:: mlx.nn + +Layers +------ + +.. autosummary:: + :toctree: _autosummary + :template: nn-module-template.rst + + Embedding + ReLU + PReLU + GELU + SiLU + Step + SELU + Mish + Linear + Conv1d + Conv2d + LayerNorm + RMSNorm + GroupNorm + RoPE + MultiHeadAttention + Sequential diff --git a/docs/build/html/_sources/python/nn/losses.rst b/docs/build/html/_sources/python/nn/losses.rst new file mode 100644 index 000000000..4808ce5ab --- /dev/null +++ b/docs/build/html/_sources/python/nn/losses.rst @@ -0,0 +1,17 @@ +.. _losses: + +.. currentmodule:: mlx.nn.losses + +Loss Functions +-------------- + +.. autosummary:: + :toctree: _autosummary_functions + :template: nn-module-template.rst + + cross_entropy + binary_cross_entropy + l1_loss + mse_loss + nll_loss + kl_div_loss diff --git a/docs/build/html/_sources/python/nn/module.rst b/docs/build/html/_sources/python/nn/module.rst deleted file mode 100644 index e14ba96f4..000000000 --- a/docs/build/html/_sources/python/nn/module.rst +++ /dev/null @@ -1,7 +0,0 @@ -mlx.nn.Module -============= - -.. currentmodule:: mlx.nn - -.. autoclass:: Module - :members: diff --git a/docs/build/html/_sources/python/ops.rst b/docs/build/html/_sources/python/ops.rst index b9a4c9066..ea25b90f9 100644 --- a/docs/build/html/_sources/python/ops.rst +++ b/docs/build/html/_sources/python/ops.rst @@ -26,6 +26,7 @@ Operations argsort array_equal broadcast_to + ceil concatenate convolve conv1d @@ -39,6 +40,8 @@ Operations exp expand_dims eye + floor + flatten full greater greater_equal @@ -59,6 +62,7 @@ Operations mean min minimum + moveaxis multiply negative ones @@ -82,14 +86,19 @@ Operations sqrt square squeeze + stack stop_gradient subtract sum + swapaxes take take_along_axis tan tanh transpose + tri + tril + triu var where zeros diff --git a/docs/build/html/_sources/python/optimizers.rst b/docs/build/html/_sources/python/optimizers.rst index 7f5d3a067..b8e5cfea7 100644 --- a/docs/build/html/_sources/python/optimizers.rst +++ b/docs/build/html/_sources/python/optimizers.rst @@ -38,4 +38,9 @@ model's parameters and the **optimizer state**. OptimizerState Optimizer SGD + RMSprop + Adagrad + AdaDelta Adam + AdamW + Adamax diff --git a/docs/build/html/_sources/python/transforms.rst b/docs/build/html/_sources/python/transforms.rst index cc8d681d5..fa6d1d701 100644 --- a/docs/build/html/_sources/python/transforms.rst +++ b/docs/build/html/_sources/python/transforms.rst @@ -14,3 +14,4 @@ Transforms jvp vjp vmap + simplify diff --git a/docs/build/html/cpp/ops.html b/docs/build/html/cpp/ops.html index cd7666c91..a6e660fe2 100644 --- a/docs/build/html/cpp/ops.html +++ b/docs/build/html/cpp/ops.html @@ -226,6 +226,7 @@
  • mlx.core.argsort
  • mlx.core.array_equal
  • mlx.core.broadcast_to
  • +
  • mlx.core.ceil
  • mlx.core.concatenate
  • mlx.core.convolve
  • mlx.core.conv1d
  • @@ -239,6 +240,8 @@
  • mlx.core.exp
  • mlx.core.expand_dims
  • mlx.core.eye
  • +
  • mlx.core.floor
  • +
  • mlx.core.flatten
  • mlx.core.full
  • mlx.core.greater
  • mlx.core.greater_equal
  • @@ -259,6 +262,7 @@
  • mlx.core.mean
  • mlx.core.min
  • mlx.core.minimum
  • +
  • mlx.core.moveaxis
  • mlx.core.multiply
  • mlx.core.negative
  • mlx.core.ones
  • @@ -282,14 +286,19 @@
  • mlx.core.sqrt
  • mlx.core.square
  • mlx.core.squeeze
  • +
  • mlx.core.stack
  • mlx.core.stop_gradient
  • mlx.core.subtract
  • mlx.core.sum
  • +
  • mlx.core.swapaxes
  • mlx.core.take
  • mlx.core.take_along_axis
  • mlx.core.tan
  • mlx.core.tanh
  • mlx.core.transpose
  • +
  • mlx.core.tri
  • +
  • mlx.core.tril
  • +
  • mlx.core.triu
  • mlx.core.var
  • mlx.core.where
  • mlx.core.zeros
  • @@ -316,6 +325,7 @@
  • mlx.core.jvp
  • mlx.core.vjp
  • mlx.core.vmap
  • +
  • mlx.core.simplify
  • FFT