From e7a536706b644d14fb4a4a2f1baac2bec2f4572b Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 8 Feb 2024 12:44:23 -0800 Subject: [PATCH] docs update --- docs/build/html/.buildinfo | 2 +- docs/build/html/_sources/dev/extensions.rst | 16 +- docs/build/html/_sources/index.rst | 1 + .../python/_autosummary/mlx.core.compile.rst | 6 + .../_autosummary/mlx.core.disable_compile.rst | 6 + .../_autosummary/mlx.core.enable_compile.rst | 6 + .../python/_autosummary/mlx.core.simplify.rst | 6 - .../_autosummary/mlx.optimizers.AdaDelta.rst | 1 + .../_autosummary/mlx.optimizers.Adafactor.rst | 1 + .../_autosummary/mlx.optimizers.Adagrad.rst | 1 + .../_autosummary/mlx.optimizers.Adam.rst | 1 + .../_autosummary/mlx.optimizers.Adamax.rst | 1 + .../_autosummary/mlx.optimizers.Lion.rst | 1 + ...x.optimizers.Optimizer.apply_gradients.rst | 6 + .../mlx.optimizers.Optimizer.init.rst | 6 + .../_autosummary/mlx.optimizers.Optimizer.rst | 20 - .../mlx.optimizers.Optimizer.state.rst | 6 + .../mlx.optimizers.Optimizer.update.rst | 6 + .../mlx.optimizers.OptimizerState.rst | 17 - .../_autosummary/mlx.optimizers.RMSprop.rst | 1 + .../_autosummary/mlx.optimizers.SGD.rst | 1 + .../python/nn/_autosummary/mlx.nn.ALiBi.rst | 4 +- .../nn/_autosummary/mlx.nn.BatchNorm.rst | 4 +- .../python/nn/_autosummary/mlx.nn.Conv1d.rst | 4 +- .../python/nn/_autosummary/mlx.nn.Conv2d.rst | 4 +- .../python/nn/_autosummary/mlx.nn.Dropout.rst | 4 +- .../nn/_autosummary/mlx.nn.Dropout2d.rst | 4 +- .../nn/_autosummary/mlx.nn.Dropout3d.rst | 4 +- .../nn/_autosummary/mlx.nn.Embedding.rst | 4 +- .../python/nn/_autosummary/mlx.nn.GELU.rst | 4 +- .../nn/_autosummary/mlx.nn.GroupNorm.rst | 4 +- .../nn/_autosummary/mlx.nn.InstanceNorm.rst | 4 +- .../nn/_autosummary/mlx.nn.LayerNorm.rst | 4 +- .../python/nn/_autosummary/mlx.nn.Linear.rst | 4 +- .../python/nn/_autosummary/mlx.nn.Mish.rst | 4 +- .../nn/_autosummary/mlx.nn.Module.state.rst | 6 + .../mlx.nn.MultiHeadAttention.rst | 4 +- .../python/nn/_autosummary/mlx.nn.PReLU.rst | 4 +- .../_autosummary/mlx.nn.QuantizedLinear.rst | 4 +- .../python/nn/_autosummary/mlx.nn.RMSNorm.rst | 4 +- .../python/nn/_autosummary/mlx.nn.ReLU.rst | 4 +- .../python/nn/_autosummary/mlx.nn.RoPE.rst | 4 +- .../python/nn/_autosummary/mlx.nn.SELU.rst | 4 +- .../nn/_autosummary/mlx.nn.Sequential.rst | 4 +- .../python/nn/_autosummary/mlx.nn.SiLU.rst | 4 +- .../mlx.nn.SinusoidalPositionalEncoding.rst | 4 +- .../nn/_autosummary/mlx.nn.Softshrink.rst | 4 +- .../python/nn/_autosummary/mlx.nn.Step.rst | 4 +- .../nn/_autosummary/mlx.nn.Transformer.rst | 4 +- .../nn/_autosummary_functions/mlx.nn.gelu.rst | 4 +- .../mlx.nn.gelu_approx.rst | 4 +- .../mlx.nn.gelu_fast_approx.rst | 4 +- .../mlx.nn.init.constant.rst | 6 - .../mlx.nn.init.glorot_normal.rst | 6 - .../mlx.nn.init.glorot_uniform.rst | 6 - .../mlx.nn.init.he_normal.rst | 6 - .../mlx.nn.init.he_uniform.rst | 6 - .../mlx.nn.init.identity.rst | 6 - .../mlx.nn.init.normal.rst | 6 - .../mlx.nn.init.uniform.rst | 6 - .../mlx.nn.initializers.constant.rst | 6 - .../mlx.nn.initializers.glorot_normal.rst | 6 - .../mlx.nn.initializers.glorot_uniform.rst | 6 - .../mlx.nn.initializers.he_normal.rst | 6 - .../mlx.nn.initializers.he_uniform.rst | 6 - .../mlx.nn.initializers.identity.rst | 6 - .../mlx.nn.initializers.normal.rst | 6 - .../mlx.nn.initializers.uniform.rst | 6 - .../mlx.nn.losses.binary_cross_entropy.rst | 4 +- .../mlx.nn.losses.cosine_similarity_loss.rst | 4 +- .../mlx.nn.losses.cross_entropy.rst | 4 +- .../mlx.nn.losses.gaussian_nll_loss.rst | 4 +- .../mlx.nn.losses.hinge_loss.rst | 4 +- .../mlx.nn.losses.huber_loss.rst | 4 +- .../mlx.nn.losses.kl_div_loss.rst | 4 +- .../mlx.nn.losses.l1_loss.rst | 4 +- .../mlx.nn.losses.log_cosh_loss.rst | 4 +- .../mlx.nn.losses.margin_ranking_loss.rst | 6 + .../mlx.nn.losses.mse_loss.rst | 4 +- .../mlx.nn.losses.nll_loss.rst | 4 +- .../mlx.nn.losses.smooth_l1_loss.rst | 4 +- .../mlx.nn.losses.triplet_loss.rst | 4 +- .../nn/_autosummary_functions/mlx.nn.mish.rst | 4 +- .../_autosummary_functions/mlx.nn.prelu.rst | 4 +- .../nn/_autosummary_functions/mlx.nn.relu.rst | 4 +- .../nn/_autosummary_functions/mlx.nn.selu.rst | 4 +- .../nn/_autosummary_functions/mlx.nn.silu.rst | 4 +- .../mlx.nn.softshrink.rst | 4 +- .../nn/_autosummary_functions/mlx.nn.step.rst | 4 +- .../html/_sources/python/nn/initializers.rst | 18 - docs/build/html/_sources/python/nn/losses.rst | 1 + docs/build/html/_sources/python/nn/module.rst | 1 + docs/build/html/_sources/python/optimizer.rst | 23 + .../build/html/_sources/python/optimizers.rst | 6 +- .../build/html/_sources/python/transforms.rst | 3 + docs/build/html/_sources/usage/compile.rst | 430 ++++++ .../_sources/usage/function_transforms.rst | 17 +- .../html/_static/documentation_options.js | 2 +- docs/build/html/cpp/ops.html | 23 +- docs/build/html/dev/extensions.html | 37 +- .../html/examples/linear_regression.html | 23 +- docs/build/html/examples/llama-inference.html | 23 +- docs/build/html/examples/mlp.html | 23 +- docs/build/html/genindex.html | 103 +- docs/build/html/index.html | 24 +- docs/build/html/install.html | 23 +- docs/build/html/objects.inv | Bin 7553 -> 7646 bytes .../python/_autosummary/mlx.core.Device.html | 23 +- .../python/_autosummary/mlx.core.Dtype.html | 23 +- .../python/_autosummary/mlx.core.Stream.html | 23 +- .../python/_autosummary/mlx.core.abs.html | 23 +- .../python/_autosummary/mlx.core.add.html | 23 +- .../python/_autosummary/mlx.core.all.html | 23 +- .../_autosummary/mlx.core.allclose.html | 23 +- .../python/_autosummary/mlx.core.any.html | 23 +- .../python/_autosummary/mlx.core.arange.html | 23 +- .../python/_autosummary/mlx.core.arccos.html | 23 +- .../python/_autosummary/mlx.core.arccosh.html | 23 +- .../python/_autosummary/mlx.core.arcsin.html | 23 +- .../python/_autosummary/mlx.core.arcsinh.html | 23 +- .../python/_autosummary/mlx.core.arctan.html | 23 +- .../python/_autosummary/mlx.core.arctanh.html | 23 +- .../python/_autosummary/mlx.core.argmax.html | 23 +- .../python/_autosummary/mlx.core.argmin.html | 23 +- .../_autosummary/mlx.core.argpartition.html | 23 +- .../python/_autosummary/mlx.core.argsort.html | 23 +- .../python/_autosummary/mlx.core.array.T.html | 23 +- .../_autosummary/mlx.core.array.abs.html | 23 +- .../_autosummary/mlx.core.array.all.html | 23 +- .../_autosummary/mlx.core.array.any.html | 23 +- .../_autosummary/mlx.core.array.argmax.html | 23 +- .../_autosummary/mlx.core.array.argmin.html | 23 +- .../_autosummary/mlx.core.array.astype.html | 23 +- .../_autosummary/mlx.core.array.cos.html | 23 +- .../_autosummary/mlx.core.array.dtype.html | 23 +- .../_autosummary/mlx.core.array.exp.html | 23 +- .../python/_autosummary/mlx.core.array.html | 25 +- .../_autosummary/mlx.core.array.item.html | 23 +- .../_autosummary/mlx.core.array.log.html | 23 +- .../_autosummary/mlx.core.array.log1p.html | 23 +- .../mlx.core.array.logsumexp.html | 23 +- .../_autosummary/mlx.core.array.max.html | 23 +- .../_autosummary/mlx.core.array.mean.html | 23 +- .../_autosummary/mlx.core.array.min.html | 23 +- .../_autosummary/mlx.core.array.ndim.html | 23 +- .../_autosummary/mlx.core.array.prod.html | 23 +- .../mlx.core.array.reciprocal.html | 23 +- .../_autosummary/mlx.core.array.reshape.html | 23 +- .../_autosummary/mlx.core.array.round.html | 23 +- .../_autosummary/mlx.core.array.rsqrt.html | 23 +- .../_autosummary/mlx.core.array.shape.html | 25 +- .../_autosummary/mlx.core.array.sin.html | 23 +- .../_autosummary/mlx.core.array.size.html | 23 +- .../_autosummary/mlx.core.array.split.html | 23 +- .../_autosummary/mlx.core.array.sqrt.html | 23 +- .../_autosummary/mlx.core.array.square.html | 23 +- .../_autosummary/mlx.core.array.sum.html | 23 +- .../_autosummary/mlx.core.array.tolist.html | 23 +- .../mlx.core.array.transpose.html | 23 +- .../_autosummary/mlx.core.array.var.html | 23 +- .../_autosummary/mlx.core.array_equal.html | 23 +- .../_autosummary/mlx.core.broadcast_to.html | 23 +- .../python/_autosummary/mlx.core.ceil.html | 23 +- .../python/_autosummary/mlx.core.clip.html | 23 +- .../python/_autosummary/mlx.core.compile.html | 798 +++++++++++ .../_autosummary/mlx.core.concatenate.html | 23 +- .../python/_autosummary/mlx.core.conv1d.html | 23 +- .../python/_autosummary/mlx.core.conv2d.html | 23 +- .../_autosummary/mlx.core.convolve.html | 23 +- .../python/_autosummary/mlx.core.cos.html | 23 +- .../python/_autosummary/mlx.core.cosh.html | 23 +- .../_autosummary/mlx.core.default_device.html | 23 +- .../_autosummary/mlx.core.default_stream.html | 23 +- .../_autosummary/mlx.core.dequantize.html | 23 +- .../python/_autosummary/mlx.core.diag.html | 23 +- .../_autosummary/mlx.core.diagonal.html | 23 +- ...ify.html => mlx.core.disable_compile.html} | 101 +- .../python/_autosummary/mlx.core.divide.html | 23 +- .../python/_autosummary/mlx.core.divmod.html | 23 +- ...izer.html => mlx.core.enable_compile.html} | 87 +- .../python/_autosummary/mlx.core.equal.html | 23 +- .../python/_autosummary/mlx.core.erf.html | 23 +- .../python/_autosummary/mlx.core.erfinv.html | 23 +- .../python/_autosummary/mlx.core.eval.html | 33 +- .../python/_autosummary/mlx.core.exp.html | 23 +- .../_autosummary/mlx.core.expand_dims.html | 23 +- .../python/_autosummary/mlx.core.eye.html | 23 +- .../python/_autosummary/mlx.core.fft.fft.html | 23 +- .../_autosummary/mlx.core.fft.fft2.html | 23 +- .../_autosummary/mlx.core.fft.fftn.html | 23 +- .../_autosummary/mlx.core.fft.ifft.html | 23 +- .../_autosummary/mlx.core.fft.ifft2.html | 23 +- .../_autosummary/mlx.core.fft.ifftn.html | 23 +- .../_autosummary/mlx.core.fft.irfft.html | 23 +- .../_autosummary/mlx.core.fft.irfft2.html | 23 +- .../_autosummary/mlx.core.fft.irfftn.html | 23 +- .../_autosummary/mlx.core.fft.rfft.html | 23 +- .../_autosummary/mlx.core.fft.rfft2.html | 23 +- .../_autosummary/mlx.core.fft.rfftn.html | 23 +- .../python/_autosummary/mlx.core.flatten.html | 23 +- .../python/_autosummary/mlx.core.floor.html | 23 +- .../_autosummary/mlx.core.floor_divide.html | 23 +- .../python/_autosummary/mlx.core.full.html | 23 +- .../python/_autosummary/mlx.core.grad.html | 29 +- .../python/_autosummary/mlx.core.greater.html | 23 +- .../_autosummary/mlx.core.greater_equal.html | 23 +- .../_autosummary/mlx.core.identity.html | 23 +- .../python/_autosummary/mlx.core.inner.html | 23 +- .../python/_autosummary/mlx.core.isinf.html | 23 +- .../python/_autosummary/mlx.core.isnan.html | 23 +- .../_autosummary/mlx.core.isneginf.html | 23 +- .../_autosummary/mlx.core.isposinf.html | 23 +- .../python/_autosummary/mlx.core.jvp.html | 23 +- .../python/_autosummary/mlx.core.less.html | 23 +- .../_autosummary/mlx.core.less_equal.html | 23 +- .../_autosummary/mlx.core.linalg.norm.html | 23 +- .../_autosummary/mlx.core.linalg.qr.html | 23 +- .../_autosummary/mlx.core.linspace.html | 23 +- .../python/_autosummary/mlx.core.load.html | 23 +- .../python/_autosummary/mlx.core.log.html | 23 +- .../python/_autosummary/mlx.core.log10.html | 23 +- .../python/_autosummary/mlx.core.log1p.html | 23 +- .../python/_autosummary/mlx.core.log2.html | 23 +- .../_autosummary/mlx.core.logaddexp.html | 23 +- .../_autosummary/mlx.core.logical_and.html | 23 +- .../_autosummary/mlx.core.logical_not.html | 23 +- .../_autosummary/mlx.core.logical_or.html | 23 +- .../_autosummary/mlx.core.logsumexp.html | 23 +- .../python/_autosummary/mlx.core.matmul.html | 23 +- .../python/_autosummary/mlx.core.max.html | 23 +- .../python/_autosummary/mlx.core.maximum.html | 23 +- .../python/_autosummary/mlx.core.mean.html | 23 +- .../python/_autosummary/mlx.core.min.html | 23 +- .../python/_autosummary/mlx.core.minimum.html | 23 +- .../_autosummary/mlx.core.moveaxis.html | 23 +- .../_autosummary/mlx.core.multiply.html | 23 +- .../_autosummary/mlx.core.negative.html | 23 +- .../_autosummary/mlx.core.new_stream.html | 23 +- .../python/_autosummary/mlx.core.ones.html | 23 +- .../_autosummary/mlx.core.ones_like.html | 23 +- .../python/_autosummary/mlx.core.outer.html | 23 +- .../python/_autosummary/mlx.core.pad.html | 23 +- .../_autosummary/mlx.core.partition.html | 23 +- .../python/_autosummary/mlx.core.prod.html | 23 +- .../_autosummary/mlx.core.quantize.html | 23 +- .../mlx.core.quantized_matmul.html | 23 +- .../mlx.core.random.bernoulli.html | 23 +- .../mlx.core.random.categorical.html | 23 +- .../_autosummary/mlx.core.random.gumbel.html | 23 +- .../_autosummary/mlx.core.random.key.html | 23 +- .../_autosummary/mlx.core.random.normal.html | 27 +- .../_autosummary/mlx.core.random.randint.html | 23 +- .../_autosummary/mlx.core.random.seed.html | 23 +- .../_autosummary/mlx.core.random.split.html | 23 +- .../mlx.core.random.truncated_normal.html | 23 +- .../_autosummary/mlx.core.random.uniform.html | 23 +- .../_autosummary/mlx.core.reciprocal.html | 23 +- .../python/_autosummary/mlx.core.repeat.html | 23 +- .../python/_autosummary/mlx.core.reshape.html | 23 +- .../python/_autosummary/mlx.core.round.html | 23 +- .../python/_autosummary/mlx.core.rsqrt.html | 23 +- .../python/_autosummary/mlx.core.save.html | 23 +- .../_autosummary/mlx.core.save_gguf.html | 23 +- .../mlx.core.save_safetensors.html | 23 +- .../python/_autosummary/mlx.core.savez.html | 23 +- .../mlx.core.savez_compressed.html | 23 +- .../mlx.core.set_default_device.html | 23 +- .../mlx.core.set_default_stream.html | 23 +- .../python/_autosummary/mlx.core.sigmoid.html | 23 +- .../python/_autosummary/mlx.core.sign.html | 23 +- .../python/_autosummary/mlx.core.sin.html | 23 +- .../python/_autosummary/mlx.core.sinh.html | 23 +- .../python/_autosummary/mlx.core.softmax.html | 23 +- .../python/_autosummary/mlx.core.sort.html | 23 +- .../python/_autosummary/mlx.core.split.html | 23 +- .../python/_autosummary/mlx.core.sqrt.html | 23 +- .../python/_autosummary/mlx.core.square.html | 23 +- .../python/_autosummary/mlx.core.squeeze.html | 23 +- .../python/_autosummary/mlx.core.stack.html | 23 +- .../_autosummary/mlx.core.stop_gradient.html | 23 +- .../_autosummary/mlx.core.subtract.html | 23 +- .../python/_autosummary/mlx.core.sum.html | 23 +- .../_autosummary/mlx.core.swapaxes.html | 23 +- .../python/_autosummary/mlx.core.take.html | 23 +- .../mlx.core.take_along_axis.html | 23 +- .../python/_autosummary/mlx.core.tan.html | 23 +- .../python/_autosummary/mlx.core.tanh.html | 23 +- .../_autosummary/mlx.core.tensordot.html | 23 +- .../_autosummary/mlx.core.transpose.html | 23 +- .../python/_autosummary/mlx.core.tri.html | 23 +- .../python/_autosummary/mlx.core.tril.html | 23 +- .../python/_autosummary/mlx.core.triu.html | 23 +- .../_autosummary/mlx.core.value_and_grad.html | 23 +- .../python/_autosummary/mlx.core.var.html | 23 +- .../python/_autosummary/mlx.core.vjp.html | 23 +- .../python/_autosummary/mlx.core.vmap.html | 23 +- .../python/_autosummary/mlx.core.where.html | 23 +- .../python/_autosummary/mlx.core.zeros.html | 23 +- .../_autosummary/mlx.core.zeros_like.html | 23 +- .../_autosummary/mlx.nn.value_and_grad.html | 23 +- .../_autosummary/mlx.optimizers.AdaDelta.html | 26 +- .../mlx.optimizers.Adafactor.html | 26 +- .../_autosummary/mlx.optimizers.Adagrad.html | 26 +- .../_autosummary/mlx.optimizers.Adam.html | 26 +- .../_autosummary/mlx.optimizers.AdamW.html | 23 +- .../_autosummary/mlx.optimizers.Adamax.html | 26 +- .../_autosummary/mlx.optimizers.Lion.html | 26 +- ....optimizers.Optimizer.apply_gradients.html | 785 +++++++++++ .../mlx.optimizers.Optimizer.init.html | 792 +++++++++++ ...ml => mlx.optimizers.Optimizer.state.html} | 69 +- .../mlx.optimizers.Optimizer.update.html | 782 +++++++++++ .../_autosummary/mlx.optimizers.RMSprop.html | 26 +- .../_autosummary/mlx.optimizers.SGD.html | 32 +- .../_autosummary/mlx.utils.tree_flatten.html | 23 +- .../_autosummary/mlx.utils.tree_map.html | 23 +- .../mlx.utils.tree_unflatten.html | 23 +- docs/build/html/python/array.html | 25 +- docs/build/html/python/data_types.html | 23 +- .../html/python/devices_and_streams.html | 23 +- docs/build/html/python/fft.html | 23 +- docs/build/html/python/linalg.html | 23 +- docs/build/html/python/nn.html | 25 +- .../python/nn/_autosummary/mlx.nn.ALiBi.html | 23 +- .../nn/_autosummary/mlx.nn.BatchNorm.html | 23 +- .../python/nn/_autosummary/mlx.nn.Conv1d.html | 23 +- .../python/nn/_autosummary/mlx.nn.Conv2d.html | 23 +- .../nn/_autosummary/mlx.nn.Dropout.html | 23 +- .../nn/_autosummary/mlx.nn.Dropout2d.html | 23 +- .../nn/_autosummary/mlx.nn.Dropout3d.html | 23 +- .../nn/_autosummary/mlx.nn.Embedding.html | 23 +- .../python/nn/_autosummary/mlx.nn.GELU.html | 23 +- .../nn/_autosummary/mlx.nn.GroupNorm.html | 23 +- .../nn/_autosummary/mlx.nn.InstanceNorm.html | 23 +- .../nn/_autosummary/mlx.nn.LayerNorm.html | 23 +- .../python/nn/_autosummary/mlx.nn.Linear.html | 23 +- .../python/nn/_autosummary/mlx.nn.Mish.html | 23 +- .../nn/_autosummary/mlx.nn.Module.apply.html | 29 +- .../mlx.nn.Module.apply_to_modules.html | 23 +- .../_autosummary/mlx.nn.Module.children.html | 23 +- .../nn/_autosummary/mlx.nn.Module.eval.html | 23 +- .../mlx.nn.Module.filter_and_map.html | 23 +- .../nn/_autosummary/mlx.nn.Module.freeze.html | 23 +- .../mlx.nn.Module.leaf_modules.html | 23 +- .../mlx.nn.Module.load_weights.html | 23 +- .../_autosummary/mlx.nn.Module.modules.html | 23 +- .../mlx.nn.Module.named_modules.html | 23 +- .../mlx.nn.Module.parameters.html | 23 +- .../mlx.nn.Module.save_weights.html | 23 +- .../mlx.nn.Module.state.html} | 251 ++-- .../nn/_autosummary/mlx.nn.Module.train.html | 23 +- .../mlx.nn.Module.trainable_parameters.html | 23 +- .../_autosummary/mlx.nn.Module.training.html | 29 +- .../_autosummary/mlx.nn.Module.unfreeze.html | 23 +- .../nn/_autosummary/mlx.nn.Module.update.html | 23 +- .../mlx.nn.Module.update_modules.html | 23 +- .../mlx.nn.MultiHeadAttention.html | 23 +- .../python/nn/_autosummary/mlx.nn.PReLU.html | 23 +- .../_autosummary/mlx.nn.QuantizedLinear.html | 23 +- .../nn/_autosummary/mlx.nn.RMSNorm.html | 23 +- .../python/nn/_autosummary/mlx.nn.ReLU.html | 23 +- .../python/nn/_autosummary/mlx.nn.RoPE.html | 23 +- .../python/nn/_autosummary/mlx.nn.SELU.html | 23 +- .../nn/_autosummary/mlx.nn.Sequential.html | 23 +- .../python/nn/_autosummary/mlx.nn.SiLU.html | 23 +- .../mlx.nn.SinusoidalPositionalEncoding.html | 23 +- .../nn/_autosummary/mlx.nn.Softshrink.html | 23 +- .../python/nn/_autosummary/mlx.nn.Step.html | 23 +- .../nn/_autosummary/mlx.nn.Transformer.html | 23 +- .../nn/_autosummary/mlx.nn.init.constant.html | 23 +- .../mlx.nn.init.glorot_normal.html | 23 +- .../mlx.nn.init.glorot_uniform.html | 23 +- .../_autosummary/mlx.nn.init.he_normal.html | 23 +- .../_autosummary/mlx.nn.init.he_uniform.html | 23 +- .../nn/_autosummary/mlx.nn.init.identity.html | 23 +- .../nn/_autosummary/mlx.nn.init.normal.html | 23 +- .../nn/_autosummary/mlx.nn.init.uniform.html | 23 +- .../_autosummary_functions/mlx.nn.gelu.html | 27 +- .../mlx.nn.gelu_approx.html | 27 +- .../mlx.nn.gelu_fast_approx.html | 27 +- .../mlx.nn.init.constant.html | 748 ----------- .../mlx.nn.init.glorot_normal.html | 757 ----------- .../mlx.nn.init.glorot_uniform.html | 757 ----------- .../mlx.nn.init.he_normal.html | 761 ----------- .../mlx.nn.init.he_uniform.html | 761 ----------- .../mlx.nn.init.identity.html | 746 ----------- .../mlx.nn.initializers.constant.html | 720 ---------- .../mlx.nn.initializers.glorot_normal.html | 720 ---------- .../mlx.nn.initializers.glorot_uniform.html | 720 ---------- .../mlx.nn.initializers.he_normal.html | 720 ---------- .../mlx.nn.initializers.he_uniform.html | 720 ---------- .../mlx.nn.initializers.identity.html | 720 ---------- .../mlx.nn.initializers.normal.html | 720 ---------- .../mlx.nn.initializers.uniform.html | 720 ---------- .../mlx.nn.losses.binary_cross_entropy.html | 28 +- .../mlx.nn.losses.cosine_similarity_loss.html | 27 +- .../mlx.nn.losses.cross_entropy.html | 27 +- .../mlx.nn.losses.gaussian_nll_loss.html | 27 +- .../mlx.nn.losses.hinge_loss.html | 27 +- .../mlx.nn.losses.huber_loss.html | 27 +- .../mlx.nn.losses.kl_div_loss.html | 27 +- .../mlx.nn.losses.l1_loss.html | 27 +- .../mlx.nn.losses.log_cosh_loss.html | 33 +- ...=> mlx.nn.losses.margin_ranking_loss.html} | 124 +- .../mlx.nn.losses.mse_loss.html | 33 +- .../mlx.nn.losses.nll_loss.html | 27 +- .../mlx.nn.losses.smooth_l1_loss.html | 27 +- .../mlx.nn.losses.triplet_loss.html | 27 +- .../_autosummary_functions/mlx.nn.mish.html | 27 +- .../_autosummary_functions/mlx.nn.prelu.html | 27 +- .../_autosummary_functions/mlx.nn.relu.html | 27 +- .../_autosummary_functions/mlx.nn.selu.html | 27 +- .../_autosummary_functions/mlx.nn.silu.html | 27 +- .../mlx.nn.softshrink.html | 27 +- .../_autosummary_functions/mlx.nn.step.html | 27 +- docs/build/html/python/nn/functions.html | 23 +- docs/build/html/python/nn/init.html | 23 +- docs/build/html/python/nn/initializers.html | 778 ----------- docs/build/html/python/nn/layers.html | 23 +- docs/build/html/python/nn/losses.html | 35 +- docs/build/html/python/nn/module.html | 26 +- docs/build/html/python/ops.html | 23 +- docs/build/html/python/optimizer.html | 795 +++++++++++ docs/build/html/python/optimizers.html | 46 +- docs/build/html/python/random.html | 25 +- docs/build/html/python/transforms.html | 42 +- docs/build/html/python/tree_utils.html | 23 +- docs/build/html/search.html | 23 +- docs/build/html/searchindex.js | 2 +- docs/build/html/usage/compile.html | 1173 +++++++++++++++++ .../build/html/usage/function_transforms.html | 44 +- docs/build/html/usage/indexing.html | 23 +- docs/build/html/usage/lazy_evaluation.html | 23 +- docs/build/html/usage/numpy.html | 29 +- docs/build/html/usage/quick_start.html | 23 +- docs/build/html/usage/saving_and_loading.html | 23 +- docs/build/html/usage/unified_memory.html | 23 +- docs/build/html/usage/using_streams.html | 23 +- 437 files changed, 11568 insertions(+), 13689 deletions(-) create mode 100644 docs/build/html/_sources/python/_autosummary/mlx.core.compile.rst create mode 100644 docs/build/html/_sources/python/_autosummary/mlx.core.disable_compile.rst create mode 100644 docs/build/html/_sources/python/_autosummary/mlx.core.enable_compile.rst delete mode 100644 docs/build/html/_sources/python/_autosummary/mlx.core.simplify.rst create mode 100644 docs/build/html/_sources/python/_autosummary/mlx.optimizers.Optimizer.apply_gradients.rst create mode 100644 docs/build/html/_sources/python/_autosummary/mlx.optimizers.Optimizer.init.rst delete mode 100644 docs/build/html/_sources/python/_autosummary/mlx.optimizers.Optimizer.rst create mode 100644 docs/build/html/_sources/python/_autosummary/mlx.optimizers.Optimizer.state.rst create mode 100644 docs/build/html/_sources/python/_autosummary/mlx.optimizers.Optimizer.update.rst delete mode 100644 docs/build/html/_sources/python/_autosummary/mlx.optimizers.OptimizerState.rst create mode 100644 docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Module.state.rst delete mode 100644 docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.init.constant.rst delete mode 100644 docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.init.glorot_normal.rst delete mode 100644 docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.init.glorot_uniform.rst delete mode 100644 docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.init.he_normal.rst delete mode 100644 docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.init.he_uniform.rst delete mode 100644 docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.init.identity.rst delete mode 100644 docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.init.normal.rst delete mode 100644 docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.init.uniform.rst delete mode 100644 docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.initializers.constant.rst delete mode 100644 docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.initializers.glorot_normal.rst delete mode 100644 docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.initializers.glorot_uniform.rst delete mode 100644 docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.initializers.he_normal.rst delete mode 100644 docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.initializers.he_uniform.rst delete mode 100644 docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.initializers.identity.rst delete mode 100644 docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.initializers.normal.rst delete mode 100644 docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.initializers.uniform.rst create mode 100644 docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.margin_ranking_loss.rst delete mode 100644 docs/build/html/_sources/python/nn/initializers.rst create mode 100644 docs/build/html/_sources/python/optimizer.rst create mode 100644 docs/build/html/_sources/usage/compile.rst create mode 100644 docs/build/html/python/_autosummary/mlx.core.compile.html rename docs/build/html/python/_autosummary/{mlx.core.simplify.html => mlx.core.disable_compile.html} (89%) rename docs/build/html/python/_autosummary/{mlx.optimizers.Optimizer.html => mlx.core.enable_compile.html} (92%) create mode 100644 docs/build/html/python/_autosummary/mlx.optimizers.Optimizer.apply_gradients.html create mode 100644 docs/build/html/python/_autosummary/mlx.optimizers.Optimizer.init.html rename docs/build/html/python/_autosummary/{mlx.optimizers.OptimizerState.html => mlx.optimizers.Optimizer.state.html} (94%) create mode 100644 docs/build/html/python/_autosummary/mlx.optimizers.Optimizer.update.html rename docs/build/html/python/nn/{_autosummary_functions/mlx.nn.init.normal.html => _autosummary/mlx.nn.Module.state.html} (78%) delete mode 100644 docs/build/html/python/nn/_autosummary_functions/mlx.nn.init.constant.html delete mode 100644 docs/build/html/python/nn/_autosummary_functions/mlx.nn.init.glorot_normal.html delete mode 100644 docs/build/html/python/nn/_autosummary_functions/mlx.nn.init.glorot_uniform.html delete mode 100644 docs/build/html/python/nn/_autosummary_functions/mlx.nn.init.he_normal.html delete mode 100644 docs/build/html/python/nn/_autosummary_functions/mlx.nn.init.he_uniform.html delete mode 100644 docs/build/html/python/nn/_autosummary_functions/mlx.nn.init.identity.html delete mode 100644 docs/build/html/python/nn/_autosummary_functions/mlx.nn.initializers.constant.html delete mode 100644 docs/build/html/python/nn/_autosummary_functions/mlx.nn.initializers.glorot_normal.html delete mode 100644 docs/build/html/python/nn/_autosummary_functions/mlx.nn.initializers.glorot_uniform.html delete mode 100644 docs/build/html/python/nn/_autosummary_functions/mlx.nn.initializers.he_normal.html delete mode 100644 docs/build/html/python/nn/_autosummary_functions/mlx.nn.initializers.he_uniform.html delete mode 100644 docs/build/html/python/nn/_autosummary_functions/mlx.nn.initializers.identity.html delete mode 100644 docs/build/html/python/nn/_autosummary_functions/mlx.nn.initializers.normal.html delete mode 100644 docs/build/html/python/nn/_autosummary_functions/mlx.nn.initializers.uniform.html rename docs/build/html/python/nn/_autosummary_functions/{mlx.nn.init.uniform.html => mlx.nn.losses.margin_ranking_loss.html} (79%) delete mode 100644 docs/build/html/python/nn/initializers.html create mode 100644 docs/build/html/python/optimizer.html create mode 100644 docs/build/html/usage/compile.html diff --git a/docs/build/html/.buildinfo b/docs/build/html/.buildinfo index 4c56d0ae4..0241e00e4 100644 --- a/docs/build/html/.buildinfo +++ b/docs/build/html/.buildinfo @@ -1,4 +1,4 @@ # Sphinx build info version 1 # This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done. -config: df22fdae6eaa6299681f0aab7c5d6029 +config: b49cb089891263e82aedf5bc4cacbe8a tags: 645f666f9bcd5a90fca523b33c5a78b7 diff --git a/docs/build/html/_sources/dev/extensions.rst b/docs/build/html/_sources/dev/extensions.rst index a7880e396..3563305bf 100644 --- a/docs/build/html/_sources/dev/extensions.rst +++ b/docs/build/html/_sources/dev/extensions.rst @@ -677,9 +677,9 @@ Let's look at the overall directory structure first. Binding to Python ^^^^^^^^^^^^^^^^^^ -We use PyBind11_ to build a Python API for the C++ library. Since bindings -for all needed components such as `mlx.core.array`, `mlx.core.stream`, etc. -are already provided, adding our :meth:`axpby` becomes very simple! +We use PyBind11_ to build a Python API for the C++ library. Since bindings for +components such as :class:`mlx.core.array`, :class:`mlx.core.stream`, etc. are +already provided, adding our :meth:`axpby` is simple! .. code-block:: C++ @@ -927,18 +927,18 @@ Results: We see some modest improvements right away! -This operation is now good to be used to build other operations, -in :class:`mlx.nn.Module` calls, and also as a part of graph -transformations like :meth:`grad`! +This operation is now good to be used to build other operations, in +:class:`mlx.nn.Module` calls, and also as a part of graph transformations like +:meth:`grad`! Scripts ------- .. admonition:: Download the code - The full example code is available in `mlx-examples `_. + The full example code is available in `mlx `_. -.. code: `TODO_LINK/extensions`_ +.. code: `https://github.com/ml-explore/mlx/tree/main/examples/extensions/`_ .. _Accelerate: https://developer.apple.com/documentation/accelerate/blas?language=objc .. _Metal: https://developer.apple.com/documentation/metal?language=objc diff --git a/docs/build/html/_sources/index.rst b/docs/build/html/_sources/index.rst index 4f4411758..50dfe9083 100644 --- a/docs/build/html/_sources/index.rst +++ b/docs/build/html/_sources/index.rst @@ -41,6 +41,7 @@ are the CPU and GPU. usage/indexing usage/saving_and_loading usage/function_transforms + usage/compile usage/numpy usage/using_streams diff --git a/docs/build/html/_sources/python/_autosummary/mlx.core.compile.rst b/docs/build/html/_sources/python/_autosummary/mlx.core.compile.rst new file mode 100644 index 000000000..3ccea976d --- /dev/null +++ b/docs/build/html/_sources/python/_autosummary/mlx.core.compile.rst @@ -0,0 +1,6 @@ +mlx.core.compile +================ + +.. currentmodule:: mlx.core + +.. autofunction:: compile \ No newline at end of file diff --git a/docs/build/html/_sources/python/_autosummary/mlx.core.disable_compile.rst b/docs/build/html/_sources/python/_autosummary/mlx.core.disable_compile.rst new file mode 100644 index 000000000..913574b97 --- /dev/null +++ b/docs/build/html/_sources/python/_autosummary/mlx.core.disable_compile.rst @@ -0,0 +1,6 @@ +mlx.core.disable\_compile +========================= + +.. currentmodule:: mlx.core + +.. autofunction:: disable_compile \ No newline at end of file diff --git a/docs/build/html/_sources/python/_autosummary/mlx.core.enable_compile.rst b/docs/build/html/_sources/python/_autosummary/mlx.core.enable_compile.rst new file mode 100644 index 000000000..c991ee8cb --- /dev/null +++ b/docs/build/html/_sources/python/_autosummary/mlx.core.enable_compile.rst @@ -0,0 +1,6 @@ +mlx.core.enable\_compile +======================== + +.. currentmodule:: mlx.core + +.. autofunction:: enable_compile \ No newline at end of file diff --git a/docs/build/html/_sources/python/_autosummary/mlx.core.simplify.rst b/docs/build/html/_sources/python/_autosummary/mlx.core.simplify.rst deleted file mode 100644 index c0b518497..000000000 --- a/docs/build/html/_sources/python/_autosummary/mlx.core.simplify.rst +++ /dev/null @@ -1,6 +0,0 @@ -mlx.core.simplify -================= - -.. currentmodule:: mlx.core - -.. autofunction:: simplify \ No newline at end of file diff --git a/docs/build/html/_sources/python/_autosummary/mlx.optimizers.AdaDelta.rst b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.AdaDelta.rst index 2ea7cda8a..55792c434 100644 --- a/docs/build/html/_sources/python/_autosummary/mlx.optimizers.AdaDelta.rst +++ b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.AdaDelta.rst @@ -14,5 +14,6 @@ ~AdaDelta.__init__ ~AdaDelta.apply_single + ~AdaDelta.init_single diff --git a/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Adafactor.rst b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Adafactor.rst index b0e5e5c30..9047eea41 100644 --- a/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Adafactor.rst +++ b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Adafactor.rst @@ -14,5 +14,6 @@ ~Adafactor.__init__ ~Adafactor.apply_single + ~Adafactor.init_single diff --git a/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Adagrad.rst b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Adagrad.rst index 8a12fc43c..c12713e8a 100644 --- a/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Adagrad.rst +++ b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Adagrad.rst @@ -14,5 +14,6 @@ ~Adagrad.__init__ ~Adagrad.apply_single + ~Adagrad.init_single diff --git a/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Adam.rst b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Adam.rst index 074080ea6..9ca26adfa 100644 --- a/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Adam.rst +++ b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Adam.rst @@ -14,5 +14,6 @@ ~Adam.__init__ ~Adam.apply_single + ~Adam.init_single diff --git a/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Adamax.rst b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Adamax.rst index 58e6c95ca..73dc7314d 100644 --- a/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Adamax.rst +++ b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Adamax.rst @@ -14,5 +14,6 @@ ~Adamax.__init__ ~Adamax.apply_single + ~Adamax.init_single diff --git a/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Lion.rst b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Lion.rst index a00dc50f0..1454aada1 100644 --- a/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Lion.rst +++ b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Lion.rst @@ -14,5 +14,6 @@ ~Lion.__init__ ~Lion.apply_single + ~Lion.init_single diff --git a/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Optimizer.apply_gradients.rst b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Optimizer.apply_gradients.rst new file mode 100644 index 000000000..763eeb293 --- /dev/null +++ b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Optimizer.apply_gradients.rst @@ -0,0 +1,6 @@ +mlx.optimizers.Optimizer.apply\_gradients +========================================= + +.. currentmodule:: mlx.optimizers + +.. automethod:: Optimizer.apply_gradients \ No newline at end of file diff --git a/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Optimizer.init.rst b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Optimizer.init.rst new file mode 100644 index 000000000..e0245cf02 --- /dev/null +++ b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Optimizer.init.rst @@ -0,0 +1,6 @@ +mlx.optimizers.Optimizer.init +============================= + +.. currentmodule:: mlx.optimizers + +.. automethod:: Optimizer.init \ No newline at end of file diff --git a/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Optimizer.rst b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Optimizer.rst deleted file mode 100644 index 613eb02cf..000000000 --- a/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Optimizer.rst +++ /dev/null @@ -1,20 +0,0 @@ -mlx.optimizers.Optimizer -======================== - -.. currentmodule:: mlx.optimizers - -.. autoclass:: Optimizer - - - - - .. rubric:: Methods - - .. autosummary:: - - ~Optimizer.__init__ - ~Optimizer.apply_gradients - ~Optimizer.apply_single - ~Optimizer.update - - diff --git a/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Optimizer.state.rst b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Optimizer.state.rst new file mode 100644 index 000000000..e0bf31dbe --- /dev/null +++ b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Optimizer.state.rst @@ -0,0 +1,6 @@ +mlx.optimizers.Optimizer.state +============================== + +.. currentmodule:: mlx.optimizers + +.. autoproperty:: Optimizer.state \ No newline at end of file diff --git a/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Optimizer.update.rst b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Optimizer.update.rst new file mode 100644 index 000000000..e7610999e --- /dev/null +++ b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.Optimizer.update.rst @@ -0,0 +1,6 @@ +mlx.optimizers.Optimizer.update +=============================== + +.. currentmodule:: mlx.optimizers + +.. automethod:: Optimizer.update \ No newline at end of file diff --git a/docs/build/html/_sources/python/_autosummary/mlx.optimizers.OptimizerState.rst b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.OptimizerState.rst deleted file mode 100644 index b319b6d09..000000000 --- a/docs/build/html/_sources/python/_autosummary/mlx.optimizers.OptimizerState.rst +++ /dev/null @@ -1,17 +0,0 @@ -mlx.optimizers.OptimizerState -============================= - -.. currentmodule:: mlx.optimizers - -.. autoclass:: OptimizerState - - - - - .. rubric:: Methods - - .. autosummary:: - - ~OptimizerState.get - - diff --git a/docs/build/html/_sources/python/_autosummary/mlx.optimizers.RMSprop.rst b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.RMSprop.rst index 217b4619f..d9ba20078 100644 --- a/docs/build/html/_sources/python/_autosummary/mlx.optimizers.RMSprop.rst +++ b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.RMSprop.rst @@ -14,5 +14,6 @@ ~RMSprop.__init__ ~RMSprop.apply_single + ~RMSprop.init_single diff --git a/docs/build/html/_sources/python/_autosummary/mlx.optimizers.SGD.rst b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.SGD.rst index 35a9759ed..4b6f397ec 100644 --- a/docs/build/html/_sources/python/_autosummary/mlx.optimizers.SGD.rst +++ b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.SGD.rst @@ -14,5 +14,6 @@ ~SGD.__init__ ~SGD.apply_single + ~SGD.init_single diff --git a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.ALiBi.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.ALiBi.rst index 284b453cf..9159bb888 100644 --- a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.ALiBi.rst +++ b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.ALiBi.rst @@ -3,6 +3,4 @@ .. currentmodule:: mlx.nn -.. autoclass:: ALiBi - - \ No newline at end of file +.. autoclass:: ALiBi \ No newline at end of file diff --git a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.BatchNorm.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.BatchNorm.rst index b94d82e7f..d085d5af5 100644 --- a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.BatchNorm.rst +++ b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.BatchNorm.rst @@ -3,6 +3,4 @@ .. currentmodule:: mlx.nn -.. autoclass:: BatchNorm - - \ No newline at end of file +.. autoclass:: BatchNorm \ No newline at end of file diff --git a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Conv1d.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Conv1d.rst index c4128b83b..0fb6ff201 100644 --- a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Conv1d.rst +++ b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Conv1d.rst @@ -3,6 +3,4 @@ .. currentmodule:: mlx.nn -.. autoclass:: Conv1d - - \ No newline at end of file +.. autoclass:: Conv1d \ No newline at end of file diff --git a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Conv2d.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Conv2d.rst index 7bd1f08bb..566e5d1e1 100644 --- a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Conv2d.rst +++ b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Conv2d.rst @@ -3,6 +3,4 @@ .. currentmodule:: mlx.nn -.. autoclass:: Conv2d - - \ No newline at end of file +.. autoclass:: Conv2d \ No newline at end of file diff --git a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Dropout.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Dropout.rst index d1a68e793..2ec3556e1 100644 --- a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Dropout.rst +++ b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Dropout.rst @@ -3,6 +3,4 @@ .. currentmodule:: mlx.nn -.. autoclass:: Dropout - - \ No newline at end of file +.. autoclass:: Dropout \ No newline at end of file diff --git a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Dropout2d.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Dropout2d.rst index 8bf18deb8..d643adcb9 100644 --- a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Dropout2d.rst +++ b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Dropout2d.rst @@ -3,6 +3,4 @@ .. currentmodule:: mlx.nn -.. autoclass:: Dropout2d - - \ No newline at end of file +.. autoclass:: Dropout2d \ No newline at end of file diff --git a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Dropout3d.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Dropout3d.rst index d513a3d61..f386030ee 100644 --- a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Dropout3d.rst +++ b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Dropout3d.rst @@ -3,6 +3,4 @@ .. currentmodule:: mlx.nn -.. autoclass:: Dropout3d - - \ No newline at end of file +.. autoclass:: Dropout3d \ No newline at end of file diff --git a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Embedding.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Embedding.rst index ad2f3f2ce..0f29f593d 100644 --- a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Embedding.rst +++ b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Embedding.rst @@ -3,6 +3,4 @@ .. currentmodule:: mlx.nn -.. autoclass:: Embedding - - \ No newline at end of file +.. autoclass:: Embedding \ No newline at end of file diff --git a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.GELU.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.GELU.rst index c963c84f2..c6ca7a28c 100644 --- a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.GELU.rst +++ b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.GELU.rst @@ -3,6 +3,4 @@ .. currentmodule:: mlx.nn -.. autoclass:: GELU - - \ No newline at end of file +.. autoclass:: GELU \ No newline at end of file diff --git a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.GroupNorm.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.GroupNorm.rst index 762d9ffea..982103df5 100644 --- a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.GroupNorm.rst +++ b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.GroupNorm.rst @@ -3,6 +3,4 @@ .. currentmodule:: mlx.nn -.. autoclass:: GroupNorm - - \ No newline at end of file +.. autoclass:: GroupNorm \ No newline at end of file diff --git a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.InstanceNorm.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.InstanceNorm.rst index 92152b356..66d01967f 100644 --- a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.InstanceNorm.rst +++ b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.InstanceNorm.rst @@ -3,6 +3,4 @@ .. currentmodule:: mlx.nn -.. autoclass:: InstanceNorm - - \ No newline at end of file +.. autoclass:: InstanceNorm \ No newline at end of file diff --git a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.LayerNorm.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.LayerNorm.rst index cc0ac117d..817f9551e 100644 --- a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.LayerNorm.rst +++ b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.LayerNorm.rst @@ -3,6 +3,4 @@ .. currentmodule:: mlx.nn -.. autoclass:: LayerNorm - - \ No newline at end of file +.. autoclass:: LayerNorm \ No newline at end of file diff --git a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Linear.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Linear.rst index 627e6e6e6..53be170e4 100644 --- a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Linear.rst +++ b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Linear.rst @@ -3,6 +3,4 @@ .. currentmodule:: mlx.nn -.. autoclass:: Linear - - \ No newline at end of file +.. autoclass:: Linear \ No newline at end of file diff --git a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Mish.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Mish.rst index bf5397852..bd10864be 100644 --- a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Mish.rst +++ b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Mish.rst @@ -3,6 +3,4 @@ .. currentmodule:: mlx.nn -.. autoclass:: Mish - - \ No newline at end of file +.. autoclass:: Mish \ No newline at end of file diff --git a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Module.state.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Module.state.rst new file mode 100644 index 000000000..7f4819837 --- /dev/null +++ b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Module.state.rst @@ -0,0 +1,6 @@ +mlx.nn.Module.state +=================== + +.. currentmodule:: mlx.nn + +.. autoproperty:: Module.state \ No newline at end of file diff --git a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.MultiHeadAttention.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.MultiHeadAttention.rst index 2c3a8fcc1..0a3f8d184 100644 --- a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.MultiHeadAttention.rst +++ b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.MultiHeadAttention.rst @@ -3,6 +3,4 @@ .. currentmodule:: mlx.nn -.. autoclass:: MultiHeadAttention - - \ No newline at end of file +.. autoclass:: MultiHeadAttention \ No newline at end of file diff --git a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.PReLU.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.PReLU.rst index 2de33a688..4583c2d65 100644 --- a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.PReLU.rst +++ b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.PReLU.rst @@ -3,6 +3,4 @@ .. currentmodule:: mlx.nn -.. autoclass:: PReLU - - \ No newline at end of file +.. autoclass:: PReLU \ No newline at end of file diff --git a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.QuantizedLinear.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.QuantizedLinear.rst index ccbde4340..00688282e 100644 --- a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.QuantizedLinear.rst +++ b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.QuantizedLinear.rst @@ -3,6 +3,4 @@ .. currentmodule:: mlx.nn -.. autoclass:: QuantizedLinear - - \ No newline at end of file +.. autoclass:: QuantizedLinear \ No newline at end of file diff --git a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.RMSNorm.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.RMSNorm.rst index 474b1355d..d4501bd36 100644 --- a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.RMSNorm.rst +++ b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.RMSNorm.rst @@ -3,6 +3,4 @@ .. currentmodule:: mlx.nn -.. autoclass:: RMSNorm - - \ No newline at end of file +.. autoclass:: RMSNorm \ No newline at end of file diff --git a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.ReLU.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.ReLU.rst index 944917de9..6707e757e 100644 --- a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.ReLU.rst +++ b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.ReLU.rst @@ -3,6 +3,4 @@ .. currentmodule:: mlx.nn -.. autoclass:: ReLU - - \ No newline at end of file +.. autoclass:: ReLU \ No newline at end of file diff --git a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.RoPE.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.RoPE.rst index 392fbab7b..fca09a4eb 100644 --- a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.RoPE.rst +++ b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.RoPE.rst @@ -3,6 +3,4 @@ .. currentmodule:: mlx.nn -.. autoclass:: RoPE - - \ No newline at end of file +.. autoclass:: RoPE \ No newline at end of file diff --git a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.SELU.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.SELU.rst index 9fe57cdea..fa7477246 100644 --- a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.SELU.rst +++ b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.SELU.rst @@ -3,6 +3,4 @@ .. currentmodule:: mlx.nn -.. autoclass:: SELU - - \ No newline at end of file +.. autoclass:: SELU \ No newline at end of file diff --git a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Sequential.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Sequential.rst index af6ee04ab..5ae61b025 100644 --- a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Sequential.rst +++ b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Sequential.rst @@ -3,6 +3,4 @@ .. currentmodule:: mlx.nn -.. autoclass:: Sequential - - \ No newline at end of file +.. autoclass:: Sequential \ No newline at end of file diff --git a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.SiLU.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.SiLU.rst index 85069c9d5..57d18df4f 100644 --- a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.SiLU.rst +++ b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.SiLU.rst @@ -3,6 +3,4 @@ .. currentmodule:: mlx.nn -.. autoclass:: SiLU - - \ No newline at end of file +.. autoclass:: SiLU \ No newline at end of file diff --git a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.SinusoidalPositionalEncoding.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.SinusoidalPositionalEncoding.rst index bfdd633a5..30b7a1f90 100644 --- a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.SinusoidalPositionalEncoding.rst +++ b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.SinusoidalPositionalEncoding.rst @@ -3,6 +3,4 @@ .. currentmodule:: mlx.nn -.. autoclass:: SinusoidalPositionalEncoding - - \ No newline at end of file +.. autoclass:: SinusoidalPositionalEncoding \ No newline at end of file diff --git a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Softshrink.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Softshrink.rst index 464c3451b..5e17a5199 100644 --- a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Softshrink.rst +++ b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Softshrink.rst @@ -3,6 +3,4 @@ .. currentmodule:: mlx.nn -.. autoclass:: Softshrink - - \ No newline at end of file +.. autoclass:: Softshrink \ No newline at end of file diff --git a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Step.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Step.rst index 688313628..204f8cbb1 100644 --- a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Step.rst +++ b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Step.rst @@ -3,6 +3,4 @@ .. currentmodule:: mlx.nn -.. autoclass:: Step - - \ No newline at end of file +.. autoclass:: Step \ No newline at end of file diff --git a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Transformer.rst b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Transformer.rst index 01dc3a841..f7e800eff 100644 --- a/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Transformer.rst +++ b/docs/build/html/_sources/python/nn/_autosummary/mlx.nn.Transformer.rst @@ -3,6 +3,4 @@ .. currentmodule:: mlx.nn -.. autoclass:: Transformer - - \ No newline at end of file +.. autoclass:: Transformer \ No newline at end of file diff --git a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.gelu.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.gelu.rst index 3e1668eb6..616cb1e22 100644 --- a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.gelu.rst +++ b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.gelu.rst @@ -3,6 +3,4 @@ .. currentmodule:: mlx.nn -.. autoclass:: gelu - - \ No newline at end of file +.. autofunction:: gelu \ No newline at end of file diff --git a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.gelu_approx.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.gelu_approx.rst index de08dc88c..d634ee1de 100644 --- a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.gelu_approx.rst +++ b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.gelu_approx.rst @@ -3,6 +3,4 @@ .. currentmodule:: mlx.nn -.. autoclass:: gelu_approx - - \ No newline at end of file +.. autofunction:: gelu_approx \ No newline at end of file diff --git a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.gelu_fast_approx.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.gelu_fast_approx.rst index c84114e6c..36cc04480 100644 --- a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.gelu_fast_approx.rst +++ b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.gelu_fast_approx.rst @@ -3,6 +3,4 @@ .. currentmodule:: mlx.nn -.. autoclass:: gelu_fast_approx - - \ No newline at end of file +.. autofunction:: gelu_fast_approx \ No newline at end of file diff --git a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.init.constant.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.init.constant.rst deleted file mode 100644 index e61e02905..000000000 --- a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.init.constant.rst +++ /dev/null @@ -1,6 +0,0 @@ -mlx.nn.init.constant -==================== - -.. currentmodule:: mlx.nn.init - -.. autofunction:: constant \ No newline at end of file diff --git a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.init.glorot_normal.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.init.glorot_normal.rst deleted file mode 100644 index b500f578d..000000000 --- a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.init.glorot_normal.rst +++ /dev/null @@ -1,6 +0,0 @@ -mlx.nn.init.glorot\_normal -========================== - -.. currentmodule:: mlx.nn.init - -.. autofunction:: glorot_normal \ No newline at end of file diff --git a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.init.glorot_uniform.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.init.glorot_uniform.rst deleted file mode 100644 index b266fc94a..000000000 --- a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.init.glorot_uniform.rst +++ /dev/null @@ -1,6 +0,0 @@ -mlx.nn.init.glorot\_uniform -=========================== - -.. currentmodule:: mlx.nn.init - -.. autofunction:: glorot_uniform \ No newline at end of file diff --git a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.init.he_normal.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.init.he_normal.rst deleted file mode 100644 index 51c3287a7..000000000 --- a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.init.he_normal.rst +++ /dev/null @@ -1,6 +0,0 @@ -mlx.nn.init.he\_normal -====================== - -.. currentmodule:: mlx.nn.init - -.. autofunction:: he_normal \ No newline at end of file diff --git a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.init.he_uniform.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.init.he_uniform.rst deleted file mode 100644 index ee299e247..000000000 --- a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.init.he_uniform.rst +++ /dev/null @@ -1,6 +0,0 @@ -mlx.nn.init.he\_uniform -======================= - -.. currentmodule:: mlx.nn.init - -.. autofunction:: he_uniform \ No newline at end of file diff --git a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.init.identity.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.init.identity.rst deleted file mode 100644 index a5772adfa..000000000 --- a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.init.identity.rst +++ /dev/null @@ -1,6 +0,0 @@ -mlx.nn.init.identity -==================== - -.. currentmodule:: mlx.nn.init - -.. autofunction:: identity \ No newline at end of file diff --git a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.init.normal.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.init.normal.rst deleted file mode 100644 index 6f9ce0023..000000000 --- a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.init.normal.rst +++ /dev/null @@ -1,6 +0,0 @@ -mlx.nn.init.normal -================== - -.. currentmodule:: mlx.nn.init - -.. autofunction:: normal \ No newline at end of file diff --git a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.init.uniform.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.init.uniform.rst deleted file mode 100644 index 7d3b82560..000000000 --- a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.init.uniform.rst +++ /dev/null @@ -1,6 +0,0 @@ -mlx.nn.init.uniform -=================== - -.. currentmodule:: mlx.nn.init - -.. autofunction:: uniform \ No newline at end of file diff --git a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.initializers.constant.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.initializers.constant.rst deleted file mode 100644 index 7e983ec9c..000000000 --- a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.initializers.constant.rst +++ /dev/null @@ -1,6 +0,0 @@ -mlx.nn.initializers.constant -============================ - -.. currentmodule:: mlx.nn.initializers - -.. autofunction:: constant \ No newline at end of file diff --git a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.initializers.glorot_normal.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.initializers.glorot_normal.rst deleted file mode 100644 index 1860f0f1a..000000000 --- a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.initializers.glorot_normal.rst +++ /dev/null @@ -1,6 +0,0 @@ -mlx.nn.initializers.glorot\_normal -================================== - -.. currentmodule:: mlx.nn.initializers - -.. autofunction:: glorot_normal \ No newline at end of file diff --git a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.initializers.glorot_uniform.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.initializers.glorot_uniform.rst deleted file mode 100644 index 1693bb019..000000000 --- a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.initializers.glorot_uniform.rst +++ /dev/null @@ -1,6 +0,0 @@ -mlx.nn.initializers.glorot\_uniform -=================================== - -.. currentmodule:: mlx.nn.initializers - -.. autofunction:: glorot_uniform \ No newline at end of file diff --git a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.initializers.he_normal.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.initializers.he_normal.rst deleted file mode 100644 index 76e5d0ac7..000000000 --- a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.initializers.he_normal.rst +++ /dev/null @@ -1,6 +0,0 @@ -mlx.nn.initializers.he\_normal -============================== - -.. currentmodule:: mlx.nn.initializers - -.. autofunction:: he_normal \ No newline at end of file diff --git a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.initializers.he_uniform.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.initializers.he_uniform.rst deleted file mode 100644 index 7482519a1..000000000 --- a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.initializers.he_uniform.rst +++ /dev/null @@ -1,6 +0,0 @@ -mlx.nn.initializers.he\_uniform -=============================== - -.. currentmodule:: mlx.nn.initializers - -.. autofunction:: he_uniform \ No newline at end of file diff --git a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.initializers.identity.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.initializers.identity.rst deleted file mode 100644 index 8548c4439..000000000 --- a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.initializers.identity.rst +++ /dev/null @@ -1,6 +0,0 @@ -mlx.nn.initializers.identity -============================ - -.. currentmodule:: mlx.nn.initializers - -.. autofunction:: identity \ No newline at end of file diff --git a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.initializers.normal.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.initializers.normal.rst deleted file mode 100644 index 3e82a3645..000000000 --- a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.initializers.normal.rst +++ /dev/null @@ -1,6 +0,0 @@ -mlx.nn.initializers.normal -========================== - -.. currentmodule:: mlx.nn.initializers - -.. autofunction:: normal \ No newline at end of file diff --git a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.initializers.uniform.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.initializers.uniform.rst deleted file mode 100644 index 28c504bd1..000000000 --- a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.initializers.uniform.rst +++ /dev/null @@ -1,6 +0,0 @@ -mlx.nn.initializers.uniform -=========================== - -.. currentmodule:: mlx.nn.initializers - -.. autofunction:: uniform \ No newline at end of file diff --git a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.binary_cross_entropy.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.binary_cross_entropy.rst index be553e4c0..ba5254eff 100644 --- a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.binary_cross_entropy.rst +++ b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.binary_cross_entropy.rst @@ -3,6 +3,4 @@ .. currentmodule:: mlx.nn.losses -.. autoclass:: binary_cross_entropy - - \ No newline at end of file +.. autofunction:: binary_cross_entropy \ No newline at end of file diff --git a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.cosine_similarity_loss.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.cosine_similarity_loss.rst index 7970aaca7..27b5d4a8e 100644 --- a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.cosine_similarity_loss.rst +++ b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.cosine_similarity_loss.rst @@ -3,6 +3,4 @@ .. currentmodule:: mlx.nn.losses -.. autoclass:: cosine_similarity_loss - - \ No newline at end of file +.. autofunction:: cosine_similarity_loss \ No newline at end of file diff --git a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.cross_entropy.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.cross_entropy.rst index 9c50fd349..fd7f9e6f6 100644 --- a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.cross_entropy.rst +++ b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.cross_entropy.rst @@ -3,6 +3,4 @@ .. currentmodule:: mlx.nn.losses -.. autoclass:: cross_entropy - - \ No newline at end of file +.. autofunction:: cross_entropy \ No newline at end of file diff --git a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.gaussian_nll_loss.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.gaussian_nll_loss.rst index 63cc52978..a481e2317 100644 --- a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.gaussian_nll_loss.rst +++ b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.gaussian_nll_loss.rst @@ -3,6 +3,4 @@ .. currentmodule:: mlx.nn.losses -.. autoclass:: gaussian_nll_loss - - \ No newline at end of file +.. autofunction:: gaussian_nll_loss \ No newline at end of file diff --git a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.hinge_loss.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.hinge_loss.rst index 3b94ae64c..092dcd383 100644 --- a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.hinge_loss.rst +++ b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.hinge_loss.rst @@ -3,6 +3,4 @@ .. currentmodule:: mlx.nn.losses -.. autoclass:: hinge_loss - - \ No newline at end of file +.. autofunction:: hinge_loss \ No newline at end of file diff --git a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.huber_loss.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.huber_loss.rst index 5b5dc918e..da5e4d417 100644 --- a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.huber_loss.rst +++ b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.huber_loss.rst @@ -3,6 +3,4 @@ .. currentmodule:: mlx.nn.losses -.. autoclass:: huber_loss - - \ No newline at end of file +.. autofunction:: huber_loss \ No newline at end of file diff --git a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.kl_div_loss.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.kl_div_loss.rst index 11e070650..04d2fcce3 100644 --- a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.kl_div_loss.rst +++ b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.kl_div_loss.rst @@ -3,6 +3,4 @@ .. currentmodule:: mlx.nn.losses -.. autoclass:: kl_div_loss - - \ No newline at end of file +.. autofunction:: kl_div_loss \ No newline at end of file diff --git a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.l1_loss.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.l1_loss.rst index 34ae66d69..950aff725 100644 --- a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.l1_loss.rst +++ b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.l1_loss.rst @@ -3,6 +3,4 @@ .. currentmodule:: mlx.nn.losses -.. autoclass:: l1_loss - - \ No newline at end of file +.. autofunction:: l1_loss \ No newline at end of file diff --git a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.log_cosh_loss.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.log_cosh_loss.rst index b00c1a51f..b7a7461c9 100644 --- a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.log_cosh_loss.rst +++ b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.log_cosh_loss.rst @@ -3,6 +3,4 @@ .. currentmodule:: mlx.nn.losses -.. autoclass:: log_cosh_loss - - \ No newline at end of file +.. autofunction:: log_cosh_loss \ No newline at end of file diff --git a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.margin_ranking_loss.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.margin_ranking_loss.rst new file mode 100644 index 000000000..ede1b9f0a --- /dev/null +++ b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.margin_ranking_loss.rst @@ -0,0 +1,6 @@ +mlx.nn.losses.margin\_ranking\_loss +=================================== + +.. currentmodule:: mlx.nn.losses + +.. autofunction:: margin_ranking_loss \ No newline at end of file diff --git a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.mse_loss.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.mse_loss.rst index 534ed1e14..c4e44ddb1 100644 --- a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.mse_loss.rst +++ b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.mse_loss.rst @@ -3,6 +3,4 @@ .. currentmodule:: mlx.nn.losses -.. autoclass:: mse_loss - - \ No newline at end of file +.. autofunction:: mse_loss \ No newline at end of file diff --git a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.nll_loss.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.nll_loss.rst index c94eb82a1..e64b55dfc 100644 --- a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.nll_loss.rst +++ b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.nll_loss.rst @@ -3,6 +3,4 @@ .. currentmodule:: mlx.nn.losses -.. autoclass:: nll_loss - - \ No newline at end of file +.. autofunction:: nll_loss \ No newline at end of file diff --git a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.smooth_l1_loss.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.smooth_l1_loss.rst index 00a647a75..d96bb5823 100644 --- a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.smooth_l1_loss.rst +++ b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.smooth_l1_loss.rst @@ -3,6 +3,4 @@ .. currentmodule:: mlx.nn.losses -.. autoclass:: smooth_l1_loss - - \ No newline at end of file +.. autofunction:: smooth_l1_loss \ No newline at end of file diff --git a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.triplet_loss.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.triplet_loss.rst index 4698d6155..f52eaab92 100644 --- a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.triplet_loss.rst +++ b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.losses.triplet_loss.rst @@ -3,6 +3,4 @@ .. currentmodule:: mlx.nn.losses -.. autoclass:: triplet_loss - - \ No newline at end of file +.. autofunction:: triplet_loss \ No newline at end of file diff --git a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.mish.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.mish.rst index 85bf0899b..49c1cfcb9 100644 --- a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.mish.rst +++ b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.mish.rst @@ -3,6 +3,4 @@ .. currentmodule:: mlx.nn -.. autoclass:: mish - - \ No newline at end of file +.. autofunction:: mish \ No newline at end of file diff --git a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.prelu.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.prelu.rst index f3757c1c3..51085ec23 100644 --- a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.prelu.rst +++ b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.prelu.rst @@ -3,6 +3,4 @@ .. currentmodule:: mlx.nn -.. autoclass:: prelu - - \ No newline at end of file +.. autofunction:: prelu \ No newline at end of file diff --git a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.relu.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.relu.rst index 93a69272a..f1c28d7aa 100644 --- a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.relu.rst +++ b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.relu.rst @@ -3,6 +3,4 @@ .. currentmodule:: mlx.nn -.. autoclass:: relu - - \ No newline at end of file +.. autofunction:: relu \ No newline at end of file diff --git a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.selu.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.selu.rst index 00c1d0923..f1530a805 100644 --- a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.selu.rst +++ b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.selu.rst @@ -3,6 +3,4 @@ .. currentmodule:: mlx.nn -.. autoclass:: selu - - \ No newline at end of file +.. autofunction:: selu \ No newline at end of file diff --git a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.silu.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.silu.rst index b30c17b06..cd5ff218e 100644 --- a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.silu.rst +++ b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.silu.rst @@ -3,6 +3,4 @@ .. currentmodule:: mlx.nn -.. autoclass:: silu - - \ No newline at end of file +.. autofunction:: silu \ No newline at end of file diff --git a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.softshrink.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.softshrink.rst index e6af930b6..b844f9242 100644 --- a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.softshrink.rst +++ b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.softshrink.rst @@ -3,6 +3,4 @@ .. currentmodule:: mlx.nn -.. autoclass:: softshrink - - \ No newline at end of file +.. autofunction:: softshrink \ No newline at end of file diff --git a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.step.rst b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.step.rst index 1395bd012..0ad2c19e6 100644 --- a/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.step.rst +++ b/docs/build/html/_sources/python/nn/_autosummary_functions/mlx.nn.step.rst @@ -3,6 +3,4 @@ .. currentmodule:: mlx.nn -.. autoclass:: step - - \ No newline at end of file +.. autofunction:: step \ No newline at end of file diff --git a/docs/build/html/_sources/python/nn/initializers.rst b/docs/build/html/_sources/python/nn/initializers.rst deleted file mode 100644 index 59dddbe22..000000000 --- a/docs/build/html/_sources/python/nn/initializers.rst +++ /dev/null @@ -1,18 +0,0 @@ -.. _initializers: - -.. currentmodule:: mlx.nn.initializers - -Initializers --------------- - -.. autosummary:: - :toctree: _autosummary_functions - - constant - normal - uniform - identity - glorot_normal - glorot_uniform - he_normal - he_uniform diff --git a/docs/build/html/_sources/python/nn/losses.rst b/docs/build/html/_sources/python/nn/losses.rst index 6c4327eb8..6a2e128c5 100644 --- a/docs/build/html/_sources/python/nn/losses.rst +++ b/docs/build/html/_sources/python/nn/losses.rst @@ -18,6 +18,7 @@ Loss Functions kl_div_loss l1_loss log_cosh_loss + margin_ranking_loss mse_loss nll_loss smooth_l1_loss diff --git a/docs/build/html/_sources/python/nn/module.rst b/docs/build/html/_sources/python/nn/module.rst index 042a88028..c3a4dfa62 100644 --- a/docs/build/html/_sources/python/nn/module.rst +++ b/docs/build/html/_sources/python/nn/module.rst @@ -11,6 +11,7 @@ Module :toctree: _autosummary Module.training + Module.state .. rubric:: Methods diff --git a/docs/build/html/_sources/python/optimizer.rst b/docs/build/html/_sources/python/optimizer.rst new file mode 100644 index 000000000..cf6034dee --- /dev/null +++ b/docs/build/html/_sources/python/optimizer.rst @@ -0,0 +1,23 @@ +Optimizer +========= + +.. currentmodule:: mlx.optimizers + +.. autoclass:: Optimizer + + + .. rubric:: Attributes + + .. autosummary:: + :toctree: _autosummary + + Optimizer.state + + .. rubric:: Methods + + .. autosummary:: + :toctree: _autosummary + + Optimizer.apply_gradients + Optimizer.init + Optimizer.update diff --git a/docs/build/html/_sources/python/optimizers.rst b/docs/build/html/_sources/python/optimizers.rst index fe8632a7e..4ef43d50f 100644 --- a/docs/build/html/_sources/python/optimizers.rst +++ b/docs/build/html/_sources/python/optimizers.rst @@ -29,14 +29,16 @@ model's parameters and the **optimizer state**. # Compute the new parameters but also the optimizer state. mx.eval(model.parameters(), optimizer.state) +.. toctree:: + + optimizer + .. currentmodule:: mlx.optimizers .. autosummary:: :toctree: _autosummary :template: optimizers-template.rst - OptimizerState - Optimizer SGD RMSprop Adagrad diff --git a/docs/build/html/_sources/python/transforms.rst b/docs/build/html/_sources/python/transforms.rst index cc8d681d5..ad9ba579b 100644 --- a/docs/build/html/_sources/python/transforms.rst +++ b/docs/build/html/_sources/python/transforms.rst @@ -9,6 +9,9 @@ Transforms :toctree: _autosummary eval + compile + disable_compile + enable_compile grad value_and_grad jvp diff --git a/docs/build/html/_sources/usage/compile.rst b/docs/build/html/_sources/usage/compile.rst new file mode 100644 index 000000000..97d5503a3 --- /dev/null +++ b/docs/build/html/_sources/usage/compile.rst @@ -0,0 +1,430 @@ +.. _compile: + +Compilation +=========== + +.. currentmodule:: mlx.core + +MLX has a :func:`compile` function transformation which compiles computation +graphs. Function compilation results in smaller graphs by merging common work +and fusing certain operations. In many cases this can lead to big improvements +in run-time and memory use. + +Getting started with :func:`compile` is simple, but there are some edge cases +that are good to be aware of for more complex graphs and advanced usage. + +Basics of Compile +----------------- + +Let's start with a simple example: + +.. code-block:: python + + def fun(x, y): + return mx.exp(-x) + y + + x = mx.array(1.0) + y = mx.array(2.0) + + # Regular call, no compilation + # Prints: array(2.36788, dtype=float32) + print(fun(x, y)) + + # Compile the function + compiled_fun = mx.compile(fun) + + # Prints: array(2.36788, dtype=float32) + print(compiled_fun(x, y)) + +The output of both the regular function and the compiled function is the same +up to numerical precision. + +The first time you call a compiled function, MLX will build the compute +graph, optimize it, and generate and compile code. This can be relatively +slow. However, MLX will cache compiled functions, so calling a compiled +function multiple times will not initiate a new compilation. This means you +should typically compile functions that you plan to use more than once. + +.. code-block:: python + + def fun(x, y): + return mx.exp(-x) + y + + x = mx.array(1.0) + y = mx.array(2.0) + + compiled_fun = mx.compile(fun) + + # Compiled here + compiled_fun(x, y) + + # Not compiled again + compiled_fun(x, y) + + # Not compiled again + mx.compile(fun)(x, y) + +There are some important cases to be aware of that can cause a function to +be recompiled: + +* Changing the shape or number of dimensions +* Changing the type of any of the inputs +* Changing the number of inputs to the function + +In certain cases only some of the compilation stack will be rerun (for +example when changing the shapes) and in other cases the full compilation +stack will be rerun (for example when changing the types). In general you +should avoid compiling functions too frequently. + +Another idiom to watch out for is compiling functions which get created and +destroyed frequently. This can happen, for example, when compiling an anonymous +function in a loop: + +.. code-block:: python + + a = mx.array(1.0) + # Don't do this, compiles lambda at each iteration + for _ in range(5): + mx.compile(lambda x: mx.exp(mx.abs(x)))(a) + +Example Speedup +--------------- + +The :func:`mlx.nn.gelu` is a nonlinear activation function commonly used with +Transformer-based models. The implementation involves several unary and binary +element-wise operations: + +.. code-block:: python + + def gelu(x): + return x * (1 + mx.erf(x / math.sqrt(2))) / 2 + +If you use this function with small arrays, it will be overhead bound. If you +use it with large arrays it will be memory bandwidth bound. However, all of +the operations in the ``gelu`` are fusible into a single kernel with +:func:`compile`. This can speedup both cases considerably. + +Let's compare the runtime of the regular function versus the compiled +function. We'll use the following timing helper which does a warm up and +handles synchronization: + +.. code-block:: python + + import time + + def timeit(fun, x): + # warm up + for _ in range(10): + mx.eval(fun(x)) + + tic = time.perf_counter() + for _ in range(100): + mx.eval(fun(x)) + toc = time.perf_counter() + tpi = 1e3 * (toc - tic) / 100 + print(f"Time per iteration {tpi:.3f} (ms)") + + +Now make an array, and benchmark both functions: + +.. code-block:: python + + x = mx.random.uniform(shape=(32, 1000, 4096)) + timeit(nn.gelu, x) + timeit(mx.compile(nn.gelu), x) + +On an M1 Max the times are 15.5 and 3.1 milliseconds. The compiled ``gelu`` is +five times faster. + +.. note:: + + As of the latest MLX, CPU functions are not fully compiled. Compiling CPU + functions can still be helpful, but won't typically result in as large a + speedup as compiling operations that run on the GPU. + + +Debugging +--------- + +When a compiled function is first called, it is traced with placeholder +inputs. This means you can't evaluate arrays (for example to print their +contents) inside compiled functions. + +.. code-block:: python + + @mx.compile + def fun(x): + z = -x + print(z) # Crash + return mx.exp(z) + + fun(mx.array(5.0)) + +For debugging, inspecting arrays can be helpful. One way to do that is to +globally disable compilation using the :func:`disable_compile` function or +``MLX_DISABLE_COMPILE`` flag. For example the following is okay even though +``fun`` is compiled: + +.. code-block:: python + + @mx.compile + def fun(x): + z = -x + print(z) # Okay + return mx.exp(z) + + mx.disable_compile() + fun(mx.array(5.0)) + + +Pure Functions +-------------- + +Compiled functions are intended to be *pure*; that is they should not have side +effects. For example: + +.. code-block:: python + + state = [] + + @mx.compile + def fun(x, y): + z = x + y + state.append(z) + return mx.exp(z) + + fun(mx.array(1.0), mx.array(2.0)) + # Crash! + print(state) + +After the first call of ``fun``, the ``state`` list will hold a placeholder +array. The placeholder does not have any data; it is only used to build the +computation graph. Printing such an array results in a crash. + +You have two options to deal with this. The first option is to simply return +``state`` as an output: + +.. code-block:: python + + state = [] + + @mx.compile + def fun(x, y): + z = x + y + state.append(z) + return mx.exp(z), state + + _, state = fun(mx.array(1.0), mx.array(2.0)) + # Prints [array(3, dtype=float32)] + print(state) + +In some cases returning updated state can be pretty inconvenient. Hence, +:func:`compile` has a parameter to capture implicit outputs: + +.. code-block:: python + + from functools import partial + + state = [] + + # Tell compile to capture state as an output + @partial(mx.compile, outputs=state) + def fun(x, y): + z = x + y + state.append(z) + return mx.exp(z), state + + fun(mx.array(1.0), mx.array(2.0)) + # Prints [array(3, dtype=float32)] + print(state) + +This is particularly useful for compiling a function which includes an update +to a container of arrays, as is commonly done when training the parameters of a +:class:`mlx.nn.Module`. + +Compiled functions will also treat any inputs not in the parameter list as +constants. For example: + +.. code-block:: python + + state = [mx.array(1.0)] + + @mx.compile + def fun(x): + return x + state[0] + + # Prints array(2, dtype=float32) + print(fun(mx.array(1.0))) + + # Update state + state[0] = mx.array(5.0) + + # Still prints array(2, dtype=float32) + print(fun(mx.array(1.0))) + +In order to have the change of state reflected in the outputs of ``fun`` you +again have two options. The first option is to simply pass ``state`` as input +to the function. In some cases this can be pretty inconvenient. Hence, +:func:`compile` also has a parameter to capture implicit inputs: + +.. code-block:: python + + from functools import partial + state = [mx.array(1.0)] + + # Tell compile to capture state as an input + @partial(mx.compile, inputs=state) + def fun(x): + return x + state[0] + + # Prints array(2, dtype=float32) + print(fun(mx.array(1.0))) + + # Update state + state[0] = mx.array(5.0) + + # Prints array(6, dtype=float32) + print(fun(mx.array(1.0))) + + +Compiling Training Graphs +------------------------- + +This section will step through how to use :func:`compile` with a simple example +of a common setup: training a model with :obj:`mlx.nn.Module` using an +:obj:`mlx.optimizers.Optimizer` with state. We will show how to compile the +full forward, backward, and update with :func:`compile`. + +To start, here is the simple example without any compilation: + +.. code-block:: python + + import mlx.core as mx + import mlx.nn as nn + import mlx.optimizers as optim + + # 4 examples with 10 features each + x = mx.random.uniform(shape=(4, 10)) + + # 0, 1 targets + y = mx.array([0, 1, 0, 1]) + + # Simple linear model + model = nn.Linear(10, 1) + + # SGD with momentum + optimizer = optim.SGD(learning_rate=0.1, momentum=0.8) + + def loss_fn(model, x, y): + logits = model(x).squeeze() + return nn.losses.binary_cross_entropy(logits, y) + + loss_and_grad_fn = nn.value_and_grad(model, loss_fn) + + # Perform 10 steps of gradient descent + for it in range(10): + loss, grads = loss_and_grad_fn(model, x, y) + optimizer.update(model, grads) + mx.eval(model.parameters(), optimizer.state) + +To compile the update we can put it all in a function and compile it with the +appropriate input and output captures. Here's the same example but compiled: + +.. code-block:: python + + import mlx.core as mx + import mlx.nn as nn + import mlx.optimizers as optim + from functools import partial + + # 4 examples with 10 features each + x = mx.random.uniform(shape=(4, 10)) + + # 0, 1 targets + y = mx.array([0, 1, 0, 1]) + + # Simple linear model + model = nn.Linear(10, 1) + + # SGD with momentum + optimizer = optim.SGD(learning_rate=0.1, momentum=0.8) + + def loss_fn(model, x, y): + logits = model(x).squeeze() + return nn.losses.binary_cross_entropy(logits, y) + + # The state that will be captured as input and output + state = [model.state, optimizer.state] + + @partial(mx.compile, inputs=state, outputs=state) + def step(x, y): + loss_and_grad_fn = nn.value_and_grad(model, loss_fn) + loss, grads = loss_and_grad_fn(model, x, y) + optimizer.update(model, grads) + return loss + + # Perform 10 steps of gradient descent + for it in range(10): + loss = step(x, y) + # Evaluate the model and optimizer state + mx.eval(state) + print(loss) + + +.. note:: + + If you are using a module which performs random sampling such as + :func:`mlx.nn.Dropout`, make sure you also include ``mx.random.state`` in the + ``state`` captured by :func:`compile`, i.e. ``state = [model.state, + optimizer.state, mx.random.state]``. + + +.. note:: + + For more examples of compiling full training graphs checkout the `MLX + Examples `_ GitHub repo. + +Transformations with Compile +---------------------------- + +In MLX function transformations are composable. You can apply any function +transformation to the output of any other function transformation. For more on +this, see the documentation on :ref:`function transforms +`. + +Compiling transformed functions works just as expected: + +.. code-block:: python + + grad_fn = mx.grad(mx.exp) + + compiled_grad_fn = mx.compile(grad_fn) + + # Prints: array(2.71828, dtype=float32) + print(grad_fn(mx.array(1.0))) + + # Also prints: array(2.71828, dtype=float32) + print(compiled_grad_fn(mx.array(1.0))) + +.. note:: + + In order to compile as much as possible, a transformation of a compiled + function will not by default be compiled. To compile the transformed + function simply pass it through :func:`compile`. + +You can also compile functions which themselves call compiled functions. A +good practice is to compile the outer most function to give :func:`compile` +the most opportunity to optimize the computation graph: + +.. code-block:: python + + @mx.compile + def inner(x): + return mx.exp(-mx.abs(x)) + + def outer(x): + inner(inner(x)) + + # Compiling the outer function is good to do as it will likely + # be faster even though the inner functions are compiled + fun = mx.compile(outer) diff --git a/docs/build/html/_sources/usage/function_transforms.rst b/docs/build/html/_sources/usage/function_transforms.rst index 72a313f97..02c5dec48 100644 --- a/docs/build/html/_sources/usage/function_transforms.rst +++ b/docs/build/html/_sources/usage/function_transforms.rst @@ -5,9 +5,12 @@ Function Transforms .. currentmodule:: mlx.core -MLX uses composable function transformations for automatic differentiation and -vectorization. The key idea behind composable function transformations is that -every transformation returns a function which can be further transformed. +MLX uses composable function transformations for automatic differentiation, +vectorization, and compute graph optimizations. To see the complete list of +function transformations check-out the :ref:`API documentation `. + +The key idea behind composable function transformations is that every +transformation returns a function which can be further transformed. Here is a simple example: @@ -36,10 +39,10 @@ Using :func:`grad` on the output of :func:`grad` is always ok. You keep getting higher order derivatives. Any of the MLX function transformations can be composed in any order to any -depth. To see the complete list of function transformations check-out the -:ref:`API documentation `. See the following sections for more -information on :ref:`automatic differentiaion ` and -:ref:`automatic vectorization `. +depth. See the following sections for more information on :ref:`automatic +differentiaion ` and :ref:`automatic vectorization `. +For more information on :func:`compile` see the :ref:`compile documentation `. + Automatic Differentiation ------------------------- diff --git a/docs/build/html/_static/documentation_options.js b/docs/build/html/_static/documentation_options.js index 5404195e5..ff2bcd035 100644 --- a/docs/build/html/_static/documentation_options.js +++ b/docs/build/html/_static/documentation_options.js @@ -1,6 +1,6 @@ var DOCUMENTATION_OPTIONS = { URL_ROOT: document.getElementById("documentation_options").getAttribute('data-url_root'), - VERSION: '0.1.0', + VERSION: '0.2.0', LANGUAGE: 'en', COLLAPSE_INDEX: false, BUILDER: 'html', diff --git a/docs/build/html/cpp/ops.html b/docs/build/html/cpp/ops.html index 81be1fd39..fc78b923d 100644 --- a/docs/build/html/cpp/ops.html +++ b/docs/build/html/cpp/ops.html @@ -9,7 +9,7 @@ - Operations — MLX 0.1.0 documentation + Operations — MLX 0.2.0 documentation @@ -134,8 +134,8 @@ - MLX 0.1.0 documentation - Home - + MLX 0.2.0 documentation - Home + @@ -153,6 +153,7 @@
  • Indexing Arrays
  • Saving and Loading Arrays
  • Function Transforms
  • +
  • Compilation
  • Conversion to NumPy and Other Frameworks
  • Using Streams
  • @@ -348,6 +349,9 @@
  • Transforms