angelos's commit files

Angelos Katharopoulos
2023-11-29 10:42:59 -08:00
parent 8ca7f9e8e9
commit d1f86272a2
56 changed files with 12350 additions and 0 deletions


@@ -0,0 +1,3 @@
from mlx.nn.layers import *
from mlx.nn import losses
from mlx.nn.utils import value_and_grad


@@ -0,0 +1,23 @@
from mlx.nn.layers.base import Module
from mlx.nn.layers.activations import (
    GELU,
    ReLU,
    SiLU,
    gelu,
    gelu_approx,
    gelu_fast_approx,
    relu,
    silu,
)
from mlx.nn.layers.containers import Sequential
from mlx.nn.layers.convolution import Conv1d, Conv2d
from mlx.nn.layers.dropout import Dropout
from mlx.nn.layers.embedding import Embedding
from mlx.nn.layers.linear import Linear
from mlx.nn.layers.normalization import GroupNorm, LayerNorm, RMSNorm
from mlx.nn.layers.positional_encoding import RoPE, SinusoidalPositionalEncoding
from mlx.nn.layers.transformer import (
    MultiHeadAttention,
    TransformerEncoder,
    TransformerEncoderLayer,
)


@@ -0,0 +1,129 @@
import math

import mlx.core as mx

from mlx.nn.layers.base import Module


def _make_activation_module(f):
    def decorator(klass):
        klass.__doc__ = f.__doc__
        klass.__call__ = lambda self, x: f(x)
        return klass

    return decorator


def relu(x):
    """Applies the Rectified Linear Unit.

    Simply ``mx.maximum(x, 0)``.
    """
    return mx.maximum(x, 0)


def silu(x):
    r"""Applies the Sigmoid Linear Unit.

    Applies :math:`x \sigma(x)` element wise, where :math:`\sigma(\cdot)` is
    the logistic sigmoid.
    """
    return x * mx.sigmoid(x)


def gelu(x):
    r"""Applies the Gaussian Error Linear Units function.

    .. math::
        \textrm{GELU}(x) = x * \Phi(x)

    where :math:`\Phi(x)` is the Gaussian CDF.

    See also :func:`gelu_approx` and :func:`gelu_fast_approx` for faster
    approximations.
    """
    return x * (1 + mx.erf(x / math.sqrt(2))) / 2


def gelu_approx(x):
    r"""An approximation to Gaussian Error Linear Unit.

    See :func:`gelu` for the exact computation.

    This function approximates ``gelu`` with a maximum absolute error :math:`<
    0.0003` in the range :math:`[-6, 6]` using the following

    .. math::

        x = x \sigma\left(1.60033 x \left(1 + 0.0433603 x^2\right)\right)

    where :math:`\sigma(\cdot)` is the logistic sigmoid.
    """
    return x * mx.sigmoid(1.60033 * x * (1 + 0.0433603 * x.square()))


def gelu_fast_approx(x):
    r"""A fast approximation to Gaussian Error Linear Unit.

    See :func:`gelu` for the exact computation.

    This function approximates ``gelu`` with a maximum absolute error :math:`<
    0.015` in the range :math:`[-6, 6]` using the following

    .. math::

        x = x \sigma\left(1.773 x\right)

    where :math:`\sigma(\cdot)` is the logistic sigmoid.
    """
    return x * mx.sigmoid(1.773 * x)


@_make_activation_module(relu)
class ReLU(Module):
    pass


@_make_activation_module(silu)
class SiLU(Module):
    pass


class GELU(Module):
    r"""Applies the Gaussian Error Linear Units.

    .. math::
        \textrm{GELU}(x) = x * \Phi(x)

    where :math:`\Phi(x)` is the Gaussian CDF.

    However, if ``approx`` is set to 'precise' or 'fast' it applies

    .. math::
        \textrm{GELUApprox}(x) &= x * \sigma\left(1.60033 * x \left(1 + 0.0433603 * x^2\right)\right) \\
        \textrm{GELUFast}(x) &= x * \sigma\left(1.773 * x\right)

    respectively.

    See :func:`gelu`, :func:`gelu_approx` and :func:`gelu_fast_approx` for the
    functional equivalents and information regarding error bounds.

    Args:
        approx ('none' | 'precise' | 'fast'): Which approximation to gelu to use if any.
    """

    def __init__(self, approx="none"):
        super().__init__()

        if approx == "none":
            self._act = gelu
        elif approx == "precise":
            self._act = gelu_approx
        elif approx == "fast":
            self._act = gelu_fast_approx
        else:
            raise ValueError(
                f"The approximation should be in ['none', 'precise', 'fast'] but '{approx}' was given"
            )

    def __call__(self, x):
        return self._act(x)
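
A quick usage sketch of the activations above (my own illustration, not part of the commit; it assumes GELU, gelu, gelu_approx, and gelu_fast_approx are re-exported through mlx.nn as in the package __init__ files in this commit, and that mx.arange, mx.abs, and mx.max are available):

import mlx.core as mx
from mlx.nn import GELU, gelu, gelu_approx, gelu_fast_approx

x = mx.arange(-3.0, 3.0, 0.5)

# The module form simply dispatches to the functional form selected by `approx`,
# so these pairs produce identical outputs.
print(mx.max(mx.abs(GELU()(x) - gelu(x))))                         # 0
print(mx.max(mx.abs(GELU(approx="precise")(x) - gelu_approx(x))))  # 0
# The fast approximation stays within its documented error bound on [-6, 6].
print(mx.max(mx.abs(gelu(x) - gelu_fast_approx(x))))               # < 0.015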

python/mlx/nn/losses.py

@@ -0,0 +1,6 @@
import mlx.core as mx


def cross_entropy(logits: mx.array, targets: mx.array, axis: int = -1):
    score = mx.take_along_axis(logits, targets[..., None], axis).squeeze(-1)
    return mx.logsumexp(logits, axis=axis) - score
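
This is the standard log-softmax identity, -log softmax(logits)[target] = logsumexp(logits) - logits[target], computed per example without materializing the softmax. A minimal usage sketch (shapes and values are my own illustration, not from the commit):

import mlx.core as mx
from mlx.nn.losses import cross_entropy

logits = mx.array([[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]])  # (batch=2, num_classes=3)
targets = mx.array([0, 2])                              # one class index per example
print(cross_entropy(logits, targets))                   # per-example losses, shape (2,)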

python/mlx/utils.py

@@ -0,0 +1,136 @@
def tree_map(fn, tree, *rest):
    """Applies ``fn`` to the leaves of the python tree ``tree`` and
    returns a new collection with the results.

    If ``rest`` is provided, every item is assumed to be a superset of ``tree``
    and the corresponding leaves are provided as extra positional arguments to
    ``fn``. In that respect, :meth:`tree_map` is closer to :func:`itertools.starmap`
    than to :func:`map`.

    .. code-block:: python

        import mlx.nn as nn
        from mlx.utils import tree_map

        model = nn.Linear(10, 10)
        print(model.parameters().keys())
        # dict_keys(['weight', 'bias'])

        # square the parameters
        model.update(tree_map(lambda x: x*x, model.parameters()))

    Args:
        fn (Callable): The function that processes the leaves of the tree
        tree (Any): The main python tree that will be iterated upon
        rest (Tuple[Any]): Extra trees to be iterated together with tree

    Returns:
        A python tree with the new values returned by ``fn``.
    """
    if isinstance(tree, list):
        return [
            tree_map(fn, child, *(r[i] for r in rest)) for i, child in enumerate(tree)
        ]
    elif isinstance(tree, tuple):
        return tuple(
            tree_map(fn, child, *(r[i] for r in rest)) for i, child in enumerate(tree)
        )
    elif isinstance(tree, dict):
        return {
            k: tree_map(fn, child, *(r[k] for r in rest)) for k, child in tree.items()
        }
    else:
        return fn(tree, *rest)


def tree_flatten(tree, prefix="", is_leaf=None):
    """Flattens a python tree to a list of key, value tuples.

    The keys use dot notation to define trees of arbitrary depth and
    complexity.

    .. code-block:: python

        from mlx.utils import tree_flatten

        print(tree_flatten([[[0]]]))
        # [("0.0.0", 0)]

        print(tree_flatten([[[0]]], ".hello"))
        # [("hello.0.0.0", 0)]

    .. note::
        Dictionaries should have keys that are valid python identifiers.

    Args:
        tree (Any): The python tree to be flattened.
        prefix (str): A prefix to use for the keys. The first character is
            always discarded.
        is_leaf (Callable): An optional callable that returns True if the
            passed object is considered a leaf or False otherwise.

    Returns:
        List[Tuple[str, Any]]: The flat representation of the python tree.
    """
    flat_tree = []

    if is_leaf is None or not is_leaf(tree):
        if isinstance(tree, (list, tuple)):
            for i, t in enumerate(tree):
                flat_tree.extend(tree_flatten(t, f"{prefix}.{i}", is_leaf))
            return flat_tree
        if isinstance(tree, dict):
            for k, t in tree.items():
                flat_tree.extend(tree_flatten(t, f"{prefix}.{k}", is_leaf))
            return flat_tree

    return [(prefix[1:], tree)]


def tree_unflatten(tree):
    """Recreate a python tree from its flat representation.

    .. code-block:: python

        from mlx.utils import tree_unflatten

        d = tree_unflatten([("hello.world", 42)])
        print(d)
        # {"hello": {"world": 42}}

    Args:
        tree (List[Tuple[str, Any]]): The flat representation of a python tree.
            For instance as returned by :meth:`tree_flatten`.

    Returns:
        A python tree.
    """
    if len(tree) == 1 and tree[0][0] == "":
        return tree[0][1]

    try:
        int(tree[0][0].split(".", maxsplit=1)[0])
        is_list = True
    except ValueError:
        is_list = False

    # collect children
    children = {}
    for key, value in tree:
        current_idx, *next_idx = key.split(".", maxsplit=1)
        next_idx = "" if not next_idx else next_idx[0]
        if current_idx not in children:
            children[current_idx] = []
        children[current_idx].append((next_idx, value))

    # recursively map them to the original container
    if is_list:
        keys = sorted((int(idx), idx) for idx in children.keys())
        l = []
        for i, k in keys:
            # pad with empty dicts for any missing list indices
            while i > len(l):
                l.append({})
            l.append(tree_unflatten(children[k]))
        return l
    else:
        return {k: tree_unflatten(v) for k, v in children.items()}
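
A short round-trip sketch tying the three helpers together (the example tree is my own illustration, not from the commit):

from mlx.utils import tree_flatten, tree_map, tree_unflatten

params = {"layers": [{"w": 1.0, "b": 0.0}, {"w": 2.0, "b": 0.5}]}

flat = tree_flatten(params)
# [('layers.0.w', 1.0), ('layers.0.b', 0.0), ('layers.1.w', 2.0), ('layers.1.b', 0.5)]

doubled = tree_map(lambda x: 2 * x, params)
# same nesting as `params`, every leaf multiplied by 2

assert tree_unflatten(flat) == params  # flatten/unflatten round-trips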