From ffe51a69ca3d6def78074d09268a3b87da47433d Mon Sep 17 00:00:00 2001
From: Awni Hannun
We see some modest improvements right away!
This operation is now good to be used to build other operations,
-in mlx.nn.Module calls, and also as a part of graph
+in mlx.nn.Module calls, and also as a part of graph
transformations such as grad() and simplify()!
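For illustration, here is a minimal sketch of what that looks like in practice. It assumes the custom op from the extension example has been built and is importable as axpby from a package named mlx_sample_extensions (both names are assumptions); everything else uses only mlx.core.

import mlx.core as mx

# Assumption: the custom extension op is built and importable like this.
from mlx_sample_extensions import axpby


def f(x, y):
    # Compose the custom op with regular mlx.core ops.
    return mx.sum(axpby(x, y, 2.0, 1.0) ** 2)


x = mx.random.normal((4,))
y = mx.random.normal((4,))

# The op participates in graph transformations like any built-in op.
dfdx = mx.grad(f)(x, y)   # gradient of f with respect to x
z = f(x, y)
mx.simplify(z)            # common-subexpression simplification of z's graph
mx.eval(dfdx, z)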
diff --git a/docs/build/html/examples/linear_regression.html b/docs/build/html/examples/linear_regression.html
index dadfc4423..f81c9f4cd 100644
--- a/docs/build/html/examples/linear_regression.html
+++ b/docs/build/html/examples/linear_regression.html
@@ -226,6 +226,7 @@
mlx.core.argsort
mlx.core.array_equal
mlx.core.broadcast_to
+mlx.core.ceil
mlx.core.concatenate
mlx.core.convolve
mlx.core.conv1d
@@ -239,6 +240,8 @@
mlx.core.exp
mlx.core.expand_dims
mlx.core.eye
+mlx.core.floor
+mlx.core.flatten
mlx.core.full
mlx.core.greater
mlx.core.greater_equal
@@ -259,6 +262,7 @@
mlx.core.mean
mlx.core.min
mlx.core.minimum
+mlx.core.moveaxis
mlx.core.multiply
mlx.core.negative
mlx.core.ones
@@ -282,14 +286,19 @@
mlx.core.sqrt
mlx.core.square
mlx.core.squeeze
+mlx.core.stack
mlx.core.stop_gradient
mlx.core.subtract
mlx.core.sum
+mlx.core.swapaxes
mlx.core.take
mlx.core.take_along_axis
mlx.core.tan
mlx.core.tanh
mlx.core.transpose
+mlx.core.tri
+mlx.core.tril
+mlx.core.triu
mlx.core.var
mlx.core.where
mlx.core.zeros
@@ -316,6 +325,7 @@
mlx.core.jvp
mlx.core.vjp
mlx.core.vmap
+mlx.core.simplify
FFT
@@ -335,48 +345,63 @@
Neural Networks
- mlx.nn.value_and_grad
-- mlx.nn.Embedding
-- mlx.nn.ReLU
-- mlx.nn.PReLU
-- mlx.nn.GELU
-- mlx.nn.SiLU
-- mlx.nn.Step
-- mlx.nn.SELU
-- mlx.nn.Mish
-- mlx.nn.Linear
-- mlx.nn.Conv1d
-- mlx.nn.Conv2d
-- mlx.nn.LayerNorm
-- mlx.nn.RMSNorm
-- mlx.nn.GroupNorm
-- mlx.nn.RoPE
-- mlx.nn.MultiHeadAttention
-- mlx.nn.Sequential
-- mlx.nn.gelu
-- mlx.nn.gelu_approx
-- mlx.nn.gelu_fast_approx
-- mlx.nn.relu
-- mlx.nn.prelu
-- mlx.nn.silu
-- mlx.nn.step
-- mlx.nn.selu
-- mlx.nn.mish
-- mlx.nn.losses.cross_entropy
-- mlx.nn.losses.binary_cross_entropy
-- mlx.nn.losses.l1_loss
-- mlx.nn.losses.mse_loss
-- mlx.nn.losses.nll_loss
-- mlx.nn.losses.kl_div_loss
+- mlx.nn.Module
+- Layers
-- Optimizers
+- Functions
+
+- Loss Functions
+
+
+
+- Optimizers
-- Tree Utils
+- Tree Utils
- mlx.utils.tree_flatten
- mlx.utils.tree_unflatten
- mlx.utils.tree_map
diff --git a/docs/build/html/examples/llama-inference.html b/docs/build/html/examples/llama-inference.html
index 7ad78bcad..f102d4087 100644
--- a/docs/build/html/examples/llama-inference.html
+++ b/docs/build/html/examples/llama-inference.html
@@ -226,6 +226,7 @@
- mlx.core.argsort
- mlx.core.array_equal
- mlx.core.broadcast_to
+- mlx.core.ceil
- mlx.core.concatenate
- mlx.core.convolve
- mlx.core.conv1d
@@ -239,6 +240,8 @@
- mlx.core.exp
- mlx.core.expand_dims
- mlx.core.eye
+- mlx.core.floor
+- mlx.core.flatten
- mlx.core.full
- mlx.core.greater
- mlx.core.greater_equal
@@ -259,6 +262,7 @@
- mlx.core.mean
- mlx.core.min
- mlx.core.minimum
+- mlx.core.moveaxis
- mlx.core.multiply
- mlx.core.negative
- mlx.core.ones
@@ -282,14 +286,19 @@
- mlx.core.sqrt
- mlx.core.square
- mlx.core.squeeze
+- mlx.core.stack
- mlx.core.stop_gradient
- mlx.core.subtract
- mlx.core.sum
+- mlx.core.swapaxes
- mlx.core.take
- mlx.core.take_along_axis
- mlx.core.tan
- mlx.core.tanh
- mlx.core.transpose
+- mlx.core.tri
+- mlx.core.tril
+- mlx.core.triu
- mlx.core.var
- mlx.core.where
- mlx.core.zeros
@@ -316,6 +325,7 @@
- mlx.core.jvp
- mlx.core.vjp
- mlx.core.vmap
+- mlx.core.simplify
- FFT
@@ -335,48 +345,63 @@
- Neural Networks
- mlx.nn.value_and_grad
-- mlx.nn.Embedding
-- mlx.nn.ReLU
-- mlx.nn.PReLU
-- mlx.nn.GELU
-- mlx.nn.SiLU
-- mlx.nn.Step
-- mlx.nn.SELU
-- mlx.nn.Mish
-- mlx.nn.Linear
-- mlx.nn.Conv1d
-- mlx.nn.Conv2d
-- mlx.nn.LayerNorm
-- mlx.nn.RMSNorm
-- mlx.nn.GroupNorm
-- mlx.nn.RoPE
-- mlx.nn.MultiHeadAttention
-- mlx.nn.Sequential
-- mlx.nn.gelu
-- mlx.nn.gelu_approx
-- mlx.nn.gelu_fast_approx
-- mlx.nn.relu
-- mlx.nn.prelu
-- mlx.nn.silu
-- mlx.nn.step
-- mlx.nn.selu
-- mlx.nn.mish
-- mlx.nn.losses.cross_entropy
-- mlx.nn.losses.binary_cross_entropy
-- mlx.nn.losses.l1_loss
-- mlx.nn.losses.mse_loss
-- mlx.nn.losses.nll_loss
-- mlx.nn.losses.kl_div_loss
+- mlx.nn.Module
+- Layers
-- Optimizers
+- Functions
+
+- Loss Functions
+
+
+
+- Optimizers
-- Tree Utils
+- Tree Utils
- mlx.utils.tree_flatten
- mlx.utils.tree_unflatten
- mlx.utils.tree_map
@@ -591,8 +616,8 @@ module to concisely define the model architecture.
positional encoding. [1] In addition, our attention layer will optionally use a
key/value cache that will be concatenated with the provided keys and values to
support efficient inference.
-Our implementation uses mlx.nn.Linear for all the projections and
-mlx.nn.RoPE for the positional encoding.
+Our implementation uses mlx.nn.Linear for all the projections and
+mlx.nn.RoPE for the positional encoding.
import mlx.core as mx
import mlx.nn as nn
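To make the shape bookkeeping concrete, here is a rough sketch of how such an attention layer could look. The class name LlamaAttention and the (keys, values) cache format are assumptions rather than the exact example code; it only relies on nn.Linear, nn.RoPE, and basic mlx.core ops.

import math

import mlx.core as mx
import mlx.nn as nn


class LlamaAttention(nn.Module):
    # Sketch: multi-head attention with RoPE and an optional key/value cache.
    def __init__(self, dims: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.rope = nn.RoPE(dims // num_heads)
        self.query_proj = nn.Linear(dims, dims, bias=False)
        self.key_proj = nn.Linear(dims, dims, bias=False)
        self.value_proj = nn.Linear(dims, dims, bias=False)
        self.out_proj = nn.Linear(dims, dims, bias=False)

    def __call__(self, queries, keys, values, mask=None, cache=None):
        queries = self.query_proj(queries)
        keys = self.key_proj(keys)
        values = self.value_proj(values)

        # Split heads: (B, L, D) -> (B, num_heads, L, D // num_heads)
        B, L, D = queries.shape
        queries = queries.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3)

        if cache is not None:
            # Offset RoPE by the cache length, then append to the cached keys/values.
            key_cache, value_cache = cache
            queries = self.rope(queries, offset=key_cache.shape[2])
            keys = self.rope(keys, offset=key_cache.shape[2])
            keys = mx.concatenate([key_cache, keys], axis=2)
            values = mx.concatenate([value_cache, values], axis=2)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        # Scaled dot-product attention.
        scale = 1.0 / math.sqrt(queries.shape[-1])
        scores = (queries * scale) @ keys.transpose(0, 1, 3, 2)
        if mask is not None:
            scores = scores + mask
        scores = mx.softmax(scores, axis=-1)
        out = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)

        # Return the new cache so it can be reused at the next decoding step.
        return self.out_proj(out), (keys, values)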
@@ -650,7 +675,7 @@ support efficient inference.
Encoder layer
The other component of the Llama model is the encoder layer which uses RMS
normalization [2] and SwiGLU. [3] For RMS normalization we will use
-mlx.nn.RMSNorm that is already provided in mlx.nn.
+mlx.nn.RMSNorm that is already provided in mlx.nn.
class LlamaEncoderLayer(nn.Module):
def __init__(self, dims: int, mlp_dims: int, num_heads: int):
super().__init__()
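The class fragment above is cut off by the diff; the sketch below shows one way it could continue, assuming the attention sketch earlier, two nn.RMSNorm instances, and three nn.Linear projections for the SwiGLU MLP. Names and layer sizes are illustrative, not the exact example code.

class LlamaEncoderLayer(nn.Module):
    def __init__(self, dims: int, mlp_dims: int, num_heads: int):
        super().__init__()
        self.attention = LlamaAttention(dims, num_heads)
        self.norm1 = nn.RMSNorm(dims)
        self.norm2 = nn.RMSNorm(dims)
        # SwiGLU MLP: two input projections and one output projection.
        self.linear1 = nn.Linear(dims, mlp_dims, bias=False)
        self.linear2 = nn.Linear(dims, mlp_dims, bias=False)
        self.linear3 = nn.Linear(mlp_dims, dims, bias=False)

    def __call__(self, x, mask=None, cache=None):
        # Attention block with a residual connection.
        y = self.norm1(x)
        y, cache = self.attention(y, y, y, mask, cache)
        x = x + y

        # SwiGLU feed-forward block with a residual connection.
        y = self.norm2(x)
        a = self.linear1(y)
        b = self.linear2(y)
        x = x + self.linear3(nn.silu(a) * b)

        return x, cache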
@@ -683,7 +708,7 @@ normalization
Full model
To implement any Llama model we simply have to combine LlamaEncoderLayer
-instances with an mlx.nn.Embedding to embed the input tokens.
+instances with an mlx.nn.Embedding to embed the input tokens.
class Llama(nn.Module):
def __init__(
self, num_layers: int, vocab_size: int, dims: int, mlp_dims: int, num_heads: int
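Again the fragment is truncated by the diff; a possible completion is sketched below. It stacks LlamaEncoderLayer instances after an nn.Embedding, with a final nn.RMSNorm and an output nn.Linear head. The causal-mask helper and the output projection are assumptions of this sketch.

class Llama(nn.Module):
    def __init__(
        self, num_layers: int, vocab_size: int, dims: int, mlp_dims: int, num_heads: int
    ):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, dims)
        self.layers = [
            LlamaEncoderLayer(dims, mlp_dims, num_heads) for _ in range(num_layers)
        ]
        self.norm = nn.RMSNorm(dims)
        self.out_proj = nn.Linear(dims, vocab_size, bias=False)

    def __call__(self, x):
        # Causal mask so each token only attends to earlier positions.
        mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
        mask = mask.astype(self.embedding.weight.dtype)

        x = self.embedding(x)
        for layer in self.layers:
            x, _ = layer(x, mask)
        x = self.norm(x)
        return self.out_proj(x)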
diff --git a/docs/build/html/examples/mlp.html b/docs/build/html/examples/mlp.html
index e0d5c70bd..b5f6b0989 100644
--- a/docs/build/html/examples/mlp.html
+++ b/docs/build/html/examples/mlp.html
@@ -226,6 +226,7 @@
- mlx.core.argsort
- mlx.core.array_equal
- mlx.core.broadcast_to
+- mlx.core.ceil
- mlx.core.concatenate
- mlx.core.convolve
- mlx.core.conv1d
@@ -239,6 +240,8 @@
- mlx.core.exp
- mlx.core.expand_dims
- mlx.core.eye
+- mlx.core.floor
+- mlx.core.flatten
- mlx.core.full
- mlx.core.greater
- mlx.core.greater_equal
@@ -259,6 +262,7 @@
- mlx.core.mean
- mlx.core.min
- mlx.core.minimum
+- mlx.core.moveaxis
- mlx.core.multiply
- mlx.core.negative
- mlx.core.ones
@@ -282,14 +286,19 @@
- mlx.core.sqrt
- mlx.core.square
- mlx.core.squeeze
+- mlx.core.stack
- mlx.core.stop_gradient
- mlx.core.subtract
- mlx.core.sum
+- mlx.core.swapaxes
- mlx.core.take
- mlx.core.take_along_axis
- mlx.core.tan
- mlx.core.tanh
- mlx.core.transpose
+- mlx.core.tri
+- mlx.core.tril
+- mlx.core.triu
- mlx.core.var
- mlx.core.where
- mlx.core.zeros
@@ -316,6 +325,7 @@
- mlx.core.jvp
- mlx.core.vjp
- mlx.core.vmap
+- mlx.core.simplify
- FFT
@@ -335,48 +345,63 @@
- Neural Networks
- mlx.nn.value_and_grad
-- mlx.nn.Embedding
-- mlx.nn.ReLU
-- mlx.nn.PReLU
-- mlx.nn.GELU
-- mlx.nn.SiLU
-- mlx.nn.Step
-- mlx.nn.SELU
-- mlx.nn.Mish
-- mlx.nn.Linear
-- mlx.nn.Conv1d
-- mlx.nn.Conv2d
-- mlx.nn.LayerNorm
-- mlx.nn.RMSNorm
-- mlx.nn.GroupNorm
-- mlx.nn.RoPE
-- mlx.nn.MultiHeadAttention
-- mlx.nn.Sequential
-- mlx.nn.gelu
-- mlx.nn.gelu_approx
-- mlx.nn.gelu_fast_approx
-- mlx.nn.relu
-- mlx.nn.prelu
-- mlx.nn.silu
-- mlx.nn.step
-- mlx.nn.selu
-- mlx.nn.mish
-- mlx.nn.losses.cross_entropy
-- mlx.nn.losses.binary_cross_entropy
-- mlx.nn.losses.l1_loss
-- mlx.nn.losses.mse_loss
-- mlx.nn.losses.nll_loss
-- mlx.nn.losses.kl_div_loss
+- mlx.nn.Module
+- Layers
-- Optimizers
+- Functions
+
+- Loss Functions
+
+
+
+- Optimizers
-- Tree Utils
+- Tree Utils
- mlx.utils.tree_flatten
- mlx.utils.tree_unflatten
- mlx.utils.tree_map
@@ -568,11 +593,11 @@ multi-layer perceptron to classify MNIST.
The model is defined as the MLP class which inherits from
-mlx.nn.Module. We follow the standard idiom to make a new module:
+mlx.nn.Module. We follow the standard idiom to make a new module:
Define an __init__ where the parameters and/or submodules are setup. See
the Module class docs for more information on how
-mlx.nn.Module registers parameters.
+mlx.nn.Module registers parameters.
Define a __call__ where the computation is implemented.
class MLP(nn.Module):
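The class body is cut off by the diff; one way it could be filled in following that idiom is sketched below. The layer sizing and the ReLU written as mx.maximum are assumptions consistent with the surrounding text.

import mlx.core as mx
import mlx.nn as nn


class MLP(nn.Module):
    def __init__(
        self, num_layers: int, input_dim: int, hidden_dim: int, output_dim: int
    ):
        super().__init__()
        # __init__: set up the submodules (here, a list of nn.Linear layers).
        layer_sizes = [input_dim] + [hidden_dim] * num_layers + [output_dim]
        self.layers = [
            nn.Linear(idim, odim)
            for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:])
        ]

    def __call__(self, x):
        # __call__: implement the computation (ReLU between the linear layers).
        for layer in self.layers[:-1]:
            x = mx.maximum(layer(x), 0.0)
        return self.layers[-1](x)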
@@ -605,7 +630,9 @@ set:
return mx.mean(mx.argmax(model(X), axis=1) == y)
-Next, setup the problem parameters and load the data:
+Next, setup the problem parameters and load the data. To load the data, you need our
+mnist data loader, which
+we will import as mnist.
num_layers = 2
hidden_dim = 32
num_classes = 10
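A short sketch of how those parameters and the data loader might come together, assuming mnist.py from the separate examples repository is on the path and exposes an mnist() function returning train and test splits (that function name is an assumption):

import mnist  # assumption: the mnist.py loader from the examples repository

# Load the data as mlx arrays.
train_images, train_labels, test_images, test_labels = map(
    mx.array, mnist.mnist()
)

# Build the model from the parameters above and materialize its weights.
model = MLP(num_layers, train_images.shape[-1], hidden_dim, num_classes)
mx.eval(model.parameters())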
diff --git a/docs/build/html/genindex.html b/docs/build/html/genindex.html
index edd8f7d6e..7a9ba3568 100644
--- a/docs/build/html/genindex.html
+++ b/docs/build/html/genindex.html
@@ -223,6 +223,7 @@
- mlx.core.argsort
- mlx.core.array_equal
- mlx.core.broadcast_to
+- mlx.core.ceil
- mlx.core.concatenate
- mlx.core.convolve
- mlx.core.conv1d
@@ -236,6 +237,8 @@
- mlx.core.exp
- mlx.core.expand_dims
- mlx.core.eye
+- mlx.core.floor
+- mlx.core.flatten
- mlx.core.full
- mlx.core.greater
- mlx.core.greater_equal
@@ -256,6 +259,7 @@
- mlx.core.mean
- mlx.core.min
- mlx.core.minimum
+- mlx.core.moveaxis
- mlx.core.multiply
- mlx.core.negative
- mlx.core.ones
@@ -279,14 +283,19 @@
- mlx.core.sqrt
- mlx.core.square
- mlx.core.squeeze
+- mlx.core.stack
- mlx.core.stop_gradient
- mlx.core.subtract
- mlx.core.sum
+- mlx.core.swapaxes
- mlx.core.take
- mlx.core.take_along_axis
- mlx.core.tan
- mlx.core.tanh
- mlx.core.transpose
+- mlx.core.tri
+- mlx.core.tril
+- mlx.core.triu
- mlx.core.var
- mlx.core.where
- mlx.core.zeros
@@ -313,6 +322,7 @@
- mlx.core.jvp
- mlx.core.vjp
- mlx.core.vmap
+- mlx.core.simplify
- FFT
@@ -332,48 +342,63 @@
- Neural Networks
- mlx.nn.value_and_grad
-- mlx.nn.Embedding
-- mlx.nn.ReLU
-- mlx.nn.PReLU
-- mlx.nn.GELU
-- mlx.nn.SiLU
-- mlx.nn.Step
-- mlx.nn.SELU
-- mlx.nn.Mish
-- mlx.nn.Linear
-- mlx.nn.Conv1d
-- mlx.nn.Conv2d
-- mlx.nn.LayerNorm
-- mlx.nn.RMSNorm
-- mlx.nn.GroupNorm
-- mlx.nn.RoPE
-- mlx.nn.MultiHeadAttention
-- mlx.nn.Sequential
-- mlx.nn.gelu
-- mlx.nn.gelu_approx
-- mlx.nn.gelu_fast_approx
-- mlx.nn.relu
-- mlx.nn.prelu
-- mlx.nn.silu
-- mlx.nn.step
-- mlx.nn.selu
-- mlx.nn.mish
-- mlx.nn.losses.cross_entropy
-- mlx.nn.losses.binary_cross_entropy
-- mlx.nn.losses.l1_loss
-- mlx.nn.losses.mse_loss
-- mlx.nn.losses.nll_loss
-- mlx.nn.losses.kl_div_loss
+- mlx.nn.Module
+- Layers
-- Optimizers
+- Functions
+
+- Loss Functions
+
+
+
+- Optimizers
-- Tree Utils
+- Tree Utils
- mlx.utils.tree_flatten
- mlx.utils.tree_unflatten
- mlx.utils.tree_map
@@ -544,6 +569,8 @@ document.write(`
- (mlx.core.Dtype method)
- (mlx.core.Stream method)
+
+ - (mlx.nn.Module method)
@@ -558,7 +585,15 @@ document.write(`
- (mlx.core.array method)
+ - AdaDelta (class in mlx.optimizers)
+
+ - Adagrad (class in mlx.optimizers)
+
- Adam (class in mlx.optimizers)
+
+ - Adamax (class in mlx.optimizers)
+
+ - AdamW (class in mlx.optimizers)
- add() (in module mlx.core)
@@ -576,16 +611,12 @@ document.write(`
- (mlx.core.array method)
- - apply() (mlx.nn.Module method)
-
- - apply_to_modules() (mlx.nn.Module method)
-
- arange() (in module mlx.core)
-
- - arccos() (in module mlx.core)
+ - arccos() (in module mlx.core)
+
- arccosh() (in module mlx.core)
- arcsin() (in module mlx.core)
@@ -628,7 +659,7 @@ document.write(`
- - binary_cross_entropy (class in mlx.nn.losses)
+ - binary_cross_entropy (class in mlx.nn.losses)
- broadcast_to() (in module mlx.core)
@@ -640,15 +671,15 @@ document.write(`
@@ -664,7 +695,7 @@ document.write(`
- cosh() (in module mlx.core)
- - cross_entropy (class in mlx.nn.losses)
+ - cross_entropy (class in mlx.nn.losses)
@@ -692,7 +723,7 @@ document.write(`
E