Mirror of https://github.com/ml-explore/mlx.git (synced 2025-08-21 20:46:46 +08:00)
style nits in docs
This commit is contained in:
parent 47a64c480b
commit 913fd33c9c
@@ -9,7 +9,7 @@ Layers
   :toctree: _autosummary
   :template: nn-module-template.rst

   Embedding
   Sequential
   ReLU
   PReLU
   GELU
@@ -17,17 +17,19 @@ Layers
   Step
   SELU
   Mish
   Embedding
   Linear
   QuantizedLinear
   Conv1d
   Conv2d
   BatchNorm
   LayerNorm
   RMSNorm
   GroupNorm
   RoPE
   MultiHeadAttention
   Sequential
   QuantizedLinear
   Dropout
   Dropout2d

   Transformer
   MultiHeadAttention
   ALiBi
   RoPE
   SinusoidalPositionalEncoding
@@ -14,6 +14,7 @@ from mlx.nn.layers.activations import (
    SiLU,
    Softplus,
    Step,
    Tanh,
    celu,
    elu,
    gelu,
@@ -29,6 +30,7 @@ from mlx.nn.layers.activations import (
    silu,
    softplus,
    step,
    tanh,
)
from mlx.nn.layers.base import Module
from mlx.nn.layers.containers import Sequential
@@ -41,6 +43,7 @@ from mlx.nn.layers.positional_encoding import ALiBi, RoPE, SinusoidalPositionalE
from mlx.nn.layers.quantized import QuantizedLinear
from mlx.nn.layers.transformer import (
    MultiHeadAttention,
    Transformer,
    TransformerEncoder,
    TransformerEncoderLayer,
)
@@ -179,12 +179,12 @@ def selu(x):


def prelu(x: mx.array, alpha: mx.array) -> mx.array:
    r"""Applies the element-wise function:
    r"""Applies the element-wise parametric ReLU.

    .. math::
        \text{PReLU}(x) = \max(0,x) + a * \min(0,x)

    Here :math:`a` is an array.
    where :math:`a` is an array.
    """
    return mx.maximum(0, x) + alpha * mx.minimum(0, x)
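For reference, a minimal sketch of how the functional form documented above behaves. It simply re-applies the formula from the hunk; mlx is assumed to be installed and the example values are made up for illustration.

import mlx.core as mx


def prelu(x: mx.array, alpha: mx.array) -> mx.array:
    # PReLU(x) = max(0, x) + a * min(0, x), as in the docstring above
    return mx.maximum(0, x) + alpha * mx.minimum(0, x)


x = mx.array([-2.0, -0.5, 0.0, 1.5])
alpha = mx.array(0.25)        # slope applied to the negative part
print(prelu(x, alpha))        # expected: [-0.5, -0.125, 0.0, 1.5]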
@@ -8,21 +8,22 @@ from mlx.nn.layers.base import Module


class RoPE(Module):
    """Implements the rotary positional encoding [1].
    """Implements the rotary positional encoding.

    The traditional implementation rotates consecutive pairs of elements in the
    feature dimension while the default implementation rotates pairs with
    stride half the feature dimensions for efficiency.

    [1]: https://arxiv.org/abs/2104.09864
    For more details see `RoFormer: Enhanced Transformer with Rotary Position
    Embedding <https://arxiv.org/abs/2104.09864>`_.

    Args:
        dims (int): The feature dimensions to be rotated. If the input feature
            is larger than dims then the rest is left unchanged.
        traditional (bool, optional): If set to True choose the traditional
            implementation which is slightly less efficient. Default: ``False``
            implementation which is slightly less efficient. Default: ``False``.
        base (float, optional): The base used to compute angular frequency for
            each dimension in the positional encodings. Default: ``10000``
            each dimension in the positional encodings. Default: ``10000``.
    """

    def __init__(self, dims: int, traditional: bool = False, base: float = 10000):
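A usage sketch for the constructor documented above; the input layout (batch, sequence, feature) is an assumption for illustration, with only the first ``dims`` features rotated.

import mlx.core as mx
import mlx.nn as nn

# Signature taken from the __init__ shown in the hunk above.
rope = nn.RoPE(dims=32, traditional=False, base=10000)

x = mx.random.normal((2, 8, 32))   # assumed (batch, sequence, feature) layout
y = rope(x)                        # same shape, features rotated by position
print(y.shape)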
@@ -89,19 +90,23 @@ class RoPE(Module):


class SinusoidalPositionalEncoding(Module):
    """Implements sinusoidal positional encoding similar to [1].
    r"""Implements sinusoidal positional encoding.

    [1]: https://arxiv.org/abs/1706.03762
    For more details see the paper `Attention Is All You Need
    <https://arxiv.org/abs/1706.03762>`_.

    Args:
        dims (int): The dimensionality of the resulting positional embeddings.
        min_freq (float): The minimum frequency expected (default: 0.0001)
        max_freq (float): The maximum frequency expected (default: 1)
        scale (float): Scale the embeddings by that number (default: sqrt(dims//2))
        cos_first (bool): If set to True embed using ``[cos(x); sin(x)]``
            instead of the other way around (default: False)
        full_turns (bool): If set to True multiply the frequencies
            with ``2 pi`` (default: False)
        min_freq (float, optional): The minimum frequency expected. Default:
            ``0.0001``.
        max_freq (float, optional): The maximum frequency expected. Default:
            ``1``.
        scale (float, optional): A multiplicative scale for the embeddings.
            Default: ``sqrt(dims//2)``.
        cos_first (bool, optional): If ``True`` embed using ``[cos(x); sin(x)]``
            instead of the reverse. Default: ``False``.
        full_turns (bool, optional): If ``True`` multiply the frequencies with
            :math:`2\pi`. Default: ``False``.
    """

    def __init__(
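A small sketch of the parameters listed above; calling the module on an array of integer positions and the resulting (..., dims) output shape are assumptions for illustration.

import mlx.core as mx
import mlx.nn as nn

# Parameter names follow the Args section above.
pe = nn.SinusoidalPositionalEncoding(
    dims=32, min_freq=0.0001, max_freq=1, cos_first=False, full_turns=False
)

positions = mx.arange(16)   # positions 0..15
emb = pe(positions)         # assumed output shape: (16, 32)
print(emb.shape)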
@@ -14,26 +14,34 @@ from mlx.nn.layers.normalization import LayerNorm
class MultiHeadAttention(Module):
    """Implements the scaled dot product attention with multiple heads.

    Given inputs for queries, keys and values the ``MultiHeadAttention`` produces
    new values by aggregating information from the input values according to
    the similarities of the input queries and keys.
    Given inputs for queries, keys and values the ``MultiHeadAttention``
    produces new values by aggregating information from the input values
    according to the similarities of the input queries and keys.

    All inputs as well as the output are linearly projected without biases.
    All inputs as well as the output are linearly projected without biases by
    default.

    MultiHeadAttention also expects an additive attention mask that should be
    broadcastable with (batch, num_heads, # queries, # keys). The mask should
    have ``-inf`` or very negative numbers to the positions that should *not* be
    attended to.
    ``MultiHeadAttention`` also takes an optional additive attention mask that
    should be broadcastable with ``(batch, num_heads, # queries, # keys)``. The
    mask should have ``-inf`` or very large negative numbers at the positions
    that should *not* be attended to.

    Args:
        dims (int): The model dimensions. If no other dims are provided then
            dims is used for queries, keys, values and the output.
        num_heads (int): How many attention heads to use
        query_input_dims (int, optional): The input dimensions of the queries (default: dims).
        key_input_dims (int, optional): The input dimensions of the keys (default: dims).
        value_input_dims (int, optional): The input dimensions of the values (default: key_input_dims).
        value_dims (int, optional): The dimensions of the values after the projection (default: dims).
        value_output_dims (int, optional): The dimensions the new values will be projected to (default: dims).
        dims (int): The model dimensions. This is also the default
            value for the queries, keys, values, and the output.
        num_heads (int): The number of attention heads to use.
        query_input_dims (int, optional): The input dimensions of the queries.
            Default: ``dims``.
        key_input_dims (int, optional): The input dimensions of the keys.
            Default: ``dims``.
        value_input_dims (int, optional): The input dimensions of the values.
            Default: ``key_input_dims``.
        value_dims (int, optional): The dimensions of the values after the
            projection. Default: ``dims``.
        value_output_dims (int, optional): The dimensions the new values will
            be projected to. Default: ``dims``.
        bias (bool, optional): Whether or not to use a bias in the projections.
            Default: ``False``.
    """

    def __init__(
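A self-attention sketch following the docstring above. The additive causal mask is built by hand here (large negative values above the diagonal) rather than with any library helper, and the shapes are illustrative.

import mlx.core as mx
import mlx.nn as nn

dims, num_heads, L = 64, 8, 10
attn = nn.MultiHeadAttention(dims, num_heads)

x = mx.random.normal((2, L, dims))          # (batch, sequence, dims)

# Additive mask: very negative entries mark positions that must not be attended to.
mask = mx.triu(mx.full((L, L), -1e9), k=1)  # broadcasts over (batch, num_heads, L, L)
y = attn(x, x, x, mask)                     # queries, keys, values, mask
print(y.shape)                              # (2, 10, 64)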
@@ -51,7 +59,8 @@ class MultiHeadAttention(Module):

        if (dims % num_heads) != 0:
            raise ValueError(
                f"The input feature dimensions should be divisible by the number of heads ({dims} % {num_heads}) != 0"
                "The input feature dimensions should be divisible by the "
                f"number of heads ({dims} % {num_heads}) != 0"
            )

        query_input_dims = query_input_dims or dims
@@ -171,9 +180,7 @@ class TransformerEncoder(Module):
    def __call__(self, x, mask):
        for l in self.layers:
            x = l(x, mask)
        x = self.ln(x)

        return x
        return self.ln(x)


class TransformerDecoderLayer(Module):
@@ -261,32 +268,45 @@ class TransformerDecoder(Module):
    def __call__(self, x, memory, x_mask, memory_mask):
        for l in self.layers:
            x = l(x, memory, x_mask, memory_mask)
        x = self.ln(x)

        return x
        return self.ln(x)


class Transformer(Module):
    """
    Implements a standard Transformer model based on the paper "Attention Is All You Need".
    Implements a standard Transformer model.

    The Transformer model consists of an encoder and a decoder. The encoder processes
    the input sequence and the decoder generates the output sequence. The interaction
    between encoder and decoder happens through the attention mechanism.
    The implementation is based on `Attention Is All You Need
    <https://arxiv.org/abs/1706.03762>`_.

    The Transformer model contains an encoder and a decoder. The encoder
    processes the input sequence and the decoder generates the output sequence.
    The interaction between encoder and decoder happens through the attention
    mechanism.

    Args:
        dims (int): The number of expected features in the encoder/decoder inputs (default: 512)
        num_heads (int): The number of heads in the multi-head attention models (default: 8)
        num_encoder_layers (int): The number of sub-encoder-layers in the Transformer encoder (default: 6)
        num_decoder_layers (int): The number of sub-decoder-layers in the Transformer decoder (default: 6)
        mlp_dims (Optional[int]): The dimensionality of the feedforward network model in each Transformer layer,
            Defaults to 4*dims if not provided (default: None)
        dropout (float): The dropout value for Transformer encoder/decoder (default: 0.0)
        activation (Callable[[Any], Any]): the activation function of encoder/decoder intermediate layer (default: relu)
        custom_encoder (Optional[Any]): A custom encoder to replace the standard Transformer encoder (default: None)
        custom_decoder (Optional[Any]): A custom decoder to replace the standard Transformer decoder (default: None)
        norm_first (bool): if ``True``, encoder and decoder layers will perform LayerNorms before
            other attention and feedforward operations, otherwise after (default: False)
        dims (int, optional): The number of expected features in the
            encoder/decoder inputs. Default: ``512``.
        num_heads (int, optional): The number of attention heads. Default:
            ``8``.
        num_encoder_layers (int, optional): The number of encoder layers in the
            Transformer encoder. Default: ``6``.
        num_decoder_layers (int, optional): The number of decoder layers in the
            Transformer decoder. Default: ``6``.
        mlp_dims (int, optional): The hidden dimension of the MLP block in each
            Transformer layer. Defaults to ``4*dims`` if not provided. Default:
            ``None``.
        dropout (float, optional): The dropout value for the Transformer
            encoder and decoder. Dropout is used after each attention layer and
            the activation in the MLP layer. Default: ``0.0``.
        activation (function, optional): the activation function for the MLP
            hidden layer. Default: :func:`mlx.nn.relu`.
        custom_encoder (nn.Module, optional): A custom encoder to replace the
            standard Transformer encoder. Default: ``None``.
        custom_decoder (nn.Module, optional): A custom decoder to replace the
            standard Transformer decoder. Default: ``None``.
        norm_first (bool, optional): if ``True``, encoder and decoder layers
            will perform layer normalization before attention and MLP
            operations, otherwise after. Default: ``False``.
    """

    def __init__(
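A minimal forward-pass sketch for the model documented above. Since the module contains no embedding layer, continuous (batch, sequence, dims) inputs are assumed, and the hand-built causal mask is only illustrative.

import mlx.core as mx
import mlx.nn as nn

model = nn.Transformer(
    dims=512, num_heads=8, num_encoder_layers=6, num_decoder_layers=6
)

src = mx.random.normal((2, 12, 512))              # encoder input
tgt = mx.random.normal((2, 10, 512))              # decoder input

src_mask = None                                   # attend to every source position
tgt_mask = mx.triu(mx.full((10, 10), -1e9), k=1)  # causal mask for the decoder
memory_mask = None

out = model(src, tgt, src_mask, tgt_mask, memory_mask)
print(out.shape)                                  # (2, 10, 512)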
@@ -331,6 +351,4 @@ class Transformer(Module):

    def __call__(self, src, tgt, src_mask, tgt_mask, memory_mask):
        memory = self.encoder(src, src_mask)
        output = self.decoder(tgt, memory, tgt_mask, memory_mask)

        return output
        return self.decoder(tgt, memory, tgt_mask, memory_mask)