From 20a3e22ff042b599457aa0a1a83c97819f707e84 Mon Sep 17 00:00:00 2001
From: Awni Hannun 
@@ -334,48 +344,63 @@
 
 
+
+
+
 
 We see some modest improvements right away!
 This operation can now be used to build other operations,
-in mlx.nn.Module calls, and also as a part of graph
+in mlx.nn.Module calls, and also as a part of graph
 transformations such as grad() and simplify()!
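
To make the last sentence concrete, here is a hedged sketch of what "used in grad() and simplify()" means for a custom op. The op name my_op and its extension module are hypothetical stand-ins for the operation built in the tutorial (which is not shown in this excerpt), and it is assumed the op defines a vjp so autodiff works; the exact mx.simplify call signature should be checked against the mlx.core.simplify docs added in this patch.

import mlx.core as mx
from my_extension import my_op  # hypothetical: the custom op built in the tutorial

def f(x):
    # Compose the custom op with built-in ops like any other MLX function
    return mx.sum(my_op(x) ** 2)

x = mx.ones((4,))
dfdx = mx.grad(f)(x)   # the op participates in autodiff via grad()
y = f(x)
mx.simplify(y)         # assumed signature: simplify the graph of y in place
mx.eval(y, dfdx)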
 
 
diff --git a/docs/build/html/examples/linear_regression.html b/docs/build/html/examples/linear_regression.html
index dadfc4423..f81c9f4cd 100644
--- a/docs/build/html/examples/linear_regression.html
+++ b/docs/build/html/examples/linear_regression.html
@@ -226,6 +226,7 @@
 mlx.core.argsort 
 mlx.core.array_equal 
 mlx.core.broadcast_to 
+mlx.core.ceil 
 mlx.core.concatenate 
 mlx.core.convolve 
 mlx.core.conv1d 
@@ -239,6 +240,8 @@
 mlx.core.exp 
 mlx.core.expand_dims 
 mlx.core.eye 
+mlx.core.floor 
+mlx.core.flatten 
 mlx.core.full 
 mlx.core.greater 
 mlx.core.greater_equal 
@@ -259,6 +262,7 @@
 mlx.core.mean 
 mlx.core.min 
 mlx.core.minimum 
+mlx.core.moveaxis 
 mlx.core.multiply 
 mlx.core.negative 
 mlx.core.ones 
@@ -282,14 +286,19 @@
 mlx.core.sqrt 
 mlx.core.square 
 mlx.core.squeeze 
+mlx.core.stack 
 mlx.core.stop_gradient 
 mlx.core.subtract 
 mlx.core.sum 
+mlx.core.swapaxes 
 mlx.core.take 
 mlx.core.take_along_axis 
 mlx.core.tan 
 mlx.core.tanh 
 mlx.core.transpose 
+mlx.core.tri 
+mlx.core.tril 
+mlx.core.triu 
 mlx.core.var 
 mlx.core.where 
 mlx.core.zeros 
@@ -316,6 +325,7 @@
 mlx.core.jvp 
 mlx.core.vjp 
 mlx.core.vmap 
+mlx.core.simplify 
 
 
 FFT
@@ -335,48 +345,63 @@
 
 
 Neural Networks
 - mlx.nn.value_and_grad
 
-- mlx.nn.Embedding
 
-- mlx.nn.ReLU
 
-- mlx.nn.PReLU
 
-- mlx.nn.GELU
 
-- mlx.nn.SiLU
 
-- mlx.nn.Step
 
-- mlx.nn.SELU
 
-- mlx.nn.Mish
 
-- mlx.nn.Linear
 
-- mlx.nn.Conv1d
 
-- mlx.nn.Conv2d
 
-- mlx.nn.LayerNorm
 
-- mlx.nn.RMSNorm
 
-- mlx.nn.GroupNorm
 
-- mlx.nn.RoPE
 
-- mlx.nn.MultiHeadAttention
 
-- mlx.nn.Sequential
 
-- mlx.nn.gelu
 
-- mlx.nn.gelu_approx
 
-- mlx.nn.gelu_fast_approx
 
-- mlx.nn.relu
 
-- mlx.nn.prelu
 
-- mlx.nn.silu
 
-- mlx.nn.step
 
-- mlx.nn.selu
 
-- mlx.nn.mish
 
-- mlx.nn.losses.cross_entropy
 
-- mlx.nn.losses.binary_cross_entropy
 
-- mlx.nn.losses.l1_loss
 
-- mlx.nn.losses.mse_loss
 
-- mlx.nn.losses.nll_loss
 
-- mlx.nn.losses.kl_div_loss
 
+- mlx.nn.Module
 
+- Layers
 
 
-- Optimizers
+- Functions
+
 
+- Loss Functions
+
 
+
+ 
+- Optimizers
 
 
-- Tree Utils
+- Tree Utils
 - mlx.utils.tree_flatten
 
 - mlx.utils.tree_unflatten
 
 - mlx.utils.tree_map
 
diff --git a/docs/build/html/examples/llama-inference.html b/docs/build/html/examples/llama-inference.html
index 7ad78bcad..f102d4087 100644
--- a/docs/build/html/examples/llama-inference.html
+++ b/docs/build/html/examples/llama-inference.html
@@ -226,6 +226,7 @@
 - mlx.core.argsort
 
 - mlx.core.array_equal
 
 - mlx.core.broadcast_to
 
+- mlx.core.ceil
 
 - mlx.core.concatenate
 
 - mlx.core.convolve
 
 - mlx.core.conv1d
 
@@ -239,6 +240,8 @@
 - mlx.core.exp
 
 - mlx.core.expand_dims
 
 - mlx.core.eye
 
+- mlx.core.floor
 
+- mlx.core.flatten
 
 - mlx.core.full
 
 - mlx.core.greater
 
 - mlx.core.greater_equal
 
@@ -259,6 +262,7 @@
 - mlx.core.mean
 
 - mlx.core.min
 
 - mlx.core.minimum
 
+- mlx.core.moveaxis
 
 - mlx.core.multiply
 
 - mlx.core.negative
 
 - mlx.core.ones
 
@@ -282,14 +286,19 @@
 - mlx.core.sqrt
 
 - mlx.core.square
 
 - mlx.core.squeeze
 
+- mlx.core.stack
 
 - mlx.core.stop_gradient
 
 - mlx.core.subtract
 
 - mlx.core.sum
 
+- mlx.core.swapaxes
 
 - mlx.core.take
 
 - mlx.core.take_along_axis
 
 - mlx.core.tan
 
 - mlx.core.tanh
 
 - mlx.core.transpose
 
+- mlx.core.tri
 
+- mlx.core.tril
 
+- mlx.core.triu
 
 - mlx.core.var
 
 - mlx.core.where
 
 - mlx.core.zeros
 
@@ -316,6 +325,7 @@
 - mlx.core.jvp
 
 - mlx.core.vjp
 
 - mlx.core.vmap
 
+- mlx.core.simplify
 
 
  
 - FFT
@@ -335,48 +345,63 @@
 
 
 - Neural Networks
 - mlx.nn.value_and_grad
 
-- mlx.nn.Embedding
 
-- mlx.nn.ReLU
 
-- mlx.nn.PReLU
 
-- mlx.nn.GELU
 
-- mlx.nn.SiLU
 
-- mlx.nn.Step
 
-- mlx.nn.SELU
 
-- mlx.nn.Mish
 
-- mlx.nn.Linear
 
-- mlx.nn.Conv1d
 
-- mlx.nn.Conv2d
 
-- mlx.nn.LayerNorm
 
-- mlx.nn.RMSNorm
 
-- mlx.nn.GroupNorm
 
-- mlx.nn.RoPE
 
-- mlx.nn.MultiHeadAttention
 
-- mlx.nn.Sequential
 
-- mlx.nn.gelu
 
-- mlx.nn.gelu_approx
 
-- mlx.nn.gelu_fast_approx
 
-- mlx.nn.relu
 
-- mlx.nn.prelu
 
-- mlx.nn.silu
 
-- mlx.nn.step
 
-- mlx.nn.selu
 
-- mlx.nn.mish
 
-- mlx.nn.losses.cross_entropy
 
-- mlx.nn.losses.binary_cross_entropy
 
-- mlx.nn.losses.l1_loss
 
-- mlx.nn.losses.mse_loss
 
-- mlx.nn.losses.nll_loss
 
-- mlx.nn.losses.kl_div_loss
 
+- mlx.nn.Module
 
+- Layers
 
 
-- Optimizers
+- Functions
+
 
+- Loss Functions
+
 
+
+ 
+- Optimizers
 
 
-- Tree Utils
+- Tree Utils
 - mlx.utils.tree_flatten
 
 - mlx.utils.tree_unflatten
 
 - mlx.utils.tree_map
 
@@ -591,8 +616,8 @@ module to concisely define the model architecture.
 positional encoding. [1] In addition, our attention layer will optionally use a
 key/value cache that will be concatenated with the provided keys and values to
 support efficient inference.
-Our implementation uses mlx.nn.Linear for all the projections and
-mlx.nn.RoPE for the positional encoding.
+Our implementation uses mlx.nn.Linear for all the projections and
+mlx.nn.RoPE for the positional encoding.
 import mlx.core as mx
 import mlx.nn as nn
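
The hunk above stops right after the imports, so the attention block the text describes is not visible in this excerpt. For reference, here is a minimal sketch of a multi-head attention layer with an optional key/value cache in the style described, using nn.Linear, nn.RoPE, and mx.softmax; the class name LlamaAttention and the exact projection layout are assumptions, not necessarily the page's actual code.

import math

class LlamaAttention(nn.Module):
    def __init__(self, dims: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.rope = nn.RoPE(dims // num_heads, traditional=True)
        self.query_proj = nn.Linear(dims, dims, bias=False)
        self.key_proj = nn.Linear(dims, dims, bias=False)
        self.value_proj = nn.Linear(dims, dims, bias=False)
        self.out_proj = nn.Linear(dims, dims, bias=False)

    def __call__(self, queries, keys, values, mask=None, cache=None):
        queries = self.query_proj(queries)
        keys = self.key_proj(keys)
        values = self.value_proj(values)

        # Split heads: (B, L, D) -> (B, num_heads, L, D // num_heads)
        B, L, D = queries.shape
        queries = queries.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3)

        if cache is not None:
            # Offset RoPE by the cached length, then append new keys/values to the cache
            key_cache, value_cache = cache
            queries = self.rope(queries, offset=key_cache.shape[2])
            keys = self.rope(keys, offset=key_cache.shape[2])
            keys = mx.concatenate([key_cache, keys], axis=2)
            values = mx.concatenate([value_cache, values], axis=2)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        # Scaled dot-product attention
        scale = 1.0 / math.sqrt(queries.shape[-1])
        scores = (queries * scale) @ keys.transpose(0, 1, 3, 2)
        if mask is not None:
            scores = scores + mask
        scores = mx.softmax(scores, axis=-1)
        out = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)

        # Return the updated keys/values so the caller can reuse them as a cache
        return self.out_proj(out), (keys, values)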
 
@@ -650,7 +675,7 @@ support efficient inference.
 Encoder layer#
 The other component of the Llama model is the encoder layer which uses RMS
 normalization [2] and SwiGLU. [3] For RMS normalization we will use
-mlx.nn.RMSNorm that is already provided in mlx.nn.
+mlx.nn.RMSNorm that is already provided in mlx.nn.
 class LlamaEncoderLayer(nn.Module):
     def __init__(self, dims: int, mlp_dims: int, num_heads: int):
         super().__init__()
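
The hunk ends immediately after super().__init__(), so the RMSNorm/SwiGLU body is not visible here. Building on the attention sketch above, a minimal sketch of how such an encoder layer is typically assembled follows; the attribute names and gating layout are assumptions.

class LlamaEncoderLayer(nn.Module):
    def __init__(self, dims: int, mlp_dims: int, num_heads: int):
        super().__init__()
        self.attention = LlamaAttention(dims, num_heads)
        self.norm1 = nn.RMSNorm(dims)
        self.norm2 = nn.RMSNorm(dims)
        # SwiGLU feed-forward: two up-projections (one gated), then a down-projection
        self.linear1 = nn.Linear(dims, mlp_dims, bias=False)
        self.linear2 = nn.Linear(dims, mlp_dims, bias=False)
        self.linear3 = nn.Linear(mlp_dims, dims, bias=False)

    def __call__(self, x, mask=None, cache=None):
        y = self.norm1(x)
        y, cache = self.attention(y, y, y, mask=mask, cache=cache)
        x = x + y

        y = self.norm2(x)
        a = self.linear1(y)
        b = self.linear2(y)
        y = a * mx.sigmoid(a) * b   # SwiGLU gating
        x = x + self.linear3(y)

        return x, cache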
@@ -683,7 +708,7 @@ normalization 
 Full model#
 To implement any Llama model we simply have to combine LlamaEncoderLayer
-instances with an mlx.nn.Embedding to embed the input tokens.
+instances with an mlx.nn.Embedding to embed the input tokens.
 class Llama(nn.Module):
     def __init__(
         self, num_layers: int, vocab_size: int, dims: int, mlp_dims: int, num_heads: int
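
Again the hunk cuts off at the constructor signature. A hedged sketch of the full model wrapper, stacking LlamaEncoderLayer instances on top of nn.Embedding with a causal mask, might look like the following; the output head and the mask helper used here are assumptions.

class Llama(nn.Module):
    def __init__(
        self, num_layers: int, vocab_size: int, dims: int, mlp_dims: int, num_heads: int
    ):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, dims)
        self.layers = [
            LlamaEncoderLayer(dims, mlp_dims, num_heads) for _ in range(num_layers)
        ]
        self.norm = nn.RMSNorm(dims)
        self.out_proj = nn.Linear(dims, vocab_size, bias=False)

    def __call__(self, x):
        # Causal mask so each token only attends to earlier positions
        mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
        mask = mask.astype(self.embedding.weight.dtype)

        x = self.embedding(x)
        for layer in self.layers:
            x, _ = layer(x, mask=mask)
        x = self.norm(x)
        return self.out_proj(x)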
diff --git a/docs/build/html/examples/mlp.html b/docs/build/html/examples/mlp.html
index e0d5c70bd..b5f6b0989 100644
--- a/docs/build/html/examples/mlp.html
+++ b/docs/build/html/examples/mlp.html
@@ -226,6 +226,7 @@
 
- mlx.core.argsort
 
 - mlx.core.array_equal
 
 - mlx.core.broadcast_to
 
+- mlx.core.ceil
 
 - mlx.core.concatenate
 
 - mlx.core.convolve
 
 - mlx.core.conv1d
 
@@ -239,6 +240,8 @@
 - mlx.core.exp
 
 - mlx.core.expand_dims
 
 - mlx.core.eye
 
+- mlx.core.floor
 
+- mlx.core.flatten
 
 - mlx.core.full
 
 - mlx.core.greater
 
 - mlx.core.greater_equal
 
@@ -259,6 +262,7 @@
 - mlx.core.mean
 
 - mlx.core.min
 
 - mlx.core.minimum
 
+- mlx.core.moveaxis
 
 - mlx.core.multiply
 
 - mlx.core.negative
 
 - mlx.core.ones
 
@@ -282,14 +286,19 @@
 - mlx.core.sqrt
 
 - mlx.core.square
 
 - mlx.core.squeeze
 
+- mlx.core.stack
 
 - mlx.core.stop_gradient
 
 - mlx.core.subtract
 
 - mlx.core.sum
 
+- mlx.core.swapaxes
 
 - mlx.core.take
 
 - mlx.core.take_along_axis
 
 - mlx.core.tan
 
 - mlx.core.tanh
 
 - mlx.core.transpose
 
+- mlx.core.tri
 
+- mlx.core.tril
 
+- mlx.core.triu
 
 - mlx.core.var
 
 - mlx.core.where
 
 - mlx.core.zeros
 
@@ -316,6 +325,7 @@
 - mlx.core.jvp
 
 - mlx.core.vjp
 
 - mlx.core.vmap
 
+- mlx.core.simplify
 
 
 
 - FFT
@@ -335,48 +345,63 @@
 
 
 - Neural Networks
 - mlx.nn.value_and_grad
 
-- mlx.nn.Embedding
 
-- mlx.nn.ReLU
 
-- mlx.nn.PReLU
 
-- mlx.nn.GELU
 
-- mlx.nn.SiLU
 
-- mlx.nn.Step
 
-- mlx.nn.SELU
 
-- mlx.nn.Mish
 
-- mlx.nn.Linear
 
-- mlx.nn.Conv1d
 
-- mlx.nn.Conv2d
 
-- mlx.nn.LayerNorm
 
-- mlx.nn.RMSNorm
 
-- mlx.nn.GroupNorm
 
-- mlx.nn.RoPE
 
-- mlx.nn.MultiHeadAttention
 
-- mlx.nn.Sequential
 
-- mlx.nn.gelu
 
-- mlx.nn.gelu_approx
 
-- mlx.nn.gelu_fast_approx
 
-- mlx.nn.relu
 
-- mlx.nn.prelu
 
-- mlx.nn.silu
 
-- mlx.nn.step
 
-- mlx.nn.selu
 
-- mlx.nn.mish
 
-- mlx.nn.losses.cross_entropy
 
-- mlx.nn.losses.binary_cross_entropy
 
-- mlx.nn.losses.l1_loss
 
-- mlx.nn.losses.mse_loss
 
-- mlx.nn.losses.nll_loss
 
-- mlx.nn.losses.kl_div_loss
 
+- mlx.nn.Module
 
+- Layers
 
 
-- Optimizers
+- Functions
+
 
+- Loss Functions
+
 
+
+ 
+- Optimizers
 
 
-- Tree Utils
+- Tree Utils
 - mlx.utils.tree_flatten
 
 - mlx.utils.tree_unflatten
 
 - mlx.utils.tree_map
 
@@ -568,11 +593,11 @@ multi-layer perceptron to classify MNIST.
 
 
 
 
 
 The model is defined as the MLP class which inherits from
-mlx.nn.Module. We follow the standard idiom to make a new module:
+mlx.nn.Module. We follow the standard idiom to make a new module:
 
 Define an __init__ where the parameters and/or submodules are set up. See
 the Module class docs for more information on how
-mlx.nn.Module registers parameters.
 
+mlx.nn.Module registers parameters.
 Define a __call__ where the computation is implemented.
 
 
 class MLP(nn.Module):
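
The hunk stops at the class statement, so for reference here is a minimal sketch following the two-step idiom above: parameters and submodules are created in __init__ (and registered by Module), and the computation lives in __call__. The layer layout and ReLU activation are illustrative assumptions, not necessarily the page's exact code.

import mlx.core as mx
import mlx.nn as nn

class MLP(nn.Module):
    def __init__(self, num_layers: int, input_dim: int, hidden_dim: int, output_dim: int):
        super().__init__()
        # __init__: set up the submodules; Module registers their parameters
        layer_sizes = [input_dim] + [hidden_dim] * num_layers + [output_dim]
        self.layers = [
            nn.Linear(idim, odim)
            for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:])
        ]

    def __call__(self, x):
        # __call__: the forward computation, with ReLU between hidden layers
        for layer in self.layers[:-1]:
            x = mx.maximum(layer(x), 0.0)
        return self.layers[-1](x)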
@@ -605,7 +630,9 @@ set:
     return mx.mean(mx.argmax(model(X), axis=1) == y)
 
 
-Next, setup the problem parameters and load the data:
+Next, set up the problem parameters and load the data. To load the data, you need our
+mnist data loader, which
+we will import as mnist.
 num_layers = 2
 hidden_dim = 32
 num_classes = 10
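
The hyperparameter block is truncated here. To make the setup concrete, a hedged sketch of loading the data with the mnist module mentioned above and running a plain SGD loop with nn.value_and_grad and mlx.optimizers follows. The batch size, epoch count, and the loader's return format are assumptions, and loss_fn / eval_fn are the small helper functions defined just above on the page.

import mlx.optimizers as optim
import mnist  # hypothetical data loader module referenced in the text

batch_size = 256
num_epochs = 10
learning_rate = 1e-1

# Assumption: mnist.mnist() returns train/test images and labels as arrays
train_images, train_labels, test_images, test_labels = map(mx.array, mnist.mnist())

model = MLP(num_layers, train_images.shape[-1], hidden_dim, num_classes)
mx.eval(model.parameters())

loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
optimizer = optim.SGD(learning_rate=learning_rate)

for epoch in range(num_epochs):
    for i in range(0, train_images.shape[0], batch_size):
        X = train_images[i : i + batch_size]
        y = train_labels[i : i + batch_size]
        loss, grads = loss_and_grad_fn(model, X, y)
        optimizer.update(model, grads)                 # apply the SGD update
        mx.eval(model.parameters(), optimizer.state)   # force evaluation each step
    accuracy = eval_fn(model, test_images, test_labels)
    print(f"Epoch {epoch}: test accuracy {accuracy.item():.3f}")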
diff --git a/docs/build/html/genindex.html b/docs/build/html/genindex.html
index edd8f7d6e..7a9ba3568 100644
--- a/docs/build/html/genindex.html
+++ b/docs/build/html/genindex.html
@@ -223,6 +223,7 @@
 
- mlx.core.argsort
 
 - mlx.core.array_equal
 
 - mlx.core.broadcast_to
 
+- mlx.core.ceil
 
 - mlx.core.concatenate
 
 - mlx.core.convolve
 
 - mlx.core.conv1d
 
@@ -236,6 +237,8 @@
 - mlx.core.exp
 
 - mlx.core.expand_dims
 
 - mlx.core.eye
 
+- mlx.core.floor
 
+- mlx.core.flatten
 
 - mlx.core.full
 
 - mlx.core.greater
 
 - mlx.core.greater_equal
 
@@ -256,6 +259,7 @@
 - mlx.core.mean
 
 - mlx.core.min
 
 - mlx.core.minimum
 
+- mlx.core.moveaxis
 
 - mlx.core.multiply
 
 - mlx.core.negative
 
 - mlx.core.ones
 
@@ -279,14 +283,19 @@
 - mlx.core.sqrt
 
 - mlx.core.square
 
 - mlx.core.squeeze
 
+- mlx.core.stack
 
 - mlx.core.stop_gradient
 
 - mlx.core.subtract
 
 - mlx.core.sum
 
+- mlx.core.swapaxes
 
 - mlx.core.take
 
 - mlx.core.take_along_axis
 
 - mlx.core.tan
 
 - mlx.core.tanh
 
 - mlx.core.transpose
 
+- mlx.core.tri
 
+- mlx.core.tril
 
+- mlx.core.triu
 
 - mlx.core.var
 
 - mlx.core.where
 
 - mlx.core.zeros
 
@@ -313,6 +322,7 @@
 - mlx.core.jvp
 
 - mlx.core.vjp
 
 - mlx.core.vmap
 
+- mlx.core.simplify
 
 
 
 - FFT
@@ -332,48 +342,63 @@
 
 
 - Neural Networks
 - mlx.nn.value_and_grad
 
-- mlx.nn.Embedding
 
-- mlx.nn.ReLU
 
-- mlx.nn.PReLU
 
-- mlx.nn.GELU
 
-- mlx.nn.SiLU
 
-- mlx.nn.Step
 
-- mlx.nn.SELU
 
-- mlx.nn.Mish
 
-- mlx.nn.Linear
 
-- mlx.nn.Conv1d
 
-- mlx.nn.Conv2d
 
-- mlx.nn.LayerNorm
 
-- mlx.nn.RMSNorm
 
-- mlx.nn.GroupNorm
 
-- mlx.nn.RoPE
 
-- mlx.nn.MultiHeadAttention
 
-- mlx.nn.Sequential
 
-- mlx.nn.gelu
 
-- mlx.nn.gelu_approx
 
-- mlx.nn.gelu_fast_approx
 
-- mlx.nn.relu
 
-- mlx.nn.prelu
 
-- mlx.nn.silu
 
-- mlx.nn.step
 
-- mlx.nn.selu
 
-- mlx.nn.mish
 
-- mlx.nn.losses.cross_entropy
 
-- mlx.nn.losses.binary_cross_entropy
 
-- mlx.nn.losses.l1_loss
 
-- mlx.nn.losses.mse_loss
 
-- mlx.nn.losses.nll_loss
 
-- mlx.nn.losses.kl_div_loss
 
+- mlx.nn.Module
 
+- Layers
 
 
-- Optimizers
+- Functions
+
 
+- Loss Functions
+
 
+
+ 
+- Optimizers
 
 
-- Tree Utils
+- Tree Utils
 - mlx.utils.tree_flatten
 
 - mlx.utils.tree_unflatten
 
 - mlx.utils.tree_map
 
@@ -544,6 +569,8 @@ document.write(`
         - (mlx.core.Dtype method)
 
 
         - (mlx.core.Stream method)
+
 
+        - (mlx.nn.Module method)
 
 
       
 
   
@@ -558,7 +585,15 @@ document.write(`
          - (mlx.core.array method)
 
 
       
 
+      - AdaDelta (class in mlx.optimizers)
+
 
+      - Adagrad (class in mlx.optimizers)
+
 
       - Adam (class in mlx.optimizers)
+
 
+      - Adamax (class in mlx.optimizers)
+
 
+      - AdamW (class in mlx.optimizers)
 
 
       - add() (in module mlx.core)
 
 
@@ -576,16 +611,12 @@ document.write(`
         - (mlx.core.array method)
 
 
       
-      - apply() (mlx.nn.Module method)
-
 
-      - apply_to_modules() (mlx.nn.Module method)
-
 
       - arange() (in module mlx.core)
-
 
-      - arccos() (in module mlx.core)
 
 
   
   
+      - arccos() (in module mlx.core)
+
 
       - arccosh() (in module mlx.core)
 
 
       - arcsin() (in module mlx.core)
@@ -628,7 +659,7 @@ document.write(`
 
 
   
 
   
-      - binary_cross_entropy (class in mlx.nn.losses)
+      
 - binary_cross_entropy (class in mlx.nn.losses)
 
 
       - broadcast_to() (in module mlx.core)
 
 
@@ -640,15 +671,15 @@ document.write(`
    
   
@@ -664,7 +695,7 @@ document.write(`
       
       - cosh() (in module mlx.core)
 
 
-      - cross_entropy (class in mlx.nn.losses)
+      
 - cross_entropy (class in mlx.nn.losses)
 
 
    
 
@@ -692,7 +723,7 @@ document.write(`
 E